00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_SPARSEDENSEPRODUCT_H
00011 #define EIGEN_SPARSEDENSEPRODUCT_H
00012
00013 namespace Eigen {
00014
00015 template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType
00016 {
00017 typedef SparseTimeDenseProduct<Lhs,Rhs> Type;
00018 };
00019
00020 template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
00021 {
00022 typedef typename internal::conditional<
00023 Lhs::IsRowMajor,
00024 SparseDenseOuterProduct<Rhs,Lhs,true>,
00025 SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type;
00026 };
00027
00028 template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
00029 {
00030 typedef DenseTimeSparseProduct<Lhs,Rhs> Type;
00031 };
00032
00033 template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
00034 {
00035 typedef typename internal::conditional<
00036 Rhs::IsRowMajor,
00037 SparseDenseOuterProduct<Rhs,Lhs,true>,
00038 SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type;
00039 };
00040
00041 namespace internal {
00042
00043 template<typename Lhs, typename Rhs, bool Tr>
00044 struct traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
00045 {
00046 typedef Sparse StorageKind;
00047 typedef typename scalar_product_traits<typename traits<Lhs>::Scalar,
00048 typename traits<Rhs>::Scalar>::ReturnType Scalar;
00049 typedef typename Lhs::Index Index;
00050 typedef typename Lhs::Nested LhsNested;
00051 typedef typename Rhs::Nested RhsNested;
00052 typedef typename remove_all<LhsNested>::type _LhsNested;
00053 typedef typename remove_all<RhsNested>::type _RhsNested;
00054
00055 enum {
00056 LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost,
00057 RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost,
00058
00059 RowsAtCompileTime = Tr ? int(traits<Rhs>::RowsAtCompileTime) : int(traits<Lhs>::RowsAtCompileTime),
00060 ColsAtCompileTime = Tr ? int(traits<Lhs>::ColsAtCompileTime) : int(traits<Rhs>::ColsAtCompileTime),
00061 MaxRowsAtCompileTime = Tr ? int(traits<Rhs>::MaxRowsAtCompileTime) : int(traits<Lhs>::MaxRowsAtCompileTime),
00062 MaxColsAtCompileTime = Tr ? int(traits<Lhs>::MaxColsAtCompileTime) : int(traits<Rhs>::MaxColsAtCompileTime),
00063
00064 Flags = Tr ? RowMajorBit : 0,
00065
00066 CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost
00067 };
00068 };
00069
00070 }
00071
00072 template<typename Lhs, typename Rhs, bool Tr>
00073 class SparseDenseOuterProduct
00074 : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
00075 {
00076 public:
00077
00078 typedef SparseMatrixBase<SparseDenseOuterProduct> Base;
00079 EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct)
00080 typedef internal::traits<SparseDenseOuterProduct> Traits;
00081
00082 private:
00083
00084 typedef typename Traits::LhsNested LhsNested;
00085 typedef typename Traits::RhsNested RhsNested;
00086 typedef typename Traits::_LhsNested _LhsNested;
00087 typedef typename Traits::_RhsNested _RhsNested;
00088
00089 public:
00090
00091 class InnerIterator;
00092
00093 EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs)
00094 : m_lhs(lhs), m_rhs(rhs)
00095 {
00096 EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
00097 }
00098
00099 EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs)
00100 : m_lhs(lhs), m_rhs(rhs)
00101 {
00102 EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
00103 }
00104
00105 EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); }
00106 EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); }
00107
00108 EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
00109 EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
00110
00111 protected:
00112 LhsNested m_lhs;
00113 RhsNested m_rhs;
00114 };
00115
00116 template<typename Lhs, typename Rhs, bool Transpose>
00117 class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator
00118 {
00119 typedef typename _LhsNested::InnerIterator Base;
00120 typedef typename SparseDenseOuterProduct::Index Index;
00121 public:
00122 EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer)
00123 : Base(prod.lhs(), 0), m_outer(outer), m_factor(get(prod.rhs(), outer, typename internal::traits<Rhs>::StorageKind() ))
00124 { }
00125
00126 inline Index outer() const { return m_outer; }
00127 inline Index row() const { return Transpose ? m_outer : Base::index(); }
00128 inline Index col() const { return Transpose ? Base::index() : m_outer; }
00129
00130 inline Scalar value() const { return Base::value() * m_factor; }
00131
00132 protected:
00133 static Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense())
00134 {
00135 return rhs.coeff(outer);
00136 }
00137
00138 static Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse())
00139 {
00140 typename Traits::_RhsNested::InnerIterator it(rhs, outer);
00141 if (it && it.index()==0)
00142 return it.value();
00143
00144 return Scalar(0);
00145 }
00146
00147 Index m_outer;
00148 Scalar m_factor;
00149 };
00150
00151 namespace internal {
00152 template<typename Lhs, typename Rhs>
00153 struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
00154 : traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> >
00155 {
00156 typedef Dense StorageKind;
00157 typedef MatrixXpr XprKind;
00158 };
00159
00160 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
00161 int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
00162 bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
00163 struct sparse_time_dense_product_impl;
00164
00165 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00166 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true>
00167 {
00168 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00169 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00170 typedef typename internal::remove_all<DenseResType>::type Res;
00171 typedef typename Lhs::Index Index;
00172 typedef typename Lhs::InnerIterator LhsInnerIterator;
00173 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
00174 {
00175 for(Index c=0; c<rhs.cols(); ++c)
00176 {
00177 Index n = lhs.outerSize();
00178 for(Index j=0; j<n; ++j)
00179 {
00180 typename Res::Scalar tmp(0);
00181 for(LhsInnerIterator it(lhs,j); it ;++it)
00182 tmp += it.value() * rhs.coeff(it.index(),c);
00183 res.coeffRef(j,c) += alpha * tmp;
00184 }
00185 }
00186 }
00187 };
00188
00189 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00190 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true>
00191 {
00192 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00193 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00194 typedef typename internal::remove_all<DenseResType>::type Res;
00195 typedef typename Lhs::InnerIterator LhsInnerIterator;
00196 typedef typename Lhs::Index Index;
00197 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
00198 {
00199 for(Index c=0; c<rhs.cols(); ++c)
00200 {
00201 for(Index j=0; j<lhs.outerSize(); ++j)
00202 {
00203 typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
00204 for(LhsInnerIterator it(lhs,j); it ;++it)
00205 res.coeffRef(it.index(),c) += it.value() * rhs_j;
00206 }
00207 }
00208 }
00209 };
00210
00211 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00212 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false>
00213 {
00214 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00215 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00216 typedef typename internal::remove_all<DenseResType>::type Res;
00217 typedef typename Lhs::InnerIterator LhsInnerIterator;
00218 typedef typename Lhs::Index Index;
00219 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
00220 {
00221 for(Index j=0; j<lhs.outerSize(); ++j)
00222 {
00223 typename Res::RowXpr res_j(res.row(j));
00224 for(LhsInnerIterator it(lhs,j); it ;++it)
00225 res_j += (alpha*it.value()) * rhs.row(it.index());
00226 }
00227 }
00228 };
00229
00230 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00231 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false>
00232 {
00233 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00234 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00235 typedef typename internal::remove_all<DenseResType>::type Res;
00236 typedef typename Lhs::InnerIterator LhsInnerIterator;
00237 typedef typename Lhs::Index Index;
00238 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
00239 {
00240 for(Index j=0; j<lhs.outerSize(); ++j)
00241 {
00242 typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
00243 for(LhsInnerIterator it(lhs,j); it ;++it)
00244 res.row(it.index()) += (alpha*it.value()) * rhs_j;
00245 }
00246 }
00247 };
00248
00249 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
00250 inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
00251 {
00252 sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha);
00253 }
00254
00255 }
00256
00257 template<typename Lhs, typename Rhs>
00258 class SparseTimeDenseProduct
00259 : public ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs>
00260 {
00261 public:
00262 EIGEN_PRODUCT_PUBLIC_INTERFACE(SparseTimeDenseProduct)
00263
00264 SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
00265 {}
00266
00267 template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const
00268 {
00269 internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha);
00270 }
00271
00272 private:
00273 SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&);
00274 };
00275
00276
00277
00278 namespace internal {
00279 template<typename Lhs, typename Rhs>
00280 struct traits<DenseTimeSparseProduct<Lhs,Rhs> >
00281 : traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> >
00282 {
00283 typedef Dense StorageKind;
00284 };
00285 }
00286
00287 template<typename Lhs, typename Rhs>
00288 class DenseTimeSparseProduct
00289 : public ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs>
00290 {
00291 public:
00292 EIGEN_PRODUCT_PUBLIC_INTERFACE(DenseTimeSparseProduct)
00293
00294 DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
00295 {}
00296
00297 template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const
00298 {
00299 Transpose<const _LhsNested> lhs_t(m_lhs);
00300 Transpose<const _RhsNested> rhs_t(m_rhs);
00301 Transpose<Dest> dest_t(dest);
00302 internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha);
00303 }
00304
00305 private:
00306 DenseTimeSparseProduct& operator=(const DenseTimeSparseProduct&);
00307 };
00308
00309 }
00310
00311 #endif // EIGEN_SPARSEDENSEPRODUCT_H