00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_MATRIX_FUNCTION
00026 #define EIGEN_MATRIX_FUNCTION
00027
00028 #include "StemFunction.h"
00029 #include "MatrixFunctionAtomic.h"
00030
00031
00037 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
00038 class MatrixFunction
00039 {
00040 private:
00041
00042 typedef typename internal::traits<MatrixType>::Index Index;
00043 typedef typename internal::traits<MatrixType>::Scalar Scalar;
00044 typedef typename internal::stem_function<Scalar>::type StemFunction;
00045
00046 public:
00047
00056 MatrixFunction(const MatrixType& A, StemFunction f);
00057
00066 template <typename ResultType>
00067 void compute(ResultType &result);
00068 };
00069
00070
00074 template <typename MatrixType>
00075 class MatrixFunction<MatrixType, 0>
00076 {
00077 private:
00078
00079 typedef internal::traits<MatrixType> Traits;
00080 typedef typename Traits::Scalar Scalar;
00081 static const int Rows = Traits::RowsAtCompileTime;
00082 static const int Cols = Traits::ColsAtCompileTime;
00083 static const int Options = MatrixType::Options;
00084 static const int MaxRows = Traits::MaxRowsAtCompileTime;
00085 static const int MaxCols = Traits::MaxColsAtCompileTime;
00086
00087 typedef std::complex<Scalar> ComplexScalar;
00088 typedef Matrix<ComplexScalar, Rows, Cols, Options, MaxRows, MaxCols> ComplexMatrix;
00089 typedef typename internal::stem_function<Scalar>::type StemFunction;
00090
00091 public:
00092
00098 MatrixFunction(const MatrixType& A, StemFunction f) : m_A(A), m_f(f) { }
00099
00109 template <typename ResultType>
00110 void compute(ResultType& result)
00111 {
00112 ComplexMatrix CA = m_A.template cast<ComplexScalar>();
00113 ComplexMatrix Cresult;
00114 MatrixFunction<ComplexMatrix> mf(CA, m_f);
00115 mf.compute(Cresult);
00116 result = Cresult.real();
00117 }
00118
00119 private:
00120 typename internal::nested<MatrixType>::type m_A;
00121 StemFunction *m_f;
00123 MatrixFunction& operator=(const MatrixFunction&);
00124 };
00125
00126
00130 template <typename MatrixType>
00131 class MatrixFunction<MatrixType, 1>
00132 {
00133 private:
00134
00135 typedef internal::traits<MatrixType> Traits;
00136 typedef typename MatrixType::Scalar Scalar;
00137 typedef typename MatrixType::Index Index;
00138 static const int RowsAtCompileTime = Traits::RowsAtCompileTime;
00139 static const int ColsAtCompileTime = Traits::ColsAtCompileTime;
00140 static const int Options = MatrixType::Options;
00141 typedef typename NumTraits<Scalar>::Real RealScalar;
00142 typedef typename internal::stem_function<Scalar>::type StemFunction;
00143 typedef Matrix<Scalar, Traits::RowsAtCompileTime, 1> VectorType;
00144 typedef Matrix<Index, Traits::RowsAtCompileTime, 1> IntVectorType;
00145 typedef Matrix<Index, Dynamic, 1> DynamicIntVectorType;
00146 typedef std::list<Scalar> Cluster;
00147 typedef std::list<Cluster> ListOfClusters;
00148 typedef Matrix<Scalar, Dynamic, Dynamic, Options, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType;
00149
00150 public:
00151
00152 MatrixFunction(const MatrixType& A, StemFunction f);
00153 template <typename ResultType> void compute(ResultType& result);
00154
00155 private:
00156
00157 void computeSchurDecomposition();
00158 void partitionEigenvalues();
00159 typename ListOfClusters::iterator findCluster(Scalar key);
00160 void computeClusterSize();
00161 void computeBlockStart();
00162 void constructPermutation();
00163 void permuteSchur();
00164 void swapEntriesInSchur(Index index);
00165 void computeBlockAtomic();
00166 Block<MatrixType> block(MatrixType& A, Index i, Index j);
00167 void computeOffDiagonal();
00168 DynMatrixType solveTriangularSylvester(const DynMatrixType& A, const DynMatrixType& B, const DynMatrixType& C);
00169
00170 typename internal::nested<MatrixType>::type m_A;
00171 StemFunction *m_f;
00172 MatrixType m_T;
00173 MatrixType m_U;
00174 MatrixType m_fT;
00175 ListOfClusters m_clusters;
00176 DynamicIntVectorType m_eivalToCluster;
00177 DynamicIntVectorType m_clusterSize;
00178 DynamicIntVectorType m_blockStart;
00179 IntVectorType m_permutation;
00187 static const RealScalar separation() { return static_cast<RealScalar>(0.1); }
00188
00189 MatrixFunction& operator=(const MatrixFunction&);
00190 };
00191
00197 template <typename MatrixType>
00198 MatrixFunction<MatrixType,1>::MatrixFunction(const MatrixType& A, StemFunction f) :
00199 m_A(A), m_f(f)
00200 {
00201
00202 }
00203
00209 template <typename MatrixType>
00210 template <typename ResultType>
00211 void MatrixFunction<MatrixType,1>::compute(ResultType& result)
00212 {
00213 computeSchurDecomposition();
00214 partitionEigenvalues();
00215 computeClusterSize();
00216 computeBlockStart();
00217 constructPermutation();
00218 permuteSchur();
00219 computeBlockAtomic();
00220 computeOffDiagonal();
00221 result = m_U * m_fT * m_U.adjoint();
00222 }
00223
00225 template <typename MatrixType>
00226 void MatrixFunction<MatrixType,1>::computeSchurDecomposition()
00227 {
00228 const ComplexSchur<MatrixType> schurOfA(m_A);
00229 m_T = schurOfA.matrixT();
00230 m_U = schurOfA.matrixU();
00231 }
00232
00244 template <typename MatrixType>
00245 void MatrixFunction<MatrixType,1>::partitionEigenvalues()
00246 {
00247 const Index rows = m_T.rows();
00248 VectorType diag = m_T.diagonal();
00249
00250 for (Index i=0; i<rows; ++i) {
00251
00252 typename ListOfClusters::iterator qi = findCluster(diag(i));
00253 if (qi == m_clusters.end()) {
00254 Cluster l;
00255 l.push_back(diag(i));
00256 m_clusters.push_back(l);
00257 qi = m_clusters.end();
00258 --qi;
00259 }
00260
00261
00262 for (Index j=i+1; j<rows; ++j) {
00263 if (internal::abs(diag(j) - diag(i)) <= separation() && std::find(qi->begin(), qi->end(), diag(j)) == qi->end()) {
00264 typename ListOfClusters::iterator qj = findCluster(diag(j));
00265 if (qj == m_clusters.end()) {
00266 qi->push_back(diag(j));
00267 } else {
00268 qi->insert(qi->end(), qj->begin(), qj->end());
00269 m_clusters.erase(qj);
00270 }
00271 }
00272 }
00273 }
00274 }
00275
00281 template <typename MatrixType>
00282 typename MatrixFunction<MatrixType,1>::ListOfClusters::iterator MatrixFunction<MatrixType,1>::findCluster(Scalar key)
00283 {
00284 typename Cluster::iterator j;
00285 for (typename ListOfClusters::iterator i = m_clusters.begin(); i != m_clusters.end(); ++i) {
00286 j = std::find(i->begin(), i->end(), key);
00287 if (j != i->end())
00288 return i;
00289 }
00290 return m_clusters.end();
00291 }
00292
00294 template <typename MatrixType>
00295 void MatrixFunction<MatrixType,1>::computeClusterSize()
00296 {
00297 const Index rows = m_T.rows();
00298 VectorType diag = m_T.diagonal();
00299 const Index numClusters = static_cast<Index>(m_clusters.size());
00300
00301 m_clusterSize.setZero(numClusters);
00302 m_eivalToCluster.resize(rows);
00303 Index clusterIndex = 0;
00304 for (typename ListOfClusters::const_iterator cluster = m_clusters.begin(); cluster != m_clusters.end(); ++cluster) {
00305 for (Index i = 0; i < diag.rows(); ++i) {
00306 if (std::find(cluster->begin(), cluster->end(), diag(i)) != cluster->end()) {
00307 ++m_clusterSize[clusterIndex];
00308 m_eivalToCluster[i] = clusterIndex;
00309 }
00310 }
00311 ++clusterIndex;
00312 }
00313 }
00314
00316 template <typename MatrixType>
00317 void MatrixFunction<MatrixType,1>::computeBlockStart()
00318 {
00319 m_blockStart.resize(m_clusterSize.rows());
00320 m_blockStart(0) = 0;
00321 for (Index i = 1; i < m_clusterSize.rows(); i++) {
00322 m_blockStart(i) = m_blockStart(i-1) + m_clusterSize(i-1);
00323 }
00324 }
00325
00327 template <typename MatrixType>
00328 void MatrixFunction<MatrixType,1>::constructPermutation()
00329 {
00330 DynamicIntVectorType indexNextEntry = m_blockStart;
00331 m_permutation.resize(m_T.rows());
00332 for (Index i = 0; i < m_T.rows(); i++) {
00333 Index cluster = m_eivalToCluster[i];
00334 m_permutation[i] = indexNextEntry[cluster];
00335 ++indexNextEntry[cluster];
00336 }
00337 }
00338
00340 template <typename MatrixType>
00341 void MatrixFunction<MatrixType,1>::permuteSchur()
00342 {
00343 IntVectorType p = m_permutation;
00344 for (Index i = 0; i < p.rows() - 1; i++) {
00345 Index j;
00346 for (j = i; j < p.rows(); j++) {
00347 if (p(j) == i) break;
00348 }
00349 eigen_assert(p(j) == i);
00350 for (Index k = j-1; k >= i; k--) {
00351 swapEntriesInSchur(k);
00352 std::swap(p.coeffRef(k), p.coeffRef(k+1));
00353 }
00354 }
00355 }
00356
00358 template <typename MatrixType>
00359 void MatrixFunction<MatrixType,1>::swapEntriesInSchur(Index index)
00360 {
00361 JacobiRotation<Scalar> rotation;
00362 rotation.makeGivens(m_T(index, index+1), m_T(index+1, index+1) - m_T(index, index));
00363 m_T.applyOnTheLeft(index, index+1, rotation.adjoint());
00364 m_T.applyOnTheRight(index, index+1, rotation);
00365 m_U.applyOnTheRight(index, index+1, rotation);
00366 }
00367
00375 template <typename MatrixType>
00376 void MatrixFunction<MatrixType,1>::computeBlockAtomic()
00377 {
00378 m_fT.resize(m_T.rows(), m_T.cols());
00379 m_fT.setZero();
00380 MatrixFunctionAtomic<DynMatrixType> mfa(m_f);
00381 for (Index i = 0; i < m_clusterSize.rows(); ++i) {
00382 block(m_fT, i, i) = mfa.compute(block(m_T, i, i));
00383 }
00384 }
00385
00387 template <typename MatrixType>
00388 Block<MatrixType> MatrixFunction<MatrixType,1>::block(MatrixType& A, Index i, Index j)
00389 {
00390 return A.block(m_blockStart(i), m_blockStart(j), m_clusterSize(i), m_clusterSize(j));
00391 }
00392
00400 template <typename MatrixType>
00401 void MatrixFunction<MatrixType,1>::computeOffDiagonal()
00402 {
00403 for (Index diagIndex = 1; diagIndex < m_clusterSize.rows(); diagIndex++) {
00404 for (Index blockIndex = 0; blockIndex < m_clusterSize.rows() - diagIndex; blockIndex++) {
00405
00406 DynMatrixType A = block(m_T, blockIndex, blockIndex);
00407 DynMatrixType B = -block(m_T, blockIndex+diagIndex, blockIndex+diagIndex);
00408 DynMatrixType C = block(m_fT, blockIndex, blockIndex) * block(m_T, blockIndex, blockIndex+diagIndex);
00409 C -= block(m_T, blockIndex, blockIndex+diagIndex) * block(m_fT, blockIndex+diagIndex, blockIndex+diagIndex);
00410 for (Index k = blockIndex + 1; k < blockIndex + diagIndex; k++) {
00411 C += block(m_fT, blockIndex, k) * block(m_T, k, blockIndex+diagIndex);
00412 C -= block(m_T, blockIndex, k) * block(m_fT, k, blockIndex+diagIndex);
00413 }
00414 block(m_fT, blockIndex, blockIndex+diagIndex) = solveTriangularSylvester(A, B, C);
00415 }
00416 }
00417 }
00418
00442 template <typename MatrixType>
00443 typename MatrixFunction<MatrixType,1>::DynMatrixType MatrixFunction<MatrixType,1>::solveTriangularSylvester(
00444 const DynMatrixType& A,
00445 const DynMatrixType& B,
00446 const DynMatrixType& C)
00447 {
00448 eigen_assert(A.rows() == A.cols());
00449 eigen_assert(A.isUpperTriangular());
00450 eigen_assert(B.rows() == B.cols());
00451 eigen_assert(B.isUpperTriangular());
00452 eigen_assert(C.rows() == A.rows());
00453 eigen_assert(C.cols() == B.rows());
00454
00455 Index m = A.rows();
00456 Index n = B.rows();
00457 DynMatrixType X(m, n);
00458
00459 for (Index i = m - 1; i >= 0; --i) {
00460 for (Index j = 0; j < n; ++j) {
00461
00462
00463 Scalar AX;
00464 if (i == m - 1) {
00465 AX = 0;
00466 } else {
00467 Matrix<Scalar,1,1> AXmatrix = A.row(i).tail(m-1-i) * X.col(j).tail(m-1-i);
00468 AX = AXmatrix(0,0);
00469 }
00470
00471
00472 Scalar XB;
00473 if (j == 0) {
00474 XB = 0;
00475 } else {
00476 Matrix<Scalar,1,1> XBmatrix = X.row(i).head(j) * B.col(j).head(j);
00477 XB = XBmatrix(0,0);
00478 }
00479
00480 X(i,j) = (C(i,j) - AX - XB) / (A(i,i) + B(j,j));
00481 }
00482 }
00483 return X;
00484 }
00485
00498 template<typename Derived> class MatrixFunctionReturnValue
00499 : public ReturnByValue<MatrixFunctionReturnValue<Derived> >
00500 {
00501 public:
00502
00503 typedef typename Derived::Scalar Scalar;
00504 typedef typename Derived::Index Index;
00505 typedef typename internal::stem_function<Scalar>::type StemFunction;
00506
00513 MatrixFunctionReturnValue(const Derived& A, StemFunction f) : m_A(A), m_f(f) { }
00514
00520 template <typename ResultType>
00521 inline void evalTo(ResultType& result) const
00522 {
00523 const typename Derived::PlainObject Aevaluated = m_A.eval();
00524 MatrixFunction<typename Derived::PlainObject> mf(Aevaluated, m_f);
00525 mf.compute(result);
00526 }
00527
00528 Index rows() const { return m_A.rows(); }
00529 Index cols() const { return m_A.cols(); }
00530
00531 private:
00532 typename internal::nested<Derived>::type m_A;
00533 StemFunction *m_f;
00534
00535 MatrixFunctionReturnValue& operator=(const MatrixFunctionReturnValue&);
00536 };
00537
00538 namespace internal {
00539 template<typename Derived>
00540 struct traits<MatrixFunctionReturnValue<Derived> >
00541 {
00542 typedef typename Derived::PlainObject ReturnType;
00543 };
00544 }
00545
00546
00547
00548
00549
00550 template <typename Derived>
00551 const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::matrixFunction(typename internal::stem_function<typename internal::traits<Derived>::Scalar>::type f) const
00552 {
00553 eigen_assert(rows() == cols());
00554 return MatrixFunctionReturnValue<Derived>(derived(), f);
00555 }
00556
00557 template <typename Derived>
00558 const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::sin() const
00559 {
00560 eigen_assert(rows() == cols());
00561 typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
00562 return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::sin);
00563 }
00564
00565 template <typename Derived>
00566 const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::cos() const
00567 {
00568 eigen_assert(rows() == cols());
00569 typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
00570 return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::cos);
00571 }
00572
00573 template <typename Derived>
00574 const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::sinh() const
00575 {
00576 eigen_assert(rows() == cols());
00577 typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
00578 return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::sinh);
00579 }
00580
00581 template <typename Derived>
00582 const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::cosh() const
00583 {
00584 eigen_assert(rows() == cols());
00585 typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
00586 return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::cosh);
00587 }
00588
00589 #endif // EIGEN_MATRIX_FUNCTION