SparseDenseProduct.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_SPARSEDENSEPRODUCT_H
00011 #define EIGEN_SPARSEDENSEPRODUCT_H
00012 
00013 namespace Eigen { 
00014 
00015 template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType
00016 {
00017   typedef SparseTimeDenseProduct<Lhs,Rhs> Type;
00018 };
00019 
00020 template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
00021 {
00022   typedef SparseDenseOuterProduct<Lhs,Rhs,false> Type;
00023 };
00024 
00025 template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
00026 {
00027   typedef DenseTimeSparseProduct<Lhs,Rhs> Type;
00028 };
00029 
00030 template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
00031 {
00032   typedef SparseDenseOuterProduct<Rhs,Lhs,true> Type;
00033 };
00034 
00035 namespace internal {
00036 
00037 template<typename Lhs, typename Rhs, bool Tr>
00038 struct traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
00039 {
00040   typedef Sparse StorageKind;
00041   typedef typename scalar_product_traits<typename traits<Lhs>::Scalar,
00042                                             typename traits<Rhs>::Scalar>::ReturnType Scalar;
00043   typedef typename Lhs::Index Index;
00044   typedef typename Lhs::Nested LhsNested;
00045   typedef typename Rhs::Nested RhsNested;
00046   typedef typename remove_all<LhsNested>::type _LhsNested;
00047   typedef typename remove_all<RhsNested>::type _RhsNested;
00048 
00049   enum {
00050     LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost,
00051     RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost,
00052 
00053     RowsAtCompileTime    = Tr ? int(traits<Rhs>::RowsAtCompileTime)     : int(traits<Lhs>::RowsAtCompileTime),
00054     ColsAtCompileTime    = Tr ? int(traits<Lhs>::ColsAtCompileTime)     : int(traits<Rhs>::ColsAtCompileTime),
00055     MaxRowsAtCompileTime = Tr ? int(traits<Rhs>::MaxRowsAtCompileTime)  : int(traits<Lhs>::MaxRowsAtCompileTime),
00056     MaxColsAtCompileTime = Tr ? int(traits<Lhs>::MaxColsAtCompileTime)  : int(traits<Rhs>::MaxColsAtCompileTime),
00057 
00058     Flags = Tr ? RowMajorBit : 0,
00059 
00060     CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost
00061   };
00062 };
00063 
00064 } // end namespace internal
00065 
00066 template<typename Lhs, typename Rhs, bool Tr>
00067 class SparseDenseOuterProduct
00068  : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
00069 {
00070   public:
00071 
00072     typedef SparseMatrixBase<SparseDenseOuterProduct> Base;
00073     EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct)
00074     typedef internal::traits<SparseDenseOuterProduct> Traits;
00075 
00076   private:
00077 
00078     typedef typename Traits::LhsNested LhsNested;
00079     typedef typename Traits::RhsNested RhsNested;
00080     typedef typename Traits::_LhsNested _LhsNested;
00081     typedef typename Traits::_RhsNested _RhsNested;
00082 
00083   public:
00084 
00085     class InnerIterator;
00086 
00087     EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs)
00088       : m_lhs(lhs), m_rhs(rhs)
00089     {
00090       EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
00091     }
00092 
00093     EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs)
00094       : m_lhs(lhs), m_rhs(rhs)
00095     {
00096       EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
00097     }
00098 
00099     EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); }
00100     EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); }
00101 
00102     EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
00103     EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
00104 
00105   protected:
00106     LhsNested m_lhs;
00107     RhsNested m_rhs;
00108 };
00109 
00110 template<typename Lhs, typename Rhs, bool Transpose>
00111 class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator
00112 {
00113     typedef typename _LhsNested::InnerIterator Base;
00114   public:
00115     EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer)
00116       : Base(prod.lhs(), 0), m_outer(outer), m_factor(prod.rhs().coeff(outer))
00117     {
00118     }
00119 
00120     inline Index outer() const { return m_outer; }
00121     inline Index row() const { return Transpose ? Base::row() : m_outer; }
00122     inline Index col() const { return Transpose ? m_outer : Base::row(); }
00123 
00124     inline Scalar value() const { return Base::value() * m_factor; }
00125 
00126   protected:
00127     int m_outer;
00128     Scalar m_factor;
00129 };
00130 
00131 namespace internal {
00132 template<typename Lhs, typename Rhs>
00133 struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
00134  : traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> >
00135 {
00136   typedef Dense StorageKind;
00137   typedef MatrixXpr XprKind;
00138 };
00139 
00140 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
00141          int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
00142          bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
00143 struct sparse_time_dense_product_impl;
00144 
00145 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00146 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true>
00147 {
00148   typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00149   typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00150   typedef typename internal::remove_all<DenseResType>::type Res;
00151   typedef typename Lhs::Index Index;
00152   typedef typename Lhs::InnerIterator LhsInnerIterator;
00153   static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00154   {
00155     for(Index c=0; c<rhs.cols(); ++c)
00156     {
00157       int n = lhs.outerSize();
00158       for(Index j=0; j<n; ++j)
00159       {
00160         typename Res::Scalar tmp(0);
00161         for(LhsInnerIterator it(lhs,j); it ;++it)
00162           tmp += it.value() * rhs.coeff(it.index(),c);
00163         res.coeffRef(j,c) = alpha * tmp;
00164       }
00165     }
00166   }
00167 };
00168 
00169 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00170 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true>
00171 {
00172   typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00173   typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00174   typedef typename internal::remove_all<DenseResType>::type Res;
00175   typedef typename Lhs::InnerIterator LhsInnerIterator;
00176   typedef typename Lhs::Index Index;
00177   static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00178   {
00179     for(Index c=0; c<rhs.cols(); ++c)
00180     {
00181       for(Index j=0; j<lhs.outerSize(); ++j)
00182       {
00183         typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
00184         for(LhsInnerIterator it(lhs,j); it ;++it)
00185           res.coeffRef(it.index(),c) += it.value() * rhs_j;
00186       }
00187     }
00188   }
00189 };
00190 
00191 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00192 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false>
00193 {
00194   typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00195   typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00196   typedef typename internal::remove_all<DenseResType>::type Res;
00197   typedef typename Lhs::InnerIterator LhsInnerIterator;
00198   typedef typename Lhs::Index Index;
00199   static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00200   {
00201     for(Index j=0; j<lhs.outerSize(); ++j)
00202     {
00203       typename Res::RowXpr res_j(res.row(j));
00204       for(LhsInnerIterator it(lhs,j); it ;++it)
00205         res_j += (alpha*it.value()) * rhs.row(it.index());
00206     }
00207   }
00208 };
00209 
00210 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
00211 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false>
00212 {
00213   typedef typename internal::remove_all<SparseLhsType>::type Lhs;
00214   typedef typename internal::remove_all<DenseRhsType>::type Rhs;
00215   typedef typename internal::remove_all<DenseResType>::type Res;
00216   typedef typename Lhs::InnerIterator LhsInnerIterator;
00217   typedef typename Lhs::Index Index;
00218   static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
00219   {
00220     for(Index j=0; j<lhs.outerSize(); ++j)
00221     {
00222       typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
00223       for(LhsInnerIterator it(lhs,j); it ;++it)
00224         res.row(it.index()) += (alpha*it.value()) * rhs_j;
00225     }
00226   }
00227 };
00228 
00229 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
00230 inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
00231 {
00232   sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha);
00233 }
00234 
00235 } // end namespace internal
00236 
00237 template<typename Lhs, typename Rhs>
00238 class SparseTimeDenseProduct
00239   : public ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs>
00240 {
00241   public:
00242     EIGEN_PRODUCT_PUBLIC_INTERFACE(SparseTimeDenseProduct)
00243 
00244     SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
00245     {}
00246 
00247     template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
00248     {
00249       internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha);
00250     }
00251 
00252   private:
00253     SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&);
00254 };
00255 
00256 
00257 // dense = dense * sparse
00258 namespace internal {
00259 template<typename Lhs, typename Rhs>
00260 struct traits<DenseTimeSparseProduct<Lhs,Rhs> >
00261  : traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> >
00262 {
00263   typedef Dense StorageKind;
00264 };
00265 } // end namespace internal
00266 
00267 template<typename Lhs, typename Rhs>
00268 class DenseTimeSparseProduct
00269   : public ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs>
00270 {
00271   public:
00272     EIGEN_PRODUCT_PUBLIC_INTERFACE(DenseTimeSparseProduct)
00273 
00274     DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
00275     {}
00276 
00277     template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
00278     {
00279       Transpose<const _LhsNested> lhs_t(m_lhs);
00280       Transpose<const _RhsNested> rhs_t(m_rhs);
00281       Transpose<Dest> dest_t(dest);
00282       internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha);
00283     }
00284 
00285   private:
00286     DenseTimeSparseProduct& operator=(const DenseTimeSparseProduct&);
00287 };
00288 
00289 // sparse * dense
00290 template<typename Derived>
00291 template<typename OtherDerived>
00292 inline const typename SparseDenseProductReturnType<Derived,OtherDerived>::Type
00293 SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
00294 {
00295   return typename SparseDenseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
00296 }
00297 
00298 } // end namespace Eigen
00299 
00300 #endif // EIGEN_SPARSEDENSEPRODUCT_H


win_eigen
Author(s): Daniel Stonier
autogenerated on Wed Sep 16 2015 07:12:04