PaStiXSupport.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2012 Désiré Nuentsa-Wakam <desire.nuentsa_wakam@inria.fr>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_PASTIXSUPPORT_H
00011 #define EIGEN_PASTIXSUPPORT_H
00012 
00013 namespace Eigen { 
00014 
// Forward declarations of the PaStiX-backed solvers; they are needed by the
// internal::pastix_traits specializations and by PastixBase.
//   PastixLU   : LU factorization. When IsStrSym is true the input sparsity
//                pattern is used as-is; otherwise it is symmetrized first.
//   PastixLLT  : LL^T (Cholesky) factorization of a selfadjoint matrix.
//   PastixLDLT : LDL^T factorization of a selfadjoint matrix.
template<typename MatrixType_, bool IsStrSym = false>
class PastixLU;

template<typename MatrixType_, int Options>
class PastixLLT;

template<typename MatrixType_, int Options>
class PastixLDLT;
00026 
00027 namespace internal
00028 {
00029     
00030   template<class Pastix> struct pastix_traits;
00031 
00032   template<typename _MatrixType>
00033   struct pastix_traits< PastixLU<_MatrixType> >
00034   {
00035     typedef _MatrixType MatrixType;
00036     typedef typename _MatrixType::Scalar Scalar;
00037     typedef typename _MatrixType::RealScalar RealScalar;
00038     typedef typename _MatrixType::Index Index;
00039   };
00040 
00041   template<typename _MatrixType, int Options>
00042   struct pastix_traits< PastixLLT<_MatrixType,Options> >
00043   {
00044     typedef _MatrixType MatrixType;
00045     typedef typename _MatrixType::Scalar Scalar;
00046     typedef typename _MatrixType::RealScalar RealScalar;
00047     typedef typename _MatrixType::Index Index;
00048   };
00049 
00050   template<typename _MatrixType, int Options>
00051   struct pastix_traits< PastixLDLT<_MatrixType,Options> >
00052   {
00053     typedef _MatrixType MatrixType;
00054     typedef typename _MatrixType::Scalar Scalar;
00055     typedef typename _MatrixType::RealScalar RealScalar;
00056     typedef typename _MatrixType::Index Index;
00057   };
00058   
00059   void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, float *vals, int *perm, int * invp, float *x, int nbrhs, int *iparm, double *dparm)
00060   {
00061     if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00062     if (nbrhs == 0) {x = NULL; nbrhs=1;}
00063     s_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm); 
00064   }
00065   
00066   void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, double *vals, int *perm, int * invp, double *x, int nbrhs, int *iparm, double *dparm)
00067   {
00068     if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00069     if (nbrhs == 0) {x = NULL; nbrhs=1;}
00070     d_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm); 
00071   }
00072   
00073   void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<float> *vals, int *perm, int * invp, std::complex<float> *x, int nbrhs, int *iparm, double *dparm)
00074   {
00075     if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00076     if (nbrhs == 0) {x = NULL; nbrhs=1;}
00077     c_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<COMPLEX*>(vals), perm, invp, reinterpret_cast<COMPLEX*>(x), nbrhs, iparm, dparm); 
00078   }
00079   
00080   void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<double> *vals, int *perm, int * invp, std::complex<double> *x, int nbrhs, int *iparm, double *dparm)
00081   {
00082     if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
00083     if (nbrhs == 0) {x = NULL; nbrhs=1;}
00084     z_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<DCOMPLEX*>(vals), perm, invp, reinterpret_cast<DCOMPLEX*>(x), nbrhs, iparm, dparm); 
00085   }
00086 
00087   // Convert the matrix  to Fortran-style Numbering
00088   template <typename MatrixType>
00089   void c_to_fortran_numbering (MatrixType& mat)
00090   {
00091     if ( !(mat.outerIndexPtr()[0]) ) 
00092     { 
00093       int i;
00094       for(i = 0; i <= mat.rows(); ++i)
00095         ++mat.outerIndexPtr()[i];
00096       for(i = 0; i < mat.nonZeros(); ++i)
00097         ++mat.innerIndexPtr()[i];
00098     }
00099   }
00100   
00101   // Convert to C-style Numbering
00102   template <typename MatrixType>
00103   void fortran_to_c_numbering (MatrixType& mat)
00104   {
00105     // Check the Numbering
00106     if ( mat.outerIndexPtr()[0] == 1 ) 
00107     { // Convert to C-style numbering
00108       int i;
00109       for(i = 0; i <= mat.rows(); ++i)
00110         --mat.outerIndexPtr()[i];
00111       for(i = 0; i < mat.nonZeros(); ++i)
00112         --mat.innerIndexPtr()[i];
00113     }
00114   }
00115 }
00116 
00117 // This is the base class to interface with PaStiX functions. 
00118 // Users should not used this class directly. 
00119 template <class Derived>
00120 class PastixBase : internal::noncopyable
00121 {
00122   public:
00123     typedef typename internal::pastix_traits<Derived>::MatrixType _MatrixType;
00124     typedef _MatrixType MatrixType;
00125     typedef typename MatrixType::Scalar Scalar;
00126     typedef typename MatrixType::RealScalar RealScalar;
00127     typedef typename MatrixType::Index Index;
00128     typedef Matrix<Scalar,Dynamic,1> Vector;
00129     typedef SparseMatrix<Scalar, ColMajor> ColSpMatrix;
00130     
00131   public:
00132     
00133     PastixBase() : m_initisOk(false), m_analysisIsOk(false), m_factorizationIsOk(false), m_isInitialized(false), m_pastixdata(0), m_size(0)
00134     {
00135       init();
00136     }
00137     
00138     ~PastixBase() 
00139     {
00140       clean();
00141     }
00142 
00147     template<typename Rhs>
00148     inline const internal::solve_retval<PastixBase, Rhs>
00149     solve(const MatrixBase<Rhs>& b) const
00150     {
00151       eigen_assert(m_isInitialized && "Pastix solver is not initialized.");
00152       eigen_assert(rows()==b.rows()
00153                 && "PastixBase::solve(): invalid number of rows of the right hand side matrix b");
00154       return internal::solve_retval<PastixBase, Rhs>(*this, b.derived());
00155     }
00156     
00157     template<typename Rhs,typename Dest>
00158     bool _solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const;
00159     
00160     Derived& derived()
00161     {
00162       return *static_cast<Derived*>(this);
00163     }
00164     const Derived& derived() const
00165     {
00166       return *static_cast<const Derived*>(this);
00167     }
00168 
00174     Array<Index,IPARM_SIZE,1>& iparm()
00175     {
00176       return m_iparm; 
00177     }
00178     
00183     int& iparm(int idxparam)
00184     {
00185       return m_iparm(idxparam);
00186     }
00187     
00192     Array<RealScalar,IPARM_SIZE,1>& dparm()
00193     {
00194       return m_dparm; 
00195     }
00196     
00197     
00201     double& dparm(int idxparam)
00202     {
00203       return m_dparm(idxparam);
00204     }
00205     
00206     inline Index cols() const { return m_size; }
00207     inline Index rows() const { return m_size; }
00208     
00217     ComputationInfo info() const
00218     {
00219       eigen_assert(m_isInitialized && "Decomposition is not initialized.");
00220       return m_info;
00221     }
00222     
00227     template<typename Rhs>
00228     inline const internal::sparse_solve_retval<PastixBase, Rhs>
00229     solve(const SparseMatrixBase<Rhs>& b) const
00230     {
00231       eigen_assert(m_isInitialized && "Pastix LU, LLT or LDLT is not initialized.");
00232       eigen_assert(rows()==b.rows()
00233                 && "PastixBase::solve(): invalid number of rows of the right hand side matrix b");
00234       return internal::sparse_solve_retval<PastixBase, Rhs>(*this, b.derived());
00235     }
00236     
00237   protected:
00238 
00239     // Initialize the Pastix data structure, check the matrix
00240     void init(); 
00241     
00242     // Compute the ordering and the symbolic factorization
00243     void analyzePattern(ColSpMatrix& mat);
00244     
00245     // Compute the numerical factorization
00246     void factorize(ColSpMatrix& mat);
00247     
00248     // Free all the data allocated by Pastix
00249     void clean()
00250     {
00251       eigen_assert(m_initisOk && "The Pastix structure should be allocated first"); 
00252       m_iparm(IPARM_START_TASK) = API_TASK_CLEAN;
00253       m_iparm(IPARM_END_TASK) = API_TASK_CLEAN;
00254       internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
00255                              m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
00256     }
00257     
00258     void compute(ColSpMatrix& mat);
00259     
00260     int m_initisOk; 
00261     int m_analysisIsOk;
00262     int m_factorizationIsOk;
00263     bool m_isInitialized;
00264     mutable ComputationInfo m_info; 
00265     mutable pastix_data_t *m_pastixdata; // Data structure for pastix
00266     mutable int m_comm; // The MPI communicator identifier
00267     mutable Matrix<int,IPARM_SIZE,1> m_iparm; // integer vector for the input parameters
00268     mutable Matrix<double,DPARM_SIZE,1> m_dparm; // Scalar vector for the input parameters
00269     mutable Matrix<Index,Dynamic,1> m_perm;  // Permutation vector
00270     mutable Matrix<Index,Dynamic,1> m_invp;  // Inverse permutation vector
00271     mutable int m_size; // Size of the matrix 
00272 }; 
00273 
00278 template <class Derived>
00279 void PastixBase<Derived>::init()
00280 {
00281   m_size = 0; 
00282   m_iparm.setZero(IPARM_SIZE);
00283   m_dparm.setZero(DPARM_SIZE);
00284   
00285   m_iparm(IPARM_MODIFY_PARAMETER) = API_NO;
00286   pastix(&m_pastixdata, MPI_COMM_WORLD,
00287          0, 0, 0, 0,
00288          0, 0, 0, 1, m_iparm.data(), m_dparm.data());
00289   
00290   m_iparm[IPARM_MATRIX_VERIFICATION] = API_NO;
00291   m_iparm[IPARM_VERBOSE]             = 2;
00292   m_iparm[IPARM_ORDERING]            = API_ORDER_SCOTCH;
00293   m_iparm[IPARM_INCOMPLETE]          = API_NO;
00294   m_iparm[IPARM_OOC_LIMIT]           = 2000;
00295   m_iparm[IPARM_RHS_MAKING]          = API_RHS_B;
00296   m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
00297   
00298   m_iparm(IPARM_START_TASK) = API_TASK_INIT;
00299   m_iparm(IPARM_END_TASK) = API_TASK_INIT;
00300   internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
00301                          0, 0, 0, 0, m_iparm.data(), m_dparm.data());
00302   
00303   // Check the returned error
00304   if(m_iparm(IPARM_ERROR_NUMBER)) {
00305     m_info = InvalidInput;
00306     m_initisOk = false;
00307   }
00308   else { 
00309     m_info = Success;
00310     m_initisOk = true;
00311   }
00312 }
00313 
00314 template <class Derived>
00315 void PastixBase<Derived>::compute(ColSpMatrix& mat)
00316 {
00317   eigen_assert(mat.rows() == mat.cols() && "The input matrix should be squared");
00318   
00319   analyzePattern(mat);  
00320   factorize(mat);
00321   
00322   m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
00323   m_isInitialized = m_factorizationIsOk;
00324 }
00325 
00326 
00327 template <class Derived>
00328 void PastixBase<Derived>::analyzePattern(ColSpMatrix& mat)
00329 {                         
00330   eigen_assert(m_initisOk && "The initialization of PaSTiX failed");
00331   
00332   // clean previous calls
00333   if(m_size>0)
00334     clean();
00335   
00336   m_size = mat.rows();
00337   m_perm.resize(m_size);
00338   m_invp.resize(m_size);
00339   
00340   m_iparm(IPARM_START_TASK) = API_TASK_ORDERING;
00341   m_iparm(IPARM_END_TASK) = API_TASK_ANALYSE;
00342   internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
00343                mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
00344   
00345   // Check the returned error
00346   if(m_iparm(IPARM_ERROR_NUMBER))
00347   {
00348     m_info = NumericalIssue;
00349     m_analysisIsOk = false;
00350   }
00351   else
00352   { 
00353     m_info = Success;
00354     m_analysisIsOk = true;
00355   }
00356 }
00357 
00358 template <class Derived>
00359 void PastixBase<Derived>::factorize(ColSpMatrix& mat)
00360 {
00361 //   if(&m_cpyMat != &mat) m_cpyMat = mat;
00362   eigen_assert(m_analysisIsOk && "The analysis phase should be called before the factorization phase");
00363   m_iparm(IPARM_START_TASK) = API_TASK_NUMFACT;
00364   m_iparm(IPARM_END_TASK) = API_TASK_NUMFACT;
00365   m_size = mat.rows();
00366   
00367   internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
00368                mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
00369   
00370   // Check the returned error
00371   if(m_iparm(IPARM_ERROR_NUMBER))
00372   {
00373     m_info = NumericalIssue;
00374     m_factorizationIsOk = false;
00375     m_isInitialized = false;
00376   }
00377   else
00378   {
00379     m_info = Success;
00380     m_factorizationIsOk = true;
00381     m_isInitialized = true;
00382   }
00383 }
00384 
00385 /* Solve the system */
00386 template<typename Base>
00387 template<typename Rhs,typename Dest>
00388 bool PastixBase<Base>::_solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const
00389 {
00390   eigen_assert(m_isInitialized && "The matrix should be factorized first");
00391   EIGEN_STATIC_ASSERT((Dest::Flags&RowMajorBit)==0,
00392                      THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
00393   int rhs = 1;
00394   
00395   x = b; /* on return, x is overwritten by the computed solution */
00396   
00397   for (int i = 0; i < b.cols(); i++){
00398     m_iparm[IPARM_START_TASK]          = API_TASK_SOLVE;
00399     m_iparm[IPARM_END_TASK]            = API_TASK_REFINE;
00400   
00401     internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, x.rows(), 0, 0, 0,
00402                            m_perm.data(), m_invp.data(), &x(0, i), rhs, m_iparm.data(), m_dparm.data());
00403   }
00404   
00405   // Check the returned error
00406   m_info = m_iparm(IPARM_ERROR_NUMBER)==0 ? Success : NumericalIssue;
00407   
00408   return m_iparm(IPARM_ERROR_NUMBER)==0;
00409 }
00410 
00430 template<typename _MatrixType, bool IsStrSym>
00431 class PastixLU : public PastixBase< PastixLU<_MatrixType> >
00432 {
00433   public:
00434     typedef _MatrixType MatrixType;
00435     typedef PastixBase<PastixLU<MatrixType> > Base;
00436     typedef typename Base::ColSpMatrix ColSpMatrix;
00437     typedef typename MatrixType::Index Index;
00438     
00439   public:
00440     PastixLU() : Base()
00441     {
00442       init();
00443     }
00444     
00445     PastixLU(const MatrixType& matrix):Base()
00446     {
00447       init();
00448       compute(matrix);
00449     }
00455     void compute (const MatrixType& matrix)
00456     {
00457       m_structureIsUptodate = false;
00458       ColSpMatrix temp;
00459       grabMatrix(matrix, temp);
00460       Base::compute(temp);
00461     }
00467     void analyzePattern(const MatrixType& matrix)
00468     {
00469       m_structureIsUptodate = false;
00470       ColSpMatrix temp;
00471       grabMatrix(matrix, temp);
00472       Base::analyzePattern(temp);
00473     }
00474 
00480     void factorize(const MatrixType& matrix)
00481     {
00482       ColSpMatrix temp;
00483       grabMatrix(matrix, temp);
00484       Base::factorize(temp);
00485     }
00486   protected:
00487     
00488     void init()
00489     {
00490       m_structureIsUptodate = false;
00491       m_iparm(IPARM_SYM) = API_SYM_NO;
00492       m_iparm(IPARM_FACTORIZATION) = API_FACT_LU;
00493     }
00494     
00495     void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
00496     {
00497       if(IsStrSym)
00498         out = matrix;
00499       else
00500       {
00501         if(!m_structureIsUptodate)
00502         {
00503           // update the transposed structure
00504           m_transposedStructure = matrix.transpose();
00505           
00506           // Set the elements of the matrix to zero 
00507           for (Index j=0; j<m_transposedStructure.outerSize(); ++j) 
00508             for(typename ColSpMatrix::InnerIterator it(m_transposedStructure, j); it; ++it)
00509               it.valueRef() = 0.0;
00510 
00511           m_structureIsUptodate = true;
00512         }
00513         
00514         out = m_transposedStructure + matrix;
00515       }
00516       internal::c_to_fortran_numbering(out);
00517     }
00518     
00519     using Base::m_iparm;
00520     using Base::m_dparm;
00521     
00522     ColSpMatrix m_transposedStructure;
00523     bool m_structureIsUptodate;
00524 };
00525 
00540 template<typename _MatrixType, int _UpLo>
00541 class PastixLLT : public PastixBase< PastixLLT<_MatrixType, _UpLo> >
00542 {
00543   public:
00544     typedef _MatrixType MatrixType;
00545     typedef PastixBase<PastixLLT<MatrixType, _UpLo> > Base;
00546     typedef typename Base::ColSpMatrix ColSpMatrix;
00547     
00548   public:
00549     enum { UpLo = _UpLo };
00550     PastixLLT() : Base()
00551     {
00552       init();
00553     }
00554     
00555     PastixLLT(const MatrixType& matrix):Base()
00556     {
00557       init();
00558       compute(matrix);
00559     }
00560 
00564     void compute (const MatrixType& matrix)
00565     {
00566       ColSpMatrix temp;
00567       grabMatrix(matrix, temp);
00568       Base::compute(temp);
00569     }
00570 
00575     void analyzePattern(const MatrixType& matrix)
00576     {
00577       ColSpMatrix temp;
00578       grabMatrix(matrix, temp);
00579       Base::analyzePattern(temp);
00580     }
00584     void factorize(const MatrixType& matrix)
00585     {
00586       ColSpMatrix temp;
00587       grabMatrix(matrix, temp);
00588       Base::factorize(temp);
00589     }
00590   protected:
00591     using Base::m_iparm;
00592     
00593     void init()
00594     {
00595       m_iparm(IPARM_SYM) = API_SYM_YES;
00596       m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT;
00597     }
00598     
00599     void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
00600     {
00601       // Pastix supports only lower, column-major matrices 
00602       out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
00603       internal::c_to_fortran_numbering(out);
00604     }
00605 };
00606 
00621 template<typename _MatrixType, int _UpLo>
00622 class PastixLDLT : public PastixBase< PastixLDLT<_MatrixType, _UpLo> >
00623 {
00624   public:
00625     typedef _MatrixType MatrixType;
00626     typedef PastixBase<PastixLDLT<MatrixType, _UpLo> > Base; 
00627     typedef typename Base::ColSpMatrix ColSpMatrix;
00628     
00629   public:
00630     enum { UpLo = _UpLo };
00631     PastixLDLT():Base()
00632     {
00633       init();
00634     }
00635     
00636     PastixLDLT(const MatrixType& matrix):Base()
00637     {
00638       init();
00639       compute(matrix);
00640     }
00641 
00645     void compute (const MatrixType& matrix)
00646     {
00647       ColSpMatrix temp;
00648       grabMatrix(matrix, temp);
00649       Base::compute(temp);
00650     }
00651 
00656     void analyzePattern(const MatrixType& matrix)
00657     { 
00658       ColSpMatrix temp;
00659       grabMatrix(matrix, temp);
00660       Base::analyzePattern(temp);
00661     }
00665     void factorize(const MatrixType& matrix)
00666     {
00667       ColSpMatrix temp;
00668       grabMatrix(matrix, temp);
00669       Base::factorize(temp);
00670     }
00671 
00672   protected:
00673     using Base::m_iparm;
00674     
00675     void init()
00676     {
00677       m_iparm(IPARM_SYM) = API_SYM_YES;
00678       m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT;
00679     }
00680     
00681     void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
00682     {
00683       // Pastix supports only lower, column-major matrices 
00684       out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
00685       internal::c_to_fortran_numbering(out);
00686     }
00687 };
00688 
00689 namespace internal {
00690 
00691 template<typename _MatrixType, typename Rhs>
00692 struct solve_retval<PastixBase<_MatrixType>, Rhs>
00693   : solve_retval_base<PastixBase<_MatrixType>, Rhs>
00694 {
00695   typedef PastixBase<_MatrixType> Dec;
00696   EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs)
00697 
00698   template<typename Dest> void evalTo(Dest& dst) const
00699   {
00700     dec()._solve(rhs(),dst);
00701   }
00702 };
00703 
00704 template<typename _MatrixType, typename Rhs>
00705 struct sparse_solve_retval<PastixBase<_MatrixType>, Rhs>
00706   : sparse_solve_retval_base<PastixBase<_MatrixType>, Rhs>
00707 {
00708   typedef PastixBase<_MatrixType> Dec;
00709   EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs)
00710 
00711   template<typename Dest> void evalTo(Dest& dst) const
00712   {
00713     this->defaultEvalTo(dst);
00714   }
00715 };
00716 
00717 } // end namespace internal
00718 
00719 } // end namespace Eigen
00720 
00721 #endif


shape_reconstruction
Author(s): Roberto Martín-Martín
autogenerated on Sat Jun 8 2019 18:34:05