00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_MATRIX_SQUARE_ROOT
00011 #define EIGEN_MATRIX_SQUARE_ROOT
00012
00013 namespace Eigen {
00014
/** \ingroup MatrixFunctions_Module
  * \brief Computes the square root of an upper quasi-triangular matrix.
  *
  * \tparam MatrixType  type of the argument; presumably a real square matrix
  *                     type (the 2-by-2 diagonal blocks are written back via
  *                     .real()) -- confirm for other instantiations.
  *
  * The argument is expected to be upper triangular except for 2-by-2 diagonal
  * blocks, which are recognised by a nonzero subdiagonal entry. Only a
  * reference to the argument is stored, so it must outlive this object.
  */
template <typename MatrixType>
class MatrixSquareRootQuasiTriangular
{
  public:

    /** \brief Constructor.
      * \param[in] A  upper quasi-triangular matrix; must be square (asserted).
      */
    MatrixSquareRootQuasiTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Computes the square root of the matrix passed to the constructor.
      * \param[out] result  resized to the size of A. The diagonal blocks and
      *                     everything above them are written; entries below the
      *                     block diagonal are not. In practice ResultType must be
      *                     MatrixType, because the helpers take MatrixType&.
      */
    template <typename ResultType> void compute(ResultType& result);

  private:
    typedef typename MatrixType::Index Index;
    typedef typename MatrixType::Scalar Scalar;

    // Square roots of the 1-by-1 and 2-by-2 diagonal blocks.
    void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    // All blocks strictly above the block diagonal, column by column.
    void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    // Square root of the 2-by-2 diagonal block starting at (i,i).
    void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);

    // Off-diagonal block at (i,j); "RxC" gives the sizes of the diagonal block
    // starting at row i (R) and of the one starting at column j (C).
    void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);
    void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
                                    typename MatrixType::Index i, typename MatrixType::Index j);

    // Solves A*X + X*B = C for 2-by-2 matrices.
    template <typename SmallMatrixType>
    static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
                                       const SmallMatrixType& B, const SmallMatrixType& C);

    const MatrixType& m_A; // non-owning reference to the argument
};
00077
00078 template <typename MatrixType>
00079 template <typename ResultType>
00080 void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
00081 {
00082 result.resize(m_A.rows(), m_A.cols());
00083 computeDiagonalPartOfSqrt(result, m_A);
00084 computeOffDiagonalPartOfSqrt(result, m_A);
00085 }
00086
00087
00088
00089 template <typename MatrixType>
00090 void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT,
00091 const MatrixType& T)
00092 {
00093 using std::sqrt;
00094 const Index size = m_A.rows();
00095 for (Index i = 0; i < size; i++) {
00096 if (i == size - 1 || T.coeff(i+1, i) == 0) {
00097 eigen_assert(T(i,i) >= 0);
00098 sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i));
00099 }
00100 else {
00101 compute2x2diagonalBlock(sqrtT, T, i);
00102 ++i;
00103 }
00104 }
00105 }
00106
00107
00108
00109 template <typename MatrixType>
00110 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
00111 const MatrixType& T)
00112 {
00113 const Index size = m_A.rows();
00114 for (Index j = 1; j < size; j++) {
00115 if (T.coeff(j, j-1) != 0)
00116 continue;
00117 for (Index i = j-1; i >= 0; i--) {
00118 if (i > 0 && T.coeff(i, i-1) != 0)
00119 continue;
00120 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
00121 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
00122 if (iBlockIs2x2 && jBlockIs2x2)
00123 compute2x2offDiagonalBlock(sqrtT, T, i, j);
00124 else if (iBlockIs2x2 && !jBlockIs2x2)
00125 compute2x1offDiagonalBlock(sqrtT, T, i, j);
00126 else if (!iBlockIs2x2 && jBlockIs2x2)
00127 compute1x2offDiagonalBlock(sqrtT, T, i, j);
00128 else if (!iBlockIs2x2 && !jBlockIs2x2)
00129 compute1x1offDiagonalBlock(sqrtT, T, i, j);
00130 }
00131 }
00132 }
00133
00134
00135
00136 template <typename MatrixType>
00137 void MatrixSquareRootQuasiTriangular<MatrixType>
00138 ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
00139 {
00140
00141
00142 Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
00143 EigenSolver<Matrix<Scalar,2,2> > es(block);
00144 sqrtT.template block<2,2>(i,i)
00145 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
00146 }
00147
00148
00149
00150
00151 template <typename MatrixType>
00152 void MatrixSquareRootQuasiTriangular<MatrixType>
00153 ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00154 typename MatrixType::Index i, typename MatrixType::Index j)
00155 {
00156 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
00157 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
00158 }
00159
00160
00161 template <typename MatrixType>
00162 void MatrixSquareRootQuasiTriangular<MatrixType>
00163 ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00164 typename MatrixType::Index i, typename MatrixType::Index j)
00165 {
00166 Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
00167 if (j-i > 1)
00168 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
00169 Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
00170 A += sqrtT.template block<2,2>(j,j).transpose();
00171 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
00172 }
00173
00174
00175 template <typename MatrixType>
00176 void MatrixSquareRootQuasiTriangular<MatrixType>
00177 ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00178 typename MatrixType::Index i, typename MatrixType::Index j)
00179 {
00180 Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
00181 if (j-i > 2)
00182 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
00183 Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
00184 A += sqrtT.template block<2,2>(i,i);
00185 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
00186 }
00187
00188
00189 template <typename MatrixType>
00190 void MatrixSquareRootQuasiTriangular<MatrixType>
00191 ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00192 typename MatrixType::Index i, typename MatrixType::Index j)
00193 {
00194 Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
00195 Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
00196 Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
00197 if (j-i > 2)
00198 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
00199 Matrix<Scalar,2,2> X;
00200 solveAuxiliaryEquation(X, A, B, C);
00201 sqrtT.template block<2,2>(i,j) = X;
00202 }
00203
00204
00205 template <typename MatrixType>
00206 template <typename SmallMatrixType>
00207 void MatrixSquareRootQuasiTriangular<MatrixType>
00208 ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
00209 const SmallMatrixType& B, const SmallMatrixType& C)
00210 {
00211 EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
00212 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
00213
00214 Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
00215 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
00216 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
00217 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
00218 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
00219 coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
00220 coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
00221 coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
00222 coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
00223 coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
00224 coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
00225 coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
00226 coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
00227
00228 Matrix<Scalar,4,1> rhs;
00229 rhs.coeffRef(0) = C.coeff(0,0);
00230 rhs.coeffRef(1) = C.coeff(0,1);
00231 rhs.coeffRef(2) = C.coeff(1,0);
00232 rhs.coeffRef(3) = C.coeff(1,1);
00233
00234 Matrix<Scalar,4,1> result;
00235 result = coeffMatrix.fullPivLu().solve(rhs);
00236
00237 X.coeffRef(0,0) = result.coeff(0);
00238 X.coeffRef(0,1) = result.coeff(1);
00239 X.coeffRef(1,0) = result.coeff(2);
00240 X.coeffRef(1,1) = result.coeff(3);
00241 }
00242
00243
/** \ingroup MatrixFunctions_Module
  * \brief Computes the square root of an upper triangular matrix.
  *
  * \tparam MatrixType  type of the argument.
  *
  * Only the upper triangle (diagonal included) of the argument is read. Only a
  * reference to the argument is stored, so it must outlive this object.
  */
template <typename MatrixType>
class MatrixSquareRootTriangular
{
  public:

    /** \brief Constructor.
      * \param[in] A  upper triangular matrix; must be square (asserted).
      */
    MatrixSquareRootTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Computes the square root of the matrix passed to the constructor.
      * \param[out] result  resized to the size of A; only its upper triangular
      *                     part (diagonal included) is written.
      */
    template <typename ResultType> void compute(ResultType& result);

  private:
    const MatrixType& m_A; // non-owning reference to the argument
};
00279
00280 template <typename MatrixType>
00281 template <typename ResultType>
00282 void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
00283 {
00284 using std::sqrt;
00285
00286
00287
00288 result.resize(m_A.rows(), m_A.cols());
00289 typedef typename MatrixType::Index Index;
00290 for (Index i = 0; i < m_A.rows(); i++) {
00291 result.coeffRef(i,i) = sqrt(m_A.coeff(i,i));
00292 }
00293 for (Index j = 1; j < m_A.cols(); j++) {
00294 for (Index i = j-1; i >= 0; i--) {
00295 typedef typename MatrixType::Scalar Scalar;
00296
00297 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
00298
00299 result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
00300 }
00301 }
00302 }
00303
00304
00312 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
00313 class MatrixSquareRoot
00314 {
00315 public:
00316
00324 MatrixSquareRoot(const MatrixType& A);
00325
00333 template <typename ResultType> void compute(ResultType &result);
00334 };
00335
00336
00337
00338
00339 template <typename MatrixType>
00340 class MatrixSquareRoot<MatrixType, 0>
00341 {
00342 public:
00343
00344 MatrixSquareRoot(const MatrixType& A)
00345 : m_A(A)
00346 {
00347 eigen_assert(A.rows() == A.cols());
00348 }
00349
00350 template <typename ResultType> void compute(ResultType &result)
00351 {
00352
00353 const RealSchur<MatrixType> schurOfA(m_A);
00354 const MatrixType& T = schurOfA.matrixT();
00355 const MatrixType& U = schurOfA.matrixU();
00356
00357
00358 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols());
00359 MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT);
00360
00361
00362 result = U * sqrtT * U.adjoint();
00363 }
00364
00365 private:
00366 const MatrixType& m_A;
00367 };
00368
00369
00370
00371
00372 template <typename MatrixType>
00373 class MatrixSquareRoot<MatrixType, 1>
00374 {
00375 public:
00376
00377 MatrixSquareRoot(const MatrixType& A)
00378 : m_A(A)
00379 {
00380 eigen_assert(A.rows() == A.cols());
00381 }
00382
00383 template <typename ResultType> void compute(ResultType &result)
00384 {
00385
00386 const ComplexSchur<MatrixType> schurOfA(m_A);
00387 const MatrixType& T = schurOfA.matrixT();
00388 const MatrixType& U = schurOfA.matrixU();
00389
00390
00391 MatrixType sqrtT;
00392 MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT);
00393
00394
00395 result = U * (sqrtT.template triangularView<Upper>() * U.adjoint());
00396 }
00397
00398 private:
00399 const MatrixType& m_A;
00400 };
00401
00402
00415 template<typename Derived> class MatrixSquareRootReturnValue
00416 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
00417 {
00418 typedef typename Derived::Index Index;
00419 public:
00425 MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
00426
00432 template <typename ResultType>
00433 inline void evalTo(ResultType& result) const
00434 {
00435 const typename Derived::PlainObject srcEvaluated = m_src.eval();
00436 MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated);
00437 me.compute(result);
00438 }
00439
00440 Index rows() const { return m_src.rows(); }
00441 Index cols() const { return m_src.cols(); }
00442
00443 protected:
00444 const Derived& m_src;
00445 private:
00446 MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&);
00447 };
00448
00449 namespace internal {
00450 template<typename Derived>
00451 struct traits<MatrixSquareRootReturnValue<Derived> >
00452 {
00453 typedef typename Derived::PlainObject ReturnType;
00454 };
00455 }
00456
00457 template <typename Derived>
00458 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
00459 {
00460 eigen_assert(rows() == cols());
00461 return MatrixSquareRootReturnValue<Derived>(derived());
00462 }
00463
00464 }
00465
#endif // EIGEN_MATRIX_SQUARE_ROOT