ConjugateGradient.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2011 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_CONJUGATE_GRADIENT_H
00011 #define EIGEN_CONJUGATE_GRADIENT_H
00012 
00013 namespace Eigen { 
00014 
00015 namespace internal {
00016 
00026 template<typename MatrixType, typename Rhs, typename Dest, typename Preconditioner>
00027 EIGEN_DONT_INLINE
00028 void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
00029                         const Preconditioner& precond, int& iters,
00030                         typename Dest::RealScalar& tol_error)
00031 {
00032   using std::sqrt;
00033   using std::abs;
00034   typedef typename Dest::RealScalar RealScalar;
00035   typedef typename Dest::Scalar Scalar;
00036   typedef Matrix<Scalar,Dynamic,1> VectorType;
00037   
00038   RealScalar tol = tol_error;
00039   int maxIters = iters;
00040   
00041   int n = mat.cols();
00042 
00043   VectorType residual = rhs - mat * x; //initial residual
00044   VectorType p(n);
00045 
00046   p = precond.solve(residual);      //initial search direction
00047 
00048   VectorType z(n), tmp(n);
00049   RealScalar absNew = internal::real(residual.dot(p));  // the square of the absolute value of r scaled by invM
00050   RealScalar rhsNorm2 = rhs.squaredNorm();
00051   RealScalar residualNorm2 = 0;
00052   RealScalar threshold = tol*tol*rhsNorm2;
00053   int i = 0;
00054   while(i < maxIters)
00055   {
00056     tmp.noalias() = mat * p;              // the bottleneck of the algorithm
00057 
00058     Scalar alpha = absNew / p.dot(tmp);   // the amount we travel on dir
00059     x += alpha * p;                       // update solution
00060     residual -= alpha * tmp;              // update residue
00061     
00062     residualNorm2 = residual.squaredNorm();
00063     if(residualNorm2 < threshold)
00064       break;
00065     
00066     z = precond.solve(residual);          // approximately solve for "A z = residual"
00067 
00068     RealScalar absOld = absNew;
00069     absNew = internal::real(residual.dot(z));     // update the absolute value of r
00070     RealScalar beta = absNew / absOld;            // calculate the Gram-Schmidt value used to create the new search direction
00071     p = z + beta * p;                             // update search direction
00072     i++;
00073   }
00074   tol_error = sqrt(residualNorm2 / rhsNorm2);
00075   iters = i;
00076 }
00077 
00078 }
00079 
00080 template< typename _MatrixType, int _UpLo=Lower,
00081           typename _Preconditioner = DiagonalPreconditioner<typename _MatrixType::Scalar> >
00082 class ConjugateGradient;
00083 
00084 namespace internal {
00085 
00086 template< typename _MatrixType, int _UpLo, typename _Preconditioner>
00087 struct traits<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> >
00088 {
00089   typedef _MatrixType MatrixType;
00090   typedef _Preconditioner Preconditioner;
00091 };
00092 
00093 }
00094 
00143 template< typename _MatrixType, int _UpLo, typename _Preconditioner>
00144 class ConjugateGradient : public IterativeSolverBase<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> >
00145 {
00146   typedef IterativeSolverBase<ConjugateGradient> Base;
00147   using Base::mp_matrix;
00148   using Base::m_error;
00149   using Base::m_iterations;
00150   using Base::m_info;
00151   using Base::m_isInitialized;
00152 public:
00153   typedef _MatrixType MatrixType;
00154   typedef typename MatrixType::Scalar Scalar;
00155   typedef typename MatrixType::Index Index;
00156   typedef typename MatrixType::RealScalar RealScalar;
00157   typedef _Preconditioner Preconditioner;
00158 
00159   enum {
00160     UpLo = _UpLo
00161   };
00162 
00163 public:
00164 
00166   ConjugateGradient() : Base() {}
00167 
00178   ConjugateGradient(const MatrixType& A) : Base(A) {}
00179 
00180   ~ConjugateGradient() {}
00181   
00187   template<typename Rhs,typename Guess>
00188   inline const internal::solve_retval_with_guess<ConjugateGradient, Rhs, Guess>
00189   solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const
00190   {
00191     eigen_assert(m_isInitialized && "ConjugateGradient is not initialized.");
00192     eigen_assert(Base::rows()==b.rows()
00193               && "ConjugateGradient::solve(): invalid number of rows of the right hand side matrix b");
00194     return internal::solve_retval_with_guess
00195             <ConjugateGradient, Rhs, Guess>(*this, b.derived(), x0);
00196   }
00197 
00199   template<typename Rhs,typename Dest>
00200   void _solveWithGuess(const Rhs& b, Dest& x) const
00201   {
00202     m_iterations = Base::maxIterations();
00203     m_error = Base::m_tolerance;
00204 
00205     for(int j=0; j<b.cols(); ++j)
00206     {
00207       m_iterations = Base::maxIterations();
00208       m_error = Base::m_tolerance;
00209 
00210       typename Dest::ColXpr xj(x,j);
00211       internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b.col(j), xj,
00212                                    Base::m_preconditioner, m_iterations, m_error);
00213     }
00214 
00215     m_isInitialized = true;
00216     m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
00217   }
00218   
00220   template<typename Rhs,typename Dest>
00221   void _solve(const Rhs& b, Dest& x) const
00222   {
00223     x.setOnes();
00224     _solveWithGuess(b,x);
00225   }
00226 
00227 protected:
00228 
00229 };
00230 
00231 
00232 namespace internal {
00233 
00234 template<typename _MatrixType, int _UpLo, typename _Preconditioner, typename Rhs>
00235 struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
00236   : solve_retval_base<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
00237 {
00238   typedef ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> Dec;
00239   EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs)
00240 
00241   template<typename Dest> void evalTo(Dest& dst) const
00242   {
00243     dec()._solve(rhs(),dst);
00244   }
00245 };
00246 
00247 } // end namespace internal
00248 
00249 } // end namespace Eigen
00250 
00251 #endif // EIGEN_CONJUGATE_GRADIENT_H


win_eigen
Author(s): Daniel Stonier
autogenerated on Mon Oct 6 2014 12:24:21