00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_SPARSESPARSEPRODUCT_H
00026 #define EIGEN_SPARSESPARSEPRODUCT_H
00027
00028 namespace internal {
00029
00030 template<typename Lhs, typename Rhs, typename ResultType>
00031 static void sparse_product_impl2(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00032 {
00033 typedef typename remove_all<Lhs>::type::Scalar Scalar;
00034 typedef typename remove_all<Lhs>::type::Index Index;
00035
00036
00037 Index rows = lhs.innerSize();
00038 Index cols = rhs.outerSize();
00039 eigen_assert(lhs.outerSize() == rhs.innerSize());
00040
00041 std::vector<bool> mask(rows,false);
00042 Matrix<Scalar,Dynamic,1> values(rows);
00043 Matrix<Index,Dynamic,1> indices(rows);
00044
00045
00046 float ratioLhs = float(lhs.nonZeros())/(float(lhs.rows())*float(lhs.cols()));
00047 float avgNnzPerRhsColumn = float(rhs.nonZeros())/float(cols);
00048 float ratioRes = std::min(ratioLhs * avgNnzPerRhsColumn, 1.f);
00049
00050
00051
00052
00053 res.resize(rows, cols);
00054 res.reserve(Index(ratioRes*rows*cols));
00055
00056 for (Index j=0; j<cols; ++j)
00057 {
00058
00059 res.startVec(j);
00060 Index nnz = 0;
00061 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
00062 {
00063 Scalar y = rhsIt.value();
00064 Index k = rhsIt.index();
00065 for (typename Lhs::InnerIterator lhsIt(lhs, k); lhsIt; ++lhsIt)
00066 {
00067 Index i = lhsIt.index();
00068 Scalar x = lhsIt.value();
00069 if(!mask[i])
00070 {
00071 mask[i] = true;
00072
00073
00074 ++nnz;
00075 }
00076 else
00077 values[i] += x * y;
00078 }
00079 }
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109 }
00110 res.finalize();
00111 }
00112
00113
00114 template<typename Lhs, typename Rhs, typename ResultType>
00115 static void sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00116 {
00117
00118
00119 typedef typename remove_all<Lhs>::type::Scalar Scalar;
00120 typedef typename remove_all<Lhs>::type::Index Index;
00121
00122
00123 Index rows = lhs.innerSize();
00124 Index cols = rhs.outerSize();
00125
00126 eigen_assert(lhs.outerSize() == rhs.innerSize());
00127
00128
00129 AmbiVector<Scalar,Index> tempVector(rows);
00130
00131
00132 float ratioLhs = float(lhs.nonZeros())/(float(lhs.rows())*float(lhs.cols()));
00133 float avgNnzPerRhsColumn = float(rhs.nonZeros())/float(cols);
00134 float ratioRes = std::min(ratioLhs * avgNnzPerRhsColumn, 1.f);
00135
00136
00137 if(ResultType::IsRowMajor)
00138 res.resize(cols, rows);
00139 else
00140 res.resize(rows, cols);
00141
00142 res.reserve(Index(ratioRes*rows*cols));
00143 for (Index j=0; j<cols; ++j)
00144 {
00145
00146
00147
00148 float ratioColRes = ratioRes;
00149 tempVector.init(ratioColRes);
00150 tempVector.setZero();
00151 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
00152 {
00153
00154 tempVector.restart();
00155 Scalar x = rhsIt.value();
00156 for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
00157 {
00158 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
00159 }
00160 }
00161 res.startVec(j);
00162 for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector); it; ++it)
00163 res.insertBackByOuterInner(j,it.index()) = it.value();
00164 }
00165 res.finalize();
00166 }
00167
00168 template<typename Lhs, typename Rhs, typename ResultType,
00169 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
00170 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
00171 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
00172 struct sparse_product_selector;
00173
00174 template<typename Lhs, typename Rhs, typename ResultType>
00175 struct sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
00176 {
00177 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
00178
00179 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00180 {
00181
00182 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00183 sparse_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res);
00184 res.swap(_res);
00185 }
00186 };
00187
00188 template<typename Lhs, typename Rhs, typename ResultType>
00189 struct sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
00190 {
00191 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00192 {
00193
00194
00195 typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
00196 SparseTemporaryType _res(res.rows(), res.cols());
00197 sparse_product_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res);
00198 res = _res;
00199 }
00200 };
00201
00202 template<typename Lhs, typename Rhs, typename ResultType>
00203 struct sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
00204 {
00205 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00206 {
00207
00208
00209 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00210 sparse_product_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res);
00211 res.swap(_res);
00212 }
00213 };
00214
00215 template<typename Lhs, typename Rhs, typename ResultType>
00216 struct sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
00217 {
00218 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00219 {
00220
00221 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00222 ColMajorMatrix colLhs(lhs);
00223 ColMajorMatrix colRhs(rhs);
00224
00225 sparse_product_impl<ColMajorMatrix,ColMajorMatrix,ResultType>(colLhs, colRhs, res);
00226
00227
00228
00229
00230
00231
00232
00233
00234 }
00235 };
00236
00237
00238
00239
00240 }
00241
00242
00243 template<typename Derived>
00244 template<typename Lhs, typename Rhs>
00245 inline Derived& SparseMatrixBase<Derived>::operator=(const SparseSparseProduct<Lhs,Rhs>& product)
00246 {
00247
00248 internal::sparse_product_selector<
00249 typename internal::remove_all<Lhs>::type,
00250 typename internal::remove_all<Rhs>::type,
00251 Derived>::run(product.lhs(),product.rhs(),derived());
00252 return derived();
00253 }
00254
00255 namespace internal {
00256
00257 template<typename Lhs, typename Rhs, typename ResultType,
00258 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
00259 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
00260 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
00261 struct sparse_product_selector2;
00262
00263 template<typename Lhs, typename Rhs, typename ResultType>
00264 struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
00265 {
00266 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
00267
00268 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00269 {
00270 sparse_product_impl2<Lhs,Rhs,ResultType>(lhs, rhs, res);
00271 }
00272 };
00273
00274 template<typename Lhs, typename Rhs, typename ResultType>
00275 struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
00276 {
00277 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00278 {
00279
00280 EIGEN_UNUSED_VARIABLE(lhs);
00281 EIGEN_UNUSED_VARIABLE(rhs);
00282 EIGEN_UNUSED_VARIABLE(res);
00283
00284
00285
00286
00287
00288
00289 }
00290 };
00291
00292 template<typename Lhs, typename Rhs, typename ResultType>
00293 struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
00294 {
00295 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00296 {
00297 typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
00298 RowMajorMatrix lhsRow = lhs;
00299 RowMajorMatrix resRow(res.rows(), res.cols());
00300 sparse_product_impl2<Rhs,RowMajorMatrix,RowMajorMatrix>(rhs, lhsRow, resRow);
00301 res = resRow;
00302 }
00303 };
00304
00305 template<typename Lhs, typename Rhs, typename ResultType>
00306 struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
00307 {
00308 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00309 {
00310 typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
00311 RowMajorMatrix resRow(res.rows(), res.cols());
00312 sparse_product_impl2<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
00313 res = resRow;
00314 }
00315 };
00316
00317
00318 template<typename Lhs, typename Rhs, typename ResultType>
00319 struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
00320 {
00321 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
00322
00323 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00324 {
00325 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00326 ColMajorMatrix resCol(res.rows(), res.cols());
00327 sparse_product_impl2<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol);
00328 res = resCol;
00329 }
00330 };
00331
00332 template<typename Lhs, typename Rhs, typename ResultType>
00333 struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
00334 {
00335 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00336 {
00337 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00338 ColMajorMatrix lhsCol = lhs;
00339 ColMajorMatrix resCol(res.rows(), res.cols());
00340 sparse_product_impl2<ColMajorMatrix,Rhs,ColMajorMatrix>(lhsCol, rhs, resCol);
00341 res = resCol;
00342 }
00343 };
00344
00345 template<typename Lhs, typename Rhs, typename ResultType>
00346 struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
00347 {
00348 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00349 {
00350 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00351 ColMajorMatrix rhsCol = rhs;
00352 ColMajorMatrix resCol(res.rows(), res.cols());
00353 sparse_product_impl2<Lhs,ColMajorMatrix,ColMajorMatrix>(lhs, rhsCol, resCol);
00354 res = resCol;
00355 }
00356 };
00357
00358 template<typename Lhs, typename Rhs, typename ResultType>
00359 struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
00360 {
00361 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00362 {
00363 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00364
00365
00366
00367
00368
00369
00370 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
00371 ColMajorMatrix lhsCol(lhs);
00372 ColMajorMatrix rhsCol(rhs);
00373 ColMajorMatrix resCol(res.rows(), res.cols());
00374 sparse_product_impl2<ColMajorMatrix,ColMajorMatrix,ColMajorMatrix>(lhsCol, rhsCol, resCol);
00375 res = resCol;
00376 }
00377 };
00378
00379 }
00380
00381 template<typename Derived>
00382 template<typename Lhs, typename Rhs>
00383 inline void SparseMatrixBase<Derived>::_experimentalNewProduct(const Lhs& lhs, const Rhs& rhs)
00384 {
00385
00386 internal::sparse_product_selector2<
00387 typename internal::remove_all<Lhs>::type,
00388 typename internal::remove_all<Rhs>::type,
00389 Derived>::run(lhs,rhs,derived());
00390 }
00391
00392
00393 template<typename Derived>
00394 template<typename OtherDerived>
00395 inline const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type
00396 SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
00397 {
00398 return typename SparseSparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
00399 }
00400
00401 #endif // EIGEN_SPARSESPARSEPRODUCT_H