00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_SPARSEDENSEPRODUCT_H
00011 #define EIGEN_SPARSEDENSEPRODUCT_H
00012
00013 namespace Eigen {
00014
00015 template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType
00016 {
00017 typedef SparseTimeDenseProduct<Lhs,Rhs> Type;
00018 };
00019
00020 template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
00021 {
00022 typedef SparseDenseOuterProduct<Lhs,Rhs,false> Type;
00023 };
00024
00025 template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
00026 {
00027 typedef DenseTimeSparseProduct<Lhs,Rhs> Type;
00028 };
00029
00030 template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
00031 {
00032 typedef SparseDenseOuterProduct<Rhs,Lhs,true> Type;
00033 };
00034
00035 namespace internal {
00036
00037 template<typename Lhs, typename Rhs, bool Tr>
00038 struct traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
00039 {
00040 typedef Sparse StorageKind;
00041 typedef typename scalar_product_traits<typename traits<Lhs>::Scalar,
00042 typename traits<Rhs>::Scalar>::ReturnType Scalar;
00043 typedef typename Lhs::Index Index;
00044 typedef typename Lhs::Nested LhsNested;
00045 typedef typename Rhs::Nested RhsNested;
00046 typedef typename remove_all<LhsNested>::type _LhsNested;
00047 typedef typename remove_all<RhsNested>::type _RhsNested;
00048
00049 enum {
00050 LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost,
00051 RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost,
00052
00053 RowsAtCompileTime = Tr ? int(traits<Rhs>::RowsAtCompileTime) : int(traits<Lhs>::RowsAtCompileTime),
00054 ColsAtCompileTime = Tr ? int(traits<Lhs>::ColsAtCompileTime) : int(traits<Rhs>::ColsAtCompileTime),
00055 MaxRowsAtCompileTime = Tr ? int(traits<Rhs>::MaxRowsAtCompileTime) : int(traits<Lhs>::MaxRowsAtCompileTime),
00056 MaxColsAtCompileTime = Tr ? int(traits<Lhs>::MaxColsAtCompileTime) : int(traits<Rhs>::MaxColsAtCompileTime),
00057
00058 Flags = Tr ? RowMajorBit : 0,
00059
00060 CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost
00061 };
00062 };
00063
00064 }
00065
00066 template<typename Lhs, typename Rhs, bool Tr>
00067 class SparseDenseOuterProduct
00068 : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
00069 {
00070 public:
00071
00072 typedef SparseMatrixBase<SparseDenseOuterProduct> Base;
00073 EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct)
00074 typedef internal::traits<SparseDenseOuterProduct> Traits;
00075
00076 private:
00077
00078 typedef typename Traits::LhsNested LhsNested;
00079 typedef typename Traits::RhsNested RhsNested;
00080 typedef typename Traits::_LhsNested _LhsNested;
00081 typedef typename Traits::_RhsNested _RhsNested;
00082
00083 public:
00084
00085 class InnerIterator;
00086
00087 EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs)
00088 : m_lhs(lhs), m_rhs(rhs)
00089 {
00090 EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
00091 }
00092
00093 EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs)
00094 : m_lhs(lhs), m_rhs(rhs)
00095 {
00096 EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
00097 }
00098
00099 EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); }
00100 EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); }
00101
00102 EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
00103 EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
00104
00105 protected:
00106 LhsNested m_lhs;
00107 RhsNested m_rhs;
00108 };
00109
00110 template<typename Lhs, typename Rhs, bool Transpose>
00111 class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator
00112 {
00113 typedef typename _LhsNested::InnerIterator Base;
00114 typedef typename SparseDenseOuterProduct::Index Index;
00115 public:
00116 EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer)
00117 : Base(prod.lhs(), 0), m_outer(outer), m_factor(prod.rhs().coeff(outer))
00118 {
00119 }
00120
00121 inline Index outer() const { return m_outer; }
00122 inline Index row() const { return Transpose ? Base::row() : m_outer; }
00123 inline Index col() const { return Transpose ? m_outer : Base::row(); }
00124
00125 inline Scalar value() const { return Base::value() * m_factor; }
00126
00127 protected:
00128 int m_outer;
00129 Scalar m_factor;
00130 };
00131
00132 namespace internal {
00133 template<typename Lhs, typename Rhs>
00134 struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
00135 : traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> >
00136 {
00137 typedef Dense StorageKind;
00138 typedef MatrixXpr XprKind;
00139 };
00140
00141 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
00142 int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
00143 bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
00144 struct sparse_time_dense_product_impl;
00145
00146 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00147 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true>
00148 {
00149 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00150 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00151 typedef typename internal::remove_all<DenseResType>::type Res;
00152 typedef typename Lhs::Index Index;
00153 typedef typename Lhs::InnerIterator LhsInnerIterator;
00154 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
00155 {
00156 for(Index c=0; c<rhs.cols(); ++c)
00157 {
00158 int n = lhs.outerSize();
00159 for(Index j=0; j<n; ++j)
00160 {
00161 typename Res::Scalar tmp(0);
00162 for(LhsInnerIterator it(lhs,j); it ;++it)
00163 tmp += it.value() * rhs.coeff(it.index(),c);
00164 res.coeffRef(j,c) = alpha * tmp;
00165 }
00166 }
00167 }
00168 };
00169
00170 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00171 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true>
00172 {
00173 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00174 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00175 typedef typename internal::remove_all<DenseResType>::type Res;
00176 typedef typename Lhs::InnerIterator LhsInnerIterator;
00177 typedef typename Lhs::Index Index;
00178 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
00179 {
00180 for(Index c=0; c<rhs.cols(); ++c)
00181 {
00182 for(Index j=0; j<lhs.outerSize(); ++j)
00183 {
00184 typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
00185 for(LhsInnerIterator it(lhs,j); it ;++it)
00186 res.coeffRef(it.index(),c) += it.value() * rhs_j;
00187 }
00188 }
00189 }
00190 };
00191
00192 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00193 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false>
00194 {
00195 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00196 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00197 typedef typename internal::remove_all<DenseResType>::type Res;
00198 typedef typename Lhs::InnerIterator LhsInnerIterator;
00199 typedef typename Lhs::Index Index;
00200 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
00201 {
00202 for(Index j=0; j<lhs.outerSize(); ++j)
00203 {
00204 typename Res::RowXpr res_j(res.row(j));
00205 for(LhsInnerIterator it(lhs,j); it ;++it)
00206 res_j += (alpha*it.value()) * rhs.row(it.index());
00207 }
00208 }
00209 };
00210
00211 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00212 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false>
00213 {
00214 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00215 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00216 typedef typename internal::remove_all<DenseResType>::type Res;
00217 typedef typename Lhs::InnerIterator LhsInnerIterator;
00218 typedef typename Lhs::Index Index;
00219 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
00220 {
00221 for(Index j=0; j<lhs.outerSize(); ++j)
00222 {
00223 typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
00224 for(LhsInnerIterator it(lhs,j); it ;++it)
00225 res.row(it.index()) += (alpha*it.value()) * rhs_j;
00226 }
00227 }
00228 };
00229
00230 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
00231 inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
00232 {
00233 sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha);
00234 }
00235
00236 }
00237
00238 template<typename Lhs, typename Rhs>
00239 class SparseTimeDenseProduct
00240 : public ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs>
00241 {
00242 public:
00243 EIGEN_PRODUCT_PUBLIC_INTERFACE(SparseTimeDenseProduct)
00244
00245 SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
00246 {}
00247
00248 template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const
00249 {
00250 internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha);
00251 }
00252
00253 private:
00254 SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&);
00255 };
00256
00257
00258
00259 namespace internal {
00260 template<typename Lhs, typename Rhs>
00261 struct traits<DenseTimeSparseProduct<Lhs,Rhs> >
00262 : traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> >
00263 {
00264 typedef Dense StorageKind;
00265 };
00266 }
00267
00268 template<typename Lhs, typename Rhs>
00269 class DenseTimeSparseProduct
00270 : public ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs>
00271 {
00272 public:
00273 EIGEN_PRODUCT_PUBLIC_INTERFACE(DenseTimeSparseProduct)
00274
00275 DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
00276 {}
00277
00278 template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const
00279 {
00280 Transpose<const _LhsNested> lhs_t(m_lhs);
00281 Transpose<const _RhsNested> rhs_t(m_rhs);
00282 Transpose<Dest> dest_t(dest);
00283 internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha);
00284 }
00285
00286 private:
00287 DenseTimeSparseProduct& operator=(const DenseTimeSparseProduct&);
00288 };
00289
00290
00291 template<typename Derived>
00292 template<typename OtherDerived>
00293 inline const typename SparseDenseProductReturnType<Derived,OtherDerived>::Type
00294 SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
00295 {
00296 return typename SparseDenseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
00297 }
00298
00299 }
00300
00301 #endif // EIGEN_SPARSEDENSEPRODUCT_H