MatrixSquareRoot.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_MATRIX_SQUARE_ROOT
00011 #define EIGEN_MATRIX_SQUARE_ROOT
00012 
00013 namespace Eigen { 
00014 
template <typename MatrixType>
class MatrixSquareRootQuasiTriangular
{
  private:
    typedef typename MatrixType::Index Index;
    typedef typename MatrixType::Scalar Scalar;

  public:
    /** \brief Constructor.
      *
      * \param[in] A  square matrix whose square root is wanted. Only a reference is
      *               stored, so \p A must stay alive until compute() has been called.
      */
    MatrixSquareRootQuasiTriangular(const MatrixType& A) 
      : m_A(A) 
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Computes the square root of the matrix passed to the constructor.
      *
      * \param[out] result  receives the matrix square root.
      */
    template <typename ResultType> void compute(ResultType &result);    

  private:
    // Building blocks of the block algorithm. Diagonal blocks of T are either 1-by-1
    // or 2-by-2; each helper is named after the block sizes it deals with.
    void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);
    void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);

    // Solves A X + X B = C where all four matrices are 2-by-2.
    template <typename SmallMatrixType>
    static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
                                       const SmallMatrixType& B, const SmallMatrixType& C);

    // The matrix whose square root is computed; held by reference, not copied.
    const MatrixType& m_A;
};
00077 
00078 template <typename MatrixType>
00079 template <typename ResultType> 
00080 void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
00081 {
00082   // Compute Schur decomposition of m_A
00083   const RealSchur<MatrixType> schurOfA(m_A);  
00084   const MatrixType& T = schurOfA.matrixT();
00085   const MatrixType& U = schurOfA.matrixU();
00086 
00087   // Compute square root of T
00088   MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00089   computeDiagonalPartOfSqrt(sqrtT, T);
00090   computeOffDiagonalPartOfSqrt(sqrtT, T);
00091 
00092   // Compute square root of m_A
00093   result = U * sqrtT * U.adjoint();
00094 }
00095 
00096 // pre:  T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
00097 // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T
00098 template <typename MatrixType>
00099 void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT, 
00100                                                                           const MatrixType& T)
00101 {
00102   const Index size = m_A.rows();
00103   for (Index i = 0; i < size; i++) {
00104     if (i == size - 1 || T.coeff(i+1, i) == 0) {
00105       eigen_assert(T(i,i) > 0);
00106       sqrtT.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
00107     }
00108     else {
00109       compute2x2diagonalBlock(sqrtT, T, i);
00110       ++i;
00111     }
00112   }
00113 }
00114 
00115 // pre:  T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T.
00116 // post: sqrtT is the square root of T.
00117 template <typename MatrixType>
00118 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, 
00119                                                                              const MatrixType& T)
00120 {
00121   const Index size = m_A.rows();
00122   for (Index j = 1; j < size; j++) {
00123       if (T.coeff(j, j-1) != 0)  // if T(j-1:j, j-1:j) is a 2-by-2 block
00124         continue;
00125     for (Index i = j-1; i >= 0; i--) {
00126       if (i > 0 && T.coeff(i, i-1) != 0)  // if T(i-1:i, i-1:i) is a 2-by-2 block
00127         continue;
00128       bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
00129       bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
00130       if (iBlockIs2x2 && jBlockIs2x2) 
00131         compute2x2offDiagonalBlock(sqrtT, T, i, j);
00132       else if (iBlockIs2x2 && !jBlockIs2x2) 
00133         compute2x1offDiagonalBlock(sqrtT, T, i, j);
00134       else if (!iBlockIs2x2 && jBlockIs2x2) 
00135         compute1x2offDiagonalBlock(sqrtT, T, i, j);
00136       else if (!iBlockIs2x2 && !jBlockIs2x2) 
00137         compute1x1offDiagonalBlock(sqrtT, T, i, j);
00138     }
00139   }
00140 }
00141 
00142 // pre:  T.block(i,i,2,2) has complex conjugate eigenvalues
00143 // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
00144 template <typename MatrixType>
00145 void MatrixSquareRootQuasiTriangular<MatrixType>
00146      ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
00147 {
00148   // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
00149   //       in EigenSolver. If we expose it, we could call it directly from here.
00150   Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
00151   EigenSolver<Matrix<Scalar,2,2> > es(block);
00152   sqrtT.template block<2,2>(i,i)
00153     = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
00154 }
00155 
00156 // pre:  block structure of T is such that (i,j) is a 1x1 block,
00157 //       all blocks of sqrtT to left of and below (i,j) are correct
00158 // post: sqrtT(i,j) has the correct value
00159 template <typename MatrixType>
00160 void MatrixSquareRootQuasiTriangular<MatrixType>
00161      ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00162                                   typename MatrixType::Index i, typename MatrixType::Index j)
00163 {
00164   Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
00165   sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
00166 }
00167 
00168 // similar to compute1x1offDiagonalBlock()
00169 template <typename MatrixType>
00170 void MatrixSquareRootQuasiTriangular<MatrixType>
00171      ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00172                                   typename MatrixType::Index i, typename MatrixType::Index j)
00173 {
00174   Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
00175   if (j-i > 1)
00176     rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
00177   Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
00178   A += sqrtT.template block<2,2>(j,j).transpose();
00179   sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
00180 }
00181 
00182 // similar to compute1x1offDiagonalBlock()
00183 template <typename MatrixType>
00184 void MatrixSquareRootQuasiTriangular<MatrixType>
00185      ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00186                                   typename MatrixType::Index i, typename MatrixType::Index j)
00187 {
00188   Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
00189   if (j-i > 2)
00190     rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
00191   Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
00192   A += sqrtT.template block<2,2>(i,i);
00193   sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
00194 }
00195 
00196 // similar to compute1x1offDiagonalBlock()
00197 template <typename MatrixType>
00198 void MatrixSquareRootQuasiTriangular<MatrixType>
00199      ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 
00200                                   typename MatrixType::Index i, typename MatrixType::Index j)
00201 {
00202   Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
00203   Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
00204   Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
00205   if (j-i > 2)
00206     C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
00207   Matrix<Scalar,2,2> X;
00208   solveAuxiliaryEquation(X, A, B, C);
00209   sqrtT.template block<2,2>(i,j) = X;
00210 }
00211 
00212 // solves the equation A X + X B = C where all matrices are 2-by-2
00213 template <typename MatrixType>
00214 template <typename SmallMatrixType>
00215 void MatrixSquareRootQuasiTriangular<MatrixType>
00216      ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
00217                               const SmallMatrixType& B, const SmallMatrixType& C)
00218 {
00219   EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
00220                       EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
00221 
00222   Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
00223   coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
00224   coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
00225   coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
00226   coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
00227   coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
00228   coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
00229   coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
00230   coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
00231   coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
00232   coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
00233   coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
00234   coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
00235   
00236   Matrix<Scalar,4,1> rhs;
00237   rhs.coeffRef(0) = C.coeff(0,0);
00238   rhs.coeffRef(1) = C.coeff(0,1);
00239   rhs.coeffRef(2) = C.coeff(1,0);
00240   rhs.coeffRef(3) = C.coeff(1,1);
00241   
00242   Matrix<Scalar,4,1> result;
00243   result = coeffMatrix.fullPivLu().solve(rhs);
00244 
00245   X.coeffRef(0,0) = result.coeff(0);
00246   X.coeffRef(0,1) = result.coeff(1);
00247   X.coeffRef(1,0) = result.coeff(2);
00248   X.coeffRef(1,1) = result.coeff(3);
00249 }
00250 
00251 
template <typename MatrixType>
class MatrixSquareRootTriangular
{
  public:
    /** \brief Constructor.
      *
      * \param[in] A  square matrix whose square root is wanted. Only a reference is
      *               stored, so \p A must stay alive until compute() has been called.
      */
    MatrixSquareRootTriangular(const MatrixType& A) 
      : m_A(A) 
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Computes the square root of the matrix passed to the constructor.
      *
      * \param[out] result  receives the matrix square root.
      */
    template <typename ResultType> void compute(ResultType &result);    

  private:
    // The matrix whose square root is computed; held by reference, not copied.
    const MatrixType& m_A;
};
00287 
00288 template <typename MatrixType>
00289 template <typename ResultType> 
00290 void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
00291 {
00292   // Compute Schur decomposition of m_A
00293   const ComplexSchur<MatrixType> schurOfA(m_A);  
00294   const MatrixType& T = schurOfA.matrixT();
00295   const MatrixType& U = schurOfA.matrixU();
00296 
00297   // Compute square root of T and store it in upper triangular part of result
00298   // This uses that the square root of triangular matrices can be computed directly.
00299   result.resize(m_A.rows(), m_A.cols());
00300   typedef typename MatrixType::Index Index;
00301   for (Index i = 0; i < m_A.rows(); i++) {
00302     result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
00303   }
00304   for (Index j = 1; j < m_A.cols(); j++) {
00305     for (Index i = j-1; i >= 0; i--) {
00306       typedef typename MatrixType::Scalar Scalar;
00307       // if i = j-1, then segment has length 0 so tmp = 0
00308       Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
00309       // denominator may be zero if original matrix is singular
00310       result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
00311     }
00312   }
00313 
00314   // Compute square root of m_A as U * result * U.adjoint()
00315   MatrixType tmp;
00316   tmp.noalias() = U * result.template triangularView<Upper>();
00317   result.noalias() = tmp * U.adjoint();
00318 }
00319 
00320 
00328 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
00329 class MatrixSquareRoot
00330 {
00331   public:
00332 
00340     MatrixSquareRoot(const MatrixType& A); 
00341     
00349     template <typename ResultType> void compute(ResultType &result);    
00350 };
00351 
00352 
00353 // ********** Partial specialization for real matrices **********
00354 
00355 template <typename MatrixType>
00356 class MatrixSquareRoot<MatrixType, 0>
00357 {
00358   public:
00359 
00360     MatrixSquareRoot(const MatrixType& A) 
00361       : m_A(A) 
00362     {  
00363       eigen_assert(A.rows() == A.cols());
00364     }
00365   
00366     template <typename ResultType> void compute(ResultType &result)
00367     {
00368       // Compute Schur decomposition of m_A
00369       const RealSchur<MatrixType> schurOfA(m_A);  
00370       const MatrixType& T = schurOfA.matrixT();
00371       const MatrixType& U = schurOfA.matrixU();
00372     
00373       // Compute square root of T
00374       MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
00375       MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00376       tmp.compute(sqrtT);
00377     
00378       // Compute square root of m_A
00379       result = U * sqrtT * U.adjoint();
00380     }
00381     
00382   private:
00383     const MatrixType& m_A;
00384 };
00385 
00386 
00387 // ********** Partial specialization for complex matrices **********
00388 
00389 template <typename MatrixType>
00390 class MatrixSquareRoot<MatrixType, 1>
00391 {
00392   public:
00393 
00394     MatrixSquareRoot(const MatrixType& A) 
00395       : m_A(A) 
00396     {  
00397       eigen_assert(A.rows() == A.cols());
00398     }
00399   
00400     template <typename ResultType> void compute(ResultType &result)
00401     {
00402       // Compute Schur decomposition of m_A
00403       const ComplexSchur<MatrixType> schurOfA(m_A);  
00404       const MatrixType& T = schurOfA.matrixT();
00405       const MatrixType& U = schurOfA.matrixU();
00406     
00407       // Compute square root of T
00408       MatrixSquareRootTriangular<MatrixType> tmp(T);
00409       MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00410       tmp.compute(sqrtT);
00411     
00412       // Compute square root of m_A
00413       result = U * sqrtT * U.adjoint();
00414     }
00415     
00416   private:
00417     const MatrixType& m_A;
00418 };
00419 
00420 
00433 template<typename Derived> class MatrixSquareRootReturnValue
00434 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
00435 {
00436     typedef typename Derived::Index Index;
00437   public:
00443     MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
00444 
00450     template <typename ResultType>
00451     inline void evalTo(ResultType& result) const
00452     {
00453       const typename Derived::PlainObject srcEvaluated = m_src.eval();
00454       MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated);
00455       me.compute(result);
00456     }
00457 
00458     Index rows() const { return m_src.rows(); }
00459     Index cols() const { return m_src.cols(); }
00460 
00461   protected:
00462     const Derived& m_src;
00463   private:
00464     MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&);
00465 };
00466 
00467 namespace internal {
00468 template<typename Derived>
00469 struct traits<MatrixSquareRootReturnValue<Derived> >
00470 {
00471   typedef typename Derived::PlainObject ReturnType;
00472 };
00473 }
00474 
00475 template <typename Derived>
00476 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
00477 {
00478   eigen_assert(rows() == cols());
00479   return MatrixSquareRootReturnValue<Derived>(derived());
00480 }
00481 
00482 } // end namespace Eigen
00483 
00484 #endif // EIGEN_MATRIX_SQUARE_ROOT


win_eigen
Author(s): Daniel Stonier
autogenerated on Wed Sep 16 2015 07:11:18