00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef EIGEN_SPARSE_QR_H
00012 #define EIGEN_SPARSE_QR_H
00013
00014 namespace Eigen {
00015
00016 template<typename MatrixType, typename OrderingType> class SparseQR;
00017 template<typename SparseQRType> struct SparseQRMatrixQReturnType;
00018 template<typename SparseQRType> struct SparseQRMatrixQTransposeReturnType;
00019 template<typename SparseQRType, typename Derived> struct SparseQR_QProduct;
00020 namespace internal {
00021 template <typename SparseQRType> struct traits<SparseQRMatrixQReturnType<SparseQRType> >
00022 {
00023 typedef typename SparseQRType::MatrixType ReturnType;
00024 typedef typename ReturnType::Index Index;
00025 typedef typename ReturnType::StorageKind StorageKind;
00026 };
00027 template <typename SparseQRType> struct traits<SparseQRMatrixQTransposeReturnType<SparseQRType> >
00028 {
00029 typedef typename SparseQRType::MatrixType ReturnType;
00030 };
00031 template <typename SparseQRType, typename Derived> struct traits<SparseQR_QProduct<SparseQRType, Derived> >
00032 {
00033 typedef typename Derived::PlainObject ReturnType;
00034 };
00035 }
00036
00063 template<typename _MatrixType, typename _OrderingType>
00064 class SparseQR
00065 {
00066 public:
00067 typedef _MatrixType MatrixType;
00068 typedef _OrderingType OrderingType;
00069 typedef typename MatrixType::Scalar Scalar;
00070 typedef typename MatrixType::RealScalar RealScalar;
00071 typedef typename MatrixType::Index Index;
00072 typedef SparseMatrix<Scalar,ColMajor,Index> QRMatrixType;
00073 typedef Matrix<Index, Dynamic, 1> IndexVector;
00074 typedef Matrix<Scalar, Dynamic, 1> ScalarVector;
00075 typedef PermutationMatrix<Dynamic, Dynamic, Index> PermutationType;
00076 public:
00077 SparseQR () : m_isInitialized(false), m_analysisIsok(false), m_lastError(""), m_useDefaultThreshold(true),m_isQSorted(false)
00078 { }
00079
00080 SparseQR(const MatrixType& mat) : m_isInitialized(false), m_analysisIsok(false), m_lastError(""), m_useDefaultThreshold(true),m_isQSorted(false)
00081 {
00082 compute(mat);
00083 }
00084 void compute(const MatrixType& mat)
00085 {
00086 analyzePattern(mat);
00087 factorize(mat);
00088 }
00089 void analyzePattern(const MatrixType& mat);
00090 void factorize(const MatrixType& mat);
00091
00094 inline Index rows() const { return m_pmat.rows(); }
00095
00098 inline Index cols() const { return m_pmat.cols();}
00099
00102 const QRMatrixType& matrixR() const { return m_R; }
00103
00108 Index rank() const
00109 {
00110 eigen_assert(m_isInitialized && "The factorization should be called first, use compute()");
00111 return m_nonzeropivots;
00112 }
00113
00132 SparseQRMatrixQReturnType<SparseQR> matrixQ() const
00133 { return SparseQRMatrixQReturnType<SparseQR>(*this); }
00134
00138 const PermutationType& colsPermutation() const
00139 {
00140 eigen_assert(m_isInitialized && "Decomposition is not initialized.");
00141 return m_outputPerm_c;
00142 }
00143
00147 std::string lastErrorMessage() const { return m_lastError; }
00148
00150 template<typename Rhs, typename Dest>
00151 bool _solve(const MatrixBase<Rhs> &B, MatrixBase<Dest> &dest) const
00152 {
00153 eigen_assert(m_isInitialized && "The factorization should be called first, use compute()");
00154 eigen_assert(this->rows() == B.rows() && "SparseQR::solve() : invalid number of rows in the right hand side matrix");
00155
00156 Index rank = this->rank();
00157
00158
00159 typename Dest::PlainObject y, b;
00160 y = this->matrixQ().transpose() * B;
00161 b = y;
00162
00163
00164 y.topRows(rank) = this->matrixR().topLeftCorner(rank, rank).template triangularView<Upper>().solve(b.topRows(rank));
00165 y.bottomRows(y.size()-rank).setZero();
00166
00167
00168 if (m_perm_c.size()) dest.topRows(cols()) = colsPermutation() * y.topRows(cols());
00169 else dest = y.topRows(cols());
00170
00171 m_info = Success;
00172 return true;
00173 }
00174
00175
00181 void setPivotThreshold(const RealScalar& threshold)
00182 {
00183 m_useDefaultThreshold = false;
00184 m_threshold = threshold;
00185 }
00186
00191 template<typename Rhs>
00192 inline const internal::solve_retval<SparseQR, Rhs> solve(const MatrixBase<Rhs>& B) const
00193 {
00194 eigen_assert(m_isInitialized && "The factorization should be called first, use compute()");
00195 eigen_assert(this->rows() == B.rows() && "SparseQR::solve() : invalid number of rows in the right hand side matrix");
00196 return internal::solve_retval<SparseQR, Rhs>(*this, B.derived());
00197 }
00198 template<typename Rhs>
00199 inline const internal::sparse_solve_retval<SparseQR, Rhs> solve(const SparseMatrixBase<Rhs>& B) const
00200 {
00201 eigen_assert(m_isInitialized && "The factorization should be called first, use compute()");
00202 eigen_assert(this->rows() == B.rows() && "SparseQR::solve() : invalid number of rows in the right hand side matrix");
00203 return internal::sparse_solve_retval<SparseQR, Rhs>(*this, B.derived());
00204 }
00205
00214 ComputationInfo info() const
00215 {
00216 eigen_assert(m_isInitialized && "Decomposition is not initialized.");
00217 return m_info;
00218 }
00219
00220 protected:
00221 inline void sort_matrix_Q()
00222 {
00223 if(this->m_isQSorted) return;
00224
00225 SparseMatrix<Scalar, RowMajor, Index> mQrm(this->m_Q);
00226 this->m_Q = mQrm;
00227 this->m_isQSorted = true;
00228 }
00229
00230
00231 protected:
00232 bool m_isInitialized;
00233 bool m_analysisIsok;
00234 bool m_factorizationIsok;
00235 mutable ComputationInfo m_info;
00236 std::string m_lastError;
00237 QRMatrixType m_pmat;
00238 QRMatrixType m_R;
00239 QRMatrixType m_Q;
00240 ScalarVector m_hcoeffs;
00241 PermutationType m_perm_c;
00242 PermutationType m_pivotperm;
00243 PermutationType m_outputPerm_c;
00244 RealScalar m_threshold;
00245 bool m_useDefaultThreshold;
00246 Index m_nonzeropivots;
00247 IndexVector m_etree;
00248 IndexVector m_firstRowElt;
00249 bool m_isQSorted;
00250
00251 template <typename, typename > friend struct SparseQR_QProduct;
00252 template <typename > friend struct SparseQRMatrixQReturnType;
00253
00254 };
00255
00263 template <typename MatrixType, typename OrderingType>
00264 void SparseQR<MatrixType,OrderingType>::analyzePattern(const MatrixType& mat)
00265 {
00266
00267 OrderingType ord;
00268 ord(mat, m_perm_c);
00269 Index n = mat.cols();
00270 Index m = mat.rows();
00271
00272 if (!m_perm_c.size())
00273 {
00274 m_perm_c.resize(n);
00275 m_perm_c.indices().setLinSpaced(n, 0,n-1);
00276 }
00277
00278
00279 m_outputPerm_c = m_perm_c.inverse();
00280 internal::coletree(mat, m_etree, m_firstRowElt, m_outputPerm_c.indices().data());
00281
00282 m_R.resize(n, n);
00283 m_Q.resize(m, n);
00284
00285
00286 m_R.reserve(2*mat.nonZeros());
00287 m_Q.reserve(2*mat.nonZeros());
00288 m_hcoeffs.resize(n);
00289 m_analysisIsok = true;
00290 }
00291
00299 template <typename MatrixType, typename OrderingType>
00300 void SparseQR<MatrixType,OrderingType>::factorize(const MatrixType& mat)
00301 {
00302 using std::abs;
00303 using std::max;
00304
00305 eigen_assert(m_analysisIsok && "analyzePattern() should be called before this step");
00306 Index m = mat.rows();
00307 Index n = mat.cols();
00308 IndexVector mark(m); mark.setConstant(-1);
00309 IndexVector Ridx(n), Qidx(m);
00310 Index nzcolR, nzcolQ;
00311 ScalarVector tval(m);
00312 bool found_diag;
00313
00314 m_pmat = mat;
00315 m_pmat.uncompress();
00316
00317 for (int i = 0; i < n; i++)
00318 {
00319 Index p = m_perm_c.size() ? m_perm_c.indices()(i) : i;
00320 m_pmat.outerIndexPtr()[p] = mat.outerIndexPtr()[i];
00321 m_pmat.innerNonZeroPtr()[p] = mat.outerIndexPtr()[i+1] - mat.outerIndexPtr()[i];
00322 }
00323
00324
00325
00326
00327
00328 if(m_useDefaultThreshold)
00329 {
00330 RealScalar max2Norm = 0.0;
00331 for (int j = 0; j < n; j++) max2Norm = (max)(max2Norm, m_pmat.col(j).norm());
00332 m_threshold = 20 * (m + n) * max2Norm * NumTraits<RealScalar>::epsilon();
00333 }
00334
00335
00336 m_pivotperm.setIdentity(n);
00337
00338 Index nonzeroCol = 0;
00339
00340
00341 for (Index col = 0; col < n; ++col)
00342 {
00343 mark.setConstant(-1);
00344 m_R.startVec(col);
00345 m_Q.startVec(col);
00346 mark(nonzeroCol) = col;
00347 Qidx(0) = nonzeroCol;
00348 nzcolR = 0; nzcolQ = 1;
00349 found_diag = false;
00350 tval.setZero();
00351
00352
00353
00354
00355
00356 for (typename MatrixType::InnerIterator itp(m_pmat, col); itp || !found_diag; ++itp)
00357 {
00358 Index curIdx = nonzeroCol ;
00359 if(itp) curIdx = itp.row();
00360 if(curIdx == nonzeroCol) found_diag = true;
00361
00362
00363 Index st = m_firstRowElt(curIdx);
00364 if (st < 0 )
00365 {
00366 m_lastError = "Empty row found during numerical factorization";
00367 m_info = InvalidInput;
00368 return;
00369 }
00370
00371
00372 Index bi = nzcolR;
00373 for (; mark(st) != col; st = m_etree(st))
00374 {
00375 Ridx(nzcolR) = st;
00376 mark(st) = col;
00377 nzcolR++;
00378 }
00379
00380
00381 Index nt = nzcolR-bi;
00382 for(Index i = 0; i < nt/2; i++) std::swap(Ridx(bi+i), Ridx(nzcolR-i-1));
00383
00384
00385 if(itp) tval(curIdx) = itp.value();
00386 else tval(curIdx) = Scalar(0);
00387
00388
00389 if(curIdx > nonzeroCol && mark(curIdx) != col )
00390 {
00391 Qidx(nzcolQ) = curIdx;
00392 mark(curIdx) = col;
00393 nzcolQ++;
00394 }
00395 }
00396
00397
00398 for (Index i = nzcolR-1; i >= 0; i--)
00399 {
00400 Index curIdx = m_pivotperm.indices()(Ridx(i));
00401
00402
00403 Scalar tdot(0);
00404
00405
00406 tdot = m_Q.col(curIdx).dot(tval);
00407
00408 tdot *= m_hcoeffs(curIdx);
00409
00410
00411
00412 for (typename QRMatrixType::InnerIterator itq(m_Q, curIdx); itq; ++itq)
00413 tval(itq.row()) -= itq.value() * tdot;
00414
00415
00416 if(m_etree(Ridx(i)) == nonzeroCol)
00417 {
00418 for (typename QRMatrixType::InnerIterator itq(m_Q, curIdx); itq; ++itq)
00419 {
00420 Index iQ = itq.row();
00421 if (mark(iQ) != col)
00422 {
00423 Qidx(nzcolQ++) = iQ;
00424 mark(iQ) = col;
00425 }
00426 }
00427 }
00428 }
00429
00430
00431
00432 Scalar tau;
00433 RealScalar beta;
00434 Scalar c0 = nzcolQ ? tval(Qidx(0)) : Scalar(0);
00435
00436
00437 RealScalar sqrNorm = 0.;
00438 for (Index itq = 1; itq < nzcolQ; ++itq) sqrNorm += numext::abs2(tval(Qidx(itq)));
00439
00440 if(sqrNorm == RealScalar(0) && numext::imag(c0) == RealScalar(0))
00441 {
00442 tau = RealScalar(0);
00443 beta = numext::real(c0);
00444 tval(Qidx(0)) = 1;
00445 }
00446 else
00447 {
00448 beta = std::sqrt(numext::abs2(c0) + sqrNorm);
00449 if(numext::real(c0) >= RealScalar(0))
00450 beta = -beta;
00451 tval(Qidx(0)) = 1;
00452 for (Index itq = 1; itq < nzcolQ; ++itq)
00453 tval(Qidx(itq)) /= (c0 - beta);
00454 tau = numext::conj((beta-c0) / beta);
00455
00456 }
00457
00458
00459 for (Index i = nzcolR-1; i >= 0; i--)
00460 {
00461 Index curIdx = Ridx(i);
00462 if(curIdx < nonzeroCol)
00463 {
00464 m_R.insertBackByOuterInnerUnordered(col, curIdx) = tval(curIdx);
00465 tval(curIdx) = Scalar(0.);
00466 }
00467 }
00468
00469 if(abs(beta) >= m_threshold)
00470 {
00471 m_R.insertBackByOuterInner(col, nonzeroCol) = beta;
00472 nonzeroCol++;
00473
00474 m_hcoeffs(col) = tau;
00475
00476 for (Index itq = 0; itq < nzcolQ; ++itq)
00477 {
00478 Index iQ = Qidx(itq);
00479 m_Q.insertBackByOuterInnerUnordered(col,iQ) = tval(iQ);
00480 tval(iQ) = Scalar(0.);
00481 }
00482 }
00483 else
00484 {
00485
00486 m_hcoeffs(col) = Scalar(0);
00487 for (Index j = nonzeroCol; j < n-1; j++)
00488 std::swap(m_pivotperm.indices()(j), m_pivotperm.indices()[j+1]);
00489
00490
00491 internal::coletree(m_pmat, m_etree, m_firstRowElt, m_pivotperm.indices().data());
00492 }
00493 }
00494
00495
00496 m_Q.finalize();
00497 m_Q.makeCompressed();
00498 m_R.finalize();
00499 m_R.makeCompressed();
00500 m_isQSorted = false;
00501
00502 m_nonzeropivots = nonzeroCol;
00503
00504 if(nonzeroCol<n)
00505 {
00506
00507 MatrixType tempR(m_R);
00508 m_R = tempR * m_pivotperm;
00509
00510
00511 m_outputPerm_c = m_outputPerm_c * m_pivotperm;
00512 }
00513
00514 m_isInitialized = true;
00515 m_factorizationIsok = true;
00516 m_info = Success;
00517 }
00518
00519 namespace internal {
00520
00521 template<typename _MatrixType, typename OrderingType, typename Rhs>
00522 struct solve_retval<SparseQR<_MatrixType,OrderingType>, Rhs>
00523 : solve_retval_base<SparseQR<_MatrixType,OrderingType>, Rhs>
00524 {
00525 typedef SparseQR<_MatrixType,OrderingType> Dec;
00526 EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs)
00527
00528 template<typename Dest> void evalTo(Dest& dst) const
00529 {
00530 dec()._solve(rhs(),dst);
00531 }
00532 };
00533 template<typename _MatrixType, typename OrderingType, typename Rhs>
00534 struct sparse_solve_retval<SparseQR<_MatrixType, OrderingType>, Rhs>
00535 : sparse_solve_retval_base<SparseQR<_MatrixType, OrderingType>, Rhs>
00536 {
00537 typedef SparseQR<_MatrixType, OrderingType> Dec;
00538 EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec, Rhs)
00539
00540 template<typename Dest> void evalTo(Dest& dst) const
00541 {
00542 this->defaultEvalTo(dst);
00543 }
00544 };
00545 }
00546
00547 template <typename SparseQRType, typename Derived>
00548 struct SparseQR_QProduct : ReturnByValue<SparseQR_QProduct<SparseQRType, Derived> >
00549 {
00550 typedef typename SparseQRType::QRMatrixType MatrixType;
00551 typedef typename SparseQRType::Scalar Scalar;
00552 typedef typename SparseQRType::Index Index;
00553
00554 SparseQR_QProduct(const SparseQRType& qr, const Derived& other, bool transpose) :
00555 m_qr(qr),m_other(other),m_transpose(transpose) {}
00556 inline Index rows() const { return m_transpose ? m_qr.rows() : m_qr.cols(); }
00557 inline Index cols() const { return m_other.cols(); }
00558
00559
00560 template<typename DesType>
00561 void evalTo(DesType& res) const
00562 {
00563 Index n = m_qr.cols();
00564 res = m_other;
00565 if (m_transpose)
00566 {
00567 eigen_assert(m_qr.m_Q.rows() == m_other.rows() && "Non conforming object sizes");
00568
00569 for(Index j = 0; j < res.cols(); j++){
00570 for (Index k = 0; k < n; k++)
00571 {
00572 Scalar tau = Scalar(0);
00573 tau = m_qr.m_Q.col(k).dot(res.col(j));
00574 tau = tau * m_qr.m_hcoeffs(k);
00575 res.col(j) -= tau * m_qr.m_Q.col(k);
00576 }
00577 }
00578 }
00579 else
00580 {
00581 eigen_assert(m_qr.m_Q.cols() == m_other.rows() && "Non conforming object sizes");
00582
00583 for(Index j = 0; j < res.cols(); j++)
00584 {
00585 for (Index k = n-1; k >=0; k--)
00586 {
00587 Scalar tau = Scalar(0);
00588 tau = m_qr.m_Q.col(k).dot(res.col(j));
00589 tau = tau * m_qr.m_hcoeffs(k);
00590 res.col(j) -= tau * m_qr.m_Q.col(k);
00591 }
00592 }
00593 }
00594 }
00595
00596 const SparseQRType& m_qr;
00597 const Derived& m_other;
00598 bool m_transpose;
00599 };
00600
00601 template<typename SparseQRType>
00602 struct SparseQRMatrixQReturnType : public EigenBase<SparseQRMatrixQReturnType<SparseQRType> >
00603 {
00604 typedef typename SparseQRType::Index Index;
00605 typedef typename SparseQRType::Scalar Scalar;
00606 typedef Matrix<Scalar,Dynamic,Dynamic> DenseMatrix;
00607 SparseQRMatrixQReturnType(const SparseQRType& qr) : m_qr(qr) {}
00608 template<typename Derived>
00609 SparseQR_QProduct<SparseQRType, Derived> operator*(const MatrixBase<Derived>& other)
00610 {
00611 return SparseQR_QProduct<SparseQRType,Derived>(m_qr,other.derived(),false);
00612 }
00613 SparseQRMatrixQTransposeReturnType<SparseQRType> adjoint() const
00614 {
00615 return SparseQRMatrixQTransposeReturnType<SparseQRType>(m_qr);
00616 }
00617 inline Index rows() const { return m_qr.rows(); }
00618 inline Index cols() const { return m_qr.cols(); }
00619
00620 SparseQRMatrixQTransposeReturnType<SparseQRType> transpose() const
00621 {
00622 return SparseQRMatrixQTransposeReturnType<SparseQRType>(m_qr);
00623 }
00624 template<typename Dest> void evalTo(MatrixBase<Dest>& dest) const
00625 {
00626 dest.derived() = m_qr.matrixQ() * Dest::Identity(m_qr.rows(), m_qr.rows());
00627 }
00628 template<typename Dest> void evalTo(SparseMatrixBase<Dest>& dest) const
00629 {
00630 Dest idMat(m_qr.rows(), m_qr.rows());
00631 idMat.setIdentity();
00632
00633 const_cast<SparseQRType *>(&m_qr)->sort_matrix_Q();
00634 dest.derived() = SparseQR_QProduct<SparseQRType, Dest>(m_qr, idMat, false);
00635 }
00636
00637 const SparseQRType& m_qr;
00638 };
00639
00640 template<typename SparseQRType>
00641 struct SparseQRMatrixQTransposeReturnType
00642 {
00643 SparseQRMatrixQTransposeReturnType(const SparseQRType& qr) : m_qr(qr) {}
00644 template<typename Derived>
00645 SparseQR_QProduct<SparseQRType,Derived> operator*(const MatrixBase<Derived>& other)
00646 {
00647 return SparseQR_QProduct<SparseQRType,Derived>(m_qr,other.derived(), true);
00648 }
00649 const SparseQRType& m_qr;
00650 };
00651
00652 }
00653
00654 #endif