KroneckerTensorProduct.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2011 Kolja Brix <brix@igpm.rwth-aachen.de>
5 // Copyright (C) 2011 Andreas Platen <andiplaten@gmx.de>
6 // Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla
9 // Public License v. 2.0. If a copy of the MPL was not distributed
10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11 
12 #ifndef KRONECKER_TENSOR_PRODUCT_H
13 #define KRONECKER_TENSOR_PRODUCT_H
14 
15 namespace Eigen {
16 
24 template<typename Derived>
25 class KroneckerProductBase : public ReturnByValue<Derived>
26 {
27  private:
29  typedef typename Traits::Scalar Scalar;
30 
31  protected:
32  typedef typename Traits::Lhs Lhs;
33  typedef typename Traits::Rhs Rhs;
34 
35  public:
37  KroneckerProductBase(const Lhs& A, const Rhs& B)
38  : m_A(A), m_B(B)
39  {}
40 
41  inline Index rows() const { return m_A.rows() * m_B.rows(); }
42  inline Index cols() const { return m_A.cols() * m_B.cols(); }
43 
48  Scalar coeff(Index row, Index col) const
49  {
50  return m_A.coeff(row / m_B.rows(), col / m_B.cols()) *
51  m_B.coeff(row % m_B.rows(), col % m_B.cols());
52  }
53 
58  Scalar coeff(Index i) const
59  {
61  return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
62  }
63 
64  protected:
65  typename Lhs::Nested m_A;
66  typename Rhs::Nested m_B;
67 };
68 
81 template<typename Lhs, typename Rhs>
82 class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs,Rhs> >
83 {
84  private:
86  using Base::m_A;
87  using Base::m_B;
88 
89  public:
91  KroneckerProduct(const Lhs& A, const Rhs& B)
92  : Base(A, B)
93  {}
94 
96  template<typename Dest> void evalTo(Dest& dst) const;
97 };
98 
114 template<typename Lhs, typename Rhs>
115 class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs,Rhs> >
116 {
117  private:
119  using Base::m_A;
120  using Base::m_B;
121 
122  public:
124  KroneckerProductSparse(const Lhs& A, const Rhs& B)
125  : Base(A, B)
126  {}
127 
129  template<typename Dest> void evalTo(Dest& dst) const;
130 };
131 
132 template<typename Lhs, typename Rhs>
133 template<typename Dest>
135 {
136  const int BlockRows = Rhs::RowsAtCompileTime,
137  BlockCols = Rhs::ColsAtCompileTime;
138  const Index Br = m_B.rows(),
139  Bc = m_B.cols();
140  for (Index i=0; i < m_A.rows(); ++i)
141  for (Index j=0; j < m_A.cols(); ++j)
142  Block<Dest,BlockRows,BlockCols>(dst,i*Br,j*Bc,Br,Bc) = m_A.coeff(i,j) * m_B;
143 }
144 
145 template<typename Lhs, typename Rhs>
146 template<typename Dest>
148 {
149  Index Br = m_B.rows(), Bc = m_B.cols();
150  dst.resize(this->rows(), this->cols());
151  dst.resizeNonZeros(0);
152 
153  // 1 - evaluate the operands if needed:
154  typedef typename internal::nested_eval<Lhs,Dynamic>::type Lhs1;
155  typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned;
156  const Lhs1 lhs1(m_A);
157  typedef typename internal::nested_eval<Rhs,Dynamic>::type Rhs1;
158  typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned;
159  const Rhs1 rhs1(m_B);
160 
161  // 2 - construct respective iterators
162  typedef Eigen::InnerIterator<Lhs1Cleaned> LhsInnerIterator;
163  typedef Eigen::InnerIterator<Rhs1Cleaned> RhsInnerIterator;
164 
165  // compute number of non-zeros per innervectors of dst
166  {
167  // TODO VectorXi is not necessarily big enough!
168  VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
169  for (Index kA=0; kA < m_A.outerSize(); ++kA)
170  for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
171  nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
172 
173  VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
174  for (Index kB=0; kB < m_B.outerSize(); ++kB)
175  for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
176  nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
177 
178  Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
179  dst.reserve(VectorXi::Map(nnzAB.data(), nnzAB.size()));
180  }
181 
182  for (Index kA=0; kA < m_A.outerSize(); ++kA)
183  {
184  for (Index kB=0; kB < m_B.outerSize(); ++kB)
185  {
186  for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
187  {
188  for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
189  {
190  Index i = itA.row() * Br + itB.row(),
191  j = itA.col() * Bc + itB.col();
192  dst.insert(i,j) = itA.value() * itB.value();
193  }
194  }
195  }
196  }
197 }
198 
199 namespace internal {
200 
// Traits of the dense Kronecker product expression KroneckerProduct<_Lhs,_Rhs>.
// Lhs / Rhs are the operand types with references and cv-qualifiers stripped.
// NOTE(review): upstream lines 206-207, 210-213 and 216 are missing from this
// extracted copy. Judging by the tooltip residue, they presumably hold the
// Scalar / StorageKind / StorageIndex typedefs, the enum entries and a
// ReturnType typedef. Confirm against upstream Eigen before relying on this struct.
201 template<typename _Lhs, typename _Rhs>
202 struct traits<KroneckerProduct<_Lhs,_Rhs> >
203 {
204  typedef typename remove_all<_Lhs>::type Lhs;
205  typedef typename remove_all<_Rhs>::type Rhs;
208 
209  enum {
214  };
215 
217 };
218 
// Traits of the sparse Kronecker product expression KroneckerProductSparse<_Lhs,_Rhs>.
// NOTE(review): upstream lines 222, 225-227, 233-236 and 242 are missing from
// this extracted copy (presumably typedefs, the size enum entries and a
// continuation of the Flags expression). The enum as shown lacks a comma
// after Flags. Confirm against upstream Eigen.
219 template<typename _Lhs, typename _Rhs>
220 struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
221 {
223  typedef typename remove_all<_Lhs>::type Lhs;
224  typedef typename remove_all<_Rhs>::type Rhs;
228 
229  enum {
230  LhsFlags = Lhs::Flags,
231  RhsFlags = Rhs::Flags,
232 
237 
// Non-zero only when BOTH operands carry RowMajorBit.
238  EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit),
// Mask that clears RowMajorBit unless both operands are row-major.
239  RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
240 
// Union of the operands' hereditary flags, with RowMajorBit kept only when
// both are row-major.
241  Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
// Coefficient access is marked as very expensive (HugeCost).
243  CoeffReadCost = HugeCost
244  };
245 
247 };
248 
249 } // end namespace internal
250 
270 template<typename A, typename B>
272 {
273  return KroneckerProduct<A, B>(a.derived(), b.derived());
274 }
275 
// Sparse overload of kroneckerProduct.
// NOTE(review): upstream line 298 (the function signature) and line 300 (the
// return statement) are missing from this extracted copy, so the function as
// shown is incomplete. Presumably it builds a KroneckerProductSparse<A,B> from
// the two operands, by analogy with the dense overload. Restore from upstream
// Eigen rather than guessing.
297 template<typename A, typename B>
299 {
301 }
302 
303 } // end namespace Eigen
304 
305 #endif // KRONECKER_TENSOR_PRODUCT_H
promote_index_type< typename Lhs::StorageIndex, typename Rhs::StorageIndex >::type StorageIndex
ScalarBinaryOpTraits< typename Lhs::Scalar, typename Rhs::Scalar >::ReturnType Scalar
const int HugeCost
Definition: Constants.h:39
KroneckerProductBase< KroneckerProductSparse > Base
A versatile sparse matrix representation.
Definition: SparseMatrix.h:96
internal::traits< Derived > Traits
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar * data() const
Definition: LDLT.h:16
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
EIGEN_DEVICE_FUNC void evalTo(Dest &dst) const
Definition: ReturnByValue.h:61
const unsigned int RowMajorBit
Definition: Constants.h:61
KroneckerProductBase< KroneckerProduct > Base
ScalarBinaryOpTraits< typename Lhs::Scalar, typename Rhs::Scalar >::ReturnType Scalar
const unsigned int HereditaryBits
Definition: Constants.h:190
EIGEN_DEVICE_FUNC ColXpr col(Index i)
This is the const version of col().
Definition: BlockMethods.h:838
KroneckerProductSparse(const Lhs &A, const Rhs &B)
Constructor.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:33
EIGEN_DEVICE_FUNC RowXpr row(Index i)
This is the const version of row().
Definition: BlockMethods.h:859
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
promote_index_type< typename Lhs::StorageIndex, typename Rhs::StorageIndex >::type StorageIndex
KroneckerProductBase(const Lhs &A, const Rhs &B)
Constructor.
cwise_promote_storage_type< typename traits< Lhs >::StorageKind, typename traits< Rhs >::StorageKind, scalar_product_op< typename Lhs::Scalar, typename Rhs::Scalar > >::ret StorageKind
Expression of a fixed-size or dynamic-size block.
Definition: Block.h:103
Scalar coeff(Index row, Index col) const
Determines whether the given binary operation of two numeric types is allowed and what the scalar return type is.
Definition: XprHelper.h:766
KroneckerProduct(const Lhs &A, const Rhs &B)
Constructor.
KroneckerProduct< A, B > kroneckerProduct(const MatrixBase< A > &a, const MatrixBase< B > &b)
const unsigned int EvalBeforeNestingBit
Definition: Constants.h:65
The matrix class, also used for vectors and row-vectors.
Definition: Matrix.h:178
#define EIGEN_STATIC_ASSERT_VECTOR_ONLY(TYPE)
Definition: StaticAssert.h:137
EIGEN_DEVICE_FUNC const Scalar & b
EIGEN_DEVICE_FUNC Derived & derived()
Definition: EigenBase.h:45
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:48
Kronecker tensor product helper class for dense matrices.
The base class of dense and sparse Kronecker product.
Kronecker tensor product helper class for sparse matrices.
An InnerIterator allows to loop over the element of any matrix expression.
Definition: CoreIterators.h:33


hebiros
Author(s): Xavier Artache , Matthew Tesch
autogenerated on Thu Sep 3 2020 04:08:19