SparseSparseProductWithPruning.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
00011 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
00012 
00013 namespace Eigen { 
00014 
00015 namespace internal {
00016 
00017 
00018 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
00019 template<typename Lhs, typename Rhs, typename ResultType>
00020 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, typename ResultType::RealScalar tolerance)
00021 {
00022   // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
00023 
00024   typedef typename remove_all<Lhs>::type::Scalar Scalar;
00025   typedef typename remove_all<Lhs>::type::Index Index;
00026 
00027   // make sure to call innerSize/outerSize since we fake the storage order.
00028   Index rows = lhs.innerSize();
00029   Index cols = rhs.outerSize();
00030   //int size = lhs.outerSize();
00031   eigen_assert(lhs.outerSize() == rhs.innerSize());
00032 
00033   // allocate a temporary buffer
00034   AmbiVector<Scalar,Index> tempVector(rows);
00035 
00036   // estimate the number of non zero entries
00037   // given a rhs column containing Y non zeros, we assume that the respective Y columns
00038   // of the lhs differs in average of one non zeros, thus the number of non zeros for
00039   // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
00040   // per column of the lhs.
00041   // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
00042   Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
00043 
00044   // mimics a resizeByInnerOuter:
00045   if(ResultType::IsRowMajor)
00046     res.resize(cols, rows);
00047   else
00048     res.resize(rows, cols);
00049 
00050   res.reserve(estimated_nnz_prod);
00051   double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
00052   for (Index j=0; j<cols; ++j)
00053   {
00054     // FIXME:
00055     //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
00056     // let's do a more accurate determination of the nnz ratio for the current column j of res
00057     tempVector.init(ratioColRes);
00058     tempVector.setZero();
00059     for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
00060     {
00061       // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
00062       tempVector.restart();
00063       Scalar x = rhsIt.value();
00064       for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
00065       {
00066         tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
00067       }
00068     }
00069     res.startVec(j);
00070     for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it)
00071       res.insertBackByOuterInner(j,it.index()) = it.value();
00072   }
00073   res.finalize();
00074 }
00075 
00076 template<typename Lhs, typename Rhs, typename ResultType,
00077   int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
00078   int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
00079   int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
00080 struct sparse_sparse_product_with_pruning_selector;
00081 
00082 template<typename Lhs, typename Rhs, typename ResultType>
00083 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
00084 {
00085   typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
00086   typedef typename ResultType::RealScalar RealScalar;
00087 
00088   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00089   {
00090     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00091     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
00092     res.swap(_res);
00093   }
00094 };
00095 
00096 template<typename Lhs, typename Rhs, typename ResultType>
00097 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
00098 {
00099   typedef typename ResultType::RealScalar RealScalar;
00100   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00101   {
00102     // we need a col-major matrix to hold the result
00103     typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
00104     SparseTemporaryType _res(res.rows(), res.cols());
00105     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
00106     res = _res;
00107   }
00108 };
00109 
00110 template<typename Lhs, typename Rhs, typename ResultType>
00111 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
00112 {
00113   typedef typename ResultType::RealScalar RealScalar;
00114   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00115   {
00116     // let's transpose the product to get a column x column product
00117     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00118     internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
00119     res.swap(_res);
00120   }
00121 };
00122 
00123 template<typename Lhs, typename Rhs, typename ResultType>
00124 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
00125 {
00126   typedef typename ResultType::RealScalar RealScalar;
00127   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00128   {
00129     typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00130     ColMajorMatrix colLhs(lhs);
00131     ColMajorMatrix colRhs(rhs);
00132     internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrix,ColMajorMatrix,ResultType>(colLhs, colRhs, res, tolerance);
00133 
00134     // let's transpose the product to get a column x column product
00135 //     typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
00136 //     SparseTemporaryType _res(res.cols(), res.rows());
00137 //     sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
00138 //     res = _res.transpose();
00139   }
00140 };
00141 
// NOTE: the two other cases (col row *) must never occur, since they are caught
// by ProductReturnType, which transforms them into (col col *) by evaluating rhs.
00144 
00145 } // end namespace internal
00146 
00147 } // end namespace Eigen
00148 
00149 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H


win_eigen
Author(s): Daniel Stonier
autogenerated on Wed Sep 16 2015 07:12:13