#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H

// evaluator for thread pool device
#ifdef EIGEN_USE_THREADS

template <typename Indices, typename LeftArgType, typename RightArgType,
          typename OutputKernelType>
struct TensorEvaluator<
    const TensorContractionOp<Indices, LeftArgType, RightArgType,
                              OutputKernelType>,
    ThreadPoolDevice>
    : public TensorContractionEvaluatorBase<TensorEvaluator<
          const TensorContractionOp<Indices, LeftArgType, RightArgType,
                                    OutputKernelType>,
          ThreadPoolDevice> > {
  typedef ThreadPoolDevice Device;

  typedef TensorEvaluator<
      const TensorContractionOp<Indices, LeftArgType, RightArgType,
                                OutputKernelType>,
      Device>
      Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType,
                              OutputKernelType>
      XprType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType,
      RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType,
      LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;

  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
  template <int Alignment>
  void evalProduct(Scalar* buffer) const {
    evalProductImpl<NoCallback, Alignment>(buffer, NoCallback());
  }

  template <typename EvalToCallback, int Alignment>
  void evalProductAsync(Scalar* buffer, EvalToCallback done) const {
    evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done));
  }
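  // Illustrative usage sketch (not part of this header): both entry points
  // above are reached through the regular Tensor expression API when a
  // contraction is assigned on a ThreadPoolDevice. The concrete sizes and the
  // pool size below are assumptions made purely for the example, and the
  // async-assignment form may differ slightly between Eigen versions.
  //
  //   Eigen::ThreadPool pool(/*num_threads=*/4);
  //   Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);
  //
  //   Eigen::Tensor<float, 2> lhs(64, 32), rhs(32, 48), out(64, 48);
  //   lhs.setRandom();
  //   rhs.setRandom();
  //   Eigen::array<Eigen::IndexPair<int>, 1> dims = {
  //       Eigen::IndexPair<int>(1, 0)};  // contract lhs columns with rhs rows
  //
  //   // Synchronous evaluation: dispatches to evalProduct().
  //   out.device(device) = lhs.contract(rhs, dims);
  //
  //   // Asynchronous evaluation: dispatches to evalProductAsync() and
  //   // invokes the callback when the contraction has finished.
  //   Eigen::Barrier done(1);
  //   out.device(device, [&done]() { done.Notify(); }) =
  //       lhs.contract(rhs, dims);
  //   done.Wait();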
  template <typename DoneCallback, int Alignment>
  void evalProductImpl(Scalar* buffer, DoneCallback done) const {
    // The sync/async execution decision is made at runtime: in sync mode the
    // parallel context lives on the stack and we block until it completes; in
    // async mode the context is heap allocated and `done` is invoked once all
    // tasks have finished.
    static const bool IsEvalInSyncMode =
        std::is_same<DoneCallback, NoCallback>::value;

    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    const Index k = this->m_k_size;

    // If there is nothing to compute, return right away.
    if (m == 0 || n == 0 || k == 0) return;
    // Compute a preliminary sharding and blocking assuming two threads; both
    // are recomputed below once the actual thread count is known.
    bool shard_by_col = shardByCol(m, n, 2);

    Index bm, bn, bk;
    if (shard_by_col) {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByCol>
          blocking(k, m, n, 2);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByRow>
          blocking(k, m, n, 2);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    }

    // Compute the optimal number of threads from the cost model.
    const TensorOpCost cost =
        contractionCost(m, n, bm, bn, bk, shard_by_col, false);
    int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
        static_cast<double>(n) * m, cost, this->m_device.numThreads());
    int num_threads_by_k = numThreadsInnerDim(m, n, k);
    if (shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
      // It is more effective to shard by the inner (contraction) dimension.
      if (IsEvalInSyncMode) {
        EvalShardedByInnerDimContext<DoneCallback> ctx(
            this, num_threads_by_k, buffer, m, n, k, std::move(done));
        ctx.template run<Alignment>();
      } else {
        auto* ctx = new EvalShardedByInnerDimContext<DoneCallback>(
            this, num_threads_by_k, buffer, m, n, k, std::move(done));
        ctx->template runAsync<Alignment>();
      }
      return;
    }
    if (n == 1) num_threads = 1;

    if (num_threads == 1) {
      TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
                                  Unaligned, (buffer));
      if (!IsEvalInSyncMode) done();
      return;
    }

    // Now that we know the number of threads, recompute sharding and blocking.
    shard_by_col = shardByCol(m, n, num_threads);
    if (shard_by_col) {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByCol>
          blocking(k, m, n, num_threads);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          internal::ShardByRow>
          blocking(k, m, n, num_threads);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    }
    // Number of kernels for each dimension.
    Index nm0 = divup(m, bm);
    Index nn0 = divup(n, bn);
    Index nk = divup(k, bk);

    // Calculate task grain size (in number of kernels).
    Index gm = 1;
    Index gn = 1;
    // If we are sharding by column, then we prefer to reduce rows first.
    if (shard_by_col) {
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
    } else {
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
    }

    Index nm = divup(nm0, gm);
    Index nn = divup(nn0, gn);
    // If there is enough concurrency in the sharding dimension alone, we
    // choose not to parallelize by the other dimension.
    const Index sharding_dim_tasks = shard_by_col ? nn : nm;
    const int num_worker_threads = this->m_device.numThreadsInPool();

    // With a small number of threads we want to make sure that we do not
    // reduce parallelism too much; with a large number of threads we trade
    // maximum parallelism for better memory locality.
    const float oversharding_factor =
        num_worker_threads <= 4  ? 8.0 :
        num_worker_threads <= 8  ? 4.0 :
        num_worker_threads <= 16 ? 2.0 :
        num_worker_threads <= 32 ? 1.0 :
        num_worker_threads <= 64 ? 0.8 : 0.6;

    const bool parallelize_by_sharding_dim_only =
        sharding_dim_tasks >= oversharding_factor * num_worker_threads;
    // We are interested in parallel packing if there are few tasks.
    bool parallel_pack = num_threads >= nm * nn;
    // Also do parallel packing if all data fits into L2 cache.
    if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <=
        l2CacheSize() * num_threads)
      parallel_pack = true;
    // But don't do it if we will use each rhs only once. Locality seems to be
    // more important in this case.
    if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
    // Also don't get in the way of the parallelize_by_sharding_dim_only
    // optimization.
    if (parallelize_by_sharding_dim_only) parallel_pack = false;
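    // Worked example of the L2 condition above (illustrative numbers only):
    // for m = n = 1024, bk = 256 and float inputs, the packed data per slice
    // is 1024 * 256 * 4 B on each side, i.e. 2 MiB in total; with a 256 KiB
    // L2 cache per core and num_threads = 8 the right-hand side is also
    // 2 MiB, so the slice just fits and parallel packing is enabled (unless
    // one of the later checks turns it off again).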
    if (IsEvalInSyncMode) {
#define CONTEXT_ARGS                                                        \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
   nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
   NoCallback())                                                            \
      .run()
      TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment,
                                  CONTEXT_ARGS);
#undef CONTEXT_ARGS
    } else {
#define CONTEXT_ARGS                                                        \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
   nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
   std::move(done))
      TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback,
                                        Alignment, CONTEXT_ARGS, run());
#undef CONTEXT_ARGS
    }
  }

  // Dummy struct used as the DoneCallback in synchronous evaluation; it must
  // never actually be invoked.
  struct NoCallback {
    void operator()() {
      eigen_assert(false && "NoCallback should never be called");
    }
  };
  // Forward declaration.
  template <typename DoneCallback, typename Context>
  class EvalParallelNotification;

  // Used in EvalParallelContext in case of synchronous execution: blocks the
  // calling thread until evaluation is finished.
  template <typename Context>
  class EvalParallelNotification<NoCallback, Context> {
   public:
    EvalParallelNotification(Context*, NoCallback) {}
    void Notify() { done_.Notify(); }
    void Wait() { done_.Wait(); }

   private:
    Eigen::Notification done_;
  };

  // Used in EvalParallelContext in case of asynchronous execution: deletes
  // the heap-allocated context and then calls the done callback.
  template <typename DoneCallback, typename Context>
  class EvalParallelNotification {
   public:
    EvalParallelNotification(Context* ctx, DoneCallback done)
        : ctx_(ctx), done_(std::move(done)) {}

    void Notify() {
      // Make a copy of the done callback, because it will be destructed when
      // the context is deleted on the next line (this notification is a data
      // member of the context).
      DoneCallback done_copy = std::move(done_);
      delete ctx_;
      done_copy();
    }

    void Wait() {}

   private:
    Context* ctx_;
    DoneCallback done_;
  };
  // EvalParallelContext orchestrates sync/async parallel contraction
  // evaluation. When it is executed in asynchronous mode, it owns all the
  // shared state that might be accessed by block packing and kernel tasks.
  template <typename DoneCallback, bool lhs_inner_dim_contiguous,
            bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
            int Alignment>
  class EvalParallelContext {
   public:
    typedef internal::TensorContractionInputMapper<
        LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
        contract_t, internal::packet_traits<LhsScalar>::size,
        lhs_inner_dim_contiguous, false, Unaligned>
        LhsMapper;
    typedef internal::TensorContractionInputMapper<
        RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
        contract_t, internal::packet_traits<RhsScalar>::size,
        rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
        RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    typedef internal::TensorContractionKernel<
        Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
        TensorContractionKernel;

    typedef typename TensorContractionKernel::LhsBlock LhsBlock;
    typedef typename TensorContractionKernel::RhsBlock RhsBlock;
    typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
    EvalParallelContext(const Self* self, int num_threads, Scalar* buffer,
                        Index tm, Index tn, Index tk, Index bm, Index bn,
                        Index bk, Index nm, Index nn, Index nk, Index gm,
                        Index gn, Index nm0, Index nn0, bool shard_by_col,
                        bool parallel_pack,
                        bool parallelize_by_sharding_dim_only,
                        DoneCallback done)
        : created_by_thread_id_(std::this_thread::get_id()),
          done_(this, std::move(done)),
          device_(self->m_device),
          lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
               self->m_i_strides, self->m_left_contracting_strides,
               self->m_k_strides),
          rhs_(self->m_rightImpl, self->m_right_nocontract_strides,
               self->m_j_strides, self->m_right_contracting_strides,
               self->m_k_strides),
          buffer_(buffer),
          output_(buffer, tm),
          output_kernel_(self->m_output_kernel),
          tensor_contraction_params_(self->m_tensor_contraction_params),
          num_threads_(num_threads),
          shard_by_col_(shard_by_col),
          parallel_pack_(parallel_pack),
          parallelize_by_sharding_dim_only_(parallelize_by_sharding_dim_only),
          m_(tm), n_(tn), k_(tk),
          bm_(bm), bn_(bn), bk_(bk),
          nm_(nm), nn_(nn), nk_(nk),
          gm_(gm), gn_(gn), nm0_(nm0), nn0_(nn0),
          kernel_(m_, k_, n_, bm_, bk_, bn_),
          num_thread_local_allocations_(0),
          // Reserve 2x more capacity for thread-local values than the number
          // of threads in the pool, to efficiently handle task stealing by
          // threads that are not managed by the pool.
          thread_local_capacity(2 * (parallelize_by_sharding_dim_only_
                                         ? device_.numThreadsInPool()
                                         : 0)),
          // Only one of the Lhs/Rhs thread-local storages is used, depending
          // on the shard_by_col value.
          lhs_thread_local_blocks_(shard_by_col_ ? 0 : thread_local_capacity,
                                   {*this}, {*this}),
          rhs_thread_local_blocks_(shard_by_col_ ? thread_local_capacity : 0,
                                   {*this}, {*this}) {
      // These two options are mutually exclusive.
      eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
      for (Index x = 0; x < P; x++) {
        // Normal number of notifications for a k slice switch is
        // nm_ + nn_ + nm_ * nn_. However, the first P - 1 slices will receive
        // only nm_ + nn_ notifications, because they will not receive
        // notifications from preceding kernels.
        state_switch_[x] =
            x == 0
                ? 1
                : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) +
                      (x == P - 1 ? nm_ * nn_ : 0);
        state_packing_ready_[x] =
            parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
        state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
        for (Index m = 0; m < nm_; m++) {
          state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
          // Kernels generally receive 3 notifications (previous kernel + 2
          // packing), but the first slice won't get notifications from
          // previous kernels.
          for (Index n = 0; n < nn_; n++)
            state_kernel_[x][m][n].store(
                (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1),
                std::memory_order_relaxed);
        }
      }

      // Allocate memory for packed rhs/lhs matrices.
      packed_mem_ = kernel_.allocateSlices(
          device_,
          /*num_lhs=*/nm_,
          /*num_rhs=*/nn_,
          /*num_slices=*/std::min<Index>(nk_, P - 1),
          packed_lhs_, packed_rhs_);
      if (parallelize_by_sharding_dim_only_) {
        const int num_worker_threads = device_.numThreadsInPool();

        if (shard_by_col) {
          can_use_thread_local_packed_ = new std::atomic<bool>[nn_];
          for (int i = 0; i < nn_; ++i)
            can_use_thread_local_packed_[i].store(true,
                                                  std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gn_;
          thread_local_pre_alocated_mem_ = kernel_.allocateSlices(
              device_,
              /*num_lhs=*/0,
              /*num_rhs=*/num_blocks,
              /*num_slices=*/1,
              /*lhs_blocks=*/nullptr, &rhs_thread_local_pre_allocated_);
        } else {
          can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
          for (int i = 0; i < nm_; ++i)
            can_use_thread_local_packed_[i].store(true,
                                                  std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gm_;
          thread_local_pre_alocated_mem_ = kernel_.allocateSlices(
              device_,
              /*num_lhs=*/num_blocks,
              /*num_rhs=*/0,
              /*num_slices=*/1, &lhs_thread_local_pre_allocated_,
              /*rhs_blocks=*/nullptr);
        }
      }
    }
    ~EvalParallelContext() {
      for (Index x = 0; x < P; x++) {
        for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
        delete[] state_kernel_[x];
      }
      kernel_.deallocate(device_, packed_mem_);
      if (parallelize_by_sharding_dim_only_) {
        kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
        delete[] can_use_thread_local_packed_;
      }
    }
    // Kick off evaluation: issue packing for the first k slice and wait on
    // the notification (a no-op in async mode, where the last task deletes
    // the context and invokes the done callback via done_.Notify()).
    void run() {
      signal_switch(0, 1);
      done_.Wait();
    }

   private:
    const std::thread::id created_by_thread_id_;

    // This notification is specialized on the type of DoneCallback: blocking
    // in sync mode, non-blocking in async mode.
    EvalParallelNotification<DoneCallback, EvalParallelContext> done_;

    const Device& device_;
    LhsMapper lhs_;
    RhsMapper rhs_;
    Scalar* const buffer_;
    OutputMapper output_;
    OutputKernelType output_kernel_;
    TensorContractionParams tensor_contraction_params_;
    const int num_threads_;
    const bool shard_by_col_;
    const bool parallel_pack_;
    const bool parallelize_by_sharding_dim_only_;
    // Problem sizes, block sizes, block counts and grain sizes (mirroring the
    // constructor parameters of the same names).
    const Index m_, n_, k_, bm_, bn_, bk_, nm_, nn_, nk_, gm_, gn_, nm0_, nn0_;
    TensorContractionKernel kernel_;

    // Number of k slices for which packing/kernel state is kept alive.
    static const Index P = 3;

    // Handle to the allocated temporary storage for packed Lhs/Rhs blocks.
    BlockMemHandle packed_mem_;
    std::vector<LhsBlock> packed_lhs_[P - 1];
    std::vector<RhsBlock> packed_rhs_[P - 1];

    // Pre-allocated thread-local memory, used only when parallelizing by the
    // sharding dimension; only one of the two vectors is initialized,
    // depending on shard_by_col.
    BlockMemHandle thread_local_pre_alocated_mem_;
    std::vector<LhsBlock> lhs_thread_local_pre_allocated_;
    std::vector<RhsBlock> rhs_thread_local_pre_allocated_;

    // How many thread-local block allocations were made so far.
    std::atomic<int> num_thread_local_allocations_;
    const int thread_local_capacity;
    // ThreadLocalBlocks is a wrapper around packed Lhs/Rhs blocks stored in
    // thread-local storage: either a view into pre-allocated memory, or
    // separately allocated blocks owned via a BlockMemHandle.
    template <typename BlockType>
    class ThreadLocalBlocks {
     public:
      ThreadLocalBlocks() = default;

      ThreadLocalBlocks(BlockType* base, size_t grain_size)
          : is_pre_allocated_(true),
            thread_local_pre_allocated_base_(base),
            grain_size_(grain_size) {}

      ThreadLocalBlocks(BlockMemHandle mem_handle,
                        std::vector<BlockType> blocks)
          : is_pre_allocated_(false),
            mem_handle_(std::move(mem_handle)),
            blocks_(std::move(blocks)) {}

      BlockType& block(int grain_index) {
        eigen_assert(grain_index >= 0);
        eigen_assert(static_cast<size_t>(grain_index) < size());
        return is_pre_allocated_ ? thread_local_pre_allocated_base_[grain_index]
                                 : blocks_[grain_index];
      }

      void Release(EvalParallelContext& ctx) const {
        if (!is_pre_allocated_) {
          ctx.kernel_.deallocate(ctx.device_, mem_handle_);
        }
      }

      size_t size() const {
        return is_pre_allocated_ ? grain_size_ : blocks_.size();
      }

     private:
      bool is_pre_allocated_;

      // Reuse pre-allocated thread-local buffers.
      BlockType* thread_local_pre_allocated_base_ = nullptr;
      size_t grain_size_ = 0;

      // These are initialized only for locally allocated blocks.
      BlockMemHandle mem_handle_{};
      std::vector<BlockType> blocks_;
    };
    // ThreadLocalBlocksInitialize initializes a thread-local block storage:
    // it reuses the pre-allocated memory where possible, and allocates fresh
    // blocks for threads outside the pool (e.g. task-stealing callers).
    template <typename BlockType, bool is_rhs>
    class ThreadLocalBlocksInitialize {
      static constexpr bool kIsLhs =
          !is_rhs && std::is_same<BlockType, LhsBlock>::value;
      static const bool kIsRhs =
          is_rhs && std::is_same<BlockType, RhsBlock>::value;
      static_assert(kIsLhs || kIsRhs, "Unknown block type");

      using Blocks = ThreadLocalBlocks<BlockType>;

     public:
      ThreadLocalBlocksInitialize(EvalParallelContext& ctx)
          : ctx_(ctx),
            num_worker_threads_(ctx_.device_.numThreadsInPool()) {}

      void operator()(Blocks& blocks) {
        const int n = ctx_.num_thread_local_allocations_.fetch_add(
            1, std::memory_order_relaxed);

        if (n >= num_worker_threads_) {
          ThreadLocalBlocksAllocator<is_rhs>::allocate(ctx_, blocks);
        } else {
          ThreadLocalBlocksAllocator<is_rhs>::reuse(ctx_, n, blocks);
        }
      }
     private:
      // NOTE: Without `if constexpr` we have to put calls to slightly
      // different allocation functions into template specializations.
      template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
      struct ThreadLocalBlocksAllocator;

      template <typename EvalCtx>
      struct ThreadLocalBlocksAllocator</*pack_rhs=*/true, EvalCtx> {
        static void allocate(EvalCtx& ctx, Blocks& blocks) {
          std::vector<RhsBlock> rhs_blocks;
          BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
              ctx.device_,
              /*num_lhs=*/0,
              /*num_rhs=*/ctx.gn_,
              /*num_slices=*/1,
              /*lhs_blocks=*/nullptr, /*rhs_blocks=*/&rhs_blocks);

          blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle),
                                               std::move(rhs_blocks));
        }

        static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
          RhsBlock* ptr = &ctx.rhs_thread_local_pre_allocated_[ctx.gn_ * index];
          blocks = ThreadLocalBlocks<RhsBlock>(ptr, ctx.gn_);
        }
      };

      template <typename EvalCtx>
      struct ThreadLocalBlocksAllocator</*pack_rhs=*/false, EvalCtx> {
        static void allocate(EvalCtx& ctx, Blocks& blocks) {
          std::vector<LhsBlock> lhs_blocks;
          BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
              ctx.device_,
              /*num_lhs=*/ctx.gm_,
              /*num_rhs=*/0,
              /*num_slices=*/1,
              /*lhs_blocks=*/&lhs_blocks, /*rhs_blocks=*/nullptr);

          blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle),
                                               std::move(lhs_blocks));
        }

        static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
          LhsBlock* ptr = &ctx.lhs_thread_local_pre_allocated_[ctx.gm_ * index];
          blocks = ThreadLocalBlocks<LhsBlock>(ptr, ctx.gm_);
        }
      };

      EvalParallelContext& ctx_;
      const int num_worker_threads_;
    };
    // ThreadLocalBlocksRelease releases the thread-local storage when the
    // context is destroyed.
    template <typename BlockType>
    class ThreadLocalBlocksRelease {
     public:
      using Blocks = ThreadLocalBlocks<BlockType>;
      ThreadLocalBlocksRelease(EvalParallelContext& ctx) : ctx_(ctx) {}
      void operator()(Blocks& blocks) { blocks.Release(ctx_); }

     private:
      EvalParallelContext& ctx_;
    };

    // ThreadLocalBlocks initialization callables.
    using ThreadLocalLhsInit =
        ThreadLocalBlocksInitialize<LhsBlock, /*is_rhs=*/false>;
    using ThreadLocalRhsInit =
        ThreadLocalBlocksInitialize<RhsBlock, /*is_rhs=*/true>;

    // ThreadLocalBlocks release callables.
    using ThreadLocalLhsRelease = ThreadLocalBlocksRelease<LhsBlock>;
    using ThreadLocalRhsRelease = ThreadLocalBlocksRelease<RhsBlock>;

    // Thread-local containers for Lhs/Rhs block packs. In practice only one
    // of them is used, depending on the shard_by_col value.
    Eigen::ThreadLocal<ThreadLocalBlocks<LhsBlock>, ThreadLocalLhsInit,
                       ThreadLocalLhsRelease>
        lhs_thread_local_blocks_;
    Eigen::ThreadLocal<ThreadLocalBlocks<RhsBlock>, ThreadLocalRhsInit,
                       ThreadLocalRhsRelease>
        rhs_thread_local_blocks_;
    std::atomic<bool>* can_use_thread_local_packed_;

    std::atomic<uint8_t>** state_kernel_[P];
    std::atomic<Index> state_packing_ready_[P];
    std::atomic<Index> state_switch_[P];
    LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
      if (use_thread_local) {
        eigen_assert(!shard_by_col_);
        ThreadLocalBlocks<LhsBlock>& blocks = lhs_thread_local_blocks_.local();

        Index grain_index = m1 - m * gm_;
        return blocks.block(internal::convert_index<int>(grain_index));
      } else {
        return packed_lhs_[k % (P - 1)][m1];
      }
    }

    RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
      if (use_thread_local) {
        eigen_assert(shard_by_col_);
        ThreadLocalBlocks<RhsBlock>& blocks = rhs_thread_local_blocks_.local();

        Index grain_index = n1 - n * gn_;
        return blocks.block(internal::convert_index<int>(grain_index));
      } else {
        return packed_rhs_[k % (P - 1)][n1];
      }
    }
    // In pack_lhs and pack_rhs, if we know for sure that a kernel with the
    // packed data will be executed immediately in the same thread (and not
    // submitted to the thread pool), we can pack into thread-local memory.
    void pack_lhs(Index m, Index k) {
      bool use_thread_local = false;

      if (parallelize_by_sharding_dim_only_ && !shard_by_col_ &&
          can_use_thread_local_packed_[m].load(std::memory_order_relaxed)) {
        if (state_kernel_[k % P][m][0].load(std::memory_order_relaxed) == 1) {
          use_thread_local = true;
        } else {
          // If we cannot guarantee that all kernels in the k slice will be
          // executed sequentially in the current thread, it is no longer safe
          // to use thread-local memory in the following slices along k.
          eigen_assert(k > 0);
          can_use_thread_local_packed_[m].store(false,
                                                std::memory_order_relaxed);
        }
      }

      const Index mend = m * gm_ + gm(m);
      for (Index m1 = m * gm_; m1 < mend; m1++)
        kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local),
                        lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));

      if (!parallel_pack_ && shard_by_col_) {
        assert(!use_thread_local);
        signal_packing(k);
      } else {
        signal_switch(k + 1);
        for (Index n = nn_ - 1; n >= 0; n--) {
          bool sync = parallelize_by_sharding_dim_only_ || n == 0;
          signal_kernel(m, n, k, sync, use_thread_local);
        }
      }
    }
    void pack_rhs(Index n, Index k) {
      bool use_thread_local = false;

      if (parallelize_by_sharding_dim_only_ && shard_by_col_ &&
          can_use_thread_local_packed_[n].load(std::memory_order_relaxed)) {
        if (state_kernel_[k % P][0][n].load(std::memory_order_relaxed) == 1) {
          use_thread_local = true;
        } else {
          eigen_assert(k > 0);
          can_use_thread_local_packed_[n].store(false,
                                                std::memory_order_relaxed);
        }
      }

      const Index nend = n * gn_ + gn(n);
      for (Index n1 = n * gn_; n1 < nend; n1++) {
        if (!TensorContractionKernel::HasBeta && k == 0) {
          // If the kernel does not support beta, the output memory must be
          // zeroed before accumulating the first k slice into it.
          memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
        }
        kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
                        rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
      }

      if (parallel_pack_ || shard_by_col_) {
        signal_switch(k + 1);
        for (Index m = nm_ - 1; m >= 0; m--) {
          bool sync = parallelize_by_sharding_dim_only_ || m == 0;
          signal_kernel(m, n, k, sync, use_thread_local);
        }
      } else {
        assert(!use_thread_local);
        signal_packing(k);
      }
    }
    void kernel(Index m, Index n, Index k, bool use_thread_local) {
      // Note: the order of iteration matters for performance. It is important
      // to iterate along the same dimension that was packed above, to
      // maximize cache locality.
      const Index nend = n * gn_ + gn(n);
      const Index mend = m * gm_ + gm(m);

      // NOTE: output = alpha * LHS * RHS + beta * output.
      const Scalar alpha = Scalar(1);
      const Scalar beta =
          (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);

      if (shard_by_col_) {
        for (Index n1 = n * gn_; n1 < nend; n1++) {
          for (Index m1 = m * gm_; m1 < mend; m1++) {
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            kernel_.invoke(
                output_mapper,
                packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
                bk(k), bn(n1), alpha, beta);

            // We are done with the last task for the [m1, n1] block.
            if (k + 1 == nk_) {
              output_kernel_(output_mapper, tensor_contraction_params_,
                             m1 * bm_, n1 * bn_, bm(m1), bn(n1));
            }
          }
        }
      } else {
        for (Index m1 = m * gm_; m1 < mend; m1++)
          for (Index n1 = n * gn_; n1 < nend; n1++) {
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            kernel_.invoke(
                output_mapper,
                packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
                bk(k), bn(n1), alpha, beta);

            if (k + 1 == nk_) {
              output_kernel_(output_mapper, tensor_contraction_params_,
                             m1 * bm_, n1 * bn_, bm(m1), bn(n1));
            }
          }
      }
      signal_kernel(m, n, k + 1, /*sync=*/false, /*use_thread_local=*/false);
      signal_switch(k + 2);
    }
    void signal_packing(Index k) {
      eigen_assert(!parallel_pack_);
      Index s = state_packing_ready_[k % P].fetch_sub(1);
      eigen_assert(s > 0);
      if (s != 1) return;
      state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
      enqueue_packing(k, shard_by_col_);
    }

    void signal_kernel(Index m, Index n, Index k, bool sync,
                       bool use_thread_local) {
      std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
      Index s = state->load();
      eigen_assert(s > 0);
      if (s != 1 && state->fetch_sub(1) != 1) {
        eigen_assert(!use_thread_local);
        return;
      }
      state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
      if (sync) {
        kernel(m, n, k, use_thread_local);
      } else {
        eigen_assert(!use_thread_local);
        device_.enqueueNoNotification(
            [=]() { kernel(m, n, k, use_thread_local); });
      }
    }
    void signal_switch(Index k, Index v = 1) {
      Index s = state_switch_[k % P].fetch_sub(v);
      eigen_assert(s >= v);
      if (s != v) return;

      // Ready to switch to the next k slice.
      // Reset the counter for the next iteration.
      state_switch_[k % P] =
          (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) +
          nm_ * nn_;
      if (k < nk_) {
        // Issue lhs/rhs packing. Their completion will in turn kick off
        // kernels.
        if (parallel_pack_) {
          enqueue_packing(k, !shard_by_col_);
          enqueue_packing(k, shard_by_col_);
        } else if (shard_by_col_) {
          enqueue_packing(k, false);
        } else {
          enqueue_packing(k, true);
        }
      } else if (k == nk_) {
        signal_switch(k + 1,
                      parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
      } else {
        // This was the last slice; notify completion.
        done_.Notify();
      }
    }
    void enqueue_packing(Index k, bool rhs) {
      enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs);
    }

    void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
      if (end - start == 1) {
        if (rhs)
          pack_rhs(start, k);
        else
          pack_lhs(start, k);
      } else {
        while (end - start > 1) {
          Index mid = (start + end) / 2;
          device_.enqueueNoNotification(
              [=]() { enqueue_packing_helper(mid, end, k, rhs); });
          end = mid;
        }

        // Decide whether the first packing task (start == 0) has to be
        // enqueued into the thread pool, so that it runs on a pool thread
        // with pre-allocated thread-local buffers.
        const bool pack_async =
            (start == 0) &&
            (parallelize_by_sharding_dim_only_ && shard_by_col_ == rhs) &&
            (k > 0 || std::this_thread::get_id() == created_by_thread_id_);

        if (pack_async) {
          device_.enqueueNoNotification(
              [=]() { enqueue_packing_helper(start, end, k, rhs); });
        } else {
          enqueue_packing_helper(start, end, k, rhs);
        }
      }
    }
    // Block sizes along m/n/k, accounting for a potentially incomplete last
    // block, and task grain sizes accounting for a potentially incomplete
    // last task.
    Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
    Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
    Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }

    Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
    Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }
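    // Worked example (illustrative numbers only): with m_ = 70, bm_ = 32 and
    // nm0_ = divup(70, 32) = 3, the first two blocks have size bm_ = 32 and
    // the last block has size m_ + bm_ - bm_ * nm0_ = 70 + 32 - 96 = 6, so
    // the three blocks together cover exactly 70 rows.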
    EvalParallelContext(const EvalParallelContext&) = delete;
    void operator=(const EvalParallelContext&) = delete;
  };

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  using SyncEvalParallelContext =
      EvalParallelContext<NoCallback, lhs_inner_dim_contiguous,
                          rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
                          Alignment>;
  // EvalShardedByInnerDimContext orchestrates sync/async contraction
  // evaluation when we shard by the inner (contraction) dimension. When it is
  // executed in asynchronous mode, it owns all the shared state that might be
  // accessed by block processing tasks.
  template <typename DoneCallback>
  struct EvalShardedByInnerDimContext {
    EvalShardedByInnerDimContext(const Self* self, int num_threads,
                                 Scalar* result_buffer,
                                 Index m_size, Index n_size, Index k_size,
                                 DoneCallback done_callback)
        : evaluator(self),
          m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous),
          m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous),
          m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered),
          result(result_buffer),
          m(m_size),
          n(n_size),
          k(k_size),
          done(std::move(done_callback)),
          buffer_size_bytes(m * n * sizeof(Scalar)),
          block_size(blockSize(k, num_threads)),
          num_blocks(divup<Index>(k, block_size)),
          num_pending_blocks(internal::convert_index<int>(num_blocks)),
          l0_ranges(divup<Index>(num_blocks, l0_size)),
          l0_state(l0_ranges),
          block_buffers(num_blocks) {
      // Keep count of pending gemm tasks for each l0 range.
      for (int i = 0; i < l0_ranges; ++i) {
        const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size, i);
        l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
      }

      // Allocate temporary buffers for each block; the first block
      // accumulates directly into the final result buffer.
      for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
        Scalar* buf = block_idx == 0
                          ? result
                          : static_cast<Scalar*>(evaluator->m_device.allocate(
                                buffer_size_bytes));
        block_buffers.emplace_back(buf);
      }
    }
    ~EvalShardedByInnerDimContext() {
      for (Index i = 1; i < num_blocks; ++i) {
        evaluator->m_device.deallocate(block_buffers[i]);
      }
    }

    template <int Alignment>
    void run() {
      Barrier barrier(internal::convert_index<int>(num_blocks));
      eval<Alignment>(barrier, 0, num_blocks);
      barrier.Wait();

      // Aggregate partial sums from l0 ranges.
      aggregateL0Blocks<Alignment>();

      // Apply the output kernel.
      applyOutputKernel();
    }

    template <int Alignment>
    void runAsync() {
      evalAsync<Alignment>(0, num_blocks);
    }
   private:
    static const Index packet_size = internal::packet_traits<RhsScalar>::size;

    const Self* evaluator;  // TensorContraction evaluator

    bool m_lhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_reordered;

    Scalar* result;
    Index m, n, k;
    DoneCallback done;

    Index buffer_size_bytes;
    Index block_size;
    Index num_blocks;
    // Keep track of pending tasks when evaluating in async mode.
    std::atomic<int> num_pending_blocks;

    // Blocks are grouped into ranges of l0_size; each range pre-aggregates
    // its partial results as soon as its last task completes.
    static const Index l0_size = 4;
    Index l0_ranges;
    MaxSizeVector<std::atomic<int>> l0_state;  // [0, l0_ranges)
    MaxSizeVector<Scalar*> block_buffers;      // [0, num_blocks)
    template <int Alignment>
    void processBlock(Index block_idx, Index begin, Index end) {
      Scalar* buf = block_buffers[block_idx];

      TENSOR_CONTRACTION_DISPATCH(
          evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
          (buf, begin, end,
           /*num_threads=*/internal::convert_index<int>(num_blocks)));

      // Check if this was the last task in the l0 range.
      const Index l0_index = block_idx / l0_size;
      const int v = l0_state[l0_index].fetch_sub(1);
      eigen_assert(v >= 1);

      // If we processed the last block of the range, we can aggregate all
      // partial results into the first block of the range.
      if (v == 1) {
        const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index);
        const Index dst_block_idx = l0_index * l0_size;

        if (rng_size == l0_size) {
          addAllToBuffer<Alignment>(
              m * n,
              /*src_buf0=*/block_buffers[dst_block_idx + 1],
              /*src_buf1=*/block_buffers[dst_block_idx + 2],
              /*src_buf2=*/block_buffers[dst_block_idx + 3],
              /*dst_buf= */ block_buffers[dst_block_idx]);
        } else {
          // Aggregate blocks of a potentially incomplete last range.
          for (int i = 1; i < rng_size; ++i) {
            addToBuffer<Alignment>(m * n,
                                   block_buffers[dst_block_idx + i],
                                   block_buffers[dst_block_idx]);
          }
        }
      }
    }
    // Aggregate partial sums from the l0 ranges into the final result buffer.
    template <int Alignment>
    void aggregateL0Blocks() const {
      Index l0_index = 1;

      for (; l0_index + 2 < l0_ranges; l0_index += 3) {
        addAllToBuffer<Alignment>(
            m * n,
            /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size],
            /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size],
            /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size],
            /*dst_buf= */ block_buffers[0]);
      }

      for (; l0_index < l0_ranges; ++l0_index) {
        addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
                               block_buffers[0]);
      }
    }

    void applyOutputKernel() const {
      typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
      evaluator->m_output_kernel(
          OutputMapper(result, m), evaluator->m_tensor_contraction_params,
          static_cast<Eigen::Index>(0), static_cast<Eigen::Index>(0), m, n);
    }

    // Block size, accounting for a potentially incomplete last block.
    Index actualBlockSize(Index block_idx) const {
      return block_idx + 1 < num_blocks
                 ? block_size
                 : k + block_size - block_size * num_blocks;
    }

    // Range size, accounting for a potentially incomplete last range.
    Index actualRangeSize(Index num_ranges, Index range_size,
                          Index range_idx) const {
      eigen_assert(range_idx < num_ranges);
      return range_idx + 1 < num_ranges
                 ? range_size
                 : num_blocks + range_size - range_size * num_ranges;
    }
    template <int Alignment>
    EIGEN_STRONG_INLINE static void addToBuffer(size_t n, const Scalar* src_buf,
                                                Scalar* tgt_buf) {
      const int output_packet_size =
          internal::unpacket_traits<PacketReturnType>::size;
      size_t i = 0;
      const size_t num_packets = n / output_packet_size;
      for (; i < output_packet_size * num_packets; i += output_packet_size) {
        const PacketReturnType src_val =
            internal::pload<PacketReturnType>(src_buf + i);
        const PacketReturnType tgt_val =
            internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
        const PacketReturnType sum = internal::padd(src_val, tgt_val);
        internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i,
                                                               sum);
      }
      // Scalar tail for sizes that are not a multiple of the packet size.
      for (; i < n; ++i) {
        tgt_buf[i] += src_buf[i];
      }
    }
    template <int Alignment>
    EIGEN_STRONG_INLINE static void addAllToBuffer(size_t n,
                                                   const Scalar* src_buf0,
                                                   const Scalar* src_buf1,
                                                   const Scalar* src_buf2,
                                                   Scalar* dst_buf) {
      using ::Eigen::internal::padd;
      using ::Eigen::internal::pload;
      using ::Eigen::internal::ploadt;
      using ::Eigen::internal::pstoret;

      const int output_packet_size =
          internal::unpacket_traits<PacketReturnType>::size;

      size_t i = 0;
      const size_t num_packets = n / output_packet_size;
      for (; i < output_packet_size * num_packets; i += output_packet_size) {
        const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
        const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
        const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);

        const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
        const auto sum =
            padd(padd(dst_val, src_val0), padd(src_val1, src_val2));

        pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
      }
      // Scalar tail for sizes that are not a multiple of the packet size.
      for (; i < n; ++i) {
        dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
      }
    }
    template <int Alignment>
    void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
        evaluator->m_device.enqueueNoNotification(
            [this, &barrier, mid_block_idx, end_block_idx]() {
              eval<Alignment>(barrier, mid_block_idx, end_block_idx);
            });
        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;
      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);

      processBlock<Alignment>(block_idx, block_start, block_end);
      barrier.Notify();
    }
    template <int Alignment>
    void evalAsync(Index start_block_idx, Index end_block_idx) {
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
        evaluator->m_device.enqueueNoNotification(
            [this, mid_block_idx, end_block_idx]() {
              evalAsync<Alignment>(mid_block_idx, end_block_idx);
            });
        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;
      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);

      processBlock<Alignment>(block_idx, block_start, block_end);

      int v = num_pending_blocks.fetch_sub(1);
      eigen_assert(v >= 1);

      if (v == 1) {
        // This was the last block: aggregate partial sums and apply the
        // output kernel.
        aggregateL0Blocks<Alignment>();
        applyOutputKernel();

        // Move the done callback out of the context before deleting it, then
        // safely call it.
        DoneCallback done_copy = std::move(done);
        delete this;
        done_copy();
      }
    }
    // Compute the k block size, aiming for a per-thread block that is a
    // multiple of the packet size and not too small.
    static Index blockSize(Index k, int num_threads) {
      const auto round_up = [=](Index index) -> Index {
        const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
        return divup<Index>(index, kmultiple) * kmultiple;
      };

      const Index target_block_size = round_up(divup<Index>(k, num_threads));
      const Index desired_min_block_size = 12 * packet_size;

      return numext::mini<Index>(
          k, numext::maxi<Index>(desired_min_block_size, target_block_size));
    }
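    // Worked example (illustrative numbers only): for k = 1000,
    // num_threads = 8 and packet_size = 4, kmultiple = 8,
    // divup(1000, 8) = 125 and round_up(125) = 128, so target_block_size is
    // 128. The desired minimum is 12 * 4 = 48, hence blockSize returns
    // min(1000, max(48, 128)) = 128 and the contraction dimension is split
    // into divup(1000, 128) = 8 blocks.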
    EvalShardedByInnerDimContext(const EvalShardedByInnerDimContext&) = delete;
    void operator=(const EvalShardedByInnerDimContext&) = delete;
  };
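  // Sketch of what the context above computes (notation only, not code): the
  // contraction dimension k is split into num_blocks ranges, every range
  // produces a partial m x n product into its own buffer, and the partial
  // results are summed into the final output:
  //
  //   out = sum over b of  lhs(:, k_b) * rhs(k_b, :)
  //
  // where k_b is the b-th block of the contraction dimension. Block 0
  // accumulates directly into the result buffer, so only num_blocks - 1
  // temporary buffers are allocated.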
  // Decide whether to shard the m x n result space by columns or by rows.
  static bool shardByCol(Index m, Index n, Index num_threads) {
    // Shard by row if there is enough data over rows and either not enough
    // data over columns, or the column dimension does not divide evenly
    // across threads while the row dimension does (or rows dominate).
    if (m / num_threads >= Traits::nr &&
        (n / num_threads < Traits::nr ||
         (n / num_threads < 4 * Traits::nr &&
          (n % (num_threads * Traits::nr)) != 0 &&
          ((m % (num_threads * Traits::nr)) == 0 || m / n >= 6))))
      return false;
    // Wherever possible, shard by column, except when an (m, 1) thread block
    // is preferable due to strange dimensions.
    if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
    return true;
  }
  Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn,
                 int num_threads, bool shard_by_col) const {
    Index gm = 1;
    Index gm1 = 1;
    Index nm0 = divup(m, bm);
    Index nm1 = nm0;
    for (;;) {
      // Find the next candidate for the m grain size. It needs to result in a
      // different number of blocks, e.g. for 10 kernels we want to try 5 and
      // 10, but not 6, 7, 8 and 9.
      while (gm1 <= nm0 && nm1 == divup(nm0, gm1)) gm1++;
      if (gm1 > nm0) break;
      // Check the candidate.
      int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads,
                           shard_by_col);
      if (res < 0) break;
      nm1 = divup(nm0, gm1);
      if (res == 0) continue;
      // Commit the new grain size.
      gm = gm1;
    }
    return gm;
  }
  Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
                 int num_threads, bool shard_by_col) const {
    Index gn = 1;
    Index gn1 = 1;
    Index nn0 = divup(n, bn);
    Index nn1 = nn0;
    for (;;) {
      while (gn1 <= nn0 && nn1 == divup(nn0, gn1)) gn1++;
      if (gn1 > nn0) break;
      int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads,
                           shard_by_col);
      if (res < 0) break;
      nn1 = divup(nn0, gn1);
      if (res == 0) continue;
      gn = gn1;
    }
    return gn;
  }
  // checkGrain() checks whether the grain (gm, gn) is suitable and is better
  // than (oldgm, oldgn). Returns 1 to accept, 0 to keep searching, and -1 to
  // reject this and all larger grains.
  int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
                 Index gn, Index oldgm, Index oldgn, int num_threads,
                 bool shard_by_col) const {
    const TensorOpCost cost =
        contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
    double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(
        static_cast<double>(bm) * gm * bn * gn, cost);
    // If the task is too small, accept it regardless of anything else;
    // otherwise synchronization overheads will dominate.
    if (taskSize < 1) return 1;
    // If it is too large, reject it and all larger tasks.
    if (taskSize > 2) return -1;
    // Within the good task size range, the main deciding factor is
    // parallelism: prefer the grain that keeps more threads busy.
    Index nm0 = divup(m, bm);
    Index nn0 = divup(n, bn);
    Index new_tasks = divup(nm0, gm) * divup(nn0, gn);
    double new_parallelism = static_cast<double>(new_tasks) /
                             (divup<int>(new_tasks, num_threads) * num_threads);
    Index old_tasks = divup(nm0, oldgm) * divup(nn0, oldgn);
    double old_parallelism = static_cast<double>(old_tasks) /
                             (divup<int>(old_tasks, num_threads) * num_threads);
    if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
    return 0;
  }
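  // Worked example of the parallelism metric above (illustrative numbers
  // only): with 12 kernels and num_threads = 4, a grain that yields
  // new_tasks = 6 gives parallelism 6 / (divup(6, 4) * 4) = 6 / 8 = 0.75,
  // while a grain that yields new_tasks = 4 gives 4 / (divup(4, 4) * 4) =
  // 4 / 4 = 1.0, so the latter keeps all four threads busy and wins.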
  TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk,
                               bool shard_by_col, bool prepacked) const {
    const int packed_size = std::min<int>(PacketType<LhsScalar, Device>::size,
                                          PacketType<RhsScalar, Device>::size);
    const int output_packet_size =
        internal::unpacket_traits<PacketReturnType>::size;
    const double kd = static_cast<double>(bk);
    double compute_bandwidth = computeBandwidth(false, bm, bn, bk);
    // Computations.
    TensorOpCost cost =
        TensorOpCost(0, 0, kd * compute_bandwidth, true, packed_size);
    // Output stores.
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true,
                         output_packet_size);
    if (prepacked) {
      // Packing and kernels are executed in different tasks. When we
      // calculate the task grain size we look only at the kernel cost.
      return cost;
    }
    // Lhs/rhs loads + computations.
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
    // Lhs packing memory cost does not contribute considerably to the overall
    // execution time because lhs is prefetched early and accessed
    // sequentially.
    if (shard_by_col)
      lhsCost.dropMemoryCost();
    else
      rhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }
  // Decide whether to shard the m x k x n contraction over the inner
  // (contraction) dimension k.
  static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
                              int num_threads_by_k) {
    std::ptrdiff_t bufsize = m * n * sizeof(Scalar);
    bool shard_by_k = false;
    if (n == 1 ||                // matrix-vector product, or
        num_threads_by_k < 2 ||  // sharding by k is single threaded, or
        num_threads_by_k <
            num_threads ||  // sharding by k gives less parallelism, or
        bufsize > l3CacheSize() / num_threads_by_k ||  // does not fit L3, or
        k / num_threads_by_k < 2 * Traits::nr) {  // k per thread is too small.
      shard_by_k = false;
    } else if (numext::maxi(m, n) / num_threads <
                   Traits::nr ||  // both outer dimensions are tiny, or
               // k per thread is large, and one outer dimension is tiny or
               // sharding by k offers more parallelism.
               (k / num_threads_by_k > 8 * Traits::nr &&
                (numext::mini(m, n) < 2 * Traits::nr ||
                 num_threads_by_k > num_threads))) {
      shard_by_k = true;
    }
    return shard_by_k;
  }
  TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
    // Compute cost.
    const int output_packet_size =
        internal::unpacket_traits<PacketReturnType>::size;
    TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n, true,
                      output_packet_size);
    // Output stores.
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true,
                         output_packet_size);
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * n;
    // Since the inner gemm kernel is always sharded by column, the lhs
    // load cost is negligible.
    lhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }
  int numThreadsInnerDim(Index m, Index n, Index k) const {
    const int output_packet_size =
        internal::unpacket_traits<PacketReturnType>::size;
    TensorOpCost cost = contractionCostPerInnerDim(m, n, k);
    double total_parallel_cost =
        TensorCostModel<ThreadPoolDevice>::totalCost(k, cost);
    // Cost of the reduction step that accumulates the m*n per-thread buffers
    // into the result.
    double reduction_cost = TensorCostModel<ThreadPoolDevice>::totalCost(
        m * n, TensorOpCost(2, 1, 1, true, output_packet_size));
    int num_threads = 1;
    double min_cost = total_parallel_cost;
    double kPerThreadOverHead = 3000;
    double kFixedOverHead = 100000;
    for (int nt = 2; nt <= this->m_device.numThreads(); nt += 2) {
      double sequential_cost =
          kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
      double parallel_cost = total_parallel_cost / nt + sequential_cost;
      if (parallel_cost < min_cost) {
        num_threads = nt;
        min_cost = parallel_cost;
      }
    }
    return num_threads;
  }
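  // Worked example of the search above (illustrative numbers only): if
  // total_parallel_cost = 1.0e7 and reduction_cost = 1.0e4, then for nt = 2
  // the estimated cost is 1.0e7 / 2 + 1.0e5 + 2 * (1.0e4 + 3.0e3) = 5.126e6,
  // which beats the single-threaded 1.0e7, so at least two threads are used;
  // larger nt keeps winning until the per-thread overhead terms dominate.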
  double computeBandwidth(bool shard_by_col, Index bm, Index bn,
                          Index bk) const {
    // Peak VFMA bandwidth is 0.5; with too little data for vectorization it
    // drops (the 4.0 and 2.0 values are determined experimentally).
    double computeBandwidth =
        bk == 1 ? 4.0
                : (shard_by_col ? bn : bm) < Traits::nr ||
                          (shard_by_col ? bm : bn) < Traits::mr
                      ? 2.0
                      : 0.5;
#ifndef EIGEN_VECTORIZE_FMA
    // Without FMA, MULPS/ADDPS form a dependent chain of two instructions,
    // so the effective bandwidth is 1.0 rather than 0.5.
    if (computeBandwidth == 0.5) computeBandwidth = 1.0;
#endif
    return computeBandwidth;
  }
};

#endif  // EIGEN_USE_THREADS
#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H