00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_MATRIX_SQUARE_ROOT
00011 #define EIGEN_MATRIX_SQUARE_ROOT
00012
00013 namespace Eigen {
00014
/** \brief Helper computing the square root of an upper quasi-triangular real matrix.
  *
  * \tparam MatrixType  type of the (square) matrix handled by this helper.
  *
  * As the name indicates, the input is meant to be upper quasi-triangular: upper
  * triangular except for 1x1 and 2x2 blocks on the diagonal, like the T factor of a
  * RealSchur decomposition (which is how MatrixSquareRoot<MatrixType,0> uses it).
  * Only a reference to the matrix is stored, so that matrix must outlive this object.
  */
template <typename MatrixType>
class MatrixSquareRootQuasiTriangular
{
  private:
    typedef typename MatrixType::Index Index;
    typedef typename MatrixType::Scalar Scalar;

  public:
    /** \brief Stores a reference to \p A and asserts that it is square. */
    MatrixSquareRootQuasiTriangular(const MatrixType& A)
      : m_A(A)
    {
      eigen_assert(A.rows() == A.cols());
    }

    /** \brief Computes the square root and writes it to \p result. */
    template <typename ResultType> void compute(ResultType &result);

  private:
    // Square roots of the 1x1 and 2x2 diagonal blocks of T, written into sqrtT.
    void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
    // Remaining blocks above the diagonal; needs the diagonal part already in sqrtT.
    void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);

    // Square root of the 2x2 diagonal block starting at (i,i).
    void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, Index i);

    // Off-diagonal block (i,j); the prefix gives (row block size) x (column block size).
    void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, Index i, Index j);
    void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, Index i, Index j);
    void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, Index i, Index j);
    void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, Index i, Index j);

    // Solves the 2x2 Sylvester equation A*X + X*B = C for X.
    template <typename SmallMatrixType>
    static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
                                       const SmallMatrixType& B, const SmallMatrixType& C);

    const MatrixType& m_A;   // referenced, not owned
};
00077
00078 template <typename MatrixType>
00079 template <typename ResultType>
00080 void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
00081 {
00082
00083 const RealSchur<MatrixType> schurOfA(m_A);
00084 const MatrixType& T = schurOfA.matrixT();
00085 const MatrixType& U = schurOfA.matrixU();
00086
00087
00088 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00089 computeDiagonalPartOfSqrt(sqrtT, T);
00090 computeOffDiagonalPartOfSqrt(sqrtT, T);
00091
00092
00093 result = U * sqrtT * U.adjoint();
00094 }
00095
00096
00097
00098 template <typename MatrixType>
00099 void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT,
00100 const MatrixType& T)
00101 {
00102 const Index size = m_A.rows();
00103 for (Index i = 0; i < size; i++) {
00104 if (i == size - 1 || T.coeff(i+1, i) == 0) {
00105 eigen_assert(T(i,i) > 0);
00106 sqrtT.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
00107 }
00108 else {
00109 compute2x2diagonalBlock(sqrtT, T, i);
00110 ++i;
00111 }
00112 }
00113 }
00114
00115
00116
00117 template <typename MatrixType>
00118 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
00119 const MatrixType& T)
00120 {
00121 const Index size = m_A.rows();
00122 for (Index j = 1; j < size; j++) {
00123 if (T.coeff(j, j-1) != 0)
00124 continue;
00125 for (Index i = j-1; i >= 0; i--) {
00126 if (i > 0 && T.coeff(i, i-1) != 0)
00127 continue;
00128 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
00129 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
00130 if (iBlockIs2x2 && jBlockIs2x2)
00131 compute2x2offDiagonalBlock(sqrtT, T, i, j);
00132 else if (iBlockIs2x2 && !jBlockIs2x2)
00133 compute2x1offDiagonalBlock(sqrtT, T, i, j);
00134 else if (!iBlockIs2x2 && jBlockIs2x2)
00135 compute1x2offDiagonalBlock(sqrtT, T, i, j);
00136 else if (!iBlockIs2x2 && !jBlockIs2x2)
00137 compute1x1offDiagonalBlock(sqrtT, T, i, j);
00138 }
00139 }
00140 }
00141
00142
00143
00144 template <typename MatrixType>
00145 void MatrixSquareRootQuasiTriangular<MatrixType>
00146 ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
00147 {
00148
00149
00150 Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
00151 EigenSolver<Matrix<Scalar,2,2> > es(block);
00152 sqrtT.template block<2,2>(i,i)
00153 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
00154 }
00155
00156
00157
00158
00159 template <typename MatrixType>
00160 void MatrixSquareRootQuasiTriangular<MatrixType>
00161 ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00162 typename MatrixType::Index i, typename MatrixType::Index j)
00163 {
00164 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
00165 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
00166 }
00167
00168
00169 template <typename MatrixType>
00170 void MatrixSquareRootQuasiTriangular<MatrixType>
00171 ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00172 typename MatrixType::Index i, typename MatrixType::Index j)
00173 {
00174 Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
00175 if (j-i > 1)
00176 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
00177 Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
00178 A += sqrtT.template block<2,2>(j,j).transpose();
00179 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
00180 }
00181
00182
00183 template <typename MatrixType>
00184 void MatrixSquareRootQuasiTriangular<MatrixType>
00185 ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00186 typename MatrixType::Index i, typename MatrixType::Index j)
00187 {
00188 Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
00189 if (j-i > 2)
00190 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
00191 Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
00192 A += sqrtT.template block<2,2>(i,i);
00193 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
00194 }
00195
00196
00197 template <typename MatrixType>
00198 void MatrixSquareRootQuasiTriangular<MatrixType>
00199 ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
00200 typename MatrixType::Index i, typename MatrixType::Index j)
00201 {
00202 Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
00203 Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
00204 Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
00205 if (j-i > 2)
00206 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
00207 Matrix<Scalar,2,2> X;
00208 solveAuxiliaryEquation(X, A, B, C);
00209 sqrtT.template block<2,2>(i,j) = X;
00210 }
00211
00212
00213 template <typename MatrixType>
00214 template <typename SmallMatrixType>
00215 void MatrixSquareRootQuasiTriangular<MatrixType>
00216 ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
00217 const SmallMatrixType& B, const SmallMatrixType& C)
00218 {
00219 EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
00220 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
00221
00222 Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
00223 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
00224 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
00225 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
00226 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
00227 coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
00228 coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
00229 coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
00230 coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
00231 coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
00232 coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
00233 coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
00234 coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
00235
00236 Matrix<Scalar,4,1> rhs;
00237 rhs.coeffRef(0) = C.coeff(0,0);
00238 rhs.coeffRef(1) = C.coeff(0,1);
00239 rhs.coeffRef(2) = C.coeff(1,0);
00240 rhs.coeffRef(3) = C.coeff(1,1);
00241
00242 Matrix<Scalar,4,1> result;
00243 result = coeffMatrix.fullPivLu().solve(rhs);
00244
00245 X.coeffRef(0,0) = result.coeff(0);
00246 X.coeffRef(0,1) = result.coeff(1);
00247 X.coeffRef(1,0) = result.coeff(2);
00248 X.coeffRef(1,1) = result.coeff(3);
00249 }
00250
00251
/** \brief Helper computing the square root of an upper triangular matrix.
  *
  * \tparam MatrixType  type of the (square) matrix handled by this helper.
  *
  * As the name indicates, the input is meant to be upper triangular, such as the T factor
  * of a ComplexSchur decomposition (which is how MatrixSquareRoot<MatrixType,1> uses it).
  * Only a reference to the matrix is stored, so that matrix must outlive this object.
  */
template <typename MatrixType>
class MatrixSquareRootTriangular
{
  public:
    /** \brief Stores a reference to \p matrix and asserts that it is square. */
    MatrixSquareRootTriangular(const MatrixType& matrix)
      : m_A(matrix)
    {
      eigen_assert(m_A.rows() == m_A.cols());
    }

    /** \brief Computes the square root and writes it to \p result (which is resized). */
    template <typename ResultType> void compute(ResultType &result);

  private:
    const MatrixType& m_A;   // referenced, not owned
};
00287
00288 template <typename MatrixType>
00289 template <typename ResultType>
00290 void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
00291 {
00292
00293 const ComplexSchur<MatrixType> schurOfA(m_A);
00294 const MatrixType& T = schurOfA.matrixT();
00295 const MatrixType& U = schurOfA.matrixU();
00296
00297
00298
00299 result.resize(m_A.rows(), m_A.cols());
00300 typedef typename MatrixType::Index Index;
00301 for (Index i = 0; i < m_A.rows(); i++) {
00302 result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
00303 }
00304 for (Index j = 1; j < m_A.cols(); j++) {
00305 for (Index i = j-1; i >= 0; i--) {
00306 typedef typename MatrixType::Scalar Scalar;
00307
00308 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
00309
00310 result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
00311 }
00312 }
00313
00314
00315 MatrixType tmp;
00316 tmp.noalias() = U * result.template triangularView<Upper>();
00317 result.noalias() = tmp * U.adjoint();
00318 }
00319
00320
00328 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
00329 class MatrixSquareRoot
00330 {
00331 public:
00332
00340 MatrixSquareRoot(const MatrixType& A);
00341
00349 template <typename ResultType> void compute(ResultType &result);
00350 };
00351
00352
00353
00354
00355 template <typename MatrixType>
00356 class MatrixSquareRoot<MatrixType, 0>
00357 {
00358 public:
00359
00360 MatrixSquareRoot(const MatrixType& A)
00361 : m_A(A)
00362 {
00363 eigen_assert(A.rows() == A.cols());
00364 }
00365
00366 template <typename ResultType> void compute(ResultType &result)
00367 {
00368
00369 const RealSchur<MatrixType> schurOfA(m_A);
00370 const MatrixType& T = schurOfA.matrixT();
00371 const MatrixType& U = schurOfA.matrixU();
00372
00373
00374 MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
00375 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00376 tmp.compute(sqrtT);
00377
00378
00379 result = U * sqrtT * U.adjoint();
00380 }
00381
00382 private:
00383 const MatrixType& m_A;
00384 };
00385
00386
00387
00388
00389 template <typename MatrixType>
00390 class MatrixSquareRoot<MatrixType, 1>
00391 {
00392 public:
00393
00394 MatrixSquareRoot(const MatrixType& A)
00395 : m_A(A)
00396 {
00397 eigen_assert(A.rows() == A.cols());
00398 }
00399
00400 template <typename ResultType> void compute(ResultType &result)
00401 {
00402
00403 const ComplexSchur<MatrixType> schurOfA(m_A);
00404 const MatrixType& T = schurOfA.matrixT();
00405 const MatrixType& U = schurOfA.matrixU();
00406
00407
00408 MatrixSquareRootTriangular<MatrixType> tmp(T);
00409 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
00410 tmp.compute(sqrtT);
00411
00412
00413 result = U * sqrtT * U.adjoint();
00414 }
00415
00416 private:
00417 const MatrixType& m_A;
00418 };
00419
00420
00433 template<typename Derived> class MatrixSquareRootReturnValue
00434 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
00435 {
00436 typedef typename Derived::Index Index;
00437 public:
00443 MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
00444
00450 template <typename ResultType>
00451 inline void evalTo(ResultType& result) const
00452 {
00453 const typename Derived::PlainObject srcEvaluated = m_src.eval();
00454 MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated);
00455 me.compute(result);
00456 }
00457
00458 Index rows() const { return m_src.rows(); }
00459 Index cols() const { return m_src.cols(); }
00460
00461 protected:
00462 const Derived& m_src;
00463 private:
00464 MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&);
00465 };
00466
00467 namespace internal {
00468 template<typename Derived>
00469 struct traits<MatrixSquareRootReturnValue<Derived> >
00470 {
00471 typedef typename Derived::PlainObject ReturnType;
00472 };
00473 }
00474
00475 template <typename Derived>
00476 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
00477 {
00478 eigen_assert(rows() == cols());
00479 return MatrixSquareRootReturnValue<Derived>(derived());
00480 }
00481
00482 }
00483
00484 #endif // EIGEN_MATRIX_SQUARE_ROOT