#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H

#ifdef EIGEN_USE_THREADS
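// Evaluator for tensor contraction on a ThreadPoolDevice. A minimal usage
// sketch (the pool size and tensor shapes below are arbitrary choices, not
// taken from this header):
//
//   Eigen::ThreadPool pool(4);
//   Eigen::ThreadPoolDevice device(&pool, 4);
//   Eigen::Tensor<float, 2> a(64, 32), b(32, 48), c(64, 48);
//   a.setRandom();
//   b.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
//   c.device(device) = a.contract(b, dims);  // dispatches to the evaluator below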
template <typename Indices, typename LeftArgType, typename RightArgType,
          typename OutputKernelType>
struct TensorEvaluator<
    const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>,
    ThreadPoolDevice>
    : public TensorContractionEvaluatorBase<TensorEvaluator<
          const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>,
          ThreadPoolDevice> > {
  typedef ThreadPoolDevice Device;

  typedef TensorEvaluator<
      const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>,
      Device>
      Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  // Most of the code below assumes ColMajor inputs; if the inputs are
  // RowMajor, the LHS and RHS arguments are swapped.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType,
      RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType,
      LeftArgType>::type EvalRightArgType;
  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;

  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
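  // Parallel evaluation strategy (as implemented below): the output is tiled
  // into bm x bn blocks and the contraction dimension into bk slices. Packing
  // tasks copy lhs/rhs blocks into contiguous storage, gebp kernels consume
  // them, and atomic per-slice counters pipeline consecutive k slices. Work is
  // sharded across the pool either by rows or by columns; a separate path
  // shards by the inner (k) dimension when the output is small.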
  template <int Alignment>
  void evalProduct(Scalar* buffer) const {
    evalProductImpl<NoCallback, Alignment>(buffer, NoCallback());
  }
  template <typename EvalToCallback, int Alignment>
  void evalProductAsync(Scalar* buffer, EvalToCallback done) const {
    evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done));
  }
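  // Both entry points funnel into evalProductImpl: the synchronous one passes
  // the NoCallback sentinel and blocks until completion, the asynchronous one
  // passes the user callback; the evaluation context is then created on the
  // stack (sync) or on the heap (async) accordingly.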
  template <typename DoneCallback, int Alignment>
  void evalProductImpl(Scalar* buffer, DoneCallback done) const {
    // Sync mode is identified by the NoCallback sentinel type.
    static const bool IsEvalInSyncMode =
        std::is_same<DoneCallback, NoCallback>::value;

    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    const Index k = this->m_k_size;
    if (m == 0 || n == 0 || k == 0) return;
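    // The code below picks a set of interdependent parameters: kernel block
    // sizes (bm, bn, bk), task grain sizes (gm, gn), the number of threads,
    // the sharding dimension and the packing strategy. They are first
    // approximated with a placeholder thread count and then refined once the
    // actual thread count is known.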
    // First approximation of the sharding decision and of the blocking sizes;
    // the thread count is not known yet, so 2 is used as a placeholder.
    bool shard_by_col = shardByCol(m, n, 2);

    Index bm, bn, bk;
    if (shard_by_col) {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByCol>
          blocking(k, m, n, 2);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByRow>
          blocking(k, m, n, 2);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    }
    // Compute the optimal number of threads, and check whether sharding by
    // the inner (k) dimension is preferable.
    const TensorOpCost cost =
        contractionCost(m, n, bm, bn, bk, shard_by_col, false);
    int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
        static_cast<double>(n) * m, cost, this->m_device.numThreads());
    int num_threads_by_k = numThreadsInnerDim(m, n, k);
    if (shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
      // It is more effective to shard by the inner (contraction) dimension.
      if (IsEvalInSyncMode) {
        EvalShardedByInnerDimContext<DoneCallback> ctx(
            this, num_threads_by_k, buffer, m, n, k, std::move(done));
        ctx.template run<Alignment>();
      } else {
        auto* ctx = new EvalShardedByInnerDimContext<DoneCallback>(
            this, num_threads_by_k, buffer, m, n, k, std::move(done));
        ctx->template runAsync<Alignment>();
      }
      return;
    }
    // Stop-gap while the cost model is being tuned.
    if (n == 1) num_threads = 1;

    if (num_threads == 1) {
      TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
                                  Unaligned, (buffer));
      if (!IsEvalInSyncMode) done();
      return;
    }
    // Now that the number of threads is known, recompute sharding and blocking.
    shard_by_col = shardByCol(m, n, num_threads);
    if (shard_by_col) {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByCol>
          blocking(k, m, n, num_threads);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByRow>
          blocking(k, m, n, num_threads);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    }
    // Compute task grain sizes (number of kernels executed per task). Grain
    // coarsening reduces per-task overhead and improves packed-block reuse.
    // If we shard by column, coarsen rows first, and vice versa.
    Index gm = 1;
    Index gn = 1;
    if (shard_by_col) {
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
    } else {
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
    }
    // Number of kernels and tasks along each dimension.
    Index nm0 = divup(m, bm);
    Index nn0 = divup(n, bn);
    Index nm = divup(nm0, gm);
    Index nn = divup(nn0, gn);
    Index nk = divup(k, bk);

    // If there is enough concurrency in the sharding dimension alone, do not
    // parallelize over the other dimension.
    const Index sharding_dim_tasks = shard_by_col ? nn : nm;
    const int num_worker_threads = this->m_device.numThreadsInPool();

    // With a small number of threads keep maximum parallelism; with many
    // threads trade parallelism for better memory locality.
    const float oversharding_factor =
        num_worker_threads <= 4  ? 8.0 :
        num_worker_threads <= 8  ? 4.0 :
        num_worker_threads <= 16 ? 2.0 :
        num_worker_threads <= 32 ? 1.0 :
        num_worker_threads <= 64 ? 0.8 : 0.6;

    const bool parallelize_by_sharding_dim_only =
        sharding_dim_tasks >= oversharding_factor * num_worker_threads;
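    // For example, with 16 worker threads the factor is 2.0, so this mode is
    // enabled only when the sharding dimension provides at least 32 tasks.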
    // Decide whether lhs and rhs packing are issued in parallel. Parallel
    // packing exposes more parallelism; sequential packing has better locality.
    bool parallel_pack = num_threads >= nm * nn;
    // Also pack in parallel if all packed data fits into the L2 cache.
    if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <=
        l2CacheSize() * num_threads)
      parallel_pack = true;
    // But not if each rhs block is used only once: locality wins there.
    if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
    // And do not interfere with the parallelize_by_sharding_dim_only mode.
    if (parallelize_by_sharding_dim_only) parallel_pack = false;
    if (IsEvalInSyncMode) {
#define CONTEXT_ARGS                                                        \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
   nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
   NoCallback())                                                            \
      .run()
      TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment, CONTEXT_ARGS);
#undef CONTEXT_ARGS
    } else {
#define CONTEXT_ARGS                                                        \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
   nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
   std::move(done))
      TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback,
                                        Alignment, CONTEXT_ARGS, run());
#undef CONTEXT_ARGS
    }
  }
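  // The helpers below implement the two evaluation strategies dispatched
  // above: EvalParallelContext drives the row/column-sharded packing/kernel
  // pipeline, EvalShardedByInnerDimContext shards over the contraction
  // dimension, and NoCallback marks the synchronous code path.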
  // Dummy callback type used on the synchronous code path.
  struct NoCallback {
    void operator()() {
      eigen_assert(false && "NoCallback should never be called");
    }
  };
  template <typename DoneCallback, typename Context>
  class EvalParallelNotification;

  // In the synchronous case the notification is an Eigen::Notification that
  // the caller blocks on.
  template <typename Context>
  class EvalParallelNotification<NoCallback, Context> {
   public:
    EvalParallelNotification(Context*, NoCallback) {}
    void Notify() { done_.Notify(); }
    void Wait() { done_.Wait(); }
   private:
    Eigen::Notification done_;
  };
  // In the asynchronous case the notification deletes the heap-allocated
  // context and then invokes the user-provided done callback.
  template <typename DoneCallback, typename Context>
  class EvalParallelNotification {
   public:
    EvalParallelNotification(Context* ctx, DoneCallback done)
        : ctx_(ctx), done_(std::move(done)) {}
    void Notify() {
      // Copy the callback first: it lives in the context deleted on the next line.
      DoneCallback done_copy = std::move(done_);
      delete ctx_;
      done_copy();
    }
    void Wait() {}
   private:
    Context* ctx_;
    DoneCallback done_;
  };
  // Orchestrates the sync/async parallel contraction evaluation. In async mode
  // it owns all shared state accessed by packing and kernel tasks.
  template <typename DoneCallback, bool lhs_inner_dim_contiguous,
            bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
            int Alignment>
  class EvalParallelContext {
   public:
    typedef internal::TensorContractionInputMapper<
        LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
        contract_t, internal::packet_traits<LhsScalar>::size,
        lhs_inner_dim_contiguous, false, Unaligned>
        LhsMapper;
    typedef internal::TensorContractionInputMapper<
        RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
        contract_t, internal::packet_traits<RhsScalar>::size,
        rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
        RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    typedef internal::TensorContractionKernel<
        Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
        TensorContractionKernel;

    typedef typename TensorContractionKernel::LhsBlock LhsBlock;
    typedef typename TensorContractionKernel::RhsBlock RhsBlock;
    typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
    EvalParallelContext(const Self* self, int num_threads, Scalar* buffer,
                        Index tm, Index tn, Index tk, Index bm, Index bn,
                        Index bk, Index nm, Index nn, Index nk, Index gm,
                        Index gn, Index nm0, Index nn0, bool shard_by_col,
                        bool parallel_pack,
                        bool parallelize_by_sharding_dim_only,
                        DoneCallback done)
        : created_by_thread_id_(std::this_thread::get_id()),
          done_(this, std::move(done)),
          device_(self->m_device),
          lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
               self->m_i_strides, self->m_left_contracting_strides,
               self->m_k_strides),
          rhs_(self->m_rightImpl, self->m_right_nocontract_strides,
               self->m_j_strides, self->m_right_contracting_strides,
               self->m_k_strides),
          buffer_(buffer),
          output_(buffer, tm),
          output_kernel_(self->m_output_kernel),
          tensor_contraction_params_(self->m_tensor_contraction_params),
          num_threads_(num_threads),
          shard_by_col_(shard_by_col),
          parallel_pack_(parallel_pack),
          parallelize_by_sharding_dim_only_(parallelize_by_sharding_dim_only),
          m_(tm), n_(tn), k_(tk), bm_(bm), bn_(bn), bk_(bk),
          nm_(nm), nn_(nn), nk_(nk), gm_(gm), gn_(gn), nm0_(nm0), nn0_(nn0),
          kernel_(m_, k_, n_, bm_, bk_, bn_),
          num_thread_local_allocations_(0),
          // Reserve 2x the pool size to handle task stealing by threads that
          // are not managed by the pool.
          thread_local_capacity(2 * (parallelize_by_sharding_dim_only_
                                         ? device_.numThreadsInPool()
                                         : 0)),
          // Only one of the lhs/rhs thread-local stores is used, depending on
          // the sharding dimension.
          lhs_thread_local_blocks_(shard_by_col_ ? 0 : thread_local_capacity,
                                   {*this}, {*this}),
          rhs_thread_local_blocks_(shard_by_col_ ? thread_local_capacity : 0,
                                   {*this}, {*this}) {
      // These two modes are mutually exclusive.
      eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
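      // The per-slice counters initialized below (state_packing_ready_,
      // state_kernel_, state_switch_) count outstanding notifications for each
      // k slice; when a counter reaches its threshold the corresponding
      // packing tasks, kernels or slice switch are enqueued.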
      for (Index x = 0; x < P; x++) {
        // The first slice only waits for its own kernels; later slices also
        // wait for notifications coming from the preceding slice.
        state_switch_[x] =
            x == 0
                ? 1
                : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) +
                      (x == P - 1 ? nm_ * nn_ : 0);
        state_packing_ready_[x] =
            parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
        state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
        for (Index m = 0; m < nm_; m++) {
          state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
          // Kernels generally receive 3 notifications (previous kernel + 2
          // packing), but the first slice has no preceding kernels.
          for (Index n = 0; n < nn_; n++)
            state_kernel_[x][m][n].store(
                (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1),
                std::memory_order_relaxed);
        }
      }

      // Allocate memory for packed rhs/lhs matrices.
      packed_mem_ = kernel_.allocateSlices(            //
          device_,                                     //
          /*num_lhs=*/nm0_,                            //
          /*num_rhs=*/nn0_,                            //
          /*num_slices=*/std::min<Index>(nk_, P - 1),  //
          packed_lhs_, packed_rhs_);

      if (parallelize_by_sharding_dim_only_) {
        const int num_worker_threads = device_.numThreadsInPool();

        if (shard_by_col_) {
          can_use_thread_local_packed_ = new std::atomic<bool>[nn_];
          for (int i = 0; i < nn_; ++i)
            can_use_thread_local_packed_[i].store(true,
                                                  std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gn_;
          thread_local_pre_alocated_mem_ = kernel_.allocateSlices(  //
              device_,                                              //
              /*num_lhs=*/0,                                        //
              /*num_rhs=*/num_blocks,                               //
              /*num_slices=*/1,                                     //
              /*lhs_blocks=*/nullptr, &rhs_thread_local_pre_allocated_);

        } else {
          can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
          for (int i = 0; i < nm_; ++i)
            can_use_thread_local_packed_[i].store(true,
                                                  std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gm_;
          thread_local_pre_alocated_mem_ = kernel_.allocateSlices(  //
              device_,                                              //
              /*num_lhs=*/num_blocks,                               //
              /*num_rhs=*/0,                                        //
              /*num_slices=*/1, &lhs_thread_local_pre_allocated_,   //
              /*rhs_blocks=*/nullptr);
        }
      }
    }
    ~EvalParallelContext() {
      for (Index x = 0; x < P; x++) {
        for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
        delete[] state_kernel_[x];
      }
      kernel_.deallocate(device_, packed_mem_);
      if (parallelize_by_sharding_dim_only_) {
        kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
        delete[] can_use_thread_local_packed_;
      }
    }
   private:
    // Done notification for the sync and async code paths.
    EvalParallelNotification<DoneCallback, EvalParallelContext> done_;

    const Device& device_;

    OutputMapper output_;
    OutputKernelType output_kernel_;
    TensorContractionParams tensor_contraction_params_;

    const int num_threads_;
    const bool shard_by_col_;
    const bool parallel_pack_;
    const bool parallelize_by_sharding_dim_only_;

    TensorContractionKernel kernel_;

    // Packed blocks shared between tasks, for up to P - 1 k slices in flight.
    BlockMemHandle packed_mem_;
    std::vector<LhsBlock> packed_lhs_[P - 1];
    std::vector<RhsBlock> packed_rhs_[P - 1];

    // Pre-allocated thread-local packed blocks (one grain per worker thread).
    BlockMemHandle thread_local_pre_alocated_mem_;
    std::vector<LhsBlock> lhs_thread_local_pre_allocated_;
    std::vector<RhsBlock> rhs_thread_local_pre_allocated_;

    // Number of thread-local block allocations handed out so far.
    std::atomic<int> num_thread_local_allocations_;
    const int thread_local_capacity;
    template <typename BlockType>
    class ThreadLocalBlocks {
     public:
      ThreadLocalBlocks() = default;

      ThreadLocalBlocks(BlockType* base, size_t grain_size)
          : is_pre_allocated_(true),
            thread_local_pre_allocated_base_(base),
            grain_size_(grain_size) {}

      ThreadLocalBlocks(BlockMemHandle mem_handle, std::vector<BlockType> blocks)
          : is_pre_allocated_(false),
            mem_handle_(std::move(mem_handle)),
            blocks_(std::move(blocks)) {}

      BlockType& block(int grain_index) {
        eigen_assert(grain_index >= 0);
        eigen_assert(static_cast<size_t>(grain_index) < size());
        return is_pre_allocated_ ? thread_local_pre_allocated_base_[grain_index]
                                 : blocks_[grain_index];
      }

      void Release(EvalParallelContext& ctx) const {
        if (!is_pre_allocated_) {
          ctx.kernel_.deallocate(ctx.device_, mem_handle_);
        }
      }

      size_t size() const {
        return is_pre_allocated_ ? grain_size_ : blocks_.size();
      }

     private:
      bool is_pre_allocated_;

      // Pre-allocated in the constructor of EvalParallelContext.
      BlockType* thread_local_pre_allocated_base_ = nullptr;
      size_t grain_size_ = 0;

      // Allocated on demand, owned through mem_handle_.
      BlockMemHandle mem_handle_{};
      std::vector<BlockType> blocks_;
    };
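    // Threads that belong to the pool reuse the pre-allocated blocks above;
    // threads outside the pool (e.g. a caller thread that steals work) fall
    // back to dynamically allocated blocks, which are freed through Release()
    // when the thread-local storage is torn down.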
    template <typename BlockType, bool is_rhs>
    class ThreadLocalBlocksInitialize {
      static constexpr bool kIsLhs =
          !is_rhs && std::is_same<BlockType, LhsBlock>::value;
      static const bool kIsRhs =
          is_rhs && std::is_same<BlockType, RhsBlock>::value;
      static_assert(kIsLhs || kIsRhs, "Unknown block type");

      using Blocks = ThreadLocalBlocks<BlockType>;

     public:
      ThreadLocalBlocksInitialize(EvalParallelContext& ctx)
          : ctx_(ctx),
            num_worker_threads_(ctx_.device_.numThreadsInPool()) {}

      void operator()(Blocks& blocks) {
        const int n = ctx_.num_thread_local_allocations_.fetch_add(
            1, std::memory_order_relaxed);

        if (n >= num_worker_threads_) {
          ThreadLocalBlocksAllocator<is_rhs>::allocate(ctx_, blocks);
        } else {
          ThreadLocalBlocksAllocator<is_rhs>::reuse(ctx_, n, blocks);
        }
      }
     private:
      // Without `if constexpr` the calls to allocateSlices have to live in
      // template specializations; EvalCtx works around the restriction on
      // explicit specializations at class scope.
      template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
      struct ThreadLocalBlocksAllocator;

      template <typename EvalCtx>
      struct ThreadLocalBlocksAllocator</*pack_rhs=*/true, EvalCtx> {
        static void allocate(EvalCtx& ctx, Blocks& blocks) {
          std::vector<RhsBlock> rhs_blocks;
          BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
              ctx.device_,
              /*num_lhs=*/0,
              /*num_rhs=*/ctx.gn_,
              /*num_slices=*/1,
              /*lhs_blocks=*/nullptr, &rhs_blocks);

          blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle),
                                               std::move(rhs_blocks));
        }

        static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
          RhsBlock* ptr = &ctx.rhs_thread_local_pre_allocated_[ctx.gn_ * index];
          blocks = ThreadLocalBlocks<RhsBlock>(ptr, ctx.gn_);
        }
      };

      template <typename EvalCtx>
      struct ThreadLocalBlocksAllocator</*pack_rhs=*/false, EvalCtx> {
        static void allocate(EvalCtx& ctx, Blocks& blocks) {
          std::vector<LhsBlock> lhs_blocks;
          BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
              ctx.device_,
              /*num_lhs=*/ctx.gm_,
              /*num_rhs=*/0,
              /*num_slices=*/1,
              &lhs_blocks, /*rhs_blocks=*/nullptr);

          blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle),
                                               std::move(lhs_blocks));
        }

        static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
          LhsBlock* ptr = &ctx.lhs_thread_local_pre_allocated_[ctx.gm_ * index];
          blocks = ThreadLocalBlocks<LhsBlock>(ptr, ctx.gm_);
        }
      };

      EvalParallelContext& ctx_;
      const int num_worker_threads_;
    };
    template <typename BlockType>
    class ThreadLocalBlocksRelease {
     public:
      using Blocks = ThreadLocalBlocks<BlockType>;
      ThreadLocalBlocksRelease(EvalParallelContext& ctx) : ctx_(ctx) {}
      void operator()(Blocks& blocks) { blocks.Release(ctx_); }

     private:
      EvalParallelContext& ctx_;
    };

    // ThreadLocalBlocks initialization callables.
    using ThreadLocalLhsInit =
        ThreadLocalBlocksInitialize<LhsBlock, /*is_rhs=*/false>;
    using ThreadLocalRhsInit =
        ThreadLocalBlocksInitialize<RhsBlock, /*is_rhs=*/true>;

    // ThreadLocalBlocks release callables.
    using ThreadLocalLhsRelease = ThreadLocalBlocksRelease<LhsBlock>;
    using ThreadLocalRhsRelease = ThreadLocalBlocksRelease<RhsBlock>;

    // Thread-local containers for lhs/rhs block packs; only one of them is
    // actually used, depending on shard_by_col_.
    Eigen::ThreadLocal<ThreadLocalBlocks<LhsBlock>, ThreadLocalLhsInit,
                       ThreadLocalLhsRelease>
        lhs_thread_local_blocks_;
    Eigen::ThreadLocal<ThreadLocalBlocks<RhsBlock>, ThreadLocalRhsInit,
                       ThreadLocalRhsRelease>
        rhs_thread_local_blocks_;
    // Once a k slice misses its thread-local execution opportunity, following
    // slices can no longer use thread-local packing for that shard.
    std::atomic<bool>* can_use_thread_local_packed_;

    std::atomic<uint8_t>** state_kernel_[P];
    std::atomic<Index> state_packing_ready_[P];
    std::atomic<Index> state_switch_[P];
    LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
      if (use_thread_local) {
        eigen_assert(!shard_by_col_);
        ThreadLocalBlocks<LhsBlock>& blocks = lhs_thread_local_blocks_.local();

        Index grain_index = m1 - m * gm_;
        return blocks.block(internal::convert_index<int>(grain_index));
      } else {
        return packed_lhs_[k % (P - 1)][m1];
      }
    }

    RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
      if (use_thread_local) {
        eigen_assert(shard_by_col_);
        ThreadLocalBlocks<RhsBlock>& blocks = rhs_thread_local_blocks_.local();

        Index grain_index = n1 - n * gn_;
        return blocks.block(internal::convert_index<int>(grain_index));
      } else {
        return packed_rhs_[k % (P - 1)][n1];
      }
    }
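    // Shared packed blocks are indexed by k % (P - 1), so at most P - 1 slices
    // of packed data are alive at any time; the thread-local path instead
    // indexes blocks by the kernel's grain index within the current task.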
    void pack_lhs(Index m, Index k) {
      bool use_thread_local = false;

      if (parallelize_by_sharding_dim_only_ && !shard_by_col_ &&
          can_use_thread_local_packed_[m].load(std::memory_order_relaxed)) {
        if (state_kernel_[k % P][m][0].load(std::memory_order_relaxed) == 1) {
          use_thread_local = true;
        } else {
          // Kernels of the previous slice did not all run in this thread, so
          // thread-local packing is no longer safe for the following slices.
          can_use_thread_local_packed_[m].store(false,
                                                std::memory_order_relaxed);
        }
      }

      const Index mend = m * gm_ + gm(m);
      for (Index m1 = m * gm_; m1 < mend; m1++)
        kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local),
                        lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));

      if (!parallel_pack_ && shard_by_col_) {
        assert(!use_thread_local);
        signal_packing(k);
      } else {
        signal_switch(k + 1);
        for (Index n = nn_ - 1; n >= 0; n--) {
          bool sync = parallelize_by_sharding_dim_only_ || n == 0;
          signal_kernel(m, n, k, sync, use_thread_local);
        }
      }
    }
    void pack_rhs(Index n, Index k) {
      bool use_thread_local = false;

      if (parallelize_by_sharding_dim_only_ && shard_by_col_ &&
          can_use_thread_local_packed_[n].load(std::memory_order_relaxed)) {
        if (state_kernel_[k % P][0][n].load(std::memory_order_relaxed) == 1) {
          use_thread_local = true;
        } else {
          // Kernels of the previous slice did not all run in this thread, so
          // thread-local packing is no longer safe for the following slices.
          can_use_thread_local_packed_[n].store(false,
                                                std::memory_order_relaxed);
        }
      }

      const Index nend = n * gn_ + gn(n);
      for (Index n1 = n * gn_; n1 < nend; n1++) {
        if (!TensorContractionKernel::HasBeta && k == 0) {
          // If the kernel cannot scale the output with beta, zero the output
          // buffers in parallel before accumulating the first slice.
          memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
        }
        kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
                        rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
      }

      if (parallel_pack_ || shard_by_col_) {
        signal_switch(k + 1);
        for (Index m = nm_ - 1; m >= 0; m--) {
          bool sync = parallelize_by_sharding_dim_only_ || m == 0;
          signal_kernel(m, n, k, sync, use_thread_local);
        }
      } else {
        assert(!use_thread_local);
        signal_packing(k);
      }
    }
    void kernel(Index m, Index n, Index k, bool use_thread_local) {
      // The order of iteration matters: iteration over m is innermost so that
      // the same packed rhs is reused in consecutive kernels.
      const Index nend = n * gn_ + gn(n);
      const Index mend = m * gm_ + gm(m);

      // output = alpha * LHS * RHS + beta * output
      const Scalar alpha = Scalar(1);
      const Scalar beta =
          (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);

      if (shard_by_col_) {
        for (Index n1 = n * gn_; n1 < nend; n1++) {
          for (Index m1 = m * gm_; m1 < mend; m1++) {
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            kernel_.invoke(
                output_mapper,
                packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
                bk(k), bn(n1), alpha, beta);

            // We are done with the last task for the [m1, n1] block.
            if (k + 1 == nk_) {
              output_kernel_(output_mapper, tensor_contraction_params_,
                             m1 * bm_, n1 * bn_, bm(m1), bn(n1));
            }
          }
        }
      } else {
        for (Index m1 = m * gm_; m1 < mend; m1++) {
          for (Index n1 = n * gn_; n1 < nend; n1++) {
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            kernel_.invoke(
                output_mapper,
                packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
                bk(k), bn(n1), alpha, beta);

            // We are done with the last task for the [m1, n1] block.
            if (k + 1 == nk_) {
              output_kernel_(output_mapper, tensor_contraction_params_,
                             m1 * bm_, n1 * bn_, bm(m1), bn(n1));
            }
          }
        }
      }
      signal_kernel(m, n, k + 1, /*sync=*/false, /*use_thread_local=*/false);
      signal_switch(k + 2);
    }
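    // A finished kernel unblocks the same (m, n) kernel of slice k + 1 and
    // contributes one notification to the switch counter of slice k + 2, which
    // keeps the packing/kernel pipeline rolling across k slices.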
    void signal_packing(Index k) {
      eigen_assert(!parallel_pack_);
      Index s = state_packing_ready_[k % P].fetch_sub(1);
      eigen_assert(s > 0);
      if (s != 1) return;
      state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
      enqueue_packing(k, shard_by_col_);
    }
    void signal_kernel(Index m, Index n, Index k, bool sync,
                       bool use_thread_local) {
      std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
      Index s = state->load();
      eigen_assert(s > 0);
      if (s != 1 && state->fetch_sub(1) != 1) {
        eigen_assert(!use_thread_local);
        return;
      }
      state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
      if (sync) {
        kernel(m, n, k, use_thread_local);
      } else {
        eigen_assert(!use_thread_local);
        device_.enqueueNoNotification(
            [=]() { kernel(m, n, k, use_thread_local); });
      }
    }
    void signal_switch(Index k, Index v = 1) {
      Index s = state_switch_[k % P].fetch_sub(v);
      eigen_assert(s >= v);
      if (s != v) return;

      // Ready to switch to the next k slice; reset the counter for the slice
      // that will reuse this state in P iterations.
      state_switch_[k % P] =
          (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) +
          nm_ * nn_;

      if (k < nk_) {
        // Issue lhs/rhs packing; their completion will kick off the kernels.
        if (parallel_pack_) {
          enqueue_packing(k, !shard_by_col_);
          enqueue_packing(k, shard_by_col_);
        } else if (shard_by_col_) {
          enqueue_packing(k, false);
        } else {
          enqueue_packing(k, true);
        }
      } else if (k == nk_) {
        // Termination handling: pretend that all packing tasks of the
        // (nk + 1)-th slice finished instantly.
        signal_switch(k + 1,
                      parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
      } else {
        // The whole contraction is done.
        done_.Notify();
      }
    }
    // Enqueue all rhs/lhs packing tasks for the k-th slice.
    void enqueue_packing(Index k, bool rhs) {
      enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs);
    }

    void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
      if (end - start == 1) {
        if (rhs)
          pack_rhs(start, k);
        else
          pack_lhs(start, k);
      } else {
        while (end - start > 1) {
          Index mid = (start + end) / 2;
          device_.enqueueNoNotification(
              [=]() { enqueue_packing_helper(mid, end, k, rhs); });
          end = mid;
        }

        // The first packing task of the sharding dimension is sent to the pool
        // (instead of running inline) so that it can later run its kernels in
        // thread-local mode.
        const bool pack_async =
            (start == 0) &&
            (parallelize_by_sharding_dim_only_ && shard_by_col_ == rhs) &&
            (k > 0 || std::this_thread::get_id() == created_by_thread_id_);

        if (pack_async) {
          device_.enqueueNoNotification(
              [=]() { enqueue_packing_helper(start, end, k, rhs); });
        } else {
          enqueue_packing_helper(start, end, k, rhs);
        }
      }
    }
    // Block sizes accounting for a potentially incomplete last block.
    Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
    Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
    Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }
    // Task grain sizes accounting for a potentially incomplete last task.
    Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
    Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }

    EvalParallelContext(const EvalParallelContext&) = delete;
    void operator=(const EvalParallelContext&) = delete;
  };
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  using SyncEvalParallelContext =
      EvalParallelContext<NoCallback, lhs_inner_dim_contiguous,
                          rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
                          Alignment>;
  template <typename DoneCallback>
  struct EvalShardedByInnerDimContext {
    EvalShardedByInnerDimContext(const Self* self, int num_threads,
                                 Scalar* result_buffer,
                                 Index m_size, Index n_size, Index k_size,
                                 DoneCallback done_callback)
        : evaluator(self),
          m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous),
          m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous),
          m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered),
          result(result_buffer),
          m(m_size),
          n(n_size),
          k(k_size),
          done(std::move(done_callback)),
          buffer_size_bytes(m * n * sizeof(Scalar)),
          block_size(blockSize(k, num_threads)),
          num_blocks(divup<Index>(k, block_size)),
          num_pending_blocks(internal::convert_index<int>(num_blocks)),
          l0_ranges(divup<Index>(num_blocks, l0_size)),
          l0_state(l0_ranges),
          block_buffers(num_blocks) {
      // Keep count of pending gemm tasks for each l0 range.
      for (int i = 0; i < l0_ranges; ++i) {
        const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size, i);
        l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
      }

      // Allocate temporary buffers for each block.
      for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
        Scalar* buf = block_idx == 0
                          ? result
                          : static_cast<Scalar*>(evaluator->m_device.allocate(
                                buffer_size_bytes));
        block_buffers.emplace_back(buf);
      }
    }
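    // Sharding over the inner dimension: the contraction dimension k is split
    // into num_blocks blocks. Block 0 accumulates directly into the result
    // buffer, every other block into its own temporary buffer. Partial results
    // are summed within l0 ranges of l0_size blocks, then across ranges, and
    // finally the output kernel is applied to the full result.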
    ~EvalShardedByInnerDimContext() {
      for (Index i = 1; i < num_blocks; ++i) {
        evaluator->m_device.deallocate(block_buffers[i]);
      }
    }
    template <int Alignment>
    void run() {
      Barrier barrier(internal::convert_index<int>(num_blocks));
      eval<Alignment>(barrier, 0, num_blocks);
      barrier.Wait();

      // Aggregate partial sums from the l0 ranges, then apply the output kernel.
      aggregateL0Blocks<Alignment>();
      applyOutputKernel();
    }

    template <int Alignment>
    void runAsync() {
      evalAsync<Alignment>(0, num_blocks);
    }
    const Self* evaluator;  // TensorContraction evaluator

    // Cached here to avoid going through the evaluator on every access.
    bool m_lhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_reordered;

    Index buffer_size_bytes;

    std::atomic<int> num_pending_blocks;

    // Aggregation of partial sums is done over ranges of `l0_size` blocks.
    static const Index l0_size = 4;
    MaxSizeVector<std::atomic<int>> l0_state;  // pending tasks per l0 range
    MaxSizeVector<Scalar*> block_buffers;      // one buffer per block
    template <int Alignment>
    void processBlock(Index block_idx, Index begin, Index end) {
      Scalar* buf = block_buffers[block_idx];

      TENSOR_CONTRACTION_DISPATCH(
          evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
          (buf, begin, end,
           /*num_threads=*/internal::convert_index<int>(num_blocks)));

      // Check if this was the last task in the l0 range.
      const Index l0_index = block_idx / l0_size;
      const int v = l0_state[l0_index].fetch_sub(1);
      eigen_assert(v >= 1);

      // If it was, aggregate the range's partial results into its first block.
      if (v == 1) {
        const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index);
        const Index dst_block_idx = l0_index * l0_size;

        if (rng_size == l0_size) {
          addAllToBuffer<Alignment>(
              m * n,
              block_buffers[dst_block_idx + 1],
              block_buffers[dst_block_idx + 2],
              block_buffers[dst_block_idx + 3],
              block_buffers[dst_block_idx]);
        } else {
          // The last range might be incomplete.
          for (int i = 1; i < rng_size; ++i) {
            addToBuffer<Alignment>(m * n,
                                   block_buffers[dst_block_idx + i],
                                   block_buffers[dst_block_idx]);
          }
        }
      }
    }
    // Aggregate partial sums from the l0 ranges into block_buffers[0].
    template <int Alignment>
    void aggregateL0Blocks() const {
      Index l0_index = 1;

      for (; l0_index + 2 < l0_ranges; l0_index += 3) {
        addAllToBuffer<Alignment>(
            m * n,
            block_buffers[(l0_index + 0) * l0_size],
            block_buffers[(l0_index + 1) * l0_size],
            block_buffers[(l0_index + 2) * l0_size],
            block_buffers[0]);
      }

      for (; l0_index < l0_ranges; ++l0_index) {
        addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
                               block_buffers[0]);
      }
    }
    void applyOutputKernel() const {
      typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
      evaluator->m_output_kernel(
          OutputMapper(result, m), evaluator->m_tensor_contraction_params,
          static_cast<Eigen::Index>(0), static_cast<Eigen::Index>(0), m, n);
    }
    // Block size accounting for a potentially incomplete last block.
    Index actualBlockSize(Index block_idx) const {
      return block_idx + 1 < num_blocks
                 ? block_size
                 : k + block_size - block_size * num_blocks;
    }

    // Range size accounting for a potentially incomplete last range.
    Index actualRangeSize(Index num_ranges, Index range_size,
                          Index range_idx) const {
      eigen_assert(range_idx < num_ranges);
      return range_idx + 1 < num_ranges
                 ? range_size
                 : num_blocks + range_size - range_size * num_ranges;
    }
    template <int Alignment>
    EIGEN_STRONG_INLINE static void addToBuffer(size_t n, const Scalar* src_buf,
                                                Scalar* tgt_buf) {
      const int output_packet_size =
          internal::unpacket_traits<PacketReturnType>::size;
      size_t i = 0;
      const size_t num_packets = n / output_packet_size;
      for (; i < output_packet_size * num_packets; i += output_packet_size) {
        const PacketReturnType src_val =
            internal::pload<PacketReturnType>(src_buf + i);
        const PacketReturnType tgt_val =
            internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
        const PacketReturnType sum = internal::padd(src_val, tgt_val);
        internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i,
                                                               sum);
      }
      for (; i < n; ++i) {
        tgt_buf[i] += src_buf[i];
      }
    }
    template <int Alignment>
    EIGEN_STRONG_INLINE static void addAllToBuffer(size_t n,
                                                   const Scalar* src_buf0,
                                                   const Scalar* src_buf1,
                                                   const Scalar* src_buf2,
                                                   Scalar* dst_buf) {
      using ::Eigen::internal::padd;
      using ::Eigen::internal::pload;
      using ::Eigen::internal::ploadt;
      using ::Eigen::internal::pstoret;

      const int output_packet_size =
          internal::unpacket_traits<PacketReturnType>::size;

      size_t i = 0;
      const size_t num_packets = n / output_packet_size;
      for (; i < output_packet_size * num_packets; i += output_packet_size) {
        const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
        const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
        const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);

        const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
        const auto sum =
            padd(padd(dst_val, src_val0), padd(src_val1, src_val2));

        pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
      }
      for (; i < n; ++i) {
        dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
      }
    }
    template <int Alignment>
    void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
        evaluator->m_device.enqueueNoNotification(
            [this, &barrier, mid_block_idx, end_block_idx]() {
              eval<Alignment>(barrier, mid_block_idx, end_block_idx);
            });
        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;
      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);

      processBlock<Alignment>(block_idx, block_start, block_end);
      barrier.Notify();
    }
    template <int Alignment>
    void evalAsync(Index start_block_idx, Index end_block_idx) {
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
        evaluator->m_device.enqueueNoNotification(
            [this, mid_block_idx, end_block_idx]() {
              evalAsync<Alignment>(mid_block_idx, end_block_idx);
            });
        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;
      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);

      processBlock<Alignment>(block_idx, block_start, block_end);

      int v = num_pending_blocks.fetch_sub(1);
      eigen_assert(v >= 1);

      if (v == 1) {
        // Aggregate partial sums and apply the output kernel.
        aggregateL0Blocks<Alignment>();
        applyOutputKernel();

        // Move the done callback out of the context before deleting it, then
        // delete the context and call the callback.
        DoneCallback done_copy = std::move(done);
        delete this;
        done_copy();
      }
    }
    // The cost model does not capture the cost of constructing the contraction
    // mappers and computing loop bounds well, so enforce a minimum block size.
    static Index blockSize(Index k, int num_threads) {
      const auto round_up = [=](Index index) -> Index {
        const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
        return divup<Index>(index, kmultiple) * kmultiple;
      };

      const Index target_block_size = round_up(divup<Index>(k, num_threads));
      const Index desired_min_block_size = 12 * packet_size;

      return numext::mini<Index>(
          k, numext::maxi<Index>(desired_min_block_size, target_block_size));
    }
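    // Worked example for blockSize above, assuming packet_size == 8: for
    // k = 1000 and num_threads = 12, divup(1000, 12) = 84 rounds up to 88, the
    // desired minimum is 12 * 8 = 96, so the result is min(1000, max(96, 88)) = 96.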
    EvalShardedByInnerDimContext(const EvalShardedByInnerDimContext&) = delete;
    void operator=(const EvalShardedByInnerDimContext&) = delete;
  };
  static bool shardByCol(Index m, Index n, int num_threads) {
    // Shard by row (return false) when rows provide enough work per thread for
    // vectorization (Traits::nr is the gebp register-block width) and columns
    // do not; otherwise default to sharding by column.
    if (m / num_threads >= Traits::nr &&
        (n / num_threads < Traits::nr ||
         (n / num_threads < 4 * Traits::nr &&
          (n % (num_threads * Traits::nr)) != 0 &&
          ((m % (num_threads * Traits::nr)) == 0 || m / n >= 6))))
      return false;
    if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
    return true;
  }
  Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn,
                 int num_threads, bool shard_by_col) const {
    Index gm = 1;
    Index gm1 = 1;
    Index nm0 = divup(m, bm);
    Index nm1 = nm0;
    for (;;) {
      // Find the next candidate for the m grain size: it must change the
      // resulting number of tasks.
      while (gm1 <= nm0 && nm1 == divup(nm0, gm1)) gm1++;
      if (gm1 > nm0) break;
      // Check the candidate.
      int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads,
                           shard_by_col);
      if (res < 0) break;
      nm1 = divup(nm0, gm1);
      if (res == 0) continue;
      // Commit the new grain size.
      gm = gm1;
    }
    return gm;
  }
  Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
                 int num_threads, bool shard_by_col) const {
    Index gn = 1;
    Index gn1 = 1;
    Index nn0 = divup(n, bn);
    Index nn1 = nn0;
    for (;;) {
      while (gn1 <= nn0 && nn1 == divup(nn0, gn1)) gn1++;
      if (gn1 > nn0) break;
      int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads,
                           shard_by_col);
      if (res < 0) break;
      nn1 = divup(nn0, gn1);
      if (res == 0) continue;
      gn = gn1;
    }
    return gn;
  }
  // Returns 1 if the grain (gm, gn) is suitable and better than (oldgm, oldgn),
  // 0 if it is not better, and -1 if it (and all larger grains) must be rejected.
  int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
                 Index gn, Index oldgm, Index oldgn, int num_threads,
                 bool shard_by_col) const {
    const TensorOpCost cost =
        contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
    double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(
        static_cast<double>(bm) * gm * bn * gn, cost);
    // If the task is small enough, accept it regardless of anything else;
    // otherwise synchronization overheads will dominate.
    if (taskSize < 1) return 1;
    // If it is too large, reject it and all larger tasks.
    if (taskSize > 2) return -1;
    // Otherwise the deciding factor is parallelism.
    Index nm0 = divup(m, bm);
    Index nn0 = divup(n, bn);
    Index new_tasks = divup(nm0, gm) * divup(nn0, gn);
    double new_parallelism = static_cast<double>(new_tasks) /
                             (divup<int>(new_tasks, num_threads) * num_threads);
    Index old_tasks = divup(nm0, oldgm) * divup(nn0, oldgn);
    double old_parallelism = static_cast<double>(old_tasks) /
                             (divup<int>(old_tasks, num_threads) * num_threads);
    if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
    return 0;
  }
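  // Parallelism example for checkGrain above: with 12 kernels and 4 threads, a
  // grain of 2 yields 6 tasks and parallelism 6 / (2 * 4) = 0.75, while a grain
  // of 3 yields 4 tasks and parallelism 4 / (1 * 4) = 1.0, so the larger grain
  // is preferred.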
  TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk,
                               bool shard_by_col, bool prepacked) const {
    const int packed_size = std::min<int>(PacketType<LhsScalar, Device>::size,
                                          PacketType<RhsScalar, Device>::size);
    const int output_packet_size =
        internal::unpacket_traits<PacketReturnType>::size;
    const double kd = static_cast<double>(bk);
    double compute_bandwidth = computeBandwidth(false, bm, bn, bk);
    // Computations.
    TensorOpCost cost = TensorOpCost(0, 0, kd * compute_bandwidth, true, packed_size);
    // Output stores.
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
    if (prepacked) {
      // Packing and kernels run in separate tasks; for grain-size selection we
      // only look at the kernel cost.
      return cost;
    }
    // Lhs/rhs loads + computations.
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
    // The packing memory cost of the operand that stays resident does not
    // contribute much, since it is prefetched and accessed sequentially.
    if (shard_by_col)
      lhsCost.dropMemoryCost();
    else
      rhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }
  // Decide whether to shard the m x k x n contraction over the inner
  // (contraction) dimension k.
  static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
                              int num_threads_by_k) {
    std::ptrdiff_t bufsize = m * n * sizeof(Scalar);
    bool shard_by_k = false;
    if (n == 1 ||                // mat*vec, or
        num_threads_by_k < 2 ||  // sharding by k runs single threaded, or
        num_threads_by_k < num_threads ||  // gives less parallelism, or
        bufsize > l3CacheSize() / num_threads_by_k ||  // needs too much buffer, or
        k / num_threads_by_k < 2 * Traits::nr) {       // k per thread is tiny
      shard_by_k = false;
    } else if (numext::maxi(m, n) / num_threads < Traits::nr ||
               (k / num_threads_by_k > 8 * Traits::nr &&
                (numext::mini(m, n) < 2 * Traits::nr ||
                 num_threads_by_k > num_threads))) {
      shard_by_k = true;
    }
    return shard_by_k;
  }
  TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
    // Compute cost.
    const int output_packet_size =
        internal::unpacket_traits<PacketReturnType>::size;
    TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n, true,
                      output_packet_size);
    // Output stores.
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * n;
    // The inner gemm kernel is always sharded by column, so the lhs memory
    // cost is negligible.
    lhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }
  int numThreadsInnerDim(Index m, Index n, Index k) const {
    const int output_packet_size =
        internal::unpacket_traits<PacketReturnType>::size;
    TensorOpCost cost = contractionCostPerInnerDim(m, n, k);
    double total_parallel_cost =
        TensorCostModel<ThreadPoolDevice>::totalCost(k, cost);
    // Cost of reducing the m*n per-thread buffers into the final result.
    double reduction_cost = TensorCostModel<ThreadPoolDevice>::totalCost(
        m * n, TensorOpCost(2, 1, 1, true, output_packet_size));
    int num_threads = 1;
    double min_cost = total_parallel_cost;
    double kPerThreadOverHead = 3000;
    double kFixedOverHead = 100000;
    for (int nt = 2; nt <= this->m_device.numThreads(); nt += 2) {
      double sequential_cost =
          kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
      double parallel_cost = total_parallel_cost / nt + sequential_cost;
      if (parallel_cost < min_cost) {
        num_threads = nt;
        min_cost = parallel_cost;
      }
    }
    return num_threads;
  }
  double computeBandwidth(bool shard_by_col, Index bm, Index bn,
                          Index bk) const {
    // Peak VFMA bandwidth is 0.5; with too little data for vectorization the
    // effective bandwidth drops (values determined experimentally).
    double computeBandwidth =
        bk == 1 ? 4.0
                : (shard_by_col ? bn : bm) < Traits::nr ||
                          (shard_by_col ? bm : bn) < Traits::mr
                      ? 2.0
                      : 0.5;
#ifndef EIGEN_VECTORIZE_FMA
    // Without FMA, the dependent MULPS/ADDPS sequence has bandwidth 1.0.
    if (computeBandwidth == 0.5) computeBandwidth = 1.0;
#endif
    return computeBandwidth;
  }
};

#endif  // EIGEN_USE_THREADS
#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H