00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 #ifndef EIGEN_PASTIXSUPPORT_H
00011 #define EIGEN_PASTIXSUPPORT_H
00012 
00013 namespace Eigen { 
00014 
00023 template<typename _MatrixType, bool IsStrSym = false> class PastixLU;
00024 template<typename _MatrixType, int Options> class PastixLLT;
00025 template<typename _MatrixType, int Options> class PastixLDLT;
00026 
00027 namespace internal
00028 {
00029     
00030   template<class Pastix> struct pastix_traits;
00031 
00032   template<typename _MatrixType>
00033   struct pastix_traits< PastixLU<_MatrixType> >
00034   {
00035     typedef _MatrixType MatrixType;
00036     typedef typename _MatrixType::Scalar Scalar;
00037     typedef typename _MatrixType::RealScalar RealScalar;
00038     typedef typename _MatrixType::Index Index;
00039   };
00040 
00041   template<typename _MatrixType, int Options>
00042   struct pastix_traits< PastixLLT<_MatrixType,Options> >
00043   {
00044     typedef _MatrixType MatrixType;
00045     typedef typename _MatrixType::Scalar Scalar;
00046     typedef typename _MatrixType::RealScalar RealScalar;
00047     typedef typename _MatrixType::Index Index;
00048   };
00049 
00050   template<typename _MatrixType, int Options>
00051   struct pastix_traits< PastixLDLT<_MatrixType,Options> >
00052   {
00053     typedef _MatrixType MatrixType;
00054     typedef typename _MatrixType::Scalar Scalar;
00055     typedef typename _MatrixType::RealScalar RealScalar;
00056     typedef typename _MatrixType::Index Index;
00057   };
00058   
00059   void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, float *vals, int *perm, int * invp, float *x, int nbrhs, int *iparm, double *dparm)
00060   {
00061     if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00062     if (nbrhs == 0) {x = NULL; nbrhs=1;}
00063     s_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm); 
00064   }
00065   
00066   void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, double *vals, int *perm, int * invp, double *x, int nbrhs, int *iparm, double *dparm)
00067   {
00068     if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00069     if (nbrhs == 0) {x = NULL; nbrhs=1;}
00070     d_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm); 
00071   }
00072   
00073   void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<float> *vals, int *perm, int * invp, std::complex<float> *x, int nbrhs, int *iparm, double *dparm)
00074   {
00075     if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00076     if (nbrhs == 0) {x = NULL; nbrhs=1;}
00077     c_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<COMPLEX*>(vals), perm, invp, reinterpret_cast<COMPLEX*>(x), nbrhs, iparm, dparm); 
00078   }
00079   
00080   void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<double> *vals, int *perm, int * invp, std::complex<double> *x, int nbrhs, int *iparm, double *dparm)
00081   {
00082     if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00083     if (nbrhs == 0) {x = NULL; nbrhs=1;}
00084     z_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<DCOMPLEX*>(vals), perm, invp, reinterpret_cast<DCOMPLEX*>(x), nbrhs, iparm, dparm); 
00085   }
00086 
00087   
/**
 * Converts the compressed index arrays of \a mat from 0-based to 1-based
 * numbering, in place.
 *
 * The conversion is applied only when the first outer index is 0, so calling
 * it on a matrix that was already shifted does nothing.
 *
 * The outer index array holds outerSize()+1 entries; the loop is bounded by
 * outerSize() rather than rows() so that every entry is shifted even when the
 * matrix is not square.
 *
 * \param mat  object exposing outerIndexPtr(), innerIndexPtr(), outerSize()
 *             and nonZeros(); its index arrays are modified.
 */
template <typename MatrixType>
void c_to_fortran_numbering (MatrixType& mat)
{
  if ( mat.outerIndexPtr()[0] != 0 )
    return; // not 0-based: leave untouched

  for(int k = 0; k <= mat.outerSize(); ++k)
    ++mat.outerIndexPtr()[k];
  for(int k = 0; k < mat.nonZeros(); ++k)
    ++mat.innerIndexPtr()[k];
}
00100   
00101   
/**
 * Converts the compressed index arrays of \a mat from 1-based back to 0-based
 * numbering, in place.
 *
 * The conversion is applied only when the first outer index is 1, so calling
 * it on a matrix that is already 0-based does nothing.
 *
 * The outer index array holds outerSize()+1 entries; the loop is bounded by
 * outerSize() rather than rows() so that every entry is shifted even when the
 * matrix is not square.
 *
 * \param mat  object exposing outerIndexPtr(), innerIndexPtr(), outerSize()
 *             and nonZeros(); its index arrays are modified.
 */
template <typename MatrixType>
void fortran_to_c_numbering (MatrixType& mat)
{
  if ( mat.outerIndexPtr()[0] != 1 )
    return; // not 1-based: leave untouched

  for(int k = 0; k <= mat.outerSize(); ++k)
    --mat.outerIndexPtr()[k];
  for(int k = 0; k < mat.nonZeros(); ++k)
    --mat.innerIndexPtr()[k];
}
00115 }
00116 
00117 
00118 
00119 template <class Derived>
00120 class PastixBase : internal::noncopyable
00121 {
00122   public:
00123     typedef typename internal::pastix_traits<Derived>::MatrixType _MatrixType;
00124     typedef _MatrixType MatrixType;
00125     typedef typename MatrixType::Scalar Scalar;
00126     typedef typename MatrixType::RealScalar RealScalar;
00127     typedef typename MatrixType::Index Index;
00128     typedef Matrix<Scalar,Dynamic,1> Vector;
00129     typedef SparseMatrix<Scalar, ColMajor> ColSpMatrix;
00130     
00131   public:
00132     
00133     PastixBase() : m_initisOk(false), m_analysisIsOk(false), m_factorizationIsOk(false), m_isInitialized(false), m_pastixdata(0), m_size(0)
00134     {
00135       init();
00136     }
00137     
00138     ~PastixBase() 
00139     {
00140       clean();
00141     }
00142 
00147     template<typename Rhs>
00148     inline const internal::solve_retval<PastixBase, Rhs>
00149     solve(const MatrixBase<Rhs>& b) const
00150     {
00151       eigen_assert(m_isInitialized && "Pastix solver is not initialized.");
00152       eigen_assert(rows()==b.rows()
00153                 && "PastixBase::solve(): invalid number of rows of the right hand side matrix b");
00154       return internal::solve_retval<PastixBase, Rhs>(*this, b.derived());
00155     }
00156     
00157     template<typename Rhs,typename Dest>
00158     bool _solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const;
00159     
00161     template<typename Rhs, typename DestScalar, int DestOptions, typename DestIndex>
00162     void _solve_sparse(const Rhs& b, SparseMatrix<DestScalar,DestOptions,DestIndex> &dest) const
00163     {
00164       eigen_assert(m_factorizationIsOk && "The decomposition is not in a valid state for solving, you must first call either compute() or symbolic()/numeric()");
00165       eigen_assert(rows()==b.rows());
00166       
00167       
00168       static const int NbColsAtOnce = 1;
00169       int rhsCols = b.cols();
00170       int size = b.rows();
00171       Eigen::Matrix<DestScalar,Dynamic,Dynamic> tmp(size,rhsCols);
00172       for(int k=0; k<rhsCols; k+=NbColsAtOnce)
00173       {
00174         int actualCols = std::min<int>(rhsCols-k, NbColsAtOnce);
00175         tmp.leftCols(actualCols) = b.middleCols(k,actualCols);
00176         tmp.leftCols(actualCols) = derived().solve(tmp.leftCols(actualCols));
00177         dest.middleCols(k,actualCols) = tmp.leftCols(actualCols).sparseView();
00178       }
00179     }
00180     
00181     Derived& derived()
00182     {
00183       return *static_cast<Derived*>(this);
00184     }
00185     const Derived& derived() const
00186     {
00187       return *static_cast<const Derived*>(this);
00188     }
00189 
00195     Array<Index,IPARM_SIZE,1>& iparm()
00196     {
00197       return m_iparm; 
00198     }
00199     
00204     int& iparm(int idxparam)
00205     {
00206       return m_iparm(idxparam);
00207     }
00208     
00213     Array<RealScalar,IPARM_SIZE,1>& dparm()
00214     {
00215       return m_dparm; 
00216     }
00217     
00218     
00222     double& dparm(int idxparam)
00223     {
00224       return m_dparm(idxparam);
00225     }
00226     
00227     inline Index cols() const { return m_size; }
00228     inline Index rows() const { return m_size; }
00229     
00238     ComputationInfo info() const
00239     {
00240       eigen_assert(m_isInitialized && "Decomposition is not initialized.");
00241       return m_info;
00242     }
00243     
00248     template<typename Rhs>
00249     inline const internal::sparse_solve_retval<PastixBase, Rhs>
00250     solve(const SparseMatrixBase<Rhs>& b) const
00251     {
00252       eigen_assert(m_isInitialized && "Pastix LU, LLT or LDLT is not initialized.");
00253       eigen_assert(rows()==b.rows()
00254                 && "PastixBase::solve(): invalid number of rows of the right hand side matrix b");
00255       return internal::sparse_solve_retval<PastixBase, Rhs>(*this, b.derived());
00256     }
00257     
00258   protected:
00259 
00260     
00261     void init(); 
00262     
00263     
00264     void analyzePattern(ColSpMatrix& mat);
00265     
00266     
00267     void factorize(ColSpMatrix& mat);
00268     
00269     
00270     void clean()
00271     {
00272       eigen_assert(m_initisOk && "The Pastix structure should be allocated first"); 
00273       m_iparm(IPARM_START_TASK) = API_TASK_CLEAN;
00274       m_iparm(IPARM_END_TASK) = API_TASK_CLEAN;
00275       internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
00276                              m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
00277     }
00278     
00279     void compute(ColSpMatrix& mat);
00280     
00281     int m_initisOk; 
00282     int m_analysisIsOk;
00283     int m_factorizationIsOk;
00284     bool m_isInitialized;
00285     mutable ComputationInfo m_info; 
00286     mutable pastix_data_t *m_pastixdata; 
00287     mutable int m_comm; 
00288     mutable Matrix<int,IPARM_SIZE,1> m_iparm; 
00289     mutable Matrix<double,DPARM_SIZE,1> m_dparm; 
00290     mutable Matrix<Index,Dynamic,1> m_perm;  
00291     mutable Matrix<Index,Dynamic,1> m_invp;  
00292     mutable int m_size; 
00293 }; 
00294 
00299 template <class Derived>
00300 void PastixBase<Derived>::init()
00301 {
00302   m_size = 0; 
00303   m_iparm.setZero(IPARM_SIZE);
00304   m_dparm.setZero(DPARM_SIZE);
00305   
00306   m_iparm(IPARM_MODIFY_PARAMETER) = API_NO;
00307   pastix(&m_pastixdata, MPI_COMM_WORLD,
00308          0, 0, 0, 0,
00309          0, 0, 0, 1, m_iparm.data(), m_dparm.data());
00310   
00311   m_iparm[IPARM_MATRIX_VERIFICATION] = API_NO;
00312   m_iparm[IPARM_VERBOSE]             = 2;
00313   m_iparm[IPARM_ORDERING]            = API_ORDER_SCOTCH;
00314   m_iparm[IPARM_INCOMPLETE]          = API_NO;
00315   m_iparm[IPARM_OOC_LIMIT]           = 2000;
00316   m_iparm[IPARM_RHS_MAKING]          = API_RHS_B;
00317   m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
00318   
00319   m_iparm(IPARM_START_TASK) = API_TASK_INIT;
00320   m_iparm(IPARM_END_TASK) = API_TASK_INIT;
00321   internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
00322                          0, 0, 0, 0, m_iparm.data(), m_dparm.data());
00323   
00324   
00325   if(m_iparm(IPARM_ERROR_NUMBER)) {
00326     m_info = InvalidInput;
00327     m_initisOk = false;
00328   }
00329   else { 
00330     m_info = Success;
00331     m_initisOk = true;
00332   }
00333 }
00334 
00335 template <class Derived>
00336 void PastixBase<Derived>::compute(ColSpMatrix& mat)
00337 {
00338   eigen_assert(mat.rows() == mat.cols() && "The input matrix should be squared");
00339   
00340   analyzePattern(mat);  
00341   factorize(mat);
00342   
00343   m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
00344   m_isInitialized = m_factorizationIsOk;
00345 }
00346 
00347 
00348 template <class Derived>
00349 void PastixBase<Derived>::analyzePattern(ColSpMatrix& mat)
00350 {                         
00351   eigen_assert(m_initisOk && "The initialization of PaSTiX failed");
00352   
00353   
00354   if(m_size>0)
00355     clean();
00356   
00357   m_size = mat.rows();
00358   m_perm.resize(m_size);
00359   m_invp.resize(m_size);
00360   
00361   m_iparm(IPARM_START_TASK) = API_TASK_ORDERING;
00362   m_iparm(IPARM_END_TASK) = API_TASK_ANALYSE;
00363   internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
00364                mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
00365   
00366   
00367   if(m_iparm(IPARM_ERROR_NUMBER))
00368   {
00369     m_info = NumericalIssue;
00370     m_analysisIsOk = false;
00371   }
00372   else
00373   { 
00374     m_info = Success;
00375     m_analysisIsOk = true;
00376   }
00377 }
00378 
00379 template <class Derived>
00380 void PastixBase<Derived>::factorize(ColSpMatrix& mat)
00381 {
00382 
00383   eigen_assert(m_analysisIsOk && "The analysis phase should be called before the factorization phase");
00384   m_iparm(IPARM_START_TASK) = API_TASK_NUMFACT;
00385   m_iparm(IPARM_END_TASK) = API_TASK_NUMFACT;
00386   m_size = mat.rows();
00387   
00388   internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
00389                mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
00390   
00391   
00392   if(m_iparm(IPARM_ERROR_NUMBER))
00393   {
00394     m_info = NumericalIssue;
00395     m_factorizationIsOk = false;
00396     m_isInitialized = false;
00397   }
00398   else
00399   {
00400     m_info = Success;
00401     m_factorizationIsOk = true;
00402     m_isInitialized = true;
00403   }
00404 }
00405 
00406 
00407 template<typename Base>
00408 template<typename Rhs,typename Dest>
00409 bool PastixBase<Base>::_solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const
00410 {
00411   eigen_assert(m_isInitialized && "The matrix should be factorized first");
00412   EIGEN_STATIC_ASSERT((Dest::Flags&RowMajorBit)==0,
00413                      THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
00414   int rhs = 1;
00415   
00416   x = b; 
00417   
00418   for (int i = 0; i < b.cols(); i++){
00419     m_iparm[IPARM_START_TASK]          = API_TASK_SOLVE;
00420     m_iparm[IPARM_END_TASK]            = API_TASK_REFINE;
00421   
00422     internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, x.rows(), 0, 0, 0,
00423                            m_perm.data(), m_invp.data(), &x(0, i), rhs, m_iparm.data(), m_dparm.data());
00424   }
00425   
00426   
00427   m_info = m_iparm(IPARM_ERROR_NUMBER)==0 ? Success : NumericalIssue;
00428   
00429   return m_iparm(IPARM_ERROR_NUMBER)==0;
00430 }
00431 
00451 template<typename _MatrixType, bool IsStrSym>
00452 class PastixLU : public PastixBase< PastixLU<_MatrixType> >
00453 {
00454   public:
00455     typedef _MatrixType MatrixType;
00456     typedef PastixBase<PastixLU<MatrixType> > Base;
00457     typedef typename Base::ColSpMatrix ColSpMatrix;
00458     typedef typename MatrixType::Index Index;
00459     
00460   public:
00461     PastixLU() : Base()
00462     {
00463       init();
00464     }
00465     
00466     PastixLU(const MatrixType& matrix):Base()
00467     {
00468       init();
00469       compute(matrix);
00470     }
00476     void compute (const MatrixType& matrix)
00477     {
00478       m_structureIsUptodate = false;
00479       ColSpMatrix temp;
00480       grabMatrix(matrix, temp);
00481       Base::compute(temp);
00482     }
00488     void analyzePattern(const MatrixType& matrix)
00489     {
00490       m_structureIsUptodate = false;
00491       ColSpMatrix temp;
00492       grabMatrix(matrix, temp);
00493       Base::analyzePattern(temp);
00494     }
00495 
00501     void factorize(const MatrixType& matrix)
00502     {
00503       ColSpMatrix temp;
00504       grabMatrix(matrix, temp);
00505       Base::factorize(temp);
00506     }
00507   protected:
00508     
00509     void init()
00510     {
00511       m_structureIsUptodate = false;
00512       m_iparm(IPARM_SYM) = API_SYM_NO;
00513       m_iparm(IPARM_FACTORIZATION) = API_FACT_LU;
00514     }
00515     
00516     void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
00517     {
00518       if(IsStrSym)
00519         out = matrix;
00520       else
00521       {
00522         if(!m_structureIsUptodate)
00523         {
00524           
00525           m_transposedStructure = matrix.transpose();
00526           
00527           
00528           for (Index j=0; j<m_transposedStructure.outerSize(); ++j) 
00529             for(typename ColSpMatrix::InnerIterator it(m_transposedStructure, j); it; ++it)
00530               it.valueRef() = 0.0;
00531 
00532           m_structureIsUptodate = true;
00533         }
00534         
00535         out = m_transposedStructure + matrix;
00536       }
00537       internal::c_to_fortran_numbering(out);
00538     }
00539     
00540     using Base::m_iparm;
00541     using Base::m_dparm;
00542     
00543     ColSpMatrix m_transposedStructure;
00544     bool m_structureIsUptodate;
00545 };
00546 
00561 template<typename _MatrixType, int _UpLo>
00562 class PastixLLT : public PastixBase< PastixLLT<_MatrixType, _UpLo> >
00563 {
00564   public:
00565     typedef _MatrixType MatrixType;
00566     typedef PastixBase<PastixLLT<MatrixType, _UpLo> > Base;
00567     typedef typename Base::ColSpMatrix ColSpMatrix;
00568     
00569   public:
00570     enum { UpLo = _UpLo };
00571     PastixLLT() : Base()
00572     {
00573       init();
00574     }
00575     
00576     PastixLLT(const MatrixType& matrix):Base()
00577     {
00578       init();
00579       compute(matrix);
00580     }
00581 
00585     void compute (const MatrixType& matrix)
00586     {
00587       ColSpMatrix temp;
00588       grabMatrix(matrix, temp);
00589       Base::compute(temp);
00590     }
00591 
00596     void analyzePattern(const MatrixType& matrix)
00597     {
00598       ColSpMatrix temp;
00599       grabMatrix(matrix, temp);
00600       Base::analyzePattern(temp);
00601     }
00605     void factorize(const MatrixType& matrix)
00606     {
00607       ColSpMatrix temp;
00608       grabMatrix(matrix, temp);
00609       Base::factorize(temp);
00610     }
00611   protected:
00612     using Base::m_iparm;
00613     
00614     void init()
00615     {
00616       m_iparm(IPARM_SYM) = API_SYM_YES;
00617       m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT;
00618     }
00619     
00620     void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
00621     {
00622       
00623       out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
00624       internal::c_to_fortran_numbering(out);
00625     }
00626 };
00627 
00642 template<typename _MatrixType, int _UpLo>
00643 class PastixLDLT : public PastixBase< PastixLDLT<_MatrixType, _UpLo> >
00644 {
00645   public:
00646     typedef _MatrixType MatrixType;
00647     typedef PastixBase<PastixLDLT<MatrixType, _UpLo> > Base; 
00648     typedef typename Base::ColSpMatrix ColSpMatrix;
00649     
00650   public:
00651     enum { UpLo = _UpLo };
00652     PastixLDLT():Base()
00653     {
00654       init();
00655     }
00656     
00657     PastixLDLT(const MatrixType& matrix):Base()
00658     {
00659       init();
00660       compute(matrix);
00661     }
00662 
00666     void compute (const MatrixType& matrix)
00667     {
00668       ColSpMatrix temp;
00669       grabMatrix(matrix, temp);
00670       Base::compute(temp);
00671     }
00672 
00677     void analyzePattern(const MatrixType& matrix)
00678     { 
00679       ColSpMatrix temp;
00680       grabMatrix(matrix, temp);
00681       Base::analyzePattern(temp);
00682     }
00686     void factorize(const MatrixType& matrix)
00687     {
00688       ColSpMatrix temp;
00689       grabMatrix(matrix, temp);
00690       Base::factorize(temp);
00691     }
00692 
00693   protected:
00694     using Base::m_iparm;
00695     
00696     void init()
00697     {
00698       m_iparm(IPARM_SYM) = API_SYM_YES;
00699       m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT;
00700     }
00701     
00702     void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
00703     {
00704       
00705       out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
00706       internal::c_to_fortran_numbering(out);
00707     }
00708 };
00709 
00710 namespace internal {
00711 
00712 template<typename _MatrixType, typename Rhs>
00713 struct solve_retval<PastixBase<_MatrixType>, Rhs>
00714   : solve_retval_base<PastixBase<_MatrixType>, Rhs>
00715 {
00716   typedef PastixBase<_MatrixType> Dec;
00717   EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs)
00718 
00719   template<typename Dest> void evalTo(Dest& dst) const
00720   {
00721     dec()._solve(rhs(),dst);
00722   }
00723 };
00724 
00725 template<typename _MatrixType, typename Rhs>
00726 struct sparse_solve_retval<PastixBase<_MatrixType>, Rhs>
00727   : sparse_solve_retval_base<PastixBase<_MatrixType>, Rhs>
00728 {
00729   typedef PastixBase<_MatrixType> Dec;
00730   EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs)
00731 
00732   template<typename Dest> void evalTo(Dest& dst) const
00733   {
00734     dec()._solve_sparse(rhs(),dst);
00735   }
00736 };
00737 
00738 } 
00739 
00740 } 
00741 
00742 #endif