00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_SPARSE_CWISE_BINARY_OP_H
00011 #define EIGEN_SPARSE_CWISE_BINARY_OP_H
00012
00013 namespace Eigen {
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 namespace internal {
00033
00034 template<> struct promote_storage_type<Dense,Sparse>
00035 { typedef Sparse ret; };
00036
00037 template<> struct promote_storage_type<Sparse,Dense>
00038 { typedef Sparse ret; };
00039
00040 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived,
00041 typename _LhsStorageMode = typename traits<Lhs>::StorageKind,
00042 typename _RhsStorageMode = typename traits<Rhs>::StorageKind>
00043 class sparse_cwise_binary_op_inner_iterator_selector;
00044
00045 }
00046
00047 template<typename BinaryOp, typename Lhs, typename Rhs>
00048 class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse>
00049 : public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
00050 {
00051 public:
00052 class InnerIterator;
00053 class ReverseInnerIterator;
00054 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> Derived;
00055 EIGEN_SPARSE_PUBLIC_INTERFACE(Derived)
00056 CwiseBinaryOpImpl()
00057 {
00058 typedef typename internal::traits<Lhs>::StorageKind LhsStorageKind;
00059 typedef typename internal::traits<Rhs>::StorageKind RhsStorageKind;
00060 EIGEN_STATIC_ASSERT((
00061 (!internal::is_same<LhsStorageKind,RhsStorageKind>::value)
00062 || ((Lhs::Flags&RowMajorBit) == (Rhs::Flags&RowMajorBit))),
00063 THE_STORAGE_ORDER_OF_BOTH_SIDES_MUST_MATCH);
00064 }
00065 };
00066
00067 template<typename BinaryOp, typename Lhs, typename Rhs>
00068 class CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator
00069 : public internal::sparse_cwise_binary_op_inner_iterator_selector<BinaryOp,Lhs,Rhs,typename CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator>
00070 {
00071 public:
00072 typedef typename Lhs::Index Index;
00073 typedef internal::sparse_cwise_binary_op_inner_iterator_selector<
00074 BinaryOp,Lhs,Rhs, InnerIterator> Base;
00075
00076 EIGEN_STRONG_INLINE InnerIterator(const CwiseBinaryOpImpl& binOp, typename CwiseBinaryOpImpl::Index outer)
00077 : Base(binOp.derived(),outer)
00078 {}
00079 };
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090 namespace internal {
00091
00092
00093 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived>
00094 class sparse_cwise_binary_op_inner_iterator_selector<BinaryOp, Lhs, Rhs, Derived, Sparse, Sparse>
00095 {
00096 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> CwiseBinaryXpr;
00097 typedef typename traits<CwiseBinaryXpr>::Scalar Scalar;
00098 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
00099 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
00100 typedef typename _LhsNested::InnerIterator LhsIterator;
00101 typedef typename _RhsNested::InnerIterator RhsIterator;
00102 typedef typename Lhs::Index Index;
00103
00104 public:
00105
00106 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00107 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
00108 {
00109 this->operator++();
00110 }
00111
00112 EIGEN_STRONG_INLINE Derived& operator++()
00113 {
00114 if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index()))
00115 {
00116 m_id = m_lhsIter.index();
00117 m_value = m_functor(m_lhsIter.value(), m_rhsIter.value());
00118 ++m_lhsIter;
00119 ++m_rhsIter;
00120 }
00121 else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index())))
00122 {
00123 m_id = m_lhsIter.index();
00124 m_value = m_functor(m_lhsIter.value(), Scalar(0));
00125 ++m_lhsIter;
00126 }
00127 else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index())))
00128 {
00129 m_id = m_rhsIter.index();
00130 m_value = m_functor(Scalar(0), m_rhsIter.value());
00131 ++m_rhsIter;
00132 }
00133 else
00134 {
00135 m_value = 0;
00136 m_id = -1;
00137 }
00138 return *static_cast<Derived*>(this);
00139 }
00140
00141 EIGEN_STRONG_INLINE Scalar value() const { return m_value; }
00142
00143 EIGEN_STRONG_INLINE Index index() const { return m_id; }
00144 EIGEN_STRONG_INLINE Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); }
00145 EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); }
00146
00147 EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; }
00148
00149 protected:
00150 LhsIterator m_lhsIter;
00151 RhsIterator m_rhsIter;
00152 const BinaryOp& m_functor;
00153 Scalar m_value;
00154 Index m_id;
00155 };
00156
00157
00158 template<typename T, typename Lhs, typename Rhs, typename Derived>
00159 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Sparse>
00160 {
00161 typedef scalar_product_op<T> BinaryFunc;
00162 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
00163 typedef typename CwiseBinaryXpr::Scalar Scalar;
00164 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
00165 typedef typename _LhsNested::InnerIterator LhsIterator;
00166 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
00167 typedef typename _RhsNested::InnerIterator RhsIterator;
00168 typedef typename Lhs::Index Index;
00169 public:
00170
00171 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00172 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
00173 {
00174 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
00175 {
00176 if (m_lhsIter.index() < m_rhsIter.index())
00177 ++m_lhsIter;
00178 else
00179 ++m_rhsIter;
00180 }
00181 }
00182
00183 EIGEN_STRONG_INLINE Derived& operator++()
00184 {
00185 ++m_lhsIter;
00186 ++m_rhsIter;
00187 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
00188 {
00189 if (m_lhsIter.index() < m_rhsIter.index())
00190 ++m_lhsIter;
00191 else
00192 ++m_rhsIter;
00193 }
00194 return *static_cast<Derived*>(this);
00195 }
00196
00197 EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), m_rhsIter.value()); }
00198
00199 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
00200 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
00201 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
00202
00203 EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); }
00204
00205 protected:
00206 LhsIterator m_lhsIter;
00207 RhsIterator m_rhsIter;
00208 const BinaryFunc& m_functor;
00209 };
00210
00211
00212 template<typename T, typename Lhs, typename Rhs, typename Derived>
00213 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Dense>
00214 {
00215 typedef scalar_product_op<T> BinaryFunc;
00216 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
00217 typedef typename CwiseBinaryXpr::Scalar Scalar;
00218 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
00219 typedef typename traits<CwiseBinaryXpr>::RhsNested RhsNested;
00220 typedef typename _LhsNested::InnerIterator LhsIterator;
00221 typedef typename Lhs::Index Index;
00222 enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit };
00223 public:
00224
00225 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00226 : m_rhs(xpr.rhs()), m_lhsIter(xpr.lhs(),outer), m_functor(xpr.functor()), m_outer(outer)
00227 {}
00228
00229 EIGEN_STRONG_INLINE Derived& operator++()
00230 {
00231 ++m_lhsIter;
00232 return *static_cast<Derived*>(this);
00233 }
00234
00235 EIGEN_STRONG_INLINE Scalar value() const
00236 { return m_functor(m_lhsIter.value(),
00237 m_rhs.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); }
00238
00239 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
00240 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
00241 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
00242
00243 EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; }
00244
00245 protected:
00246 RhsNested m_rhs;
00247 LhsIterator m_lhsIter;
00248 const BinaryFunc m_functor;
00249 const Index m_outer;
00250 };
00251
00252
00253 template<typename T, typename Lhs, typename Rhs, typename Derived>
00254 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Dense, Sparse>
00255 {
00256 typedef scalar_product_op<T> BinaryFunc;
00257 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
00258 typedef typename CwiseBinaryXpr::Scalar Scalar;
00259 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
00260 typedef typename _RhsNested::InnerIterator RhsIterator;
00261 typedef typename Lhs::Index Index;
00262
00263 enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit };
00264 public:
00265
00266 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00267 : m_xpr(xpr), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()), m_outer(outer)
00268 {}
00269
00270 EIGEN_STRONG_INLINE Derived& operator++()
00271 {
00272 ++m_rhsIter;
00273 return *static_cast<Derived*>(this);
00274 }
00275
00276 EIGEN_STRONG_INLINE Scalar value() const
00277 { return m_functor(m_xpr.lhs().coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); }
00278
00279 EIGEN_STRONG_INLINE Index index() const { return m_rhsIter.index(); }
00280 EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); }
00281 EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); }
00282
00283 EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; }
00284
00285 protected:
00286 const CwiseBinaryXpr& m_xpr;
00287 RhsIterator m_rhsIter;
00288 const BinaryFunc& m_functor;
00289 const Index m_outer;
00290 };
00291
00292 }
00293
00294
00295
00296
00297
00298 template<typename Derived>
00299 template<typename OtherDerived>
00300 EIGEN_STRONG_INLINE Derived &
00301 SparseMatrixBase<Derived>::operator-=(const SparseMatrixBase<OtherDerived> &other)
00302 {
00303 return *this = derived() - other.derived();
00304 }
00305
00306 template<typename Derived>
00307 template<typename OtherDerived>
00308 EIGEN_STRONG_INLINE Derived &
00309 SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& other)
00310 {
00311 return *this = derived() + other.derived();
00312 }
00313
00314 template<typename Derived>
00315 template<typename OtherDerived>
00316 EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
00317 SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const
00318 {
00319 return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived());
00320 }
00321
00322 }
00323
00324 #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H