00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_PASTIXSUPPORT_H
00011 #define EIGEN_PASTIXSUPPORT_H
00012
00013 namespace Eigen {
00014
// Forward declarations of the PaStiX-based solvers. pastix_traits and
// PastixBase refer to these types before their full definitions appear.
//   PastixLU   : IsStrSym defaults to false (pattern not known to be symmetric).
//   PastixLLT  : Options selects which triangle of the input is read.
//   PastixLDLT : Options selects which triangle of the input is read.
template <typename _MatrixType, bool IsStrSym = false>
class PastixLU;

template <typename _MatrixType, int Options>
class PastixLLT;

template <typename _MatrixType, int Options>
class PastixLDLT;
00026
00027 namespace internal
00028 {
00029
00030 template<class Pastix> struct pastix_traits;
00031
00032 template<typename _MatrixType>
00033 struct pastix_traits< PastixLU<_MatrixType> >
00034 {
00035 typedef _MatrixType MatrixType;
00036 typedef typename _MatrixType::Scalar Scalar;
00037 typedef typename _MatrixType::RealScalar RealScalar;
00038 typedef typename _MatrixType::Index Index;
00039 };
00040
00041 template<typename _MatrixType, int Options>
00042 struct pastix_traits< PastixLLT<_MatrixType,Options> >
00043 {
00044 typedef _MatrixType MatrixType;
00045 typedef typename _MatrixType::Scalar Scalar;
00046 typedef typename _MatrixType::RealScalar RealScalar;
00047 typedef typename _MatrixType::Index Index;
00048 };
00049
00050 template<typename _MatrixType, int Options>
00051 struct pastix_traits< PastixLDLT<_MatrixType,Options> >
00052 {
00053 typedef _MatrixType MatrixType;
00054 typedef typename _MatrixType::Scalar Scalar;
00055 typedef typename _MatrixType::RealScalar RealScalar;
00056 typedef typename _MatrixType::Index Index;
00057 };
00058
00059 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, float *vals, int *perm, int * invp, float *x, int nbrhs, int *iparm, double *dparm)
00060 {
00061 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00062 if (nbrhs == 0) {x = NULL; nbrhs=1;}
00063 s_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
00064 }
00065
00066 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, double *vals, int *perm, int * invp, double *x, int nbrhs, int *iparm, double *dparm)
00067 {
00068 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00069 if (nbrhs == 0) {x = NULL; nbrhs=1;}
00070 d_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
00071 }
00072
00073 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<float> *vals, int *perm, int * invp, std::complex<float> *x, int nbrhs, int *iparm, double *dparm)
00074 {
00075 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00076 if (nbrhs == 0) {x = NULL; nbrhs=1;}
00077 c_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<COMPLEX*>(vals), perm, invp, reinterpret_cast<COMPLEX*>(x), nbrhs, iparm, dparm);
00078 }
00079
00080 void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<double> *vals, int *perm, int * invp, std::complex<double> *x, int nbrhs, int *iparm, double *dparm)
00081 {
00082 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00083 if (nbrhs == 0) {x = NULL; nbrhs=1;}
00084 z_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<DCOMPLEX*>(vals), perm, invp, reinterpret_cast<DCOMPLEX*>(x), nbrhs, iparm, dparm);
00085 }
00086
00087
/** \internal
  * Converts the compressed index arrays of \a mat in place from C-style
  * (0-based) to Fortran-style (1-based) numbering.
  *
  * Nothing is done when the first outer index is already non-zero, i.e. the
  * arrays already look 1-based, so calling it twice is harmless.
  *
  * The outer index array holds outerSize()+1 entries (columns+1 for a
  * column-major matrix). The loop therefore runs over outerSize(), not rows(),
  * so non-square and row-major inputs are converted correctly.
  * Assumes \a mat is compressed: innerIndexPtr() holds nonZeros() contiguous
  * entries.
  */
template <typename MatrixType>
void c_to_fortran_numbering (MatrixType& mat)
{
  typedef typename MatrixType::Index Index;

  if (mat.outerIndexPtr()[0] != 0)
    return; // already 1-based

  const Index outerCount = mat.outerSize();
  const Index nnz = mat.nonZeros();
  for (Index j = 0; j <= outerCount; ++j)
    ++mat.outerIndexPtr()[j];
  for (Index k = 0; k < nnz; ++k)
    ++mat.innerIndexPtr()[k];
}
00100
00101
/** \internal
  * Converts the compressed index arrays of \a mat in place from Fortran-style
  * (1-based) back to C-style (0-based) numbering.
  *
  * Nothing is done unless the first outer index equals 1, i.e. the arrays look
  * 1-based, so calling it on 0-based data is harmless.
  *
  * Iterates over outerSize()+1 outer entries (not rows()+1), which is the
  * actual length of the outer index array for any storage order or shape.
  * Assumes \a mat is compressed.
  */
template <typename MatrixType>
void fortran_to_c_numbering (MatrixType& mat)
{
  typedef typename MatrixType::Index Index;

  if (mat.outerIndexPtr()[0] != 1)
    return; // not 1-based: nothing to undo

  const Index outerCount = mat.outerSize();
  const Index nnz = mat.nonZeros();
  for (Index j = 0; j <= outerCount; ++j)
    --mat.outerIndexPtr()[j];
  for (Index k = 0; k < nnz; ++k)
    --mat.innerIndexPtr()[k];
}
00115 }
00116
00117
00118
00119 template <class Derived>
00120 class PastixBase : internal::noncopyable
00121 {
00122 public:
00123 typedef typename internal::pastix_traits<Derived>::MatrixType _MatrixType;
00124 typedef _MatrixType MatrixType;
00125 typedef typename MatrixType::Scalar Scalar;
00126 typedef typename MatrixType::RealScalar RealScalar;
00127 typedef typename MatrixType::Index Index;
00128 typedef Matrix<Scalar,Dynamic,1> Vector;
00129 typedef SparseMatrix<Scalar, ColMajor> ColSpMatrix;
00130
00131 public:
00132
00133 PastixBase() : m_initisOk(false), m_analysisIsOk(false), m_factorizationIsOk(false), m_isInitialized(false), m_pastixdata(0), m_size(0)
00134 {
00135 init();
00136 }
00137
00138 ~PastixBase()
00139 {
00140 clean();
00141 }
00142
00147 template<typename Rhs>
00148 inline const internal::solve_retval<PastixBase, Rhs>
00149 solve(const MatrixBase<Rhs>& b) const
00150 {
00151 eigen_assert(m_isInitialized && "Pastix solver is not initialized.");
00152 eigen_assert(rows()==b.rows()
00153 && "PastixBase::solve(): invalid number of rows of the right hand side matrix b");
00154 return internal::solve_retval<PastixBase, Rhs>(*this, b.derived());
00155 }
00156
00157 template<typename Rhs,typename Dest>
00158 bool _solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const;
00159
00160 Derived& derived()
00161 {
00162 return *static_cast<Derived*>(this);
00163 }
00164 const Derived& derived() const
00165 {
00166 return *static_cast<const Derived*>(this);
00167 }
00168
00174 Array<Index,IPARM_SIZE,1>& iparm()
00175 {
00176 return m_iparm;
00177 }
00178
00183 int& iparm(int idxparam)
00184 {
00185 return m_iparm(idxparam);
00186 }
00187
00192 Array<RealScalar,IPARM_SIZE,1>& dparm()
00193 {
00194 return m_dparm;
00195 }
00196
00197
00201 double& dparm(int idxparam)
00202 {
00203 return m_dparm(idxparam);
00204 }
00205
00206 inline Index cols() const { return m_size; }
00207 inline Index rows() const { return m_size; }
00208
00217 ComputationInfo info() const
00218 {
00219 eigen_assert(m_isInitialized && "Decomposition is not initialized.");
00220 return m_info;
00221 }
00222
00227 template<typename Rhs>
00228 inline const internal::sparse_solve_retval<PastixBase, Rhs>
00229 solve(const SparseMatrixBase<Rhs>& b) const
00230 {
00231 eigen_assert(m_isInitialized && "Pastix LU, LLT or LDLT is not initialized.");
00232 eigen_assert(rows()==b.rows()
00233 && "PastixBase::solve(): invalid number of rows of the right hand side matrix b");
00234 return internal::sparse_solve_retval<PastixBase, Rhs>(*this, b.derived());
00235 }
00236
00237 protected:
00238
00239
00240 void init();
00241
00242
00243 void analyzePattern(ColSpMatrix& mat);
00244
00245
00246 void factorize(ColSpMatrix& mat);
00247
00248
00249 void clean()
00250 {
00251 eigen_assert(m_initisOk && "The Pastix structure should be allocated first");
00252 m_iparm(IPARM_START_TASK) = API_TASK_CLEAN;
00253 m_iparm(IPARM_END_TASK) = API_TASK_CLEAN;
00254 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
00255 m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
00256 }
00257
00258 void compute(ColSpMatrix& mat);
00259
00260 int m_initisOk;
00261 int m_analysisIsOk;
00262 int m_factorizationIsOk;
00263 bool m_isInitialized;
00264 mutable ComputationInfo m_info;
00265 mutable pastix_data_t *m_pastixdata;
00266 mutable int m_comm;
00267 mutable Matrix<int,IPARM_SIZE,1> m_iparm;
00268 mutable Matrix<double,DPARM_SIZE,1> m_dparm;
00269 mutable Matrix<Index,Dynamic,1> m_perm;
00270 mutable Matrix<Index,Dynamic,1> m_invp;
00271 mutable int m_size;
00272 };
00273
00278 template <class Derived>
00279 void PastixBase<Derived>::init()
00280 {
00281 m_size = 0;
00282 m_iparm.setZero(IPARM_SIZE);
00283 m_dparm.setZero(DPARM_SIZE);
00284
00285 m_iparm(IPARM_MODIFY_PARAMETER) = API_NO;
00286 pastix(&m_pastixdata, MPI_COMM_WORLD,
00287 0, 0, 0, 0,
00288 0, 0, 0, 1, m_iparm.data(), m_dparm.data());
00289
00290 m_iparm[IPARM_MATRIX_VERIFICATION] = API_NO;
00291 m_iparm[IPARM_VERBOSE] = 2;
00292 m_iparm[IPARM_ORDERING] = API_ORDER_SCOTCH;
00293 m_iparm[IPARM_INCOMPLETE] = API_NO;
00294 m_iparm[IPARM_OOC_LIMIT] = 2000;
00295 m_iparm[IPARM_RHS_MAKING] = API_RHS_B;
00296 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
00297
00298 m_iparm(IPARM_START_TASK) = API_TASK_INIT;
00299 m_iparm(IPARM_END_TASK) = API_TASK_INIT;
00300 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
00301 0, 0, 0, 0, m_iparm.data(), m_dparm.data());
00302
00303
00304 if(m_iparm(IPARM_ERROR_NUMBER)) {
00305 m_info = InvalidInput;
00306 m_initisOk = false;
00307 }
00308 else {
00309 m_info = Success;
00310 m_initisOk = true;
00311 }
00312 }
00313
00314 template <class Derived>
00315 void PastixBase<Derived>::compute(ColSpMatrix& mat)
00316 {
00317 eigen_assert(mat.rows() == mat.cols() && "The input matrix should be squared");
00318
00319 analyzePattern(mat);
00320 factorize(mat);
00321
00322 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
00323 m_isInitialized = m_factorizationIsOk;
00324 }
00325
00326
00327 template <class Derived>
00328 void PastixBase<Derived>::analyzePattern(ColSpMatrix& mat)
00329 {
00330 eigen_assert(m_initisOk && "The initialization of PaSTiX failed");
00331
00332
00333 if(m_size>0)
00334 clean();
00335
00336 m_size = mat.rows();
00337 m_perm.resize(m_size);
00338 m_invp.resize(m_size);
00339
00340 m_iparm(IPARM_START_TASK) = API_TASK_ORDERING;
00341 m_iparm(IPARM_END_TASK) = API_TASK_ANALYSE;
00342 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
00343 mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
00344
00345
00346 if(m_iparm(IPARM_ERROR_NUMBER))
00347 {
00348 m_info = NumericalIssue;
00349 m_analysisIsOk = false;
00350 }
00351 else
00352 {
00353 m_info = Success;
00354 m_analysisIsOk = true;
00355 }
00356 }
00357
00358 template <class Derived>
00359 void PastixBase<Derived>::factorize(ColSpMatrix& mat)
00360 {
00361
00362 eigen_assert(m_analysisIsOk && "The analysis phase should be called before the factorization phase");
00363 m_iparm(IPARM_START_TASK) = API_TASK_NUMFACT;
00364 m_iparm(IPARM_END_TASK) = API_TASK_NUMFACT;
00365 m_size = mat.rows();
00366
00367 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
00368 mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
00369
00370
00371 if(m_iparm(IPARM_ERROR_NUMBER))
00372 {
00373 m_info = NumericalIssue;
00374 m_factorizationIsOk = false;
00375 m_isInitialized = false;
00376 }
00377 else
00378 {
00379 m_info = Success;
00380 m_factorizationIsOk = true;
00381 m_isInitialized = true;
00382 }
00383 }
00384
00385
00386 template<typename Base>
00387 template<typename Rhs,typename Dest>
00388 bool PastixBase<Base>::_solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const
00389 {
00390 eigen_assert(m_isInitialized && "The matrix should be factorized first");
00391 EIGEN_STATIC_ASSERT((Dest::Flags&RowMajorBit)==0,
00392 THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
00393 int rhs = 1;
00394
00395 x = b;
00396
00397 for (int i = 0; i < b.cols(); i++){
00398 m_iparm[IPARM_START_TASK] = API_TASK_SOLVE;
00399 m_iparm[IPARM_END_TASK] = API_TASK_REFINE;
00400
00401 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, x.rows(), 0, 0, 0,
00402 m_perm.data(), m_invp.data(), &x(0, i), rhs, m_iparm.data(), m_dparm.data());
00403 }
00404
00405
00406 m_info = m_iparm(IPARM_ERROR_NUMBER)==0 ? Success : NumericalIssue;
00407
00408 return m_iparm(IPARM_ERROR_NUMBER)==0;
00409 }
00410
00430 template<typename _MatrixType, bool IsStrSym>
00431 class PastixLU : public PastixBase< PastixLU<_MatrixType> >
00432 {
00433 public:
00434 typedef _MatrixType MatrixType;
00435 typedef PastixBase<PastixLU<MatrixType> > Base;
00436 typedef typename Base::ColSpMatrix ColSpMatrix;
00437 typedef typename MatrixType::Index Index;
00438
00439 public:
00440 PastixLU() : Base()
00441 {
00442 init();
00443 }
00444
00445 PastixLU(const MatrixType& matrix):Base()
00446 {
00447 init();
00448 compute(matrix);
00449 }
00455 void compute (const MatrixType& matrix)
00456 {
00457 m_structureIsUptodate = false;
00458 ColSpMatrix temp;
00459 grabMatrix(matrix, temp);
00460 Base::compute(temp);
00461 }
00467 void analyzePattern(const MatrixType& matrix)
00468 {
00469 m_structureIsUptodate = false;
00470 ColSpMatrix temp;
00471 grabMatrix(matrix, temp);
00472 Base::analyzePattern(temp);
00473 }
00474
00480 void factorize(const MatrixType& matrix)
00481 {
00482 ColSpMatrix temp;
00483 grabMatrix(matrix, temp);
00484 Base::factorize(temp);
00485 }
00486 protected:
00487
00488 void init()
00489 {
00490 m_structureIsUptodate = false;
00491 m_iparm(IPARM_SYM) = API_SYM_NO;
00492 m_iparm(IPARM_FACTORIZATION) = API_FACT_LU;
00493 }
00494
00495 void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
00496 {
00497 if(IsStrSym)
00498 out = matrix;
00499 else
00500 {
00501 if(!m_structureIsUptodate)
00502 {
00503
00504 m_transposedStructure = matrix.transpose();
00505
00506
00507 for (Index j=0; j<m_transposedStructure.outerSize(); ++j)
00508 for(typename ColSpMatrix::InnerIterator it(m_transposedStructure, j); it; ++it)
00509 it.valueRef() = 0.0;
00510
00511 m_structureIsUptodate = true;
00512 }
00513
00514 out = m_transposedStructure + matrix;
00515 }
00516 internal::c_to_fortran_numbering(out);
00517 }
00518
00519 using Base::m_iparm;
00520 using Base::m_dparm;
00521
00522 ColSpMatrix m_transposedStructure;
00523 bool m_structureIsUptodate;
00524 };
00525
00540 template<typename _MatrixType, int _UpLo>
00541 class PastixLLT : public PastixBase< PastixLLT<_MatrixType, _UpLo> >
00542 {
00543 public:
00544 typedef _MatrixType MatrixType;
00545 typedef PastixBase<PastixLLT<MatrixType, _UpLo> > Base;
00546 typedef typename Base::ColSpMatrix ColSpMatrix;
00547
00548 public:
00549 enum { UpLo = _UpLo };
00550 PastixLLT() : Base()
00551 {
00552 init();
00553 }
00554
00555 PastixLLT(const MatrixType& matrix):Base()
00556 {
00557 init();
00558 compute(matrix);
00559 }
00560
00564 void compute (const MatrixType& matrix)
00565 {
00566 ColSpMatrix temp;
00567 grabMatrix(matrix, temp);
00568 Base::compute(temp);
00569 }
00570
00575 void analyzePattern(const MatrixType& matrix)
00576 {
00577 ColSpMatrix temp;
00578 grabMatrix(matrix, temp);
00579 Base::analyzePattern(temp);
00580 }
00584 void factorize(const MatrixType& matrix)
00585 {
00586 ColSpMatrix temp;
00587 grabMatrix(matrix, temp);
00588 Base::factorize(temp);
00589 }
00590 protected:
00591 using Base::m_iparm;
00592
00593 void init()
00594 {
00595 m_iparm(IPARM_SYM) = API_SYM_YES;
00596 m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT;
00597 }
00598
00599 void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
00600 {
00601
00602 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
00603 internal::c_to_fortran_numbering(out);
00604 }
00605 };
00606
00621 template<typename _MatrixType, int _UpLo>
00622 class PastixLDLT : public PastixBase< PastixLDLT<_MatrixType, _UpLo> >
00623 {
00624 public:
00625 typedef _MatrixType MatrixType;
00626 typedef PastixBase<PastixLDLT<MatrixType, _UpLo> > Base;
00627 typedef typename Base::ColSpMatrix ColSpMatrix;
00628
00629 public:
00630 enum { UpLo = _UpLo };
00631 PastixLDLT():Base()
00632 {
00633 init();
00634 }
00635
00636 PastixLDLT(const MatrixType& matrix):Base()
00637 {
00638 init();
00639 compute(matrix);
00640 }
00641
00645 void compute (const MatrixType& matrix)
00646 {
00647 ColSpMatrix temp;
00648 grabMatrix(matrix, temp);
00649 Base::compute(temp);
00650 }
00651
00656 void analyzePattern(const MatrixType& matrix)
00657 {
00658 ColSpMatrix temp;
00659 grabMatrix(matrix, temp);
00660 Base::analyzePattern(temp);
00661 }
00665 void factorize(const MatrixType& matrix)
00666 {
00667 ColSpMatrix temp;
00668 grabMatrix(matrix, temp);
00669 Base::factorize(temp);
00670 }
00671
00672 protected:
00673 using Base::m_iparm;
00674
00675 void init()
00676 {
00677 m_iparm(IPARM_SYM) = API_SYM_YES;
00678 m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT;
00679 }
00680
00681 void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
00682 {
00683
00684 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
00685 internal::c_to_fortran_numbering(out);
00686 }
00687 };
00688
00689 namespace internal {
00690
00691 template<typename _MatrixType, typename Rhs>
00692 struct solve_retval<PastixBase<_MatrixType>, Rhs>
00693 : solve_retval_base<PastixBase<_MatrixType>, Rhs>
00694 {
00695 typedef PastixBase<_MatrixType> Dec;
00696 EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs)
00697
00698 template<typename Dest> void evalTo(Dest& dst) const
00699 {
00700 dec()._solve(rhs(),dst);
00701 }
00702 };
00703
00704 template<typename _MatrixType, typename Rhs>
00705 struct sparse_solve_retval<PastixBase<_MatrixType>, Rhs>
00706 : sparse_solve_retval_base<PastixBase<_MatrixType>, Rhs>
00707 {
00708 typedef PastixBase<_MatrixType> Dec;
00709 EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs)
00710
00711 template<typename Dest> void evalTo(Dest& dst) const
00712 {
00713 this->defaultEvalTo(dst);
00714 }
00715 };
00716
00717 }
00718
00719 }
00720
00721 #endif