LeastSquareConjugateGradient.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_LEAST_SQUARE_CONJUGATE_GRADIENT_H
11 #define EIGEN_LEAST_SQUARE_CONJUGATE_GRADIENT_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
26 template<typename MatrixType, typename Rhs, typename Dest, typename Preconditioner>
28 void least_square_conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
29  const Preconditioner& precond, Index& iters,
30  typename Dest::RealScalar& tol_error)
31 {
32  using std::sqrt;
33  using std::abs;
34  typedef typename Dest::RealScalar RealScalar;
35  typedef typename Dest::Scalar Scalar;
37 
38  RealScalar tol = tol_error;
39  Index maxIters = iters;
40 
41  Index m = mat.rows(), n = mat.cols();
42 
43  VectorType residual = rhs - mat * x;
44  VectorType normal_residual = mat.adjoint() * residual;
45 
46  RealScalar rhsNorm2 = (mat.adjoint()*rhs).squaredNorm();
47  if(rhsNorm2 == 0)
48  {
49  x.setZero();
50  iters = 0;
51  tol_error = 0;
52  return;
53  }
54  RealScalar threshold = tol*tol*rhsNorm2;
55  RealScalar residualNorm2 = normal_residual.squaredNorm();
56  if (residualNorm2 < threshold)
57  {
58  iters = 0;
59  tol_error = sqrt(residualNorm2 / rhsNorm2);
60  return;
61  }
62 
63  VectorType p(n);
64  p = precond.solve(normal_residual); // initial search direction
65 
66  VectorType z(n), tmp(m);
67  RealScalar absNew = numext::real(normal_residual.dot(p)); // the square of the absolute value of r scaled by invM
68  Index i = 0;
69  while(i < maxIters)
70  {
71  tmp.noalias() = mat * p;
72 
73  Scalar alpha = absNew / tmp.squaredNorm(); // the amount we travel on dir
74  x += alpha * p; // update solution
75  residual -= alpha * tmp; // update residual
76  normal_residual = mat.adjoint() * residual; // update residual of the normal equation
77 
78  residualNorm2 = normal_residual.squaredNorm();
79  if(residualNorm2 < threshold)
80  break;
81 
82  z = precond.solve(normal_residual); // approximately solve for "A'A z = normal_residual"
83 
84  RealScalar absOld = absNew;
85  absNew = numext::real(normal_residual.dot(z)); // update the absolute value of r
86  RealScalar beta = absNew / absOld; // calculate the Gram-Schmidt value used to create the new search direction
87  p = z + beta * p; // update search direction
88  i++;
89  }
90  tol_error = sqrt(residualNorm2 / rhsNorm2);
91  iters = i;
92 }
93 
94 }
95 
96 template< typename _MatrixType,
99 
100 namespace internal {
101 
102 template< typename _MatrixType, typename _Preconditioner>
103 struct traits<LeastSquaresConjugateGradient<_MatrixType,_Preconditioner> >
104 {
105  typedef _MatrixType MatrixType;
106  typedef _Preconditioner Preconditioner;
107 };
108 
109 }
110 
148 template< typename _MatrixType, typename _Preconditioner>
149 class LeastSquaresConjugateGradient : public IterativeSolverBase<LeastSquaresConjugateGradient<_MatrixType,_Preconditioner> >
150 {
152  using Base::matrix;
153  using Base::m_error;
154  using Base::m_iterations;
155  using Base::m_info;
156  using Base::m_isInitialized;
157 public:
158  typedef _MatrixType MatrixType;
159  typedef typename MatrixType::Scalar Scalar;
161  typedef _Preconditioner Preconditioner;
162 
163 public:
164 
167 
178  template<typename MatrixDerived>
179  explicit LeastSquaresConjugateGradient(const EigenBase<MatrixDerived>& A) : Base(A.derived()) {}
180 
182 
184  template<typename Rhs,typename Dest>
185  void _solve_with_guess_impl(const Rhs& b, Dest& x) const
186  {
187  m_iterations = Base::maxIterations();
188  m_error = Base::m_tolerance;
189 
190  for(Index j=0; j<b.cols(); ++j)
191  {
192  m_iterations = Base::maxIterations();
193  m_error = Base::m_tolerance;
194 
195  typename Dest::ColXpr xj(x,j);
196  internal::least_square_conjugate_gradient(matrix(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error);
197  }
198 
199  m_isInitialized = true;
200  m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
201  }
202 
204  using Base::_solve_impl;
205  template<typename Rhs,typename Dest>
206  void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const
207  {
208  x.setZero();
209  _solve_with_guess_impl(b.derived(),x);
210  }
211 
212 };
213 
214 } // end namespace Eigen
215 
216 #endif // EIGEN_LEAST_SQUARE_CONJUGATE_GRADIENT_H
Matrix3f m
SCALAR Scalar
Definition: bench_gemm.cpp:33
float real
Definition: datatypes.h:10
Scalar * b
Definition: benchVecAdd.cpp:17
IterativeSolverBase< LeastSquaresConjugateGradient > Base
void _solve_impl(const MatrixBase< Rhs > &b, Dest &x) const
int n
EIGEN_DEVICE_FUNC const SqrtReturnType sqrt() const
Namespace containing all symbols from the Eigen library.
Definition: jet.h:637
Block< Derived, internal::traits< Derived >::RowsAtCompileTime, 1,!IsRowMajor > ColXpr
Definition: BlockMethods.h:14
A conjugate gradient solver for sparse (or dense) least-square problems.
MatrixXf MatrixType
Jacobi preconditioner for LeastSquaresConjugateGradient.
#define EIGEN_DONT_INLINE
Definition: Macros.h:517
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:33
RealScalar alpha
void _solve_with_guess_impl(const Rhs &b, Dest &x) const
NumTraits< Scalar >::Real RealScalar
Definition: bench_gemm.cpp:34
EIGEN_DONT_INLINE void least_square_conjugate_gradient(const MatrixType &mat, const Rhs &rhs, Dest &x, const Preconditioner &precond, Index &iters, typename Dest::RealScalar &tol_error)
float * p
LeastSquaresConjugateGradient(const EigenBase< MatrixDerived > &A)
const G double tol
Definition: Group.h:83
Map< Matrix< T, Dynamic, Dynamic, ColMajor >, 0, OuterStride<> > matrix(T *data, int rows, int cols, int stride)
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
#define abs(x)
Definition: datatypes.h:17
Base class for linear iterative solvers.
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:48
std::ptrdiff_t j


gtsam
Author(s):
autogenerated on Sat May 8 2021 02:42:30