#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H

#ifdef EIGEN_USE_THREADS

#ifdef EIGEN_USE_SIMPLE_THREAD_POOL

template<typename LhsScalar, typename LhsMapper, typename Index>
struct packLhsArg {
  // ... (fields elided in this listing)
};

template<typename LhsScalar, typename RhsScalar, typename RhsMapper,
         typename OutputMapper, typename Index>
struct packRhsAndKernelArg {
  const MaxSizeVector<LhsScalar*>* blockAs;
  // ...
  const Index num_threads;
  const Index num_blockAs;
  // ...
  const Index k_block_idx;
  const Index m_block_idx;
  const Index n_block_idx;
  // ...
  MaxSizeVector<Notification*>* kernel_notifications;
  const MaxSizeVector<Notification*>* lhs_notifications;
  const bool need_to_pack;
};
#endif  // EIGEN_USE_SIMPLE_THREAD_POOL

template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > {

  typedef ThreadPoolDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename XprType::CoeffReturnType CoeffReturnType;

  // When the layout is RowMajor, the roles of lhs and rhs are swapped here so
  // that the evaluation code below can always assume ColMajor inputs.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) {}
#ifndef EIGEN_USE_SIMPLE_THREAD_POOL
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalProduct(Scalar* buffer) const {
    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    typedef internal::TensorContractionInputMapper<
        LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
        contract_t, internal::packet_traits<LhsScalar>::size,
        lhs_inner_dim_contiguous, false, Unaligned>
        LhsMapper;
    typedef internal::TensorContractionInputMapper<
        RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
        contract_t, internal::packet_traits<RhsScalar>::size,
        rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
        RhsMapper;
    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
    typedef internal::gemm_pack_lhs<LhsScalar, Index,
                                    typename LhsMapper::SubMapper, Traits::mr,
                                    Traits::LhsProgress, ColMajor>
        LhsPacker;
    typedef internal::gemm_pack_rhs<
        RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
        RhsPacker;
    typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
                                  Traits::mr, Traits::nr, false, false>
        GebpKernel;

    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    const Index k = this->m_k_size;
    if (m == 0 || n == 0 || k == 0) return;
    // First approximation of the sharding decision; it is refined below once
    // the thread count is known.
    bool shard_by_col = shardByCol(m, n, 2);
    Index bm, bn, bk;
    if (shard_by_col) {
      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
                                          internal::ShardByCol>
          blocking(k, m, n, 2);
      bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
                                          internal::ShardByRow>
          blocking(k, m, n, 2);
      bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc();
    }

    // Compute the optimal number of threads from the cost model.
    const TensorOpCost cost =
        contractionCost(m, n, bm, bn, bk, shard_by_col, false);
    int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
        static_cast<double>(n) * m, cost, this->m_device.numThreads());
    // Stop-gap while the cost model is being tuned: a single output column
    // never benefits from this pipeline.
    if (n == 1) num_threads = 1;
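    // When the cost model (or the n == 1 stop-gap above) settles on a single
    // thread, the multi-threaded pipeline below is skipped entirely and the
    // work is routed to the sequential gemv/gemm paths of the base evaluator.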
    if (num_threads == 1) {
      // The single-threaded algorithm should be faster in this case.
      if (n == 1)
        this->template evalGemv<lhs_inner_dim_contiguous,
                                rhs_inner_dim_contiguous,
                                rhs_inner_dim_reordered, Alignment>(buffer);
      else
        this->template evalGemm<lhs_inner_dim_contiguous,
                                rhs_inner_dim_contiguous,
                                rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }
    // Now that the number of threads is known, recompute sharding and blocking.
    shard_by_col = shardByCol(m, n, num_threads);
    if (shard_by_col) {
      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
                                          internal::ShardByCol>
          blocking(k, m, n, num_threads);
      bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc();
    } else {
      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
                                          internal::ShardByRow>
          blocking(k, m, n, num_threads);
      bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc();
    }

    // Number of kernels in each dimension.
    Index nm0 = divup(m, bm);
    Index nn0 = divup(n, bn);
    Index nk = divup(k, bk);
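    // divup is Eigen's ceiling division, divup(x, y) == (x + y - 1) / y, so
    // nm0/nn0/nk count blocks including a possibly smaller trailing block.
    // For illustration (numbers chosen arbitrarily): m = 1000 with bm = 64
    // gives nm0 = divup(1000, 64) = 16 blocks, the last one covering only
    // 1000 - 15 * 64 = 40 rows.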
    // Calculate the task grain sizes (number of kernels executed per task);
    // the coarsening order depends on the sharding dimension.
    Index gm = 1;
    Index gn = 1;
    if (shard_by_col) {
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
    } else {
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
    }
    // Number of tasks in each dimension.
    Index nm = divup(nm0, gm);
    Index nn = divup(nn0, gn);
    // Decide whether to pack lhs and rhs in parallel, or to pack one side
    // first and the other afterwards. Prefer parallel packing if there are
    // few tasks per thread.
    bool parallel_pack = num_threads >= nm * nn;
    // Also pack in parallel if all data fits into L2 cache.
    if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <=
        l2CacheSize() * num_threads)
      parallel_pack = true;
    // But don't do it if each rhs block is used only once; locality matters
    // more in that case.
    if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
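    // The trade-off here: with parallel packing, the lhs and rhs packing tasks
    // of a k slice are issued at once, which exposes more parallelism; with
    // sequential packing, a thread that just packed an rhs block proceeds to
    // kernels that consume it while it is still hot in cache.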
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides,
                  this->m_i_strides, this->m_left_contracting_strides,
                  this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides,
                  this->m_j_strides, this->m_right_contracting_strides,
                  this->m_k_strides);

    Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
            OutputMapper>(this->m_device, num_threads, lhs, rhs, buffer, m, n,
                          k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
                          shard_by_col, parallel_pack)
        .run();
  }
  // Context coordinates one parallel contraction: it owns the packed buffers
  // and the synchronization state shared by all packing and kernel tasks.
  template <typename LhsPacker, typename RhsPacker, typename GebpKernel,
            typename LhsMapper, typename RhsMapper, typename OutputMapper>
  class Context {
   public:
    Context(const Device& device, int num_threads, LhsMapper& lhs,
            RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk,
            Index bm, Index bn, Index bk, Index nm, Index nn, Index nk,
            Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col,
            bool parallel_pack)
        : device_(device),
          lhs_(lhs),
          rhs_(rhs),
          buffer_(buffer),
          output_(buffer, tm),
          num_threads_(num_threads),
          shard_by_col_(shard_by_col),
          parallel_pack_(parallel_pack),
          m_(tm), n_(tn), k_(tk),
          bm_(bm), bn_(bn), bk_(bk),
          nm_(nm), nn_(nn), nk_(nk),
          gm_(gm), gn_(gn),
          nm0_(nm0), nn0_(nn0) {
      // Initialize the per-slice synchronization counters; see the comment on
      // the state_* members below for the overall scheme.
      for (Index x = 0; x < P; x++) {
        state_switch_[x] =
            x == 0
                ? 1
                : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) +
                      (x == P - 1 ? nm_ * nn_ : 0);
        state_packing_ready_[x] =
            parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
        state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
        for (Index m = 0; m < nm_; m++) {
          state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
          // A kernel waits for the two packing tasks of its slice (one when
          // packing sequentially) and, except in the first slice, for the
          // same (m, n) kernel of the previous k slice.
          for (Index n = 0; n < nn_; n++)
            state_kernel_[x][m][n].store(
                (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1),
                std::memory_order_relaxed);
        }
      }
      // Allocate aligned scratch memory for the packed lhs/rhs blocks.
      size_t align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1);
      size_t lhs_size =
          divup<size_t>(bm_ * bk_ * sizeof(LhsScalar), align) * align;
      size_t rhs_size =
          divup<size_t>(bn_ * bk_ * sizeof(RhsScalar), align) * align;
      packed_mem_ = static_cast<char*>(internal::aligned_malloc(
          (nm0_ * lhs_size + nn0_ * rhs_size) * std::min<size_t>(nk_, P - 1)));
      char* mem = static_cast<char*>(packed_mem_);
      for (Index x = 0; x < numext::mini<Index>(nk_, P - 1); x++) {
        packed_lhs_[x].resize(nm0_);
        for (Index m = 0; m < nm0_; m++) {
          packed_lhs_[x][m] = reinterpret_cast<LhsScalar*>(mem);
          mem += lhs_size;
        }
        packed_rhs_[x].resize(nn0_);
        for (Index n = 0; n < nn0_; n++) {
          packed_rhs_[x][n] = reinterpret_cast<RhsScalar*>(mem);
          mem += rhs_size;
        }
      }
    }
    ~Context() {
      for (Index x = 0; x < P; x++) {
        for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
        delete[] state_kernel_[x];
      }
      internal::aligned_free(packed_mem_);
    }
    // ... (the run() entry point is elided in this listing)

    const Device& device_;
    LhsMapper& lhs_;
    RhsMapper& rhs_;
    Scalar* const buffer_;
    OutputMapper output_;
    const int num_threads_;
    const bool shard_by_col_;
    const bool parallel_pack_;
    // ... (matrix/block/grain size members m_, n_, k_, bm_, bn_, bk_, nm_,
    //      nn_, nk_, gm_, gn_, nm0_, nn0_ elided in this listing)

    // Number of k slices tracked concurrently by the state machine below.
    static const Index P = 3;
    void* packed_mem_;
    std::vector<LhsScalar*> packed_lhs_[P - 1];
    std::vector<RhsScalar*> packed_rhs_[P - 1];
    std::atomic<uint8_t>** state_kernel_[P];
    std::atomic<Index> state_packing_ready_[P];
    std::atomic<Index> state_switch_[P];
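    // Scheduling state as used by the code below, shared by up to P
    // consecutive k slices and indexed by k % P.
    // state_packing_ready_[k % P]: packing tasks remaining before the second
    //   packing phase of slice k may start (sequential packing only).
    // state_kernel_[k % P][m][n]: notifications kernel (m, n, k) still waits
    //   for (its slice's packing, plus the same (m, n) kernel of slice k - 1).
    // state_switch_[k % P]: work remaining before slice k may be issued, i.e.
    //   the packing of slice k - 1 plus the kernels of slice k - 2, whose
    //   packed buffers (indexed k % (P - 1)) slice k reuses.
    // Whichever task drains a counter re-arms it for slice k + P and triggers
    // the next stage.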
    void pack_lhs(Index m, Index k) {
      const Index mend = m * gm_ + gm(m);
      for (Index m1 = m * gm_; m1 < mend; m1++)
        LhsPacker()(packed_lhs_[k % (P - 1)][m1],
                    lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));

      if (!parallel_pack_ && shard_by_col_) {
        signal_packing(k);
      } else {
        signal_switch(k + 1);
        for (Index n = nn_ - 1; n >= 0; n--) signal_kernel(m, n, k, n == 0);
      }
    }
    void pack_rhs(Index n, Index k) {
      const Index nend = n * gn_ + gn(n);
      for (Index n1 = n * gn_; n1 < nend; n1++) {
        if (k == 0) {
          // Zero the output buffer in parallel, one (bn x m) strip per rhs
          // packing task; every kernel that writes this strip depends on the
          // completion of this task.
          memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
        }
        RhsPacker()(packed_rhs_[k % (P - 1)][n1],
                    rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
      }

      if (parallel_pack_ || shard_by_col_) {
        signal_switch(k + 1);
        for (Index m = nm_ - 1; m >= 0; m--) signal_kernel(m, n, k, m == 0);
      } else {
        signal_packing(k);
      }
    }
    void kernel(Index m, Index n, Index k) {
      // The loop nest order matters: the inner loop reuses the same packed
      // rhs block when sharding by column, and the same packed lhs block when
      // sharding by row.
      const Index nend = n * gn_ + gn(n);
      const Index mend = m * gm_ + gm(m);
      if (shard_by_col_) {
        for (Index n1 = n * gn_; n1 < nend; n1++) {
          for (Index m1 = m * gm_; m1 < mend; m1++)
            GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
                         packed_lhs_[k % (P - 1)][m1],
                         packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
                         Scalar(1), -1, -1, 0, 0);
        }
      } else {
        for (Index m1 = m * gm_; m1 < mend; m1++)
          for (Index n1 = n * gn_; n1 < nend; n1++) {
            GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
                         packed_lhs_[k % (P - 1)][m1],
                         packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
                         Scalar(1), -1, -1, 0, 0);
          }
      }
      signal_kernel(m, n, k + 1, false);
      signal_switch(k + 2);
    }
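    // A finished kernel for slice k unblocks the same (m, n) kernel of slice
    // k + 1 (whose packed inputs may already be ready) and counts towards the
    // switch to slice k + 2, which is when its packed buffers (indexed
    // k % (P - 1)) can safely be reused. The helpers below decrement the
    // corresponding counters and fire the next stage when a counter drains.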
    void signal_packing(Index k) {
      Index s = state_packing_ready_[k % P].fetch_sub(1);
      if (s != 1) return;
      state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
      enqueue_packing(k, shard_by_col_);
    }
    void signal_kernel(Index m, Index n, Index k, bool sync) {
      std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
      Index s = state->load();
      if (s != 1 && state->fetch_sub(1) != 1) return;
      state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
      if (sync)
        kernel(m, n, k);
      else
        device_.enqueueNoNotification([=]() { kernel(m, n, k); });
    }
    void signal_switch(Index k, Index v = 1) {
      Index s = state_switch_[k % P].fetch_sub(v);
      if (s != v) return;

      // Re-arm the counter for the slice that will reuse this slot (k + P).
      state_switch_[k % P] =
          (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) +
          nm_ * nn_;

      if (k < nk_) {
        // Issue lhs/rhs packing for slice k; their completion will in turn
        // kick off the kernels.
        if (parallel_pack_) {
          enqueue_packing(k, !shard_by_col_);
          enqueue_packing(k, shard_by_col_);
        } else if (shard_by_col_) {
          enqueue_packing(k, false);
        } else {
          enqueue_packing(k, true);
        }
      } else if (k == nk_) {
        // There is no packing for slice nk_; pretend those packing tasks
        // finished instantly so the final switch only waits for the last
        // kernels.
        signal_switch(k + 1,
                      parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
      }
      // (The final branch, which signals overall completion back to run(), is
      //  elided in this listing.)
    }
    // Enqueue all lhs/rhs packing tasks for the k-th slice.
    void enqueue_packing(Index k, bool rhs) {
      enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs);
    }
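    // The helper below splits the [start, end) range of packing tasks in half
    // and enqueues each half as its own task, so spawning N packing tasks is
    // itself parallelized (a binary tree of depth roughly log2(N)) instead of
    // one thread enqueueing all of them sequentially.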
    void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
      if (end - start == 1) {
        if (rhs)
          pack_rhs(start, k);
        else
          pack_lhs(start, k);
      } else {
        Index mid = (start + end) / 2;
        device_.enqueueNoNotification(
            [=]() { enqueue_packing_helper(mid, end, k, rhs); });
        device_.enqueueNoNotification(
            [=]() { enqueue_packing_helper(start, mid, k, rhs); });
      }
    }
    Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
    Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
    Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }
    Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
    Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }
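    // These return the block/grain size of the i-th block: full-sized except
    // for the last one, which simply gets whatever is left over. Illustrative
    // numbers: with m_ = 100 and bm_ = 32, nm0_ = 4 and the last block size is
    // m_ + bm_ - bm_ * nm0_ = 100 + 32 - 128 = 4; all earlier blocks span the
    // full bm_ = 32 rows.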
    Context(const Context&) = delete;
    void operator=(const Context&) = delete;
  };
  // Decide whether to shard the m x n output by columns or by rows. Sharding
  // by column is the default; shard by row only when there is enough data for
  // vectorization over rows and not enough (or unevenly divisible) data over
  // columns, or when the output is strongly elongated over rows.
  static bool shardByCol(Index m, Index n, Index num_threads) {
    if (m / num_threads >= Traits::nr &&
        (n / num_threads < Traits::nr ||
         (n / num_threads < 4 * Traits::nr &&
          (n % (num_threads * Traits::nr)) != 0 &&
          ((m % (num_threads * Traits::nr)) == 0
           /* || ... (a further tie-breaking condition is elided here) */))))
      return false;
    if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
    return true;
  }
  Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn,
                 int num_threads, bool shard_by_col) const {
    Index gm = 1;
    Index gm1 = 1;
    Index nm0 = divup(m, bm);
    Index nm1 = nm0;
    for (;;) {
      // Find the next candidate for the m grain size: it has to change the
      // resulting number of tasks.
      while (gm1 <= nm0 && nm1 == divup(nm0, gm1)) gm1++;
      // Stop when the grain size becomes too big.
      if (gm1 > nm0) break;
      // Check the candidate.
      int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads,
                           shard_by_col);
      if (res < 0) break;      // Reject this and all larger candidates.
      nm1 = divup(nm0, gm1);
      if (res == 0) continue;  // Skip this candidate.
      // Commit the new grain size.
      gm = gm1;
    }
    return gm;
  }
  Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
                 int num_threads, bool shard_by_col) const {
    Index gn = 1;
    Index gn1 = 1;
    Index nn0 = divup(n, bn);
    Index nn1 = nn0;
    for (;;) {
      while (gn1 <= nn0 && nn1 == divup(nn0, gn1)) gn1++;
      if (gn1 > nn0) break;
      int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads,
                           shard_by_col);
      if (res < 0) break;
      nn1 = divup(nn0, gn1);
      if (res == 0) continue;
      gn = gn1;
    }
    return gn;
  }
  // Returns  1 if the candidate grain size is accepted,
  //         -1 if it (and every larger size) is rejected,
  //          0 if it is skipped but the search may continue.
  int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
                 Index gn, Index oldgm, Index oldgn, int num_threads,
                 bool shard_by_col) const {
    const TensorOpCost cost =
        contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
    double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(
        static_cast<double>(bm) * gm * bn * gn, cost);
    // If the task is too small, accept it regardless of anything else;
    // otherwise synchronization overheads will dominate.
    if (taskSize < 1) return 1;
    // If it is too large, reject it and all larger tasks.
    if (taskSize > 2) return -1;
    // Within the acceptable task-size range, prefer the grain size that keeps
    // more threads busy.
    Index nm0 = divup(m, bm);
    Index nn0 = divup(n, bn);
    Index new_tasks = divup(nm0, gm) * divup(nn0, gn);
    double new_parallelism = static_cast<double>(new_tasks) /
                             (divup<int>(new_tasks, num_threads) * num_threads);
    Index old_tasks = divup(nm0, oldgm) * divup(nn0, oldgn);
    double old_parallelism = static_cast<double>(old_tasks) /
                             (divup<int>(old_tasks, num_threads) * num_threads);
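    // Parallelism here is the average core utilization over the rounds needed
    // to run all tasks. Illustrative numbers: with num_threads = 4, 12 tasks
    // give 12 / (divup(12, 4) * 4) = 12 / 12 = 1.0 (every round fills all
    // cores), while 14 tasks give 14 / (divup(14, 4) * 4) = 14 / 16 = 0.875
    // (the last round runs only two tasks).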
    // Accept if the candidate improves parallelism or utilization is already perfect.
    if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
    return 0;
  }
#else  // EIGEN_USE_SIMPLE_THREAD_POOL

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalProduct(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous,
                              rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
                              Alignment>(buffer);
      return;
    }

    evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
             rhs_inner_dim_reordered, Alignment>(buffer);
  }
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    // rows in left side
    const Index m = this->m_i_size;
    // columns in right side
    const Index n = this->m_j_size;

    // Zero out the result buffer (which must hold at least m * n scalars).
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
    // (The definitions of lhs_packet_size and rhs_packet_size are elided in
    //  this listing.)
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered,
                                                   Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    typedef internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, Traits::mr,
                                    Traits::LhsProgress, ColMajor> LhsPacker;
    typedef internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> RhsPacker;

    typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
                                  Traits::mr, Traits::nr, false, false> GebpKernel;

    typedef internal::packLhsArg<LhsScalar, LhsMapper, Index> packLArg;
    typedef internal::packRhsAndKernelArg<LhsScalar, RhsScalar, RhsMapper, OutputMapper, Index> packRKArg;
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    const Index num_threads = this->m_device.numThreads();
    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, num_threads);
    Index mc = blocking.mc();
    Index nc = blocking.nc();
    Index kc = blocking.kc();

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
    const Index k_blocks = CEIL_DIV(k, kc);
    const Index n_blocks = CEIL_DIV(n, nc);
    const Index m_blocks = CEIL_DIV(m, mc);
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;
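    // For illustration (arbitrary sizes): m = n = k = 1024 with mc = 256,
    // nc = 512 and kc = 128 gives m_blocks = 4, n_blocks = 2, k_blocks = 8,
    // and per-block scratch buffers of sizeA = 256 * 128 and sizeB = 128 * 512
    // scalars; blockAs/blockBs below are pools of such buffers.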
    const Index numBlockAs = numext::mini(num_threads, m_blocks);
    MaxSizeVector<LhsScalar *> blockAs(num_threads);
    for (int i = 0; i < num_threads; i++) {
      blockAs.push_back(static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar))));
    }

    // One rhs buffer per column block; it is packed once per k block and
    // reused across all row blocks.
    MaxSizeVector<RhsScalar *> blockBs(n_blocks);
    for (int i = 0; i < n_blocks; i++) {
      blockBs.push_back(static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar))));
    }

    // lhs_notifications starts with all null Notifications.
    MaxSizeVector<Notification*> lhs_notifications(num_threads, nullptr);

    // One notification per (lhs block, rhs block) kernel.
    const Index num_kernel_notifications = num_threads * n_blocks;
    MaxSizeVector<Notification*> kernel_notifications(num_kernel_notifications,
                                                      nullptr);
    for (Index k_block_idx = 0; k_block_idx < k_blocks; k_block_idx++) {
      const Index k_start = k_block_idx * kc;
      // Make sure we don't overshoot the edge of the matrix, then compute the size.
      const Index actual_kc = numext::mini(k_start + kc, k) - k_start;

      for (Index m_block_idx = 0; m_block_idx < m_blocks; m_block_idx += numBlockAs) {
        const Index num_blocks = numext::mini(m_blocks - m_block_idx, numBlockAs);

        for (Index mt_block_idx = m_block_idx; mt_block_idx < m_block_idx + num_blocks; mt_block_idx++) {
          const Index m_start = mt_block_idx * mc;
          const Index actual_mc = numext::mini(m_start + mc, m) - m_start;

          Index blockAId = (k_block_idx * m_blocks + mt_block_idx) % num_threads;

          // Wait for any previous kernels still using this lhs buffer before
          // repacking it, and refresh their notifications.
          for (int i = 0; i < n_blocks; ++i) {
            Index notification_id = (blockAId * n_blocks + i);
            if (kernel_notifications[notification_id]) {
              wait_until_ready(kernel_notifications[notification_id]);
              delete kernel_notifications[notification_id];
            }
            kernel_notifications[notification_id] = new Notification();
          }

          const packLArg arg = {
            // ... (field initializers elided in this listing)
          };

          delete lhs_notifications[blockAId];
          lhs_notifications[blockAId] =
              this->m_device.enqueue(&Self::packLhs<packLArg, LhsPacker>, arg);
        }

        // Now start kernels.
        const Index m_base_start = m_block_idx * mc;
        const bool need_to_pack = m_block_idx == 0;
        for (Index n_block_idx = 0; n_block_idx < n_blocks; n_block_idx++) {
          const Index n_start = n_block_idx * nc;
          const Index actual_nc = numext::mini(n_start + nc, n) - n_start;
          // Wait for kernels from the previous pass that still use these buffers.
          for (Index i = num_blocks; i < num_threads; ++i) {
            Index blockAId = (k_block_idx * m_blocks + i + m_block_idx) % num_threads;
            Index future_id = (blockAId * n_blocks + n_block_idx);
            wait_until_ready(kernel_notifications[future_id]);
          }
          packRKArg arg = {
            // ... (most field initializers elided; among them:)
            blockBs[n_block_idx],
            &kernel_notifications,
            // ...
          };
          this->m_device.enqueueNoNotification(&Self::packRhsAndKernel<packRKArg, RhsPacker, GebpKernel>, arg);
        }
      }
    }
    // Make sure all the kernels are done.
    for (size_t i = 0; i < kernel_notifications.size(); ++i) {
      wait_until_ready(kernel_notifications[i]);
      delete kernel_notifications[i];
    }

    // The lhs notifications have already been waited on; just clean them up.
    for (size_t i = 0; i < lhs_notifications.size(); ++i) {
      delete lhs_notifications[i];
    }

    // Deallocate all of the memory for both A and B blocks.
    for (size_t i = 0; i < blockAs.size(); i++) {
      this->m_device.deallocate(blockAs[i]);
    }
    for (size_t i = 0; i < blockBs.size(); i++) {
      this->m_device.deallocate(blockBs[i]);
    }

#undef CEIL_DIV
  }
  // Packs one lhs block described by arg.
  template <typename packLArg, typename LhsPacker>
  static void packLhs(const packLArg arg) {
    // perform actual packing
    LhsPacker pack_lhs;
    pack_lhs(arg.blockA, arg.lhs.getSubMapper(arg.m_start, arg.k_start), arg.kc, arg.mc);
  }
  // Packs an rhs block (if needed) and runs the gebp kernel against every lhs
  // block that is ready for it.
  template <typename packRKArg, typename RhsPacker, typename GebpKernel>
  static void packRhsAndKernel(packRKArg arg) {
    if (arg.need_to_pack) {
      RhsPacker pack_rhs;
      pack_rhs(arg.blockB, arg.rhs.getSubMapper(arg.k, arg.n), arg.kc, arg.nc);
    }

    GebpKernel gebp;
    for (Index mt_block_idx = 0; mt_block_idx < arg.num_blockAs; mt_block_idx++) {
      const Index m_base_start = arg.m + arg.mc * mt_block_idx;
      if (m_base_start < arg.max_m) {
        Index blockAId = (arg.k_block_idx * arg.m_blocks + mt_block_idx + arg.m_block_idx) % arg.num_threads;
        // Wait for the lhs block to be packed before using it.
        wait_until_ready((*arg.lhs_notifications)[blockAId]);
        const Index actual_mc = numext::mini(m_base_start + arg.mc, arg.max_m) - m_base_start;
        gebp(arg.output.getSubMapper(m_base_start, arg.n),
             (*arg.blockAs)[blockAId], arg.blockB,
             actual_mc, arg.kc, arg.nc, Scalar(1), -1, -1, 0, 0);

        // Notify the waiter for this (lhs block, rhs block) pair.
        const Index set_idx = blockAId * arg.n_blocks + arg.n_block_idx;
        (*arg.kernel_notifications)[set_idx]->Notify();
      }
    }
  }
#endif  // EIGEN_USE_SIMPLE_THREAD_POOL

  TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk,
                               bool shard_by_col, bool prepacked) const {
    // (The definitions of packed_size and output_packet_size are elided in
    //  this listing.)
    const double kd = static_cast<double>(bk);
    double computeBandwidth =
        bk == 1 ? 4.0
                : (shard_by_col ? bn : bm) < Traits::nr ||
                          (shard_by_col ? bm : bn) < Traits::mr
                      ? 2.0
                      : 0.5;
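    // computeBandwidth is an estimate of cycles per multiply-add along the k
    // dimension: 0.5 when the block shape lets the vectorized gebp kernel run
    // at full throughput, 2.0 when a block dimension is below the kernel's
    // register tiling (Traits::mr x Traits::nr), and 4.0 in the degenerate
    // bk == 1 case. The adjustment just below accounts for targets without
    // FMA, where the separate multiply + add sequence halves the throughput.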
#ifndef EIGEN_VECTORIZE_FMA
    if (computeBandwidth == 0.5) computeBandwidth = 1.0;
#endif
    // Computations.
    TensorOpCost cost = TensorOpCost(0, 0, kd * computeBandwidth, true, packed_size);
    // Output stores.
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
    if (prepacked) {
      // Packing and kernels run in different tasks; when computing the task
      // grain size only the kernel cost matters, so input costs are excluded.
      return cost;
    }
    // Lhs/rhs loads + computations.
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
    // The packing of the sharding-dimension input is prefetched early and
    // accessed sequentially, so its memory cost is negligible.
    if (shard_by_col)
      lhsCost.dropMemoryCost();
    else
      rhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }
};
#endif  // EIGEN_USE_THREADS
#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H