#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

// ...

// Expression-level traits for TensorContractionOp.
template<typename Dimensions, typename LhsXprType, typename RhsXprType,
         typename OutputKernelType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> > {
  // ...
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType,
         typename OutputKernelType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, Eigen::Dense> {
  // ...
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType,
         typename OutputKernelType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, 1,
              typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >::type> {
  // ...
};

// Traits of the evaluator for a contraction expression on a given device.
template<typename Indices_, typename LeftArgType_, typename RightArgType_,
         typename OutputKernelType_, typename Device_>
struct traits<TensorEvaluator<
    const TensorContractionOp<Indices_, LeftArgType_, RightArgType_, OutputKernelType_>, Device_> > {
  // ...
};
// ...

// Helper that carves a single device allocation into the packed Lhs and Rhs
// blocks used by the contraction kernel.
template <typename LhsScalar, typename RhsScalar>
struct TensorContractionBlockMemAllocator {
  typedef void* BlockMemHandle;

  template <typename Device>
  EIGEN_DEVICE_FUNC static BlockMemHandle allocate(Device& d, const Index bm,
                                                   const Index bk, const Index bn,
                                                   LhsScalar** lhs_block,
                                                   RhsScalar** rhs_block) {
    // ...
    BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
    char* block_mem = static_cast<char*>(d.allocate(sz.lhs_size + sz.rhs_size));
    // ...
    *lhs_block = reinterpret_cast<LhsScalar*>(block_mem);
    *rhs_block = reinterpret_cast<RhsScalar*>(block_mem + sz.lhs_size);
    return block_mem;
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC static BlockMemHandle allocateSlices(
      Device& d, const Index bm, const Index bk, const Index bn,
      const Index num_lhs, const Index num_rhs, const Index num_slices,
      std::vector<LhsScalar*>* lhs_blocks,
      std::vector<RhsScalar*>* rhs_blocks) {
    // ...
    BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
    void* block_mem = d.allocate(
        (num_lhs * sz.lhs_size + num_rhs * sz.rhs_size) * num_slices);
    // ...
    char* mem = static_cast<char*>(block_mem);

    // Hand out one lhs/rhs block per slice, each carved out of the single
    // allocation above.
    for (Index x = 0; x < num_slices; x++) {
      if (num_lhs > 0) lhs_blocks[x].resize(num_lhs);
      for (Index m = 0; m < num_lhs; m++) {
        lhs_blocks[x][m] = reinterpret_cast<LhsScalar*>(mem);
        mem += sz.lhs_size;
      }
      if (num_rhs > 0) rhs_blocks[x].resize(num_rhs);
      for (Index n = 0; n < num_rhs; n++) {
        rhs_blocks[x][n] = reinterpret_cast<RhsScalar*>(mem);
        mem += sz.rhs_size;
      }
    }
    return block_mem;
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
    d.deallocate(handle);
  }

 private:
  struct BlockSizes {
    Index lhs_size;
    Index rhs_size;
  };

  EIGEN_DEVICE_FUNC static BlockSizes ComputeLhsRhsBlockSizes(const Index bm,
                                                              const Index bk,
                                                              const Index bn) {
    // ... (align: the byte alignment required for the packed blocks)
    BlockSizes sz;
    // Round each block up to a whole number of aligned chunks.
    sz.lhs_size = divup<Index>(bm * bk * sizeof(LhsScalar), align) * align;
    sz.rhs_size = divup<Index>(bn * bk * sizeof(RhsScalar), align) * align;
    return sz;
  }
};
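// ---------------------------------------------------------------------------
// Editor's sketch (not part of Eigen): the rounding performed by
// ComputeLhsRhsBlockSizes, written out as a standalone helper. The name
// padded_block_bytes is hypothetical; divup(a, b) is (a + b - 1) / b.
inline long padded_block_bytes(long elements, long elem_bytes, long align) {
  const long bytes = elements * elem_bytes;
  // Smallest multiple of `align` that is >= bytes, i.e. divup(bytes, align) * align.
  return ((bytes + align - 1) / align) * align;
}
// For example, a 10x10 block of 4-byte floats is 400 bytes; with align = 64
// it is padded to 448 bytes, so the rhs block placed right after it in the
// same allocation stays aligned.
// ---------------------------------------------------------------------------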
// ...

template <typename ResScalar, typename LhsScalar, typename RhsScalar,
          typename StorageIndex, typename OutputMapper, typename LhsMapper,
          typename RhsMapper>
struct TensorContractionKernel {
  // The default (gebp-based) kernel does not support accumulating with a
  // beta scaling factor.
  enum { HasBeta = false };

  EIGEN_DEVICE_FUNC
  TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_,
                          StorageIndex bm_, StorageIndex bk_, StorageIndex bn_)
      : m(m_), k(k_), n(n_), bm(bm_), bk(bk_), bn(bn_) {}

  // ...
  typedef LhsScalar* LhsBlock;
  typedef RhsScalar* RhsBlock;
  typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar> BlockMemAllocator;
  typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle;
  typedef internal::gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef internal::gemm_pack_lhs<
      LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr,
      Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor>
      LhsPacker;

  typedef internal::gemm_pack_rhs<
      RhsScalar, StorageIndex, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
      RhsPacker;

  typedef internal::gebp_kernel<LhsScalar, RhsScalar, StorageIndex,
                                OutputMapper, Traits::mr, Traits::nr,
                                /*ConjugateLhs*/ false, /*ConjugateRhs*/ false>
      GebpKernel;

  template <typename Device>
  EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block,
                                            RhsBlock* rhs_block) {
    return BlockMemAllocator::allocate(d, bm, bk, bn, lhs_block, rhs_block);
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(
      Device& d, const StorageIndex num_lhs, const StorageIndex num_rhs,
      const StorageIndex num_slices, std::vector<LhsBlock>* lhs_blocks,
      std::vector<RhsBlock>* rhs_blocks) {
    return BlockMemAllocator::allocateSlices(
        d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_blocks, rhs_blocks);
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
    BlockMemAllocator::deallocate(d, handle);
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(
      LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper,
      const StorageIndex depth, const StorageIndex rows) {
    // ... (runs LhsPacker on the sub-mapper to fill *lhsBlock)
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(
      RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper,
      const StorageIndex depth, const StorageIndex cols) {
    // ... (runs RhsPacker on the sub-mapper to fill *rhsBlock)
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(
      const OutputMapper& output_mapper, const LhsBlock& lhsBlock,
      const RhsBlock& rhsBlock, const StorageIndex rows,
      const StorageIndex depth, const StorageIndex cols,
      const ResScalar alpha, const ResScalar beta) {
    // ...
    static const int kComputeStrideFromBlockDimensions = -1;
    GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha,
                 /*strideA*/ kComputeStrideFromBlockDimensions,
                 /*strideB*/ kComputeStrideFromBlockDimensions,
                 /*offsetA*/ 0, /*offsetB*/ 0);
  }

 private:
  // Problem size.
  const StorageIndex m;
  const StorageIndex k;
  const StorageIndex n;
  // Block sizes.
  const StorageIndex bm;
  const StorageIndex bk;
  const StorageIndex bn;
};
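// Editor's note: a TensorContractionKernel is driven by the evaluators below
// in the sequence allocate() -> packLhs()/packRhs() inside the blocking
// loops -> invoke() -> deallocate(); see evalGemmPartial() further down for
// the actual call sites.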
// ...

// Default output kernel: a no-op functor that leaves the contraction result
// in the output buffer untouched.
struct NoOpOutputKernel {
  // ...
  template <typename Index, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(
      const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
      const TensorContractionParams& params, Index i, Index j,
      Index num_rows, Index num_cols) const { /* ... */ }
};

template<typename Indices, typename LhsXprType, typename RhsXprType,
         typename OutputKernelType = const NoOpOutputKernel>
class TensorContractionOp
    : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>,
                        ReadOnlyAccessors> {
 public:
  // ...
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
      const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims,
      const OutputKernelType& output_kernel = OutputKernelType())
      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims),
        m_output_kernel(output_kernel) {}

  // ... (indices(), lhsExpression() and rhsExpression() accessors)

  EIGEN_DEVICE_FUNC
  const OutputKernelType& outputKernel() const { return m_output_kernel; }

 protected:
  typename LhsXprType::Nested m_lhs_xpr;
  typename RhsXprType::Nested m_rhs_xpr;
  const Indices m_indices;
  const OutputKernelType m_output_kernel;
};
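// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch, not part of this header):
// TensorContractionOp is normally created through TensorBase::contract().
// A minimal usage example, assuming <unsupported/Eigen/CXX11/Tensor> is
// included; the tensor names are illustrative only:
//
//   Eigen::Tensor<float, 2> A(3, 4), B(4, 5);
//   A.setRandom();
//   B.setRandom();
//   // Contract dimension 1 of A against dimension 0 of B, i.e. an ordinary
//   // matrix product yielding a 3x5 result.
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   Eigen::Tensor<float, 2> C = A.contract(B, dims);
// ---------------------------------------------------------------------------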
// ...

template<typename Derived>
struct TensorContractionEvaluatorBase {
  // ... (typedefs: Indices, LeftArgType, RightArgType, OutputKernelType,
  //      Device, XprType, Scalar, CoeffReturnType, PacketReturnType,
  //      EvalLeftArgType, EvalRightArgType, ...)

  enum {
    // ...
    PreferBlockAccess = false,
    // ...
  };

  // ...

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  // ...

  EIGEN_STRONG_INLINE
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      // Swap the arguments for RowMajor layouts so that the evaluation below
      // can always assume a ColMajor ordering.
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        m_device(device),
        m_output_kernel(op.outputKernel()),
        m_result(NULL) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<EvalLeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<EvalRightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);

    DSizes<Index, LDims> eval_left_dims;
    DSizes<Index, RDims> eval_right_dims;
    array<IndexPair<Index>, ContractDims> eval_op_indices;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // For ColMajor, keep the dimensions and contraction pairs as given.
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[i];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[i];
      }
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = op.indices()[i].first;
        eval_op_indices[i].second = op.indices()[i].second;
      }
    } else {
      // For RowMajor, reverse the dimensions and flip the contraction pairs.
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1];
      }
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = LDims - 1 - op.indices()[ContractDims - 1 - i].second;
        eval_op_indices[i].second = RDims - 1 - op.indices()[ContractDims - 1 - i].first;
      }
    }

    // Check that the contraction axes are unique and order them so that the
    // lhs axes appear in increasing order.
    for (int i = 0; i < ContractDims; i++) {
      for (int j = i + 1; j < ContractDims; j++) {
        eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
                     eval_op_indices[j].second != eval_op_indices[i].second &&
                     "contraction axes should be unique");
        if (eval_op_indices[j].first < eval_op_indices[i].first) {
          // ... (swap eval_op_indices[i] and eval_op_indices[j])
        }
      }
    }

    // Full column-major strides of the two inputs.
    array<Index, LDims> lhs_strides;
    lhs_strides[0] = 1;
    for (int i = 0; i < LDims-1; ++i) {
      lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
    }

    array<Index, RDims> rhs_strides;
    rhs_strides[0] = 1;
    for (int i = 0; i < RDims-1; ++i) {
      rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
    }

    if (m_i_strides.size() > 0) m_i_strides[0] = 1;
    if (m_j_strides.size() > 0) m_j_strides[0] = 1;
    if (m_k_strides.size() > 0) m_k_strides[0] = 1;
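
    // ------------------------------------------------------------------------
    // Editor's note (illustrative, not part of Eigen): for a column-major
    // tensor with dimensions {d0, d1, d2} the running products above yield
    //   stride[0] = 1, stride[1] = d0, stride[2] = d0 * d1,
    // i.e. stepping along dimension i jumps stride[i] elements; the size of
    // the last dimension never enters any stride.
    // ------------------------------------------------------------------------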

    // Map the non-contracting dimensions of the lhs to the rows ("i") of the
    // output matrix and record their strides.
    m_lhs_inner_dim_contiguous = true;
    Index dim_idx = 0;
    Index nocontract_idx = 0;

    for (int i = 0; i < LDims; i++) {
      // Is dimension i of the left tensor a contraction axis?
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].first == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        // Add the dimension size to the output dimensions.
        m_dimensions[dim_idx] = eval_left_dims[i];
        m_left_nocontract_strides[nocontract_idx] = lhs_strides[i];
        if (dim_idx != i) {
          m_lhs_inner_dim_contiguous = false;
        }
        if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) {
          m_i_strides[nocontract_idx+1] =
              m_i_strides[nocontract_idx] * eval_left_dims[i];
        } else {
          m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i];
        }
        dim_idx++;
        nocontract_idx++;
      }
    }

    // Same for the non-contracting dimensions of the rhs (the "j" columns).
    nocontract_idx = 0;
    for (int i = 0; i < RDims; i++) {
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].second == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_right_dims[i];
        if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) {
          m_j_strides[nocontract_idx+1] =
              m_j_strides[nocontract_idx] * eval_right_dims[i];
        } else {
          m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i];
        }
        m_right_nocontract_strides[nocontract_idx] = rhs_strides[i];
        dim_idx++;
        nocontract_idx++;
      }
    }

    // Strides of the contracting ("k") dimensions. The contracting axes must
    // have the same size in both tensors.
    m_rhs_inner_dim_contiguous = true;
    m_rhs_inner_dim_reordered = false;
    for (int i = 0; i < ContractDims; i++) {
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;

      Index size = eval_left_dims[left];
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i+1 < static_cast<int>(internal::array_size<contract_t>::value)) {
        m_k_strides[i+1] = m_k_strides[i] * size;
      } else {
        m_k_size = m_k_strides[i] * size;
      }
      m_left_contracting_strides[i] = lhs_strides[left];
      m_right_contracting_strides[i] = rhs_strides[right];

      if (i > 0 && right < eval_op_indices[i-1].second) {
        m_rhs_inner_dim_reordered = true;
      }
      if (right != i) {
        m_rhs_inner_dim_contiguous = false;
      }
    }

    // If the layout is RowMajor, reverse the output dimensions back into the
    // user's ordering.
    if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
      for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
        // ... (swap m_dimensions[i] and m_dimensions[j])
      }
    }

    // Parameters that let the output kernel map output-matrix coordinates
    // back to the original tensor coordinates.
    m_tensor_contraction_params.swapped_arguments = static_cast<int>(Layout) == RowMajor;
    // ...
  }

  // ...

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<EvaluatorPointerType>(
          m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }

#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
      EvaluatorPointerType dest, EvalSubExprsCallback done) {
    m_leftImpl.evalSubExprsIfNeededAsync(nullptr, [this, done, dest](bool) {
      m_rightImpl.evalSubExprsIfNeededAsync(nullptr, [this, done, dest](bool) {
        if (dest) {
          evalToAsync(dest, [done]() { done(false); });
        } else {
          m_result = static_cast<EvaluatorPointerType>(
              m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
          evalToAsync(m_result, [done]() { done(true); });
        }
      });
    });
  }
#endif  // EIGEN_USE_THREADS

#ifndef TENSOR_CONTRACTION_DISPATCH
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS)    \
  if (this->m_lhs_inner_dim_contiguous) {                       \
    if (this->m_rhs_inner_dim_contiguous) {                     \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<true, true, true, ALIGNMENT> ARGS;               \
      } else {                                                  \
        METHOD<true, true, false, ALIGNMENT> ARGS;              \
      }                                                         \
    } else {                                                    \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<true, false, true, ALIGNMENT> ARGS;              \
      } else {                                                  \
        METHOD<true, false, false, ALIGNMENT> ARGS;             \
      }                                                         \
    }                                                           \
  } else {                                                      \
    if (this->m_rhs_inner_dim_contiguous) {                     \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<false, true, true, ALIGNMENT> ARGS;              \
      } else {                                                  \
        METHOD<false, true, false, ALIGNMENT> ARGS;             \
      }                                                         \
    } else {                                                    \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<false, false, true, ALIGNMENT> ARGS;             \
      } else {                                                  \
        METHOD<false, false, false, ALIGNMENT> ARGS;            \
      }                                                         \
    }                                                           \
  }
#endif

#ifndef TENSOR_CONTRACTION_ASYNC_DISPATCH
#define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \
  if (this->m_lhs_inner_dim_contiguous) {                                    \
    if (this->m_rhs_inner_dim_contiguous) {                                  \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, true, true, true, ALIGNMENT> ARGS)->FN;            \
      } else {                                                               \
        (new METHOD<DONE, true, true, false, ALIGNMENT> ARGS)->FN;           \
      }                                                                      \
    } else {                                                                 \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, true, false, true, ALIGNMENT> ARGS)->FN;           \
      } else {                                                               \
        (new METHOD<DONE, true, false, false, ALIGNMENT> ARGS)->FN;          \
      }                                                                      \
    }                                                                        \
  } else {                                                                   \
    if (this->m_rhs_inner_dim_contiguous) {                                  \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, false, true, true, ALIGNMENT> ARGS)->FN;           \
      } else {                                                               \
        (new METHOD<DONE, false, true, false, ALIGNMENT> ARGS)->FN;          \
      }                                                                      \
    } else {                                                                 \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, false, false, true, ALIGNMENT> ARGS)->FN;          \
      } else {                                                               \
        (new METHOD<DONE, false, false, false, ALIGNMENT> ARGS)->FN;         \
      }                                                                      \
    }                                                                        \
  }
#endif

  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer);
  }
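
  // Editor's note: the dispatch macros above turn the three runtime flags
  // m_lhs_inner_dim_contiguous, m_rhs_inner_dim_contiguous and
  // m_rhs_inner_dim_reordered into compile-time template arguments. For
  // example, when all three flags are true,
  //   TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
  //                               Alignment, (buffer));
  // ends up calling
  //   this->template evalProductSequential<true, true, true, Alignment>(buffer);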

#ifdef EIGEN_USE_THREADS
  template <typename EvalToCallback>
  void evalToAsync(Scalar* buffer, EvalToCallback done) const {
    static_cast<const Derived*>(this)
        ->template evalProductAsync<EvalToCallback, Unaligned>(buffer,
                                                               std::move(done));
  }
#endif  // EIGEN_USE_THREADS

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalProductSequential(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous,
                              rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
                              Alignment>(buffer);
    } else {
      this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                              rhs_inner_dim_reordered, Alignment>(buffer);
    }
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
#if !defined(EIGEN_HIPCC)
  EIGEN_DEVICE_FUNC
#endif
  void evalGemv(Scalar* buffer) const {
    const Index rows = m_i_size;
    const Index cols = m_k_size;

    // ... (LhsScalar/RhsScalar and Left/RightEvaluator typedefs, packet sizes
    //      and alignment flags)

    typedef internal::TensorContractionInputMapper<
        LhsScalar, Index, internal::Lhs,
        LeftEvaluator, left_nocontract_t,
        contract_t, lhs_packet_size,
        lhs_inner_dim_contiguous,
        false, lhs_alignment> LhsMapper;

    typedef internal::TensorContractionInputMapper<
        RhsScalar, Index, internal::Rhs,
        RightEvaluator, right_nocontract_t,
        contract_t, rhs_packet_size,
        rhs_inner_dim_contiguous,
        rhs_inner_dim_reordered, rhs_alignment> RhsMapper;

    LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                  m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                  m_right_contracting_strides, m_k_strides);

    const Scalar alpha(1);
    const Index resIncr(1);

    // Zero out the result buffer (which must hold at least rows scalars).
    m_device.memset(buffer, 0, rows * sizeof(Scalar));

    internal::general_matrix_vector_product<
        Index, LhsScalar, LhsMapper, ColMajor, false, RhsScalar, RhsMapper,
        false>::run(rows, cols, lhs, rhs, buffer, resIncr, alpha);

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
    m_output_kernel(OutputMapper(buffer, rows), m_tensor_contraction_params,
                    static_cast<Index>(0), static_cast<Index>(0), rows,
                    static_cast<Index>(1));
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
#if !defined(EIGEN_HIPCC)
  EIGEN_DEVICE_FUNC
#endif
  void evalGemm(Scalar* buffer) const {
    const Index k = this->m_k_size;
    this->template evalGemmPartial<lhs_inner_dim_contiguous,
                                   rhs_inner_dim_contiguous,
                                   rhs_inner_dim_reordered,
                                   Alignment, true>(buffer, 0, k, 1);
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemmPartialWithoutOutputKernel(
      Scalar* buffer, Index k_start, Index k_end, int num_threads) const {
    evalGemmPartial<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                    rhs_inner_dim_reordered, Alignment,
                    /*use_output_kernel*/ false>(buffer, k_start, k_end,
                                                 num_threads);
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment, bool use_output_kernel>
  EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start,
                                         Index k_end, int num_threads) const {
    eigen_assert(k_end >= k_start && k_start >= 0 && k_end <= this->m_k_size);
    // Size of the contracted slice handled by this call.
    const Index k_slice = k_end - k_start;

    // Rows of the left-hand side.
    const Index m = this->m_i_size;

    // Columns of the right-hand side.
    const Index n = this->m_j_size;

    // Define the data mappers for the Lhs and Rhs.
    // ... (LhsScalar/RhsScalar and Left/RightEvaluator typedefs, packet sizes)

    typedef internal::TensorContractionInputMapper<
        LhsScalar, Index, internal::Lhs,
        LeftEvaluator, left_nocontract_t,
        contract_t, lhs_packet_size,
        lhs_inner_dim_contiguous,
        false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<
        RhsScalar, Index, internal::Rhs,
        RightEvaluator, right_nocontract_t,
        contract_t, rhs_packet_size,
        rhs_inner_dim_contiguous,
        rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    typedef internal::TensorContractionKernel<
        Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
        TensorContractionKernel;

    // Initialize the data mappers.
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // Cache-friendly block sizes.
    internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                        internal::ShardByCol>
        blocking(k_slice, m, n, num_threads);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());

    typename TensorContractionKernel::LhsBlock blockA;
    typename TensorContractionKernel::RhsBlock blockB;

    TensorContractionKernel kernel(m, k_slice, n, mc, kc, nc);

    typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
    const BlockMemHandle packed_mem =
        kernel.allocate(this->m_device, &blockA, &blockB);

    // The default kernel does not support beta, so the output buffer is
    // explicitly zeroed before accumulation.
    if (!TensorContractionKernel::HasBeta) {
      this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
    }

    for (Index i2 = 0; i2 < m; i2 += mc) {
      const Index actual_mc = numext::mini(i2 + mc, m) - i2;
      for (Index k2 = k_start; k2 < k_end; k2 += kc) {
        // Pack a vertical panel of the lhs without overshooting its edge.
        const Index actual_kc = numext::mini(k2 + kc, k_end) - k2;
        kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

        const Scalar alpha = Scalar(1);
        const Scalar beta = (TensorContractionKernel::HasBeta && k2 == k_start)
                                ? Scalar(0)
                                : Scalar(1);

        // Series of horizontal blocks.
        for (Index j2 = 0; j2 < n; j2 += nc) {
          // Pack a block of the rhs without overshooting its edge.
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          kernel.packRhs(&blockB, rhs.getSubMapper(k2, j2), actual_kc,
                         actual_nc);

          // Call the matrix (gebp) kernel on the packed blocks.
          const OutputMapper output_mapper = output.getSubMapper(i2, j2);
          kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc,
                        actual_nc, alpha, beta);

          // This [i2, j2] output block is complete once the last k-slice has
          // been accumulated.
          if (use_output_kernel && k2 + kc >= k_end) {
            m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2,
                            actual_mc, actual_nc);
          }
        }
      }
    }

    kernel.deallocate(this->m_device, packed_mem);
  }

  EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();

    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
  }

  // ... (costPerCoeff(), data(), and the protected data members: the two
  //      argument evaluators, the device, the stride arrays, the output
  //      kernel, the contraction parameters and m_result)
};
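
// ---------------------------------------------------------------------------
// Editor's sketch (not part of Eigen): a plain blocked matrix product with the
// same i2 / k2 / j2 loop nest as evalGemmPartial above, for column-major
// buffers. All names are hypothetical and the block sizes are arbitrary.
inline void blocked_gemm_sketch(const double* A,  // m x k, column-major
                                const double* B,  // k x n, column-major
                                double* C,        // m x n, column-major, pre-zeroed
                                long m, long k, long n,
                                long bm, long bk, long bn) {
  for (long i2 = 0; i2 < m; i2 += bm) {
    const long mc = (i2 + bm < m ? i2 + bm : m) - i2;
    for (long k2 = 0; k2 < k; k2 += bk) {
      const long kc = (k2 + bk < k ? k2 + bk : k) - k2;
      for (long j2 = 0; j2 < n; j2 += bn) {
        const long nc = (j2 + bn < n ? j2 + bn : n) - j2;
        // Accumulate the (mc x kc) block of A times the (kc x nc) block of B
        // into the (mc x nc) block of C.
        for (long j = 0; j < nc; ++j)
          for (long p = 0; p < kc; ++p)
            for (long i = 0; i < mc; ++i)
              C[(i2 + i) + (j2 + j) * m] +=
                  A[(i2 + i) + (k2 + p) * m] * B[(k2 + p) + (j2 + j) * k];
      }
    }
  }
}
// Eigen's version additionally packs each A and B block into contiguous,
// aligned scratch memory (packLhs / packRhs) before invoking the gebp kernel.
// ---------------------------------------------------------------------------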

template<typename Indices, typename LeftArgType, typename RightArgType,
         typename OutputKernelType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device>
    : public TensorContractionEvaluatorBase<TensorEvaluator<
          const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> > {
  // ... (typedefs: Self, Base, XprType, Scalar, EvalLeftArgType, EvalRightArgType, ...)

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  // ...

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
      : Base(op, device) { }

  template <int Alignment>
  void evalProduct(Scalar* buffer) const {
    TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
                                Alignment, (buffer));
  }
};

// ...

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H