#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H

#ifdef EIGEN_USE_THREADS

namespace Eigen {

#ifdef EIGEN_USE_SIMPLE_THREAD_POOL
namespace internal {

// Argument bundles handed to the packing and kernel tasks of the simple
// thread pool path below.
template<typename LhsScalar, typename LhsMapper, typename Index>
struct packLhsArg {
  LhsScalar* blockA;
  const LhsMapper& lhs;
  const Index m_start, k_start, mc, kc;
};

template<typename LhsScalar, typename RhsScalar, typename RhsMapper,
         typename OutputMapper, typename Index>
struct packRhsAndKernelArg {
  const MaxSizeVector<LhsScalar*>* blockAs;
  RhsScalar* blockB;
  const RhsMapper& rhs;
  OutputMapper& output;
  const Index m, k, n;
  const Index mc, kc, nc;
  const Index num_threads, num_blockAs, max_m;
  const Index k_block_idx, m_block_idx, n_block_idx;
  const Index m_blocks, n_blocks;
  MaxSizeVector<Notification*>* kernel_notifications;
  const MaxSizeVector<Notification*>* lhs_notifications;
  const bool need_to_pack;
};

}  // end namespace internal

#endif  // EIGEN_USE_SIMPLE_THREAD_POOL
template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > {

  typedef ThreadPoolDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;

  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;
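  // Worked example (illustrative): contracting a rank-3 LHS with a rank-2 RHS
  // over a single index pair gives LDims = 3, RDims = 2, ContractDims = 1, so
  // the result has NumDims = 3 + 2 - 2 * 1 = 3 dimensions.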
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
#ifndef EIGEN_USE_SIMPLE_THREAD_POOL
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalProduct(Scalar* buffer) const {
    typedef internal::TensorContractionInputMapper<
        LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
        contract_t, internal::packet_traits<LhsScalar>::size,
        lhs_inner_dim_contiguous, false, Unaligned>
        LhsMapper;
    typedef internal::TensorContractionInputMapper<
        RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
        contract_t, internal::packet_traits<RhsScalar>::size,
        rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
        RhsMapper;
    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
    typedef internal::gemm_pack_lhs<LhsScalar, Index,
                                    typename LhsMapper::SubMapper, Traits::mr,
                                    Traits::LhsProgress, ColMajor>
        LhsPacker;
    typedef internal::gemm_pack_rhs<
        RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
        RhsPacker;
    typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
                                  Traits::mr, Traits::nr, false, false>
        GebpKernel;

    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    const Index k = this->m_k_size;
    if (m == 0 || n == 0 || k == 0) return;
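    // Strategy sketch (illustrative numbers): the m x n output is cut into
    // bm x bn tiles and the contraction dimension into bk slices; tiles are
    // grouped into tasks of gm x gn tiles ("grain") to amortize scheduling
    // overhead, and either rows or columns are chosen as the sharding
    // dimension. For example, with m = n = k = 4096, bm = bn = 256 and
    // bk = 512 there are 16 x 16 tiles and 8 k slices; the decisions below
    // pick how those 256 tiles are distributed over the pool's threads.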
    // Decide whether to shard the contraction by rows or by columns, then pick
    // block sizes for an initial two-thread estimate.
    bool shard_by_col = shardByCol(m, n, 2);

    Index bm, bn, bk;
    if (shard_by_col) {
      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
                                          internal::ShardByCol>
          blocking(k, m, n, 2);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
                                          internal::ShardByRow>
          blocking(k, m, n, 2);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    }

    // Ask the cost model how many threads this contraction warrants.
    const TensorOpCost cost =
        contractionCost(m, n, bm, bn, bk, shard_by_col, false);
    int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
        static_cast<double>(n) * m, cost, this->m_device.numThreads());

    if (n == 1) num_threads = 1;

    if (num_threads == 1) {
      // The single-threaded algorithm should be faster in this case.
      if (n == 1)
        this->template evalGemv<lhs_inner_dim_contiguous,
                                rhs_inner_dim_contiguous,
                                rhs_inner_dim_reordered, Alignment>(buffer);
      else
        this->template evalGemm<lhs_inner_dim_contiguous,
                                rhs_inner_dim_contiguous,
                                rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }

    // Now that the thread count is known, recompute sharding and blocking.
    shard_by_col = shardByCol(m, n, num_threads);
    if (shard_by_col) {
      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
                                          internal::ShardByCol>
          blocking(k, m, n, num_threads);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
                                          internal::ShardByRow>
          blocking(k, m, n, num_threads);
      bm = blocking.mc();
      bn = blocking.nc();
      bk = blocking.kc();
    }

    // Number of kernels in each dimension.
    Index nm0 = divup(m, bm);
    Index nn0 = divup(n, bn);
    Index nk = divup(k, bk);

    // Calculate task grain size (number of kernels executed per task).
    Index gm = 1;
    Index gn = 1;
    // If we are sharding by column, then we prefer to reduce rows first.
    if (shard_by_col) {
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
    } else {
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
    }
    // Number of tasks in each dimension.
    Index nm = divup(nm0, gm);
    Index nn = divup(nn0, gn);

    // Decide whether lhs and rhs packing for a k slice is issued in parallel
    // (more parallelism) or sequentially (better locality). Parallel packing
    // is used when there are few tasks...
    bool parallel_pack = num_threads >= nm * nn;
    // ... or when all data fits into the L2 cache.
    if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <=
        l2CacheSize() * num_threads)
      parallel_pack = true;
    // But not when each rhs block is used only once; locality matters more then.
    if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;

    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides,
                  this->m_i_strides, this->m_left_contracting_strides,
                  this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides,
                  this->m_j_strides, this->m_right_contracting_strides,
                  this->m_k_strides);

    Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
            OutputMapper>(this->m_device, num_threads, lhs, rhs, buffer, m, n,
                          k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
                          shard_by_col, parallel_pack)
        .run();
  }
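  // Context (below) orchestrates a single parallel contraction: for each k
  // slice it packs lhs/rhs blocks, runs GEBP kernels on the packed blocks, and
  // advances to the next slice once the per-slice atomic counters reach zero.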
  template <typename LhsPacker, typename RhsPacker, typename GebpKernel,
            typename LhsMapper, typename RhsMapper, typename OutputMapper>
  class Context {
   public:
    Context(const Device& device, int num_threads, LhsMapper& lhs,
            RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk,
            Index bm, Index bn, Index bk, Index nm, Index nn, Index nk,
            Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col,
            bool parallel_pack)
        : device_(device),
          lhs_(lhs),
          rhs_(rhs),
          buffer_(buffer),
          output_(buffer, tm),
          num_threads_(num_threads),
          shard_by_col_(shard_by_col),
          parallel_pack_(parallel_pack),
          m_(tm), n_(tn), k_(tk),
          bm_(bm), bn_(bn), bk_(bk),
          nm_(nm), nn_(nn), nk_(nk),
          gm_(gm), gn_(gn), nm0_(nm0), nn0_(nn0) {
      for (Index x = 0; x < P; x++) {
        // Normal number of notifications for a k slice switch is
        // nm_ + nn_ + nm_ * nn_, but the first P - 1 slices do not receive
        // notifications from preceding kernels.
        state_switch_[x] =
            x == 0
                ? 1
                : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) +
                      (x == P - 1 ? nm_ * nn_ : 0);
        state_packing_ready_[x] =
            parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
        state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
        for (Index m = 0; m < nm_; m++) {
          state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
          // Kernels generally receive 3 notifications (rhs and lhs packing
          // done, previous kernel completed), but the first slice does not get
          // notifications from previous kernels.
          for (Index n = 0; n < nn_; n++)
            state_kernel_[x][m][n].store(
                (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1),
                std::memory_order_relaxed);
        }
      }
      // Allocate memory for packed rhs/lhs matrices.
      size_t align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1);
      size_t lhs_size =
          divup<size_t>(bm_ * bk_ * sizeof(LhsScalar), align) * align;
      size_t rhs_size =
          divup<size_t>(bn_ * bk_ * sizeof(RhsScalar), align) * align;
      packed_mem_ = static_cast<char*>(internal::aligned_malloc(
          (nm0_ * lhs_size + nn0_ * rhs_size) * std::min<size_t>(nk_, P - 1)));
      char* mem = static_cast<char*>(packed_mem_);
      for (Index x = 0; x < numext::mini<Index>(nk_, P - 1); x++) {
        packed_lhs_[x].resize(nm0_);
        for (Index m = 0; m < nm0_; m++) {
          packed_lhs_[x][m] = reinterpret_cast<LhsScalar*>(mem);
          mem += lhs_size;
        }
        packed_rhs_[x].resize(nn0_);
        for (Index n = 0; n < nn0_; n++) {
          packed_rhs_[x][n] = reinterpret_cast<RhsScalar*>(mem);
          mem += rhs_size;
        }
      }
    }
    ~Context() {
      for (Index x = 0; x < P; x++) {
        for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
        delete[] state_kernel_[x];
      }
      internal::aligned_free(packed_mem_);
    }

    void run() {
      signal_switch(0, 1);  // kick off packing of the first k slice
      done_.Wait();         // wait for overall completion
    }
   private:
    Notification done_;
    const Device& device_;
    LhsMapper& lhs_;
    RhsMapper& rhs_;
    Scalar* const buffer_;
    OutputMapper output_;
    const int num_threads_;
    const bool shard_by_col_;
    const bool parallel_pack_;
    const Index m_, n_, k_;
    const Index bm_, bn_, bk_;
    const Index nm_, nn_, nk_;
    const Index gm_, gn_, nm0_, nn0_;

    // The contraction is pipelined over P = 3 k slices: while slice x runs its
    // kernels, slice x + 1 can already be packed.
    static const Index P = 3;
    void* packed_mem_;
    std::vector<LhsScalar*> packed_lhs_[P - 1];
    std::vector<RhsScalar*> packed_rhs_[P - 1];
    std::atomic<uint8_t>** state_kernel_[P];
    std::atomic<Index> state_packing_ready_[P];
    std::atomic<Index> state_switch_[P];
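    // Counter bookkeeping, summarized from the constructor above: each kernel
    // cell starts at (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1) pending
    // notifications, e.g. a non-first slice with parallel packing waits for
    // 1 + 2 = 3 events (previous kernel, lhs packed, rhs packed) before it may
    // run; state_switch_ counts how many packing/kernel completions must
    // arrive before the next k slice is allowed to start.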
    void pack_lhs(Index m, Index k) {
      const Index mend = m * gm_ + gm(m);
      for (Index m1 = m * gm_; m1 < mend; m1++)
        LhsPacker()(packed_lhs_[k % (P - 1)][m1],
                    lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));

      if (!parallel_pack_ && shard_by_col_) {
        signal_packing(k);
      } else {
        signal_switch(k + 1);
        for (Index n = nn_ - 1; n >= 0; n--) signal_kernel(m, n, k, n == 0);
      }
    }
    void pack_rhs(Index n, Index k) {
      const Index nend = n * gn_ + gn(n);
      for (Index n1 = n * gn_; n1 < nend; n1++) {
        if (k == 0) {
          // Zero the output memory in parallel here, in the first rhs packing
          // task, because all kernels that write to it depend on this task.
          // Don't call device_.memset(): it blocks on a thread pool worker
          // thread, which can lead to underutilization and deadlocks.
          memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
        }
        RhsPacker()(packed_rhs_[k % (P - 1)][n1],
                    rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
      }

      if (parallel_pack_ || shard_by_col_) {
        signal_switch(k + 1);
        for (Index m = nm_ - 1; m >= 0; m--) signal_kernel(m, n, k, m == 0);
      } else {
        signal_packing(k);
      }
    }
    void kernel(Index m, Index n, Index k) {
      // Loop order matters: the non-sharding dimension is innermost so that
      // consecutive kernels reuse the same packed block.
      const Index nend = n * gn_ + gn(n);
      const Index mend = m * gm_ + gm(m);
      if (shard_by_col_) {
        for (Index n1 = n * gn_; n1 < nend; n1++) {
          for (Index m1 = m * gm_; m1 < mend; m1++)
            GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
                         packed_lhs_[k % (P - 1)][m1],
                         packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
                         Scalar(1), -1, -1, 0, 0);
        }
      } else {
        for (Index m1 = m * gm_; m1 < mend; m1++)
          for (Index n1 = n * gn_; n1 < nend; n1++) {
            GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
                         packed_lhs_[k % (P - 1)][m1],
                         packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
                         Scalar(1), -1, -1, 0, 0);
          }
      }
      signal_kernel(m, n, k + 1, false);
      signal_switch(k + 2);
    }
    void signal_packing(Index k) {
      eigen_assert(!parallel_pack_);
      Index s = state_packing_ready_[k % P].fetch_sub(1);
      eigen_assert(s > 0);
      if (s != 1) return;
      state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
      enqueue_packing(k, shard_by_col_);
    }

    void signal_kernel(Index m, Index n, Index k, bool sync) {
      std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
      Index s = state->load();
      eigen_assert(s > 0);
      if (s != 1 && state->fetch_sub(1) != 1) return;
      state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
      if (sync)
        kernel(m, n, k);
      else
        device_.enqueueNoNotification([=]() { kernel(m, n, k); });
    }
    void signal_switch(Index k, Index v = 1) {
      Index s = state_switch_[k % P].fetch_sub(v);
      eigen_assert(s >= v);
      if (s != v) return;

      // Ready to switch to the next k slice; reset the counter for the slice
      // that will reuse this slot.
      state_switch_[k % P] =
          (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) +
          nm_ * nn_;
      if (k < nk_) {
        // Issue lhs/rhs packing; their completion will in turn kick off kernels.
        if (parallel_pack_) {
          enqueue_packing(k, !shard_by_col_);
          enqueue_packing(k, shard_by_col_);
        } else if (shard_by_col_) {
          enqueue_packing(k, false);
        } else {
          enqueue_packing(k, true);
        }
      } else if (k == nk_) {
        signal_switch(k + 1,
                      parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
      } else {
        done_.Notify();
      }
    }
    void enqueue_packing(Index k, bool rhs) {
      enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs);
    }

    void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
      if (end - start == 1) {
        if (rhs)
          pack_rhs(start, k);
        else
          pack_lhs(start, k);
      } else {
        // Split the range in two and let the halves proceed in parallel.
        Index mid = (start + end) / 2;
        device_.enqueueNoNotification(
            [=]() { enqueue_packing_helper(mid, end, k, rhs); });
        device_.enqueueNoNotification(
            [=]() { enqueue_packing_helper(start, mid, k, rhs); });
      }
    }
    Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
    Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
    Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }

    Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
    Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }
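    // Illustrative check of the "last block" formulas above: with m_ = 100,
    // bm_ = 32 and nm0_ = 4, the first three blocks cover 32 rows each and
    // bm(3) = 100 + 32 - 32 * 4 = 4 rows; gm()/gn() follow the same pattern at
    // task-grain granularity.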
    Context(const Context&) = delete;
    void operator=(const Context&) = delete;
  };
  // Decide whether to shard the m x n contraction by columns or by rows.
  static bool shardByCol(Index m, Index n, Index num_threads) {
    // Sharding by column is the default, unless there are enough rows for
    // vectorized row sharding...
    if (m / num_threads >= Traits::nr &&
        // ... and not enough columns,
        (n / num_threads < Traits::nr ||
         // ... or barely enough columns, not evenly dividable across threads,
         (n / num_threads < 4 * Traits::nr &&
          (n % (num_threads * Traits::nr)) != 0 &&
          // ... while rows are evenly dividable (or vastly outnumber columns).
          ((m % (num_threads * Traits::nr)) == 0 || m / n >= 6))))
      return false;
    // Also avoid sharding by rows if there is barely enough data for the cache.
    if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
    return true;
  }
  Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn,
                 int num_threads, bool shard_by_col) const {
    Index gm = 1;
    Index gm1 = 1;
    Index nm0 = divup(m, bm);
    Index nm1 = nm0;
    for (;;) {
      // Find the next candidate for the m grain size: it must change the number
      // of blocks, e.g. for 10 kernels the progression is 10->5->4->3->2->1.
      while (gm1 <= nm0 && nm1 == divup(nm0, gm1)) gm1++;
      if (gm1 > nm0) break;
      // Check the candidate.
      int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads,
                           shard_by_col);
      if (res < 0) break;
      nm1 = divup(nm0, gm1);
      if (res == 0) continue;
      // Commit the new grain size.
      gm = gm1;
    }
    return gm;
  }
  Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
                 int num_threads, bool shard_by_col) const {
    Index gn = 1;
    Index gn1 = 1;
    Index nn0 = divup(n, bn);
    Index nn1 = nn0;
    for (;;) {
      while (gn1 <= nn0 && nn1 == divup(nn0, gn1)) gn1++;
      if (gn1 > nn0) break;
      int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads,
                           shard_by_col);
      if (res < 0) break;
      nn1 = divup(nn0, gn1);
      if (res == 0) continue;
      gn = gn1;
    }
    return gn;
  }
  // Returns 1 if the grain (gm, gn) is better than (oldgm, oldgn), 0 if it is
  // no better, and -1 if it (and anything larger) should be rejected.
  int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
                 Index gn, Index oldgm, Index oldgn, int num_threads,
                 bool shard_by_col) const {
    const TensorOpCost cost =
        contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
    double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(
        static_cast<double>(bm) * gm * bn * gn, cost);
    // If the task is small enough, accept it regardless of anything else;
    // otherwise synchronization overheads dominate.
    if (taskSize < 1) return 1;
    // If it is too large, reject it and all larger tasks.
    if (taskSize > 2) return -1;
    // In the intermediate range the deciding factor is parallelism: prefer the
    // grain that keeps more of the threads busy.
    Index nm0 = divup(m, bm);
    Index nn0 = divup(n, bn);
    Index new_tasks = divup(nm0, gm) * divup(nn0, gn);
    double new_parallelism = static_cast<double>(new_tasks) /
        (divup<int>(new_tasks, num_threads) * num_threads);
    Index old_tasks = divup(nm0, oldgm) * divup(nn0, oldgn);
    double old_parallelism = static_cast<double>(old_tasks) /
        (divup<int>(old_tasks, num_threads) * num_threads);
    if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
    return 0;
  }
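  // Worked example of the parallelism metric above: with 12 kernels and 4
  // threads, a grain of 2 yields 6 tasks and 6 / (ceil(6/4) * 4) = 0.75
  // utilization, while a grain of 3 yields 4 tasks and 4 / (ceil(4/4) * 4) = 1,
  // so the larger grain wins despite producing fewer tasks.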
#else  // EIGEN_USE_SIMPLE_THREAD_POOL

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalProduct(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }

    evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
  }
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // packet sizes for the two sides
    const int lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const int rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered,
                                                   Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    typedef internal::gemm_pack_lhs<LhsScalar, Index,
                                    typename LhsMapper::SubMapper, Traits::mr,
                                    Traits::LhsProgress, ColMajor> LhsPacker;
    typedef internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> RhsPacker;

    typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
                                  Traits::mr, Traits::nr, false, false> GebpKernel;

    typedef internal::packLhsArg<LhsScalar, LhsMapper, Index> packLArg;
    typedef internal::packRhsAndKernelArg<LhsScalar, RhsScalar, RhsMapper, OutputMapper, Index> packRKArg;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // compute block sizes (which depend on the number of threads)
    const Index num_threads = this->m_device.numThreads();
    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, num_threads);
    Index mc = blocking.mc();
    Index nc = blocking.nc();
    Index kc = blocking.kc();
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
    const Index k_blocks = CEIL_DIV(k, kc);
    const Index n_blocks = CEIL_DIV(n, nc);
    const Index m_blocks = CEIL_DIV(m, mc);
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;
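    // Quick sanity check of CEIL_DIV (illustrative numbers): with k = 1000 and
    // kc = 256, CEIL_DIV(1000, 256) = 4 blocks, the last one covering only
    // 1000 - 3 * 256 = 232 of the kc columns; sizeA/sizeB are sized for a full
    // block, so partial blocks simply use part of the buffer.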
    // Upper limit on the number of lhs blocks packed at the same time.
    const Index numBlockAs = numext::mini(num_threads, m_blocks);

    // One lhs pack buffer per thread.
    MaxSizeVector<LhsScalar *> blockAs(num_threads);
    for (int i = 0; i < num_threads; i++) {
      blockAs.push_back(static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar))));
    }

    // One rhs pack buffer per column block, so each rhs block is packed once
    // and reused by every row block.
    MaxSizeVector<RhsScalar *> blockBs(n_blocks);
    for (int i = 0; i < n_blocks; i++) {
      blockBs.push_back(static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar))));
    }

    // lhs_notifications starts with all null Notification pointers.
    MaxSizeVector<Notification*> lhs_notifications(num_threads, nullptr);

    // One notification per (lhs block buffer, rhs block) pair.
    const Index num_kernel_notifications = num_threads * n_blocks;
    MaxSizeVector<Notification*> kernel_notifications(num_kernel_notifications,
                                                      nullptr);
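    // Coordination sketch, summarizing the loop nest below: lhs packing tasks
    // write into the per-thread blockAs slots and signal lhs_notifications;
    // packRhsAndKernel tasks wait on those, run GEBP against the shared rhs
    // block, and then signal kernel_notifications[blockAId * n_blocks +
    // n_block_idx] so the slot can be safely reused for the next k block.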
    for (Index k_block_idx = 0; k_block_idx < k_blocks; k_block_idx++) {
      const Index k_start = k_block_idx * kc;
      // make sure we don't overshoot the right edge of the left matrix
      const Index actual_kc = numext::mini(k_start + kc, k) - k_start;

      for (Index m_block_idx = 0; m_block_idx < m_blocks; m_block_idx += numBlockAs) {
        const Index num_blocks = numext::mini(m_blocks - m_block_idx, numBlockAs);

        for (Index mt_block_idx = m_block_idx; mt_block_idx < m_block_idx+num_blocks; mt_block_idx++) {
          const Index m_start = mt_block_idx * mc;
          const Index actual_mc = numext::mini(m_start + mc, m) - m_start;

          Index blockAId = (k_block_idx * m_blocks + mt_block_idx) % num_threads;

          for (int i = 0; i < n_blocks; ++i) {
            Index notification_id = (blockAId * n_blocks + i);
            // Wait for any current kernels using this slot to complete before
            // reusing it.
            if (kernel_notifications[notification_id]) {
              wait_until_ready(kernel_notifications[notification_id]);
              delete kernel_notifications[notification_id];
            }
            kernel_notifications[notification_id] = new Notification();
          }
          const packLArg arg = {
            blockAs[blockAId], lhs, m_start, k_start, actual_mc, actual_kc,
          };

          // Delete any existing notification since we may be replacing it.
          delete lhs_notifications[blockAId];
          lhs_notifications[blockAId] =
            this->m_device.enqueue(&Self::packLhs<packLArg, LhsPacker>, arg);
        }

        // now start the kernels
        const Index m_base_start = m_block_idx * mc;
        const bool need_to_pack = m_block_idx == 0;

        for (Index n_block_idx = 0; n_block_idx < n_blocks; n_block_idx++) {
          const Index n_start = n_block_idx * nc;
          const Index actual_nc = numext::mini(n_start + nc, n) - n_start;

          // Before overwriting a rhs block, make sure the kernels from the
          // previous k block that read it have finished.
          if (need_to_pack) {
            for (Index i = num_blocks; i < num_threads; ++i) {
              Index blockAId = (k_block_idx * m_blocks + i + m_block_idx) % num_threads;
              Index future_id = (blockAId * n_blocks + n_block_idx);
              wait_until_ready(kernel_notifications[future_id]);
            }
          }

          packRKArg arg = {
            &blockAs, blockBs[n_block_idx], rhs, output,
            m_base_start, k_start, n_start, mc, actual_kc, actual_nc,
            num_threads, numBlockAs, m, k_block_idx, m_block_idx, n_block_idx,
            m_blocks, n_blocks, &kernel_notifications, &lhs_notifications,
            need_to_pack,
          };

          this->m_device.enqueueNoNotification(&Self::packRhsAndKernel<packRKArg, RhsPacker, GebpKernel>, arg);
        }
      }
    }
    // Make sure all the kernels are done before deallocating the buffers.
    for (size_t i = 0; i < kernel_notifications.size(); ++i) {
      wait_until_ready(kernel_notifications[i]);
      delete kernel_notifications[i];
    }

    // No need to wait on lhs notifications: the kernels above already did.
    for (size_t i = 0; i < lhs_notifications.size(); ++i) {
      delete lhs_notifications[i];
    }

    // deallocate all of the memory for both A and B blocks
    for (size_t i = 0; i < blockAs.size(); i++) {
      this->m_device.deallocate(blockAs[i]);
    }
    for (size_t i = 0; i < blockBs.size(); i++) {
      this->m_device.deallocate(blockBs[i]);
    }

#undef CEIL_DIV
  }
  // Packs a block of the left-hand side into blockA.
  template <typename packLArg, typename LhsPacker>
  static void packLhs(const packLArg arg) {
    LhsPacker pack_lhs;
    pack_lhs(arg.blockA, arg.lhs.getSubMapper(arg.m_start, arg.k_start), arg.kc, arg.mc);
  }
  // Packs a block of the right-hand side (if needed) and then runs the GEBP
  // kernel against every lhs block that maps onto it.
  template <typename packRKArg, typename RhsPacker, typename GebpKernel>
  static void packRhsAndKernel(packRKArg arg) {
    if (arg.need_to_pack) {
      RhsPacker()(arg.blockB, arg.rhs.getSubMapper(arg.k, arg.n), arg.kc, arg.nc);
    }

    GebpKernel gebp;
    for (Index mt_block_idx = 0; mt_block_idx < arg.num_blockAs; mt_block_idx++) {
      const Index m_base_start = arg.m + arg.mc*mt_block_idx;
      if (m_base_start < arg.max_m) {
        Index blockAId = (arg.k_block_idx * arg.m_blocks + mt_block_idx + arg.m_block_idx) % arg.num_threads;
        wait_until_ready((*arg.lhs_notifications)[blockAId]);
        const Index actual_mc = numext::mini(m_base_start + arg.mc, arg.max_m) - m_base_start;
        gebp(arg.output.getSubMapper(m_base_start, arg.n),
             (*arg.blockAs)[blockAId], arg.blockB,
             actual_mc, arg.kc, arg.nc, Scalar(1), -1, -1, 0, 0);

        // Notify that the kernel for this (lhs slot, rhs block) pair is done.
        const Index set_idx = blockAId * arg.n_blocks + arg.n_block_idx;
        (*arg.kernel_notifications)[set_idx]->Notify();
      }
    }
  }
#endif  // EIGEN_USE_SIMPLE_THREAD_POOL
  TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk,
                               bool shard_by_col, bool prepacked) const {
    const int packed_size = std::min<int>(PacketType<LhsScalar, Device>::size,
                                          PacketType<RhsScalar, Device>::size);
    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
    const double kd = static_cast<double>(bk);
    // Peak FMA bandwidth is 0.5 cycles per element; without enough data for
    // vectorization in either dimension the effective bandwidth is worse.
    double computeBandwidth = bk == 1 ? 4.0 :
          (shard_by_col ? bn : bm) < Traits::nr ||
          (shard_by_col ? bm : bn) < Traits::mr ? 2.0 : 0.5;
#ifndef EIGEN_VECTORIZE_FMA
    // Without FMA the multiply and add form a dependent pair, so the overall
    // bandwidth is 1.0.
    if (computeBandwidth == 0.5) computeBandwidth = 1.0;
#endif
    // Computations.
    TensorOpCost cost = TensorOpCost(0, 0, kd * computeBandwidth, true, packed_size);
    // Output stores.
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
    if (prepacked) {
      // Packing and kernels are executed in different tasks; when calculating
      // task grain size we look only at the kernel cost.
      return cost;
    }
    // Lhs/rhs loads + computations.
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
    // The packing memory cost of one side does not contribute much because
    // that side is prefetched early and accessed sequentially.
    if (shard_by_col)
      lhsCost.dropMemoryCost();
    else
      rhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }
};
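// Usage sketch (illustrative, not part of this header): a contraction picks up
// the evaluator above when it is assigned on a ThreadPoolDevice. Pool size and
// tensor shapes below are arbitrary example values.
//
//   Eigen::ThreadPool pool(8);
//   Eigen::ThreadPoolDevice device(&pool, 8);
//   Eigen::Tensor<float, 2> a(64, 32), b(32, 48), c(64, 48);
//   a.setRandom();
//   b.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   c.device(device) = a.contract(b, dims);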
}  // end namespace Eigen

#endif  // EIGEN_USE_THREADS
#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H