00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef EIGEN_SPARSEPRODUCT_H
00026 #define EIGEN_SPARSEPRODUCT_H
00027
00028 template<typename Lhs, typename Rhs> struct ei_sparse_product_mode
00029 {
00030 enum {
00031
00032 value = ((Lhs::Flags&Diagonal)==Diagonal || (Rhs::Flags&Diagonal)==Diagonal)
00033 ? DiagonalProduct
00034 : (Rhs::Flags&Lhs::Flags&SparseBit)==SparseBit
00035 ? SparseTimeSparseProduct
00036 : (Lhs::Flags&SparseBit)==SparseBit
00037 ? SparseTimeDenseProduct
00038 : DenseTimeSparseProduct };
00039 };
00040
00041 template<typename Lhs, typename Rhs, int ProductMode>
00042 struct SparseProductReturnType
00043 {
00044 typedef const typename ei_nested<Lhs,Rhs::RowsAtCompileTime>::type LhsNested;
00045 typedef const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
00046
00047 typedef SparseProduct<LhsNested, RhsNested, ProductMode> Type;
00048 };
00049
00050 template<typename Lhs, typename Rhs>
00051 struct SparseProductReturnType<Lhs,Rhs,DiagonalProduct>
00052 {
00053 typedef const typename ei_nested<Lhs,Rhs::RowsAtCompileTime>::type LhsNested;
00054 typedef const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
00055
00056 typedef SparseDiagonalProduct<LhsNested, RhsNested> Type;
00057 };
00058
00059
00060 template<typename Lhs, typename Rhs>
00061 struct SparseProductReturnType<Lhs,Rhs,SparseTimeSparseProduct>
00062 {
00063 typedef typename ei_traits<Lhs>::Scalar Scalar;
00064 enum {
00065 LhsRowMajor = ei_traits<Lhs>::Flags & RowMajorBit,
00066 RhsRowMajor = ei_traits<Rhs>::Flags & RowMajorBit,
00067 TransposeRhs = (!LhsRowMajor) && RhsRowMajor,
00068 TransposeLhs = LhsRowMajor && (!RhsRowMajor)
00069 };
00070
00071
00072
00073 typedef typename ei_meta_if<TransposeLhs,
00074 SparseMatrix<Scalar,0>,
00075 const typename ei_nested<Lhs,Rhs::RowsAtCompileTime>::type>::ret LhsNested;
00076
00077 typedef typename ei_meta_if<TransposeRhs,
00078 SparseMatrix<Scalar,0>,
00079 const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type>::ret RhsNested;
00080
00081 typedef SparseProduct<LhsNested, RhsNested, SparseTimeSparseProduct> Type;
00082 };
00083
00084 template<typename LhsNested, typename RhsNested, int ProductMode>
00085 struct ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >
00086 {
00087
00088 typedef typename ei_cleantype<LhsNested>::type _LhsNested;
00089 typedef typename ei_cleantype<RhsNested>::type _RhsNested;
00090 typedef typename _LhsNested::Scalar Scalar;
00091
00092 enum {
00093 LhsCoeffReadCost = _LhsNested::CoeffReadCost,
00094 RhsCoeffReadCost = _RhsNested::CoeffReadCost,
00095 LhsFlags = _LhsNested::Flags,
00096 RhsFlags = _RhsNested::Flags,
00097
00098 RowsAtCompileTime = _LhsNested::RowsAtCompileTime,
00099 ColsAtCompileTime = _RhsNested::ColsAtCompileTime,
00100 InnerSize = EIGEN_SIZE_MIN(_LhsNested::ColsAtCompileTime, _RhsNested::RowsAtCompileTime),
00101
00102 MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime,
00103 MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime,
00104
00105
00106
00107
00108 EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit),
00109 ResultIsSparse = ProductMode==SparseTimeSparseProduct || ProductMode==DiagonalProduct,
00110
00111 RemovedBits = ~( (EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSparse ? 0 : SparseBit) ),
00112
00113 Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
00114 | EvalBeforeAssigningBit
00115 | EvalBeforeNestingBit,
00116
00117 CoeffReadCost = Dynamic
00118 };
00119
00120 typedef typename ei_meta_if<ResultIsSparse,
00121 SparseMatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> >,
00122 MatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> > >::ret Base;
00123 };
00124
00125 template<typename LhsNested, typename RhsNested, int ProductMode>
00126 class SparseProduct : ei_no_assignment_operator,
00127 public ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >::Base
00128 {
00129 public:
00130
00131 EIGEN_GENERIC_PUBLIC_INTERFACE(SparseProduct)
00132
00133 private:
00134
00135 typedef typename ei_traits<SparseProduct>::_LhsNested _LhsNested;
00136 typedef typename ei_traits<SparseProduct>::_RhsNested _RhsNested;
00137
00138 public:
00139
00140 template<typename Lhs, typename Rhs>
00141 EIGEN_STRONG_INLINE SparseProduct(const Lhs& lhs, const Rhs& rhs)
00142 : m_lhs(lhs), m_rhs(rhs)
00143 {
00144 ei_assert(lhs.cols() == rhs.rows());
00145
00146 enum {
00147 ProductIsValid = _LhsNested::ColsAtCompileTime==Dynamic
00148 || _RhsNested::RowsAtCompileTime==Dynamic
00149 || int(_LhsNested::ColsAtCompileTime)==int(_RhsNested::RowsAtCompileTime),
00150 AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime,
00151 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested,_RhsNested)
00152 };
00153
00154
00155
00156 EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
00157 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
00158 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
00159 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
00160 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
00161 }
00162
00163 EIGEN_STRONG_INLINE int rows() const { return m_lhs.rows(); }
00164 EIGEN_STRONG_INLINE int cols() const { return m_rhs.cols(); }
00165
00166 EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
00167 EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
00168
00169 protected:
00170 LhsNested m_lhs;
00171 RhsNested m_rhs;
00172 };
00173
00174
00175 template<typename Lhs, typename Rhs, typename ResultType>
00176 static void ei_sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00177 {
00178 typedef typename ei_traits<typename ei_cleantype<Lhs>::type>::Scalar Scalar;
00179
00180
00181 int rows = lhs.innerSize();
00182 int cols = rhs.outerSize();
00183
00184 ei_assert(lhs.outerSize() == rhs.innerSize());
00185
00186
00187 AmbiVector<Scalar> tempVector(rows);
00188
00189
00190 float ratioLhs = float(lhs.nonZeros())/(float(lhs.rows())*float(lhs.cols()));
00191 float avgNnzPerRhsColumn = float(rhs.nonZeros())/float(cols);
00192 float ratioRes = std::min(ratioLhs * avgNnzPerRhsColumn, 1.f);
00193
00194 res.resize(rows, cols);
00195 res.startFill(int(ratioRes*rows*cols));
00196 for (int j=0; j<cols; ++j)
00197 {
00198
00199
00200
00201 float ratioColRes = ratioRes;
00202 tempVector.init(ratioColRes);
00203 tempVector.setZero();
00204 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
00205 {
00206
00207 tempVector.restart();
00208 Scalar x = rhsIt.value();
00209 for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
00210 {
00211 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
00212 }
00213 }
00214 for (typename AmbiVector<Scalar>::Iterator it(tempVector); it; ++it)
00215 if (ResultType::Flags&RowMajorBit)
00216 res.fill(j,it.index()) = it.value();
00217 else
00218 res.fill(it.index(), j) = it.value();
00219 }
00220 res.endFill();
00221 }
00222
00223 template<typename Lhs, typename Rhs, typename ResultType,
00224 int LhsStorageOrder = ei_traits<Lhs>::Flags&RowMajorBit,
00225 int RhsStorageOrder = ei_traits<Rhs>::Flags&RowMajorBit,
00226 int ResStorageOrder = ei_traits<ResultType>::Flags&RowMajorBit>
00227 struct ei_sparse_product_selector;
00228
00229 template<typename Lhs, typename Rhs, typename ResultType>
00230 struct ei_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
00231 {
00232 typedef typename ei_traits<typename ei_cleantype<Lhs>::type>::Scalar Scalar;
00233
00234 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00235 {
00236 typename ei_cleantype<ResultType>::type _res(res.rows(), res.cols());
00237 ei_sparse_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res);
00238 res.swap(_res);
00239 }
00240 };
00241
00242 template<typename Lhs, typename Rhs, typename ResultType>
00243 struct ei_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
00244 {
00245 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00246 {
00247
00248 typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
00249 SparseTemporaryType _res(res.rows(), res.cols());
00250 ei_sparse_product_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res);
00251 res = _res;
00252 }
00253 };
00254
00255 template<typename Lhs, typename Rhs, typename ResultType>
00256 struct ei_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
00257 {
00258 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00259 {
00260
00261 typename ei_cleantype<ResultType>::type _res(res.rows(), res.cols());
00262 ei_sparse_product_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res);
00263 res.swap(_res);
00264 }
00265 };
00266
00267 template<typename Lhs, typename Rhs, typename ResultType>
00268 struct ei_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
00269 {
00270 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
00271 {
00272
00273 typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
00274 SparseTemporaryType _res(res.cols(), res.rows());
00275 ei_sparse_product_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
00276 res = _res.transpose();
00277 }
00278 };
00279
00280
00281
00282
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309 template<typename Derived>
00310 template<typename Lhs, typename Rhs>
00311 inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product)
00312 {
00313 ei_sparse_product_selector<
00314 typename ei_cleantype<Lhs>::type,
00315 typename ei_cleantype<Rhs>::type,
00316 Derived>::run(product.lhs(),product.rhs(),derived());
00317 return derived();
00318 }
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334
00335 template<typename Derived>
00336 template<typename Lhs, typename Rhs>
00337 Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,SparseTimeDenseProduct>& product)
00338 {
00339 typedef typename ei_cleantype<Lhs>::type _Lhs;
00340 typedef typename ei_cleantype<Rhs>::type _Rhs;
00341 typedef typename _Lhs::InnerIterator LhsInnerIterator;
00342 enum {
00343 LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit,
00344 LhsIsSelfAdjoint = (_Lhs::Flags&SelfAdjointBit)==SelfAdjointBit,
00345 ProcessFirstHalf = LhsIsSelfAdjoint
00346 && ( ((_Lhs::Flags&(UpperTriangularBit|LowerTriangularBit))==0)
00347 || ( (_Lhs::Flags&UpperTriangularBit) && !LhsIsRowMajor)
00348 || ( (_Lhs::Flags&LowerTriangularBit) && LhsIsRowMajor) ),
00349 ProcessSecondHalf = LhsIsSelfAdjoint && (!ProcessFirstHalf)
00350 };
00351 derived().setZero();
00352 for (int j=0; j<product.lhs().outerSize(); ++j)
00353 {
00354 LhsInnerIterator i(product.lhs(),j);
00355 if (ProcessSecondHalf && i && (i.index()==j))
00356 {
00357 derived().row(j) += i.value() * product.rhs().row(j);
00358 ++i;
00359 }
00360 Block<Derived,1,Derived::ColsAtCompileTime> res(derived().row(LhsIsRowMajor ? j : 0));
00361 for (; (ProcessFirstHalf ? i && i.index() < j : i) ; ++i)
00362 {
00363 if (LhsIsSelfAdjoint)
00364 {
00365 int a = LhsIsRowMajor ? j : i.index();
00366 int b = LhsIsRowMajor ? i.index() : j;
00367 Scalar v = i.value();
00368 derived().row(a) += (v) * product.rhs().row(b);
00369 derived().row(b) += ei_conj(v) * product.rhs().row(a);
00370 }
00371 else if (LhsIsRowMajor)
00372 res += i.value() * product.rhs().row(i.index());
00373 else
00374 derived().row(i.index()) += i.value() * product.rhs().row(j);
00375 }
00376 if (ProcessFirstHalf && i && (i.index()==j))
00377 derived().row(j) += i.value() * product.rhs().row(j);
00378 }
00379 return derived();
00380 }
00381
00382
00383 template<typename Derived>
00384 template<typename Lhs, typename Rhs>
00385 Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,DenseTimeSparseProduct>& product)
00386 {
00387 typedef typename ei_cleantype<Rhs>::type _Rhs;
00388 typedef typename _Rhs::InnerIterator RhsInnerIterator;
00389 enum { RhsIsRowMajor = (_Rhs::Flags&RowMajorBit)==RowMajorBit };
00390 derived().setZero();
00391 for (int j=0; j<product.rhs().outerSize(); ++j)
00392 for (RhsInnerIterator i(product.rhs(),j); i; ++i)
00393 derived().col(RhsIsRowMajor ? i.index() : j) += i.value() * product.lhs().col(RhsIsRowMajor ? j : i.index());
00394 return derived();
00395 }
00396
00397
00398 template<typename Derived>
00399 template<typename OtherDerived>
00400 EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type
00401 SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
00402 {
00403 return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
00404 }
00405
00406
00407 template<typename Derived>
00408 template<typename OtherDerived>
00409 EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type
00410 SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
00411 {
00412 return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
00413 }
00414
00415 #endif // EIGEN_SPARSEPRODUCT_H