#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

// Expression traits for the contraction of LhsXprType and RhsXprType along
// the index pairs given by Dimensions.
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >;

// Evaluation helper for the contraction expression.
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>;

// Nesting helper for the contraction expression.
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1,
              typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>;

// Traits of the contraction evaluator; they expose the evaluator's template
// arguments (indices, argument expressions and device) to the shared base.
template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> >;

// The user-visible contraction expression.
template<typename Indices, typename LhsXprType, typename RhsXprType>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
  public:
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
      const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}

  EIGEN_DEVICE_FUNC
  const Indices& indices() const { return m_indices; }
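
  // Illustrative usage (a sketch, not part of this header): contracting the
  // second axis of a 2x3 tensor with the first axis of a 3x4 tensor is the
  // ordinary matrix product and yields a 2x4 tensor.
  //
  //   Eigen::Tensor<float, 2> a(2, 3), b(3, 4);
  //   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
  //   Eigen::Tensor<float, 2> c = a.contract(b, dims);  // builds a TensorContractionOp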

// CRTP base shared by the device-specific contraction evaluators.
template<typename Derived>
struct TensorContractionEvaluatorBase
{
  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = LDims + RDims - 2 * ContractDims;
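
  // For example, contracting a rank-3 tensor with a rank-2 tensor along one
  // pair of axes yields NumDims = 3 + 2 - 2*1 = 3 output dimensions.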

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        m_device(device), m_result(NULL) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<EvalLeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<EvalRightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);

    DSizes<Index, LDims> eval_left_dims;
    DSizes<Index, RDims> eval_right_dims;
    array<IndexPair<Index>, ContractDims> eval_op_indices;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // For ColMajor the dimensions and contraction pairs are used as given.
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[i];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[i];
      }
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = op.indices()[i].first;
        eval_op_indices[i].second = op.indices()[i].second;
      }
    } else {
      // For RowMajor the evaluators hold the swapped arguments, so reverse the
      // dimensions and flip the contraction pairs.
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1];
      }
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = LDims - 1 - op.indices()[ContractDims - 1 - i].second;
        eval_op_indices[i].second = RDims - 1 - op.indices()[ContractDims - 1 - i].first;
      }
    }
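
    // Illustrative example of the flip: contracting a rank-2 lhs with a rank-3
    // rhs over the ColMajor pair (1, 0). In RowMajor the operands are swapped,
    // so LDims = 3 (the rhs rank) and RDims = 2 (the lhs rank), and the pair
    // becomes (LDims - 1 - 0, RDims - 1 - 1) = (2, 0).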

    // Check that the contraction axes are distinct and sort the pairs by their
    // first index; O(n^2) is fine because ContractDims is small.
    for (int i = 0; i < ContractDims; i++) {
      for (int j = i + 1; j < ContractDims; j++) {
        eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
                     eval_op_indices[j].second != eval_op_indices[i].second &&
                     "contraction axes should be unique");
        if (eval_op_indices[j].first < eval_op_indices[i].first) {
          numext::swap(eval_op_indices[j], eval_op_indices[i]);
        }
      }
    }

    array<Index, LDims> lhs_strides;
    lhs_strides[0] = 1;
    for (int i = 0; i < LDims-1; ++i) {
      lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
    }

    array<Index, RDims> rhs_strides;
    rhs_strides[0] = 1;
    for (int i = 0; i < RDims-1; ++i) {
      rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
    }
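
    // Column-major strides, e.g. dimensions (2, 3, 4) give strides (1, 2, 6).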

    if (m_i_strides.size() > 0) m_i_strides[0] = 1;
    if (m_j_strides.size() > 0) m_j_strides[0] = 1;
    if (m_k_strides.size() > 0) m_k_strides[0] = 1;

    // The output dimensions are the non-contracting dimensions of the left
    // tensor followed by the non-contracting dimensions of the right tensor.
    m_lhs_inner_dim_contiguous = true;
    int dim_idx = 0;
    unsigned int nocontract_idx = 0;

    for (int i = 0; i < LDims; i++) {
      // Is axis i of the left tensor a contracting axis?
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].first == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_left_dims[i];
        m_left_nocontract_strides[nocontract_idx] = lhs_strides[i];
        if (dim_idx != i) {
          m_lhs_inner_dim_contiguous = false;
        }
        if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) {
          m_i_strides[nocontract_idx+1] =
              m_i_strides[nocontract_idx] * eval_left_dims[i];
        } else {
          m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i];
        }
        dim_idx++;
        nocontract_idx++;
      }
    }

    nocontract_idx = 0;
    for (int i = 0; i < RDims; i++) {
      // Is axis i of the right tensor a contracting axis?
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].second == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_right_dims[i];
        if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) {
          m_j_strides[nocontract_idx+1] =
              m_j_strides[nocontract_idx] * eval_right_dims[i];
        } else {
          m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i];
        }
        m_right_nocontract_strides[nocontract_idx] = rhs_strides[i];
        dim_idx++;
        nocontract_idx++;
      }
    }

    // Strides of the contracting dimensions. The contracting axes have the
    // same size in both tensors, so only the left sizes are used.
    m_rhs_inner_dim_contiguous = true;
    m_rhs_inner_dim_reordered = false;
    for (int i = 0; i < ContractDims; i++) {
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;

      Index size = eval_left_dims[left];
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i+1 < static_cast<int>(internal::array_size<contract_t>::value)) {
        m_k_strides[i+1] = m_k_strides[i] * size;
      } else {
        m_k_size = m_k_strides[i] * size;
      }
      m_left_contracting_strides[i] = lhs_strides[left];
      m_right_contracting_strides[i] = rhs_strides[right];

      if (i > 0 && right < eval_op_indices[i-1].second) {
        m_rhs_inner_dim_reordered = true;
      }
      if (right != i) {
        m_rhs_inner_dim_contiguous = false;
      }
    }

    // For RowMajor output, reverse the order of the output dimensions.
    if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
      for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
        numext::swap(m_dimensions[i], m_dimensions[j]);
      }
    }
  }
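
  // After this setup the contraction is carried out as a matrix product of an
  // (m_i_size x m_k_size) matricized left operand with an
  // (m_k_size x m_j_size) matricized right operand.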

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<Scalar*>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }

  // Dispatch at compile time on the three layout flags computed in the
  // constructor; each combination selects a different evalProduct instantiation.
  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemv(Scalar* buffer) const {
    const Index rows = m_i_size;
    const Index cols = m_k_size;

    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
    const int lhs_alignment = LeftEvaluator::IsAligned ? Aligned : Unaligned;
    const int rhs_alignment = RightEvaluator::IsAligned ? Aligned : Unaligned;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, lhs_alignment> LhsMapper;
    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, rhs_alignment> RhsMapper;

    LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                  m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                  m_right_contracting_strides, m_k_strides);

    const Scalar alpha(1);
    const Index resIncr(1);

    // Zero the result buffer, which must hold at least rows scalars.
    m_device.memset(buffer, 0, rows * sizeof(Scalar));

    internal::general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, false,
                                            RhsScalar, RhsMapper, false>::run(
        rows, cols, lhs, rhs,
        buffer, resIncr, alpha);
  }
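
  // Illustrative example: contracting a (3, 4) tensor with a rank-1 tensor of
  // size 4 over the pair (1, 0) gives m_i_size = 3, m_k_size = 4 and
  // m_j_size = 1, i.e. a 3x4 matrix times a length-4 vector.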

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
    // columns in the left side, rows in the right side
    const Index k = this->m_k_size;
    // rows in the left side
    const Index m = this->m_i_size;
    // columns in the right side
    const Index n = this->m_j_size;

    // Zero the result buffer, which must hold at least m * n scalars.
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef internal::gebp_traits<LhsScalar, RhsScalar> Traits;

    const Index nr = Traits::nr;
    const Index mr = Traits::mr;

    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;
    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered,
                                                   Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // GEBP packing and kernel functors.
    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs;
    internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;
    internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;

    // Data mappers that present the tensors as column-major matrices.
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);
    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // Block sizes chosen to fit the caches (see Goto and van de Geijn's GEMM paper).
    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;

    LhsScalar* blockA = static_cast<LhsScalar*>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
    RhsScalar* blockB = static_cast<RhsScalar*>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));
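
    // Sketch of the blocking with assumed example sizes: if blocking chose
    // kc = 256, mc = 64 and nc = 1024, then each (i2, k2) step packs a 64x256
    // panel of the lhs into blockA, each j2 step packs a 256x1024 panel of the
    // rhs into blockB, and gebp accumulates their product into the matching
    // 64x1024 block of the output.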

    for (Index i2 = 0; i2 < m; i2 += mc) {
      const Index actual_mc = numext::mini(i2 + mc, m) - i2;
      for (Index k2 = 0; k2 < k; k2 += kc) {
        // Pack a vertical panel of the lhs, clamped to the edge of the matrix.
        const Index actual_kc = numext::mini(k2 + kc, k) - k2;
        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);

        for (Index j2 = 0; j2 < n; j2 += nc) {
          // Pack a block of the rhs, clamped to the edge of the matrix.
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);

          // Run the gebp micro kernel; the trailing parameters match Eigen's GEMM.
          gebp(output.getSubMapper(i2, j2), blockA, blockB,
               actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0);
        }
      }
    }

    this->m_device.deallocate(blockA);
    this->m_device.deallocate(blockB);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
  }

// Evaluator of the contraction expression on the default device; the shared
// machinery comes from TensorContractionEvaluatorBase.
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
    public TensorContractionEvaluatorBase<
        TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                              rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }

    this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                            rhs_inner_dim_reordered, Alignment>(buffer);
  }
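
  // m_j_size is the number of columns of the matricized product, so the
  // matrix-vector kernel handles contractions whose result has a single
  // column; everything else goes through the blocked GEMM path above.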

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H