Go to the documentation of this file.
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
24 template<
typename Dimensions,
typename LhsXprType,
typename RhsXprType,
typename OutputKernelType>
53 template<
typename Dimensions,
typename LhsXprType,
typename RhsXprType,
typename OutputKernelType>
59 template<
typename Dimensions,
typename LhsXprType,
typename RhsXprType,
typename OutputKernelType>
65 template<
typename Indices_,
typename LeftArgType_,
typename RightArgType_,
typename OutputKernelType_,
typename Device_>
78 template <
typename LhsScalar,
typename RhsScalar>
82 template <
typename Device>
86 LhsScalar** lhs_block,
87 RhsScalar** rhs_block) {
93 *lhs_block =
reinterpret_cast<LhsScalar*
>(block_mem);
94 *rhs_block =
reinterpret_cast<RhsScalar*
>(block_mem + sz.
lhs_size);
98 template <
typename Device>
102 std::vector<LhsScalar*>* lhs_blocks,
103 std::vector<RhsScalar*>* rhs_blocks) {
109 void* block_mem =
d.allocate(
112 char* mem =
static_cast<char*
>(block_mem);
114 for (
Index x = 0;
x < num_slices;
x++) {
115 if (num_lhs > 0) lhs_blocks[
x].resize(num_lhs);
116 for (
Index m = 0;
m < num_lhs;
m++) {
117 lhs_blocks[
x][
m] =
reinterpret_cast<LhsScalar*
>(mem);
120 if (num_rhs > 0) rhs_blocks[
x].resize(num_rhs);
121 for (
Index n = 0;
n < num_rhs;
n++) {
122 rhs_blocks[
x][
n] =
reinterpret_cast<RhsScalar*
>(mem);
130 template <
typename Device>
179 template <
typename ResScalar,
typename LhsScalar,
typename RhsScalar,
180 typename StorageIndex,
typename OutputMapper,
typename LhsMapper,
189 StorageIndex bm_, StorageIndex bk_, StorageIndex bn_)
190 :
m(m_),
k(k_),
n(n_),
bm(bm_),
bk(bk_),
bn(bn_) {}
204 LhsScalar, StorageIndex,
typename LhsMapper::SubMapper,
Traits::mr,
218 template <
typename Device>
224 template <
typename Device>
226 Device&
d,
const StorageIndex num_lhs,
const StorageIndex num_rhs,
227 const StorageIndex num_slices, std::vector<LhsBlock>* lhs_blocks,
228 std::vector<RhsBlock>* rhs_blocks) {
230 d,
bm,
bk,
bn, num_lhs, num_rhs, num_slices, lhs_blocks, rhs_blocks);
233 template <
typename Device>
239 LhsBlock* lhsBlock,
const typename LhsMapper::SubMapper& data_mapper,
240 const StorageIndex
depth,
const StorageIndex
rows) {
246 RhsBlock* rhsBlock,
const typename RhsMapper::SubMapper& data_mapper,
247 const StorageIndex
depth,
const StorageIndex
cols) {
252 const OutputMapper& output_mapper,
const LhsBlock& lhsBlock,
254 const StorageIndex
depth,
const StorageIndex
cols,
255 const ResScalar
alpha,
const ResScalar
beta) {
258 static const int kComputeStrideFromBlockDimensions = -1;
260 kComputeStrideFromBlockDimensions,
261 kComputeStrideFromBlockDimensions,
269 const StorageIndex
m;
270 const StorageIndex
k;
271 const StorageIndex
n;
272 const StorageIndex
bm;
273 const StorageIndex
bk;
274 const StorageIndex
bn;
310 template <
typename Index,
typename Scalar>
324 template<
typename Indices,
typename LhsXprType,
typename RhsXprType,
typename OutputKernelType = const NoOpOutputKernel>
336 const LhsXprType& lhs,
const RhsXprType& rhs,
const Indices& dims,
337 const OutputKernelType& output_kernel = OutputKernelType())
364 template<
typename Derived>
423 op.lhsExpression(), op.rhsExpression()), device),
425 op.rhsExpression(), op.lhsExpression()), device),
431 YOU_MADE_A_PROGRAMMING_MISTAKE);
447 eval_op_indices[
i].first = op.
indices()[
i].first;
448 eval_op_indices[
i].second = op.
indices()[
i].second;
470 eigen_assert(eval_op_indices[
j].first != eval_op_indices[
i].first &&
471 eval_op_indices[
j].second != eval_op_indices[
i].second &&
472 "contraction axes should be unique");
473 if (eval_op_indices[
j].first < eval_op_indices[
i].first) {
482 lhs_strides[
i+1] = lhs_strides[
i] * eval_left_dims[
i];
488 rhs_strides[
i+1] = rhs_strides[
i] * eval_right_dims[
i];
505 Index nocontract_idx = 0;
509 bool contracting =
false;
511 if (eval_op_indices[
j].first ==
i) {
536 bool contracting =
false;
539 if (eval_op_indices[
j].second ==
i) {
571 "Contraction axes must be same size");
581 if (
i > 0 &&
right < eval_op_indices[
i-1].second) {
618 #ifdef EIGEN_USE_THREADS
619 template <
typename EvalSubExprsCallback>
622 m_leftImpl.evalSubExprsIfNeededAsync(
nullptr, [
this, done, dest](
bool) {
623 m_rightImpl.evalSubExprsIfNeededAsync(
nullptr, [
this, done, dest](
bool) {
625 evalToAsync(dest, [done]() { done(
false); });
629 evalToAsync(
m_result, [done]() { done(
true); });
634 #endif // EIGEN_USE_THREADS
636 #ifndef TENSOR_CONTRACTION_DISPATCH
637 #define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
638 if (this->m_lhs_inner_dim_contiguous) { \
639 if (this->m_rhs_inner_dim_contiguous) { \
640 if (this->m_rhs_inner_dim_reordered) { \
641 METHOD<true, true, true, ALIGNMENT> ARGS; \
643 METHOD<true, true, false, ALIGNMENT> ARGS; \
646 if (this->m_rhs_inner_dim_reordered) { \
647 METHOD<true, false, true, ALIGNMENT> ARGS; \
649 METHOD<true, false, false, ALIGNMENT> ARGS; \
653 if (this->m_rhs_inner_dim_contiguous) { \
654 if (this->m_rhs_inner_dim_reordered) { \
655 METHOD<false, true, true, ALIGNMENT> ARGS; \
657 METHOD<false, true, false, ALIGNMENT> ARGS; \
660 if (this->m_rhs_inner_dim_reordered) { \
661 METHOD<false, false, true, ALIGNMENT> ARGS; \
663 METHOD<false, false, false, ALIGNMENT> ARGS; \
669 #ifndef TENSOR_CONTRACTION_ASYNC_DISPATCH
670 #define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \
671 if (this->m_lhs_inner_dim_contiguous) { \
672 if (this->m_rhs_inner_dim_contiguous) { \
673 if (this->m_rhs_inner_dim_reordered) { \
674 (new METHOD<DONE, true, true, true, ALIGNMENT> ARGS)->FN; \
676 (new METHOD<DONE, true, true, false, ALIGNMENT> ARGS)->FN; \
679 if (this->m_rhs_inner_dim_reordered) { \
680 (new METHOD<DONE, true, false, true, ALIGNMENT> ARGS)->FN; \
682 (new METHOD<DONE, true, false, false, ALIGNMENT> ARGS)->FN; \
686 if (this->m_rhs_inner_dim_contiguous) { \
687 if (this->m_rhs_inner_dim_reordered) { \
688 (new METHOD<DONE, false, true, true, ALIGNMENT> ARGS)->FN; \
690 (new METHOD<DONE, false, true, false, ALIGNMENT> ARGS)->FN; \
693 if (this->m_rhs_inner_dim_reordered) { \
694 (new METHOD<DONE, false, false, true, ALIGNMENT> ARGS)->FN; \
696 (new METHOD<DONE, false, false, false, ALIGNMENT> ARGS)->FN; \
703 static_cast<const Derived*
>(
this)->
template evalProduct<Unaligned>(
buffer);
706 #ifdef EIGEN_USE_THREADS
707 template <
typename EvalToCallback>
708 void evalToAsync(
Scalar*
buffer, EvalToCallback done)
const {
709 static_cast<const Derived*
>(
this)
710 ->
template evalProductAsync<EvalToCallback, Unaligned>(
buffer,
713 #endif // EIGEN_USE_THREADS
715 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
716 bool rhs_inner_dim_reordered,
int Alignment>
719 this->
template evalGemv<lhs_inner_dim_contiguous,
720 rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
723 this->
template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
724 rhs_inner_dim_reordered, Alignment>(
buffer);
728 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
729 #if !defined(EIGEN_HIPCC)
747 lhs_inner_dim_contiguous,
748 false, lhs_alignment> LhsMapper;
753 rhs_inner_dim_contiguous,
754 rhs_inner_dim_reordered, rhs_alignment> RhsMapper;
762 const Index resIncr(1);
774 static_cast<Index>(1));
777 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
778 #if !defined(EIGEN_HIPCC)
785 rhs_inner_dim_contiguous,
786 rhs_inner_dim_reordered,
787 Alignment,
true>(
buffer, 0, k, 1);
790 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
791 bool rhs_inner_dim_reordered,
int Alignment>
795 rhs_inner_dim_reordered, Alignment,
796 false>(
buffer, k_start, k_end,
800 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment,
bool use_output_kernel>
804 const Index k_slice = k_end - k_start;
825 lhs_inner_dim_contiguous,
831 rhs_inner_dim_contiguous,
832 rhs_inner_dim_reordered,
Unaligned> RhsMapper;
837 Scalar, LhsScalar, RhsScalar,
Index, OutputMapper, LhsMapper, RhsMapper>
838 TensorContractionKernel;
847 OutputMapper output(
buffer,
m);
852 blocking(k_slice,
m,
n, num_threads);
853 const Index kc = blocking.kc();
857 typedef typename TensorContractionKernel::LhsBlock LhsBlock;
858 typedef typename TensorContractionKernel::RhsBlock RhsBlock;
863 TensorContractionKernel kernel(
m, k_slice,
n, mc, kc,
nc);
865 typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
866 const BlockMemHandle packed_mem =
867 kernel.allocate(this->
m_device, &blockA, &blockB);
871 if (!TensorContractionKernel::HasBeta) {
875 for(
Index i2=0; i2<
m; i2+=mc)
878 for (
Index k2 = k_start; k2 < k_end; k2 += kc) {
881 kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
886 const Scalar beta = (TensorContractionKernel::HasBeta && k2 == k_start)
891 for (
Index j2 = 0; j2 <
n; j2 +=
nc) {
894 kernel.packRhs(&blockB, rhs.getSubMapper(k2, j2), actual_kc,
899 const OutputMapper output_mapper = output.getSubMapper(i2, j2);
900 kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc,
904 if (use_output_kernel && k2 + kc >= k_end) {
906 actual_mc, actual_nc);
912 kernel.deallocate(this->
m_device, packed_mem);
933 template<
int LoadMode>
935 return internal::ploadt<PacketReturnType, LoadMode>(
m_result + index);
971 template<
typename Indices,
typename LeftArgType,
typename RightArgType,
typename OutputKernelType,
typename Device>
974 TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> > {
997 static const int LDims =
999 static const int RDims =
1007 static const int NumDims = LDims + RDims - 2 * ContractDims;
1013 Base(op, device) { }
1015 template <
int Alignment>
1023 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
RhsXprType::Nested m_rhs_xpr
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(const LhsXprType &lhs, const RhsXprType &rhs, const Indices &dims, const OutputKernelType &output_kernel=OutputKernelType())
array< Index, RDims - ContractDims > right_nocontract_t
EIGEN_STRONG_INLINE TensorContractionEvaluatorBase(const XprType &op, const Device &device)
Eigen::internal::traits< TensorContractionOp >::Index Index
#define EIGEN_DEVICE_FUNC
Namespace containing all symbols from the Eigen library.
internal::conditional< static_cast< int >(Layout)==static_cast< int >(ColMajor), RightArgType, LeftArgType >::type EvalRightArgType
TensorEvaluator< const TensorContractionOp< Indices, LeftArgType, RightArgType, OutputKernelType >, Device > Self
TensorContractionEvaluatorBase< Self > Base
static const double d[K][N]
array< Index, LDims - ContractDims > left_nocontract_t
Eigen::internal::traits< TensorContractionOp >::StorageKind StorageKind
PacketType< CoeffReturnType, Device >::type PacketReturnType
static EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device &d, const Index bm, const Index bk, const Index bn, LhsScalar **lhs_block, RhsScalar **rhs_block)
Storage::Type EvaluatorPointerType
internal::gebp_traits< LhsScalar, RhsScalar > Traits
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
DSizes< Index, NumDims > Dimensions
Eigen::internal::traits< TensorContractionOp >::Scalar Scalar
EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device &d, LhsBlock *lhs_block, RhsBlock *rhs_block)
TensorContractionParams m_tensor_contraction_params
EIGEN_DEVICE_FUNC void evalGemm(Scalar *buffer) const
internal::conditional< static_cast< int >(Layout)==static_cast< int >(ColMajor), LeftArgType, RightArgType >::type EvalLeftArgType
right_nocontract_t m_j_strides
internal::traits< Derived >::Device Device
static Similarity3 align(const Point3Pairs &d_abPointPairs, const Rot3 &aRb, const Point3Pair &centroids)
This method estimates the similarity transform from differences point pairs.
EIGEN_ALWAYS_INLINE void operator()(const internal::blas_data_mapper< Scalar, Index, ColMajor > &output_mapper, const TensorContractionParams &params, Index i, Index j, Index num_rows, Index num_cols) const
PacketType< CoeffReturnType, Device >::type PacketReturnType
static const SmartProjectionParams params
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(RhsBlock *rhsBlock, const typename RhsMapper::SubMapper &data_mapper, const StorageIndex depth, const StorageIndex cols)
TensorEvaluator< EvalRightArgType, Device > m_rightImpl
double beta(double a, double b)
LhsXprType::Nested LhsNested
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS)
gebp_traits< typename remove_const< typename LhsXprType::Scalar >::type, typename remove_const< typename RhsXprType::Scalar >::type >::ResScalar Scalar
TensorEvaluator(const XprType &op, const Device &device)
static EIGEN_DEVICE_FUNC void deallocate(Device &d, BlockMemHandle handle)
EIGEN_DEVICE_FUNC TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_, StorageIndex bm_, StorageIndex bk_, StorageIndex bn_)
internal::remove_const< typename XprType::Scalar >::type Scalar
conditional< Pointer_type_promotion< typename LhsXprType::Scalar, Scalar >::val, typename traits< LhsXprType >::PointerType, typename traits< RhsXprType >::PointerType >::type PointerType
LhsXprType::Nested m_lhs_xpr
internal::conditional< static_cast< int >(Layout)==static_cast< int >(ColMajor), RightArgType, LeftArgType >::type EvalRightArgType
OutputKernelType m_output_kernel
internal::gemm_pack_rhs< RhsScalar, StorageIndex, typename RhsMapper::SubMapper, Traits::nr, ColMajor > RhsPacker
EIGEN_DEVICE_FUNC void evalGemmPartialWithoutOutputKernel(Scalar *buffer, Index k_start, Index k_end, int num_threads) const
remove_reference< LhsNested >::type _LhsNested
right_nocontract_t m_right_nocontract_strides
TensorEvaluator< EvalLeftArgType, Device > m_leftImpl
#define EIGEN_UNUSED_VARIABLE(var)
promote_storage_type< typename traits< LhsXprType >::StorageKind, typename traits< RhsXprType >::StorageKind >::ret StorageKind
void evalProduct(Scalar *buffer) const
TensorEvaluator< EvalLeftArgType, Device > LeftEvaluatorType
const EIGEN_DEVICE_FUNC internal::remove_all< typename RhsXprType::Nested >::type & rhsExpression() const
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(const OutputMapper &output_mapper, const LhsBlock &lhsBlock, const RhsBlock &rhsBlock, const StorageIndex rows, const StorageIndex depth, const StorageIndex cols, const ResScalar alpha, const ResScalar beta)
internal::gebp_traits< typename LhsXprType::CoeffReturnType, typename RhsXprType::CoeffReturnType >::ResScalar CoeffReturnType
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
internal::traits< Derived >::OutputKernelType OutputKernelType
TensorContractionOp< Indices, LeftArgType, RightArgType, OutputKernelType > XprType
array< Index, RDims - ContractDims > right_nocontract_t
#define EIGEN_STRONG_INLINE
StorageMemory< Scalar, Device > Storage
RhsXprType::Nested RhsNested
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EvaluatorPointerType data() const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex TotalSize() const
internal::traits< Derived >::Indices Indices
EIGEN_DEVICE_FUNC void evalGemv(Scalar *buffer) const
const EIGEN_DEVICE_FUNC internal::remove_all< typename LhsXprType::Nested >::type & lhsExpression() const
static const int ContractDims
BlockMemAllocator::BlockMemHandle BlockMemHandle
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
XprType::CoeffReturnType CoeffReturnType
#define EIGEN_ALWAYS_INLINE
XprType::CoeffReturnType CoeffReturnType
EIGEN_DEVICE_FUNC const EIGEN_STRONG_INLINE Dimensions & dimensions() const
internal::TensorBlockNotImplemented TensorBlock
void evalProductSequential(Scalar *buffer) const
EIGEN_DEVICE_FUNC void evalTo(Scalar *buffer) const
EvaluatorPointerType m_result
EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(Device &d, const StorageIndex num_lhs, const StorageIndex num_rhs, const StorageIndex num_slices, std::vector< LhsBlock > *lhs_blocks, std::vector< RhsBlock > *rhs_blocks)
promote_index_type< typename traits< LhsXprType >::Index, typename traits< RhsXprType >::Index >::type Index
left_nocontract_t m_left_nocontract_strides
static EIGEN_DEVICE_FUNC void deallocate(Device &d, BlockMemHandle handle)
internal::gebp_kernel< LhsScalar, RhsScalar, StorageIndex, OutputMapper, Traits::mr, Traits::nr, false, false > GebpKernel
internal::remove_const< typename XprType::Scalar >::type Scalar
internal::conditional< static_cast< int >(Layout)==static_cast< int >(ColMajor), LeftArgType, RightArgType >::type EvalLeftArgType
RightArgType_ RightArgType
array< Index, LDims - ContractDims > left_nocontract_t
const Device EIGEN_DEVICE_REF m_device
bool m_rhs_inner_dim_contiguous
#define EIGEN_STATIC_ASSERT(CONDITION, MSG)
const EIGEN_DEVICE_FUNC Indices & indices() const
static EIGEN_DEVICE_FUNC BlockSizes ComputeLhsRhsBlockSizes(const Index bm, const Index bk, const Index bn)
EIGEN_STRONG_INLINE void swap(T &a, T &b)
bool m_rhs_inner_dim_reordered
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data)
DSizes< Index, NumDims > Dimensions
Eigen::internal::nested< TensorContractionOp >::type Nested
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::size_t size()
TensorContractionOp< Indices, LeftArgType, RightArgType, OutputKernelType > XprType
array< Index, ContractDims > contract_t
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool) const
TensorContractionOp< Dimensions, LhsXprType, RhsXprType, OutputKernelType > type
contract_t m_left_contracting_strides
const EIGEN_DEVICE_FUNC OutputKernelType & outputKernel() const
A cost model used to limit the number of threads used for evaluating tensor expression.
left_nocontract_t m_i_strides
internal::traits< Derived >::RightArgType RightArgType
internal::traits< Derived >::LeftArgType LeftArgType
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T maxi(const T &x, const T &y)
const OutputKernelType m_output_kernel
bool m_lhs_inner_dim_contiguous
array< Index, ContractDims > contract_t
EIGEN_DEVICE_FUNC const EIGEN_ALWAYS_INLINE T1 & choose(Cond< true >, const T1 &first, const T2 &)
remove_reference< RhsNested >::type _RhsNested
std::vector< size_t > Indices
TensorContractionBlockMemAllocator< LhsScalar, RhsScalar > BlockMemAllocator
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
TensorEvaluator< EvalRightArgType, Device > RightEvaluatorType
contract_t m_right_contracting_strides
EIGEN_STRONG_INLINE void cleanup()
LhsPacket LhsPacket4Packing
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(LhsBlock *lhsBlock, const typename LhsMapper::SubMapper &data_mapper, const StorageIndex depth, const StorageIndex rows)
const typedef TensorContractionOp< Dimensions, LhsXprType, RhsXprType, OutputKernelType > & type
static EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(Device &d, const Index bm, const Index bk, const Index bn, const Index num_lhs, const Index num_rhs, const Index num_slices, std::vector< LhsScalar * > *lhs_blocks, std::vector< RhsScalar * > *rhs_blocks)
OutputKernelType_ OutputKernelType
EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar *buffer, Index k_start, Index k_end, int num_threads) const
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
#define EIGEN_DONT_INLINE
internal::gemm_pack_lhs< LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor > LhsPacker
gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:05:30