00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
00011 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
00012
00013 namespace Eigen {
00014
00015 namespace internal {
00016
00017
00018
00019 template<typename Lhs, typename Rhs, typename ResultType>
00020 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, typename ResultType::RealScalar tolerance)
00021 {
00022
00023
00024 typedef typename remove_all<Lhs>::type::Scalar Scalar;
00025 typedef typename remove_all<Lhs>::type::Index Index;
00026
00027
00028 Index rows = lhs.innerSize();
00029 Index cols = rhs.outerSize();
00030
00031 eigen_assert(lhs.outerSize() == rhs.innerSize());
00032
00033
00034 AmbiVector<Scalar,Index> tempVector(rows);
00035
00036
00037
00038
00039
00040
00041
00042 Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
00043
00044
00045 if(ResultType::IsRowMajor)
00046 res.resize(cols, rows);
00047 else
00048 res.resize(rows, cols);
00049
00050 res.reserve(estimated_nnz_prod);
00051 double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
00052 for (Index j=0; j<cols; ++j)
00053 {
00054
00055
00056
00057 tempVector.init(ratioColRes);
00058 tempVector.setZero();
00059 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
00060 {
00061
00062 tempVector.restart();
00063 Scalar x = rhsIt.value();
00064 for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
00065 {
00066 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
00067 }
00068 }
00069 res.startVec(j);
00070 for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it)
00071 res.insertBackByOuterInner(j,it.index()) = it.value();
00072 }
00073 res.finalize();
00074 }
00075
00076 template<typename Lhs, typename Rhs, typename ResultType,
00077 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
00078 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
00079 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
00080 struct sparse_sparse_product_with_pruning_selector;
00081
00082 template<typename Lhs, typename Rhs, typename ResultType>
00083 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
00084 {
00085 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
00086 typedef typename ResultType::RealScalar RealScalar;
00087
00088 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00089 {
00090 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00091 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
00092 res.swap(_res);
00093 }
00094 };
00095
00096 template<typename Lhs, typename Rhs, typename ResultType>
00097 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
00098 {
00099 typedef typename ResultType::RealScalar RealScalar;
00100 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00101 {
00102
00103 typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
00104 SparseTemporaryType _res(res.rows(), res.cols());
00105 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
00106 res = _res;
00107 }
00108 };
00109
00110 template<typename Lhs, typename Rhs, typename ResultType>
00111 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
00112 {
00113 typedef typename ResultType::RealScalar RealScalar;
00114 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00115 {
00116
00117 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00118 internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
00119 res.swap(_res);
00120 }
00121 };
00122
00123 template<typename Lhs, typename Rhs, typename ResultType>
00124 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
00125 {
00126 typedef typename ResultType::RealScalar RealScalar;
00127 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance)
00128 {
00129 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00130 ColMajorMatrix colLhs(lhs);
00131 ColMajorMatrix colRhs(rhs);
00132 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrix,ColMajorMatrix,ResultType>(colLhs, colRhs, res, tolerance);
00133
00134
00135
00136
00137
00138
00139 }
00140 };
00141
00142
00143
00144
00145 }
00146
00147 }
00148
00149 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H