00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_SPARSE_CWISE_BINARY_OP_H
00011 #define EIGEN_SPARSE_CWISE_BINARY_OP_H
00012
00013 namespace Eigen {
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 namespace internal {
00033
00034 template<> struct promote_storage_type<Dense,Sparse>
00035 { typedef Sparse ret; };
00036
00037 template<> struct promote_storage_type<Sparse,Dense>
00038 { typedef Sparse ret; };
00039
00040 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived,
00041 typename _LhsStorageMode = typename traits<Lhs>::StorageKind,
00042 typename _RhsStorageMode = typename traits<Rhs>::StorageKind>
00043 class sparse_cwise_binary_op_inner_iterator_selector;
00044
00045 }
00046
00047 template<typename BinaryOp, typename Lhs, typename Rhs>
00048 class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse>
00049 : public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
00050 {
00051 public:
00052 class InnerIterator;
00053 class ReverseInnerIterator;
00054 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> Derived;
00055 EIGEN_SPARSE_PUBLIC_INTERFACE(Derived)
00056 CwiseBinaryOpImpl()
00057 {
00058 typedef typename internal::traits<Lhs>::StorageKind LhsStorageKind;
00059 typedef typename internal::traits<Rhs>::StorageKind RhsStorageKind;
00060 EIGEN_STATIC_ASSERT((
00061 (!internal::is_same<LhsStorageKind,RhsStorageKind>::value)
00062 || ((Lhs::Flags&RowMajorBit) == (Rhs::Flags&RowMajorBit))),
00063 THE_STORAGE_ORDER_OF_BOTH_SIDES_MUST_MATCH);
00064 }
00065 };
00066
00067 template<typename BinaryOp, typename Lhs, typename Rhs>
00068 class CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator
00069 : public internal::sparse_cwise_binary_op_inner_iterator_selector<BinaryOp,Lhs,Rhs,typename CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator>
00070 {
00071 public:
00072 typedef typename Lhs::Index Index;
00073 typedef internal::sparse_cwise_binary_op_inner_iterator_selector<
00074 BinaryOp,Lhs,Rhs, InnerIterator> Base;
00075
00076
00077 EIGEN_STRONG_INLINE InnerIterator(const CwiseBinaryOpImpl& binOp, typename Lhs::Index outer)
00078 : Base(binOp.derived(),outer)
00079 {}
00080 };
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091 namespace internal {
00092
00093
00094 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived>
00095 class sparse_cwise_binary_op_inner_iterator_selector<BinaryOp, Lhs, Rhs, Derived, Sparse, Sparse>
00096 {
00097 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> CwiseBinaryXpr;
00098 typedef typename traits<CwiseBinaryXpr>::Scalar Scalar;
00099 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
00100 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
00101 typedef typename _LhsNested::InnerIterator LhsIterator;
00102 typedef typename _RhsNested::InnerIterator RhsIterator;
00103 typedef typename Lhs::Index Index;
00104
00105 public:
00106
00107 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00108 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
00109 {
00110 this->operator++();
00111 }
00112
00113 EIGEN_STRONG_INLINE Derived& operator++()
00114 {
00115 if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index()))
00116 {
00117 m_id = m_lhsIter.index();
00118 m_value = m_functor(m_lhsIter.value(), m_rhsIter.value());
00119 ++m_lhsIter;
00120 ++m_rhsIter;
00121 }
00122 else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index())))
00123 {
00124 m_id = m_lhsIter.index();
00125 m_value = m_functor(m_lhsIter.value(), Scalar(0));
00126 ++m_lhsIter;
00127 }
00128 else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index())))
00129 {
00130 m_id = m_rhsIter.index();
00131 m_value = m_functor(Scalar(0), m_rhsIter.value());
00132 ++m_rhsIter;
00133 }
00134 else
00135 {
00136 m_value = 0;
00137 m_id = -1;
00138 }
00139 return *static_cast<Derived*>(this);
00140 }
00141
00142 EIGEN_STRONG_INLINE Scalar value() const { return m_value; }
00143
00144 EIGEN_STRONG_INLINE Index index() const { return m_id; }
00145 EIGEN_STRONG_INLINE Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); }
00146 EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); }
00147
00148 EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; }
00149
00150 protected:
00151 LhsIterator m_lhsIter;
00152 RhsIterator m_rhsIter;
00153 const BinaryOp& m_functor;
00154 Scalar m_value;
00155 Index m_id;
00156 };
00157
00158
00159 template<typename T, typename Lhs, typename Rhs, typename Derived>
00160 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Sparse>
00161 {
00162 typedef scalar_product_op<T> BinaryFunc;
00163 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
00164 typedef typename CwiseBinaryXpr::Scalar Scalar;
00165 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
00166 typedef typename _LhsNested::InnerIterator LhsIterator;
00167 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
00168 typedef typename _RhsNested::InnerIterator RhsIterator;
00169 typedef typename Lhs::Index Index;
00170 public:
00171
00172 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00173 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
00174 {
00175 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
00176 {
00177 if (m_lhsIter.index() < m_rhsIter.index())
00178 ++m_lhsIter;
00179 else
00180 ++m_rhsIter;
00181 }
00182 }
00183
00184 EIGEN_STRONG_INLINE Derived& operator++()
00185 {
00186 ++m_lhsIter;
00187 ++m_rhsIter;
00188 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
00189 {
00190 if (m_lhsIter.index() < m_rhsIter.index())
00191 ++m_lhsIter;
00192 else
00193 ++m_rhsIter;
00194 }
00195 return *static_cast<Derived*>(this);
00196 }
00197
00198 EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), m_rhsIter.value()); }
00199
00200 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
00201 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
00202 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
00203
00204 EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); }
00205
00206 protected:
00207 LhsIterator m_lhsIter;
00208 RhsIterator m_rhsIter;
00209 const BinaryFunc& m_functor;
00210 };
00211
00212
00213 template<typename T, typename Lhs, typename Rhs, typename Derived>
00214 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Dense>
00215 {
00216 typedef scalar_product_op<T> BinaryFunc;
00217 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
00218 typedef typename CwiseBinaryXpr::Scalar Scalar;
00219 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
00220 typedef typename traits<CwiseBinaryXpr>::RhsNested RhsNested;
00221 typedef typename _LhsNested::InnerIterator LhsIterator;
00222 typedef typename Lhs::Index Index;
00223 enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit };
00224 public:
00225
00226 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00227 : m_rhs(xpr.rhs()), m_lhsIter(xpr.lhs(),outer), m_functor(xpr.functor()), m_outer(outer)
00228 {}
00229
00230 EIGEN_STRONG_INLINE Derived& operator++()
00231 {
00232 ++m_lhsIter;
00233 return *static_cast<Derived*>(this);
00234 }
00235
00236 EIGEN_STRONG_INLINE Scalar value() const
00237 { return m_functor(m_lhsIter.value(),
00238 m_rhs.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); }
00239
00240 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
00241 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
00242 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
00243
00244 EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; }
00245
00246 protected:
00247 RhsNested m_rhs;
00248 LhsIterator m_lhsIter;
00249 const BinaryFunc m_functor;
00250 const Index m_outer;
00251 };
00252
00253
00254 template<typename T, typename Lhs, typename Rhs, typename Derived>
00255 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Dense, Sparse>
00256 {
00257 typedef scalar_product_op<T> BinaryFunc;
00258 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
00259 typedef typename CwiseBinaryXpr::Scalar Scalar;
00260 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
00261 typedef typename _RhsNested::InnerIterator RhsIterator;
00262 typedef typename Lhs::Index Index;
00263
00264 enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit };
00265 public:
00266
00267 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
00268 : m_xpr(xpr), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()), m_outer(outer)
00269 {}
00270
00271 EIGEN_STRONG_INLINE Derived& operator++()
00272 {
00273 ++m_rhsIter;
00274 return *static_cast<Derived*>(this);
00275 }
00276
00277 EIGEN_STRONG_INLINE Scalar value() const
00278 { return m_functor(m_xpr.lhs().coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); }
00279
00280 EIGEN_STRONG_INLINE Index index() const { return m_rhsIter.index(); }
00281 EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); }
00282 EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); }
00283
00284 EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; }
00285
00286 protected:
00287 const CwiseBinaryXpr& m_xpr;
00288 RhsIterator m_rhsIter;
00289 const BinaryFunc& m_functor;
00290 const Index m_outer;
00291 };
00292
00293 }
00294
00295
00296
00297
00298
00299 template<typename Derived>
00300 template<typename OtherDerived>
00301 EIGEN_STRONG_INLINE Derived &
00302 SparseMatrixBase<Derived>::operator-=(const SparseMatrixBase<OtherDerived> &other)
00303 {
00304 return derived() = derived() - other.derived();
00305 }
00306
00307 template<typename Derived>
00308 template<typename OtherDerived>
00309 EIGEN_STRONG_INLINE Derived &
00310 SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& other)
00311 {
00312 return derived() = derived() + other.derived();
00313 }
00314
00315 template<typename Derived>
00316 template<typename OtherDerived>
00317 EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
00318 SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const
00319 {
00320 return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived());
00321 }
00322
00323 }
00324
00325 #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H