TriangularSolver.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // Eigen is free software; you can redistribute it and/or
00007 // modify it under the terms of the GNU Lesser General Public
00008 // License as published by the Free Software Foundation; either
00009 // version 3 of the License, or (at your option) any later version.
00010 //
00011 // Alternatively, you can redistribute it and/or
00012 // modify it under the terms of the GNU General Public License as
00013 // published by the Free Software Foundation; either version 2 of
00014 // the License, or (at your option) any later version.
00015 //
00016 // Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
00017 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
00018 // FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
00019 // GNU General Public License for more details.
00020 //
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License and a copy of the GNU General Public License along with
00023 // Eigen. If not, see <http://www.gnu.org/licenses/>.
00024 
00025 #ifndef EIGEN_SPARSETRIANGULARSOLVER_H
00026 #define EIGEN_SPARSETRIANGULARSOLVER_H
00027 
00028 namespace internal {
00029 
00030 template<typename Lhs, typename Rhs, int Mode,
00031   int UpLo = (Mode & Lower)
00032            ? Lower
00033            : (Mode & Upper)
00034            ? Upper
00035            : -1,
00036   int StorageOrder = int(traits<Lhs>::Flags) & RowMajorBit>
00037 struct sparse_solve_triangular_selector;
00038 
00039 // forward substitution, row-major
00040 template<typename Lhs, typename Rhs, int Mode>
00041 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
00042 {
00043   typedef typename Rhs::Scalar Scalar;
00044   static void run(const Lhs& lhs, Rhs& other)
00045   {
00046     for(int col=0 ; col<other.cols() ; ++col)
00047     {
00048       for(int i=0; i<lhs.rows(); ++i)
00049       {
00050         Scalar tmp = other.coeff(i,col);
00051         Scalar lastVal = 0;
00052         int lastIndex = 0;
00053         for(typename Lhs::InnerIterator it(lhs, i); it; ++it)
00054         {
00055           lastVal = it.value();
00056           lastIndex = it.index();
00057           if(lastIndex==i)
00058             break;
00059           tmp -= lastVal * other.coeff(lastIndex,col);
00060         }
00061         if (Mode & UnitDiag)
00062           other.coeffRef(i,col) = tmp;
00063         else
00064         {
00065           eigen_assert(lastIndex==i);
00066           other.coeffRef(i,col) = tmp/lastVal;
00067         }
00068       }
00069     }
00070   }
00071 };
00072 
00073 // backward substitution, row-major
00074 template<typename Lhs, typename Rhs, int Mode>
00075 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor>
00076 {
00077   typedef typename Rhs::Scalar Scalar;
00078   static void run(const Lhs& lhs, Rhs& other)
00079   {
00080     for(int col=0 ; col<other.cols() ; ++col)
00081     {
00082       for(int i=lhs.rows()-1 ; i>=0 ; --i)
00083       {
00084         Scalar tmp = other.coeff(i,col);
00085         typename Lhs::InnerIterator it(lhs, i);
00086         if (it && it.index() == i)
00087           ++it;
00088         for(; it; ++it)
00089         {
00090           tmp -= it.value() * other.coeff(it.index(),col);
00091         }
00092 
00093         if (Mode & UnitDiag)
00094           other.coeffRef(i,col) = tmp;
00095         else
00096         {
00097           typename Lhs::InnerIterator it(lhs, i);
00098           eigen_assert(it && it.index() == i);
00099           other.coeffRef(i,col) = tmp/it.value();
00100         }
00101       }
00102     }
00103   }
00104 };
00105 
00106 // forward substitution, col-major
00107 template<typename Lhs, typename Rhs, int Mode>
00108 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor>
00109 {
00110   typedef typename Rhs::Scalar Scalar;
00111   static void run(const Lhs& lhs, Rhs& other)
00112   {
00113     for(int col=0 ; col<other.cols() ; ++col)
00114     {
00115       for(int i=0; i<lhs.cols(); ++i)
00116       {
00117         Scalar& tmp = other.coeffRef(i,col);
00118         if (tmp!=Scalar(0)) // optimization when other is actually sparse
00119         {
00120           typename Lhs::InnerIterator it(lhs, i);
00121           if(!(Mode & UnitDiag))
00122           {
00123             eigen_assert(it.index()==i);
00124             tmp /= it.value();
00125           }
00126           if (it && it.index()==i)
00127             ++it;
00128           for(; it; ++it)
00129             other.coeffRef(it.index(), col) -= tmp * it.value();
00130         }
00131       }
00132     }
00133   }
00134 };
00135 
00136 // backward substitution, col-major
00137 template<typename Lhs, typename Rhs, int Mode>
00138 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
00139 {
00140   typedef typename Rhs::Scalar Scalar;
00141   static void run(const Lhs& lhs, Rhs& other)
00142   {
00143     for(int col=0 ; col<other.cols() ; ++col)
00144     {
00145       for(int i=lhs.cols()-1; i>=0; --i)
00146       {
00147         Scalar& tmp = other.coeffRef(i,col);
00148         if (tmp!=Scalar(0)) // optimization when other is actually sparse
00149         {
00150           if(!(Mode & UnitDiag))
00151           {
00152             // FIXME lhs.coeff(i,i) might not be always efficient while it must simply be the
00153             // last element of the column !
00154             other.coeffRef(i,col) /= lhs.innerVector(i).lastCoeff();
00155           }
00156           typename Lhs::InnerIterator it(lhs, i);
00157           for(; it && it.index()<i; ++it)
00158             other.coeffRef(it.index(), col) -= tmp * it.value();
00159         }
00160       }
00161     }
00162   }
00163 };
00164 
00165 } // end namespace internal
00166 
00167 template<typename ExpressionType,int Mode>
00168 template<typename OtherDerived>
00169 void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDerived>& other) const
00170 {
00171   eigen_assert(m_matrix.cols() == m_matrix.rows());
00172   eigen_assert(m_matrix.cols() == other.rows());
00173   eigen_assert(!(Mode & ZeroDiag));
00174   eigen_assert(Mode & (Upper|Lower));
00175 
00176   enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
00177 
00178   typedef typename internal::conditional<copy,
00179     typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
00180   OtherCopy otherCopy(other.derived());
00181 
00182   internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(m_matrix, otherCopy);
00183 
00184   if (copy)
00185     other = otherCopy;
00186 }
00187 
00188 template<typename ExpressionType,int Mode>
00189 template<typename OtherDerived>
00190 typename internal::plain_matrix_type_column_major<OtherDerived>::type
00191 SparseTriangularView<ExpressionType,Mode>::solve(const MatrixBase<OtherDerived>& other) const
00192 {
00193   typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other);
00194   solveInPlace(res);
00195   return res;
00196 }
00197 
00198 // pure sparse path
00199 
00200 namespace internal {
00201 
00202 template<typename Lhs, typename Rhs, int Mode,
00203   int UpLo = (Mode & Lower)
00204            ? Lower
00205            : (Mode & Upper)
00206            ? Upper
00207            : -1,
00208   int StorageOrder = int(Lhs::Flags) & (RowMajorBit)>
00209 struct sparse_solve_triangular_sparse_selector;
00210 
00211 // forward substitution, col-major
00212 template<typename Lhs, typename Rhs, int Mode, int UpLo>
00213 struct sparse_solve_triangular_sparse_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
00214 {
00215   typedef typename Rhs::Scalar Scalar;
00216   typedef typename promote_index_type<typename traits<Lhs>::Index,
00217                                          typename traits<Rhs>::Index>::type Index;
00218   static void run(const Lhs& lhs, Rhs& other)
00219   {
00220     const bool IsLower = (UpLo==Lower);
00221     AmbiVector<Scalar,Index> tempVector(other.rows()*2);
00222     tempVector.setBounds(0,other.rows());
00223 
00224     Rhs res(other.rows(), other.cols());
00225     res.reserve(other.nonZeros());
00226 
00227     for(int col=0 ; col<other.cols() ; ++col)
00228     {
00229       // FIXME estimate number of non zeros
00230       tempVector.init(.99/*float(other.col(col).nonZeros())/float(other.rows())*/);
00231       tempVector.setZero();
00232       tempVector.restart();
00233       for (typename Rhs::InnerIterator rhsIt(other, col); rhsIt; ++rhsIt)
00234       {
00235         tempVector.coeffRef(rhsIt.index()) = rhsIt.value();
00236       }
00237 
00238       for(int i=IsLower?0:lhs.cols()-1;
00239           IsLower?i<lhs.cols():i>=0;
00240           i+=IsLower?1:-1)
00241       {
00242         tempVector.restart();
00243         Scalar& ci = tempVector.coeffRef(i);
00244         if (ci!=Scalar(0))
00245         {
00246           // find
00247           typename Lhs::InnerIterator it(lhs, i);
00248           if(!(Mode & UnitDiag))
00249           {
00250             if (IsLower)
00251             {
00252               eigen_assert(it.index()==i);
00253               ci /= it.value();
00254             }
00255             else
00256               ci /= lhs.coeff(i,i);
00257           }
00258           tempVector.restart();
00259           if (IsLower)
00260           {
00261             if (it.index()==i)
00262               ++it;
00263             for(; it; ++it)
00264               tempVector.coeffRef(it.index()) -= ci * it.value();
00265           }
00266           else
00267           {
00268             for(; it && it.index()<i; ++it)
00269               tempVector.coeffRef(it.index()) -= ci * it.value();
00270           }
00271         }
00272       }
00273 
00274 
00275       int count = 0;
00276       // FIXME compute a reference value to filter zeros
00277       for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector/*,1e-12*/); it; ++it)
00278       {
00279         ++ count;
00280 //         std::cerr << "fill " << it.index() << ", " << col << "\n";
00281 //         std::cout << it.value() << "  ";
00282         // FIXME use insertBack
00283         res.insert(it.index(), col) = it.value();
00284       }
00285 //       std::cout << "tempVector.nonZeros() == " << int(count) << " / " << (other.rows()) << "\n";
00286     }
00287     res.finalize();
00288     other = res.markAsRValue();
00289   }
00290 };
00291 
00292 } // end namespace internal
00293 
00294 template<typename ExpressionType,int Mode>
00295 template<typename OtherDerived>
00296 void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
00297 {
00298   eigen_assert(m_matrix.cols() == m_matrix.rows());
00299   eigen_assert(m_matrix.cols() == other.rows());
00300   eigen_assert(!(Mode & ZeroDiag));
00301   eigen_assert(Mode & (Upper|Lower));
00302 
00303 //   enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
00304 
00305 //   typedef typename internal::conditional<copy,
00306 //     typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
00307 //   OtherCopy otherCopy(other.derived());
00308 
00309   internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(m_matrix, other.derived());
00310 
00311 //   if (copy)
00312 //     other = otherCopy;
00313 }
00314 
00315 #ifdef EIGEN2_SUPPORT
00316 
00317 // deprecated stuff:
00318 
00320 template<typename Derived>
00321 template<typename OtherDerived>
00322 void SparseMatrixBase<Derived>::solveTriangularInPlace(MatrixBase<OtherDerived>& other) const
00323 {
00324   this->template triangular<Flags&(Upper|Lower)>().solveInPlace(other);
00325 }
00326 
00328 template<typename Derived>
00329 template<typename OtherDerived>
00330 typename internal::plain_matrix_type_column_major<OtherDerived>::type
00331 SparseMatrixBase<Derived>::solveTriangular(const MatrixBase<OtherDerived>& other) const
00332 {
00333   typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other);
00334   derived().solveTriangularInPlace(res);
00335   return res;
00336 }
00337 #endif // EIGEN2_SUPPORT
00338 
00339 #endif // EIGEN_SPARSETRIANGULARSOLVER_H


libicr
Author(s): Robert Krug
autogenerated on Mon Jan 6 2014 11:33:49