00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef KRONECKER_TENSOR_PRODUCT_H
00013 #define KRONECKER_TENSOR_PRODUCT_H
00014
00015 namespace Eigen {
00016
00017 template<typename Scalar, int Options, typename Index> class SparseMatrix;
00018
00029 template<typename Lhs, typename Rhs>
00030 class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
00031 {
00032 private:
00033 typedef ReturnByValue<KroneckerProduct> Base;
00034 typedef typename Base::Scalar Scalar;
00035 typedef typename Base::Index Index;
00036
00037 public:
00039 KroneckerProduct(const Lhs& A, const Rhs& B)
00040 : m_A(A), m_B(B)
00041 {}
00042
00044 template<typename Dest> void evalTo(Dest& dst) const;
00045
00046 inline Index rows() const { return m_A.rows() * m_B.rows(); }
00047 inline Index cols() const { return m_A.cols() * m_B.cols(); }
00048
00049 Scalar coeff(Index row, Index col) const
00050 {
00051 return m_A.coeff(row / m_B.rows(), col / m_B.cols()) *
00052 m_B.coeff(row % m_B.rows(), col % m_B.cols());
00053 }
00054
00055 Scalar coeff(Index i) const
00056 {
00057 EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct);
00058 return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
00059 }
00060
00061 private:
00062 typename Lhs::Nested m_A;
00063 typename Rhs::Nested m_B;
00064 };
00065
00079 template<typename Lhs, typename Rhs>
00080 class KroneckerProductSparse : public EigenBase<KroneckerProductSparse<Lhs,Rhs> >
00081 {
00082 private:
00083 typedef typename internal::traits<KroneckerProductSparse>::Index Index;
00084
00085 public:
00087 KroneckerProductSparse(const Lhs& A, const Rhs& B)
00088 : m_A(A), m_B(B)
00089 {}
00090
00092 template<typename Dest> void evalTo(Dest& dst) const;
00093
00094 inline Index rows() const { return m_A.rows() * m_B.rows(); }
00095 inline Index cols() const { return m_A.cols() * m_B.cols(); }
00096
00097 template<typename Scalar, int Options, typename Index>
00098 operator SparseMatrix<Scalar, Options, Index>()
00099 {
00100 SparseMatrix<Scalar, Options, Index> result;
00101 evalTo(result.derived());
00102 return result;
00103 }
00104
00105 private:
00106 typename Lhs::Nested m_A;
00107 typename Rhs::Nested m_B;
00108 };
00109
00110 template<typename Lhs, typename Rhs>
00111 template<typename Dest>
00112 void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const
00113 {
00114 const int BlockRows = Rhs::RowsAtCompileTime,
00115 BlockCols = Rhs::ColsAtCompileTime;
00116 const Index Br = m_B.rows(),
00117 Bc = m_B.cols();
00118 for (Index i=0; i < m_A.rows(); ++i)
00119 for (Index j=0; j < m_A.cols(); ++j)
00120 Block<Dest,BlockRows,BlockCols>(dst,i*Br,j*Bc,Br,Bc) = m_A.coeff(i,j) * m_B;
00121 }
00122
00123 template<typename Lhs, typename Rhs>
00124 template<typename Dest>
00125 void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
00126 {
00127 const Index Br = m_B.rows(),
00128 Bc = m_B.cols();
00129 dst.resize(rows(),cols());
00130 dst.resizeNonZeros(0);
00131 dst.reserve(m_A.nonZeros() * m_B.nonZeros());
00132
00133 for (Index kA=0; kA < m_A.outerSize(); ++kA)
00134 {
00135 for (Index kB=0; kB < m_B.outerSize(); ++kB)
00136 {
00137 for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA)
00138 {
00139 for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB)
00140 {
00141 const Index i = itA.row() * Br + itB.row(),
00142 j = itA.col() * Bc + itB.col();
00143 dst.insert(i,j) = itA.value() * itB.value();
00144 }
00145 }
00146 }
00147 }
00148 }
00149
00150 namespace internal {
00151
00152 template<typename _Lhs, typename _Rhs>
00153 struct traits<KroneckerProduct<_Lhs,_Rhs> >
00154 {
00155 typedef typename remove_all<_Lhs>::type Lhs;
00156 typedef typename remove_all<_Rhs>::type Rhs;
00157 typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
00158
00159 enum {
00160 Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
00161 Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
00162 MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
00163 MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,
00164 CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost
00165 };
00166
00167 typedef Matrix<Scalar,Rows,Cols> ReturnType;
00168 };
00169
00170 template<typename _Lhs, typename _Rhs>
00171 struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
00172 {
00173 typedef MatrixXpr XprKind;
00174 typedef typename remove_all<_Lhs>::type Lhs;
00175 typedef typename remove_all<_Rhs>::type Rhs;
00176 typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
00177 typedef typename promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind>::ret StorageKind;
00178 typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
00179
00180 enum {
00181 LhsFlags = Lhs::Flags,
00182 RhsFlags = Rhs::Flags,
00183
00184 RowsAtCompileTime = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
00185 ColsAtCompileTime = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
00186 MaxRowsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
00187 MaxColsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,
00188
00189 EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit),
00190 RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
00191
00192 Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
00193 | EvalBeforeNestingBit | EvalBeforeAssigningBit,
00194 CoeffReadCost = Dynamic
00195 };
00196 };
00197
00198 }
00199
00219 template<typename A, typename B>
00220 KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<B>& b)
00221 {
00222 return KroneckerProduct<A, B>(a.derived(), b.derived());
00223 }
00224
00236 template<typename A, typename B>
00237 KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenBase<B>& b)
00238 {
00239 return KroneckerProductSparse<A,B>(a.derived(), b.derived());
00240 }
00241
00242 }
00243
00244 #endif // KRONECKER_TENSOR_PRODUCT_H