#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H

#ifdef EIGEN_USE_THREADS

template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> > {

  typedef ThreadPoolDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;

  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
  template <int Alignment>
  // ...
    evalProductImpl<NoCallback, Alignment>(buffer, NoCallback());

  template <typename EvalToCallback, int Alignment>
  void evalProductAsync(Scalar* buffer, EvalToCallback done) const {
    evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done));
  }

  template <typename DoneCallback, int Alignment>
  void evalProductImpl(Scalar* buffer, DoneCallback done) const {
    // ...
    static const bool IsEvalInSyncMode =
        /* ... */;

    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    const Index k = this->m_k_size;
    if (m == 0 || n == 0 || k == 0) return;
    bool shard_by_col = shardByCol(m, n, 2);
    // ... (block sizes come from one of two blocking strategies, selected by
    //      shard_by_col)
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          /* ... */>
          blocking(k, m, n, 2);
    // ...
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          /* ... */>
          blocking(k, m, n, 2);
    // ...

    const TensorOpCost cost =
        contractionCost(m, n, bm, bn, bk, shard_by_col, false);
    // ...
        static_cast<double>(n) * m, cost, this->m_device.numThreads());
    int num_threads_by_k = numThreadsInnerDim(m, n, k);
    if (shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
      // Sharding over the inner (k) dimension is more effective here.
      if (IsEvalInSyncMode) {
        EvalShardedByInnerDimContext<DoneCallback> ctx(
            this, num_threads_by_k, buffer, m, n, k, std::move(done));
        ctx.template run<Alignment>();
      // ...
        auto* ctx = new EvalShardedByInnerDimContext<DoneCallback>(
            this, num_threads_by_k, buffer, m, n, k, std::move(done));
        ctx->template runAsync<Alignment>();
      // ...

    if (n == 1) num_threads = 1;

    if (num_threads == 1) {
      // ... (single-threaded evaluation)
      if (!IsEvalInSyncMode) done();
    shard_by_col = shardByCol(m, n, num_threads);
    // ... (recompute block sizes with the actual number of threads)
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          /* ... */>
          blocking(k, m, n, num_threads);
    // ...
      internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index,
                                          /* ... */>
          blocking(k, m, n, num_threads);
    // ...

    // ... (grain sizes are coarsened in the order dictated by the sharding
    //      direction)
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
    // ...
      gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
      gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
    // ...

    const Index sharding_dim_tasks = shard_by_col ? nn : nm;
    const int num_worker_threads = this->m_device.numThreadsInPool();

    const float oversharding_factor =
        num_worker_threads <= 4  ? 8.0 :
        num_worker_threads <= 8  ? 4.0 :
        num_worker_threads <= 16 ? 2.0 :
        num_worker_threads <= 32 ? 1.0 :
        num_worker_threads <= 64 ? 0.8 : 0.6;

    const bool parallelize_by_sharding_dim_only =
        sharding_dim_tasks >= oversharding_factor * num_worker_threads;
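
    // Note on the heuristic above: it asks for several times more tasks along
    // the sharding dimension than there are worker threads (8x for very small
    // pools, shrinking toward 0.6x for very large ones). When that holds there
    // is enough parallelism in the sharding dimension alone, so the code can
    // keep packing and compute for a strip on one thread and use the
    // thread-local packed blocks set up further below.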
 
    bool parallel_pack = num_threads >= nm * nn;
    if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <=
        /* ... */)
      parallel_pack = true;
    if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
    // ...
    if (parallelize_by_sharding_dim_only) parallel_pack = false;
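
    // parallel_pack lets every thread help with packing lhs and rhs blocks for
    // a k slice instead of tying packing to the owner of a row or column
    // strip. It is enabled when threads outnumber the nm * nn kernel tasks (or
    // the packed blocks are small enough, per the size check above), and
    // disabled when there is only one task along the non-sharded dimension or
    // when parallelize_by_sharding_dim_only is in effect.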
 
    if (IsEvalInSyncMode) {
#define CONTEXT_ARGS                                                        \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
   nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
   /* ... */)
      // ...
#define CONTEXT_ARGS                                                        \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
   nn0, shard_by_col, parallel_pack, parallelize_by_sharding_dim_only,      \
   /* ... */)
      // ...
                                        Alignment, CONTEXT_ARGS, run());
    // ...

      eigen_assert(false && "NoCallback should never be called");
  template <typename DoneCallback, typename Context>
  class EvalParallelNotification;

  template <typename Context>
  class EvalParallelNotification<NoCallback, Context> {
    // ...
    EvalParallelNotification(Context*, NoCallback) {}
    void Notify() { done_.Notify(); }
    void Wait() { done_.Wait(); }
    // ...

  template <typename DoneCallback, typename Context>
  class EvalParallelNotification {
    // ...
    EvalParallelNotification(Context* ctx, DoneCallback done)
        : ctx_(ctx), done_(std::move(done)) {}
    // ...
      DoneCallback done_copy = std::move(done_);
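
  // The NoCallback specialization above serves the synchronous path: callers
  // block in Wait() until Notify() fires. The general DoneCallback version
  // serves the asynchronous path, where the stored callback is first moved
  // into a local copy (done_copy) so that it can run safely after the
  // notification machinery has finished with the context.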
 
  template <typename DoneCallback, bool lhs_inner_dim_contiguous,
            bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
            /* ... */>
  class EvalParallelContext {
    // ...
    typedef internal::TensorContractionInputMapper<
        /* ... */
        lhs_inner_dim_contiguous, false, Unaligned>
        LhsMapper;
    typedef internal::TensorContractionInputMapper<
        /* ... */
        rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
        RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    typedef internal::TensorContractionKernel<
        Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
        TensorContractionKernel;

    typedef typename TensorContractionKernel::LhsBlock LhsBlock;
    typedef typename TensorContractionKernel::RhsBlock RhsBlock;
    typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
    EvalParallelContext(const Self* self, int num_threads, Scalar* buffer,
                        /* ... */
                        bool parallelize_by_sharding_dim_only,
                        /* ... */)
        : created_by_thread_id_(std::this_thread::get_id()),
          // ...
          lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
               self->m_i_strides, self->m_left_contracting_strides,
               /* ... */),
          rhs_(self->m_rightImpl, self->m_right_nocontract_strides,
               self->m_j_strides, self->m_right_contracting_strides,
               /* ... */),
          // ...
          output_kernel_(self->m_output_kernel),
          tensor_contraction_params_(self->m_tensor_contraction_params),
          num_threads_(num_threads),
          shard_by_col_(shard_by_col),
          parallel_pack_(parallel_pack),
          parallelize_by_sharding_dim_only_(parallelize_by_sharding_dim_only),
          // ...
          kernel_(m_, k_, n_, bm_, bk_, bn_),
          num_thread_local_allocations_(0),
          // ...
          thread_local_capacity(2 * (parallelize_by_sharding_dim_only_
                                         ? device_.numThreadsInPool()
                                         /* ... */)),
          // ...
          lhs_thread_local_blocks_(shard_by_col_ ? 0 : thread_local_capacity,
                                   /* ... */),
          rhs_thread_local_blocks_(shard_by_col_ ? thread_local_capacity : 0,
                                   /* ... */) {
      // ...
      eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
      // ... (per-slice synchronization state is initialized for each of the P
      //      in-flight k slices)
                : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) +
                      (x == P - 1 ? nm_ * nn_ : 0);
        state_packing_ready_[x] =
            parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
        state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
        // ...
          state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
          // ...
            state_kernel_[x][m][n].store(
                (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1),
                std::memory_order_relaxed);
      // ...

      packed_mem_ = kernel_.allocateSlices(
          /* ... */
          std::min<Index>(nk_, P - 1),
          packed_lhs_, packed_rhs_);

      if (parallelize_by_sharding_dim_only_) {
        const int num_worker_threads = device_.numThreadsInPool();

        // ... (column-sharded case: per-thread packed rhs blocks)
          can_use_thread_local_packed_ = new std::atomic<bool>[nn_];
          for (int i = 0; i < nn_; ++i)
            can_use_thread_local_packed_[i].store(true,
                                                  std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gn_;
          thread_local_pre_alocated_mem_ = kernel_.allocateSlices(
              /* ... */
              nullptr, &rhs_thread_local_pre_allocated_);

        // ... (row-sharded case: per-thread packed lhs blocks)
          can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
          for (int i = 0; i < nm_; ++i)
            can_use_thread_local_packed_[i].store(true,
                                                  std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gm_;
          thread_local_pre_alocated_mem_ = kernel_.allocateSlices(
              /* ... */
              1, &lhs_thread_local_pre_allocated_,
              /* ... */);
    ~EvalParallelContext() {
      // ...
        for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
        delete[] state_kernel_[x];
      // ...
      kernel_.deallocate(device_, packed_mem_);
      if (parallelize_by_sharding_dim_only_) {
        kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
        delete[] can_use_thread_local_packed_;
    EvalParallelNotification<DoneCallback, EvalParallelContext> done_;

    const Device& device_;
    // ...
    OutputMapper output_;
    OutputKernelType output_kernel_;
    TensorContractionParams tensor_contraction_params_;

    const int num_threads_;
    const bool shard_by_col_;
    const bool parallel_pack_;
    const bool parallelize_by_sharding_dim_only_;
    // ...
    TensorContractionKernel kernel_;
    // ...
    BlockMemHandle packed_mem_;
    std::vector<LhsBlock> packed_lhs_[P - 1];
    std::vector<RhsBlock> packed_rhs_[P - 1];
    // ...
    BlockMemHandle thread_local_pre_alocated_mem_;
    // ...
    std::vector<LhsBlock> lhs_thread_local_pre_allocated_;
    std::vector<RhsBlock> rhs_thread_local_pre_allocated_;

    std::atomic<int> num_thread_local_allocations_;
    const int thread_local_capacity;
    template <typename BlockType>
    class ThreadLocalBlocks {
      // ...
      ThreadLocalBlocks() = default;

      ThreadLocalBlocks(BlockType* base, size_t grain_size)
          : is_pre_allocated_(true),
            thread_local_pre_allocated_base_(base),
            grain_size_(grain_size) {}

      ThreadLocalBlocks(BlockMemHandle mem_handle,
                        std::vector<BlockType> blocks)
          : is_pre_allocated_(false),
            mem_handle_(std::move(mem_handle)),
            /* ... */ {}

      BlockType& block(int grain_index) {
        // ...
        return is_pre_allocated_ ? thread_local_pre_allocated_base_[grain_index]
                                 : blocks_[grain_index];
      // ...

      void Release(EvalParallelContext& ctx) const {
        if (!is_pre_allocated_) {
          ctx.kernel_.deallocate(ctx.device_, mem_handle_);
      // ...

      size_t size() const {
        return is_pre_allocated_ ? grain_size_ : blocks_.size();
      // ...

      bool is_pre_allocated_;
      // ...
      BlockType* thread_local_pre_allocated_base_ = nullptr;
      size_t grain_size_ = 0;
      // ...
      BlockMemHandle mem_handle_{};
      std::vector<BlockType> blocks_;
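
    // The ThreadLocalBlocks class above is a small view over per-thread packed
    // block storage: it either points into the pre-allocated arena carved out
    // in the constructor (is_pre_allocated_ == true, indexed by grain), or it
    // owns a separately allocated BlockMemHandle plus a vector of blocks,
    // which Release() hands back to the kernel's allocator.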
 
    template <typename BlockType, bool is_rhs>
    class ThreadLocalBlocksInitialize {
      static constexpr bool kIsLhs =
          /* ... */;
      static const bool kIsRhs =
          /* ... */;
      static_assert(kIsLhs || kIsRhs, "Unknown block type");

      using Blocks = ThreadLocalBlocks<BlockType>;

      // ...
      ThreadLocalBlocksInitialize(EvalParallelContext& ctx)
          // ...
            num_worker_threads_(ctx_.device_.numThreadsInPool()) {}

      // ...
        const int n = ctx_.num_thread_local_allocations_.fetch_add(
            1, std::memory_order_relaxed);
        // ...
        if (n >= num_worker_threads_) {
          ThreadLocalBlocksAllocator<is_rhs>::allocate(ctx_, blocks);
        // ...
          ThreadLocalBlocksAllocator<is_rhs>::reuse(ctx_, n, blocks);
      // ...

      template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
      struct ThreadLocalBlocksAllocator;

      template <typename EvalCtx>
      struct ThreadLocalBlocksAllocator<true, EvalCtx> {
        static void allocate(EvalCtx& ctx, Blocks& blocks) {
          std::vector<RhsBlock> rhs_blocks;
          BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
              /* ... */
              nullptr, &rhs_blocks);

          blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle),
                                               std::move(rhs_blocks));
        // ...

        static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
          RhsBlock* ptr = &ctx.rhs_thread_local_pre_allocated_[ctx.gn_ * index];
          blocks = ThreadLocalBlocks<RhsBlock>(ptr, ctx.gn_);
      // ...

      template <typename EvalCtx>
      struct ThreadLocalBlocksAllocator<false, EvalCtx> {
        static void allocate(EvalCtx& ctx, Blocks& blocks) {
          std::vector<LhsBlock> lhs_blocks;
          BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(
              /* ... */
              &lhs_blocks, nullptr);

          blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle),
                                               std::move(lhs_blocks));
        // ...

        static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
          LhsBlock* ptr = &ctx.lhs_thread_local_pre_allocated_[ctx.gm_ * index];
          blocks = ThreadLocalBlocks<LhsBlock>(ptr, ctx.gm_);
      // ...

      EvalParallelContext& ctx_;
      const int num_worker_threads_;
    template <typename BlockType>
    class ThreadLocalBlocksRelease {
      // ...
      using Blocks = ThreadLocalBlocks<BlockType>;
      ThreadLocalBlocksRelease(EvalParallelContext& ctx) : ctx_(ctx) {}
      void operator()(Blocks& blocks) { blocks.Release(ctx_); }

      // ...
      EvalParallelContext& ctx_;
    // ...

    using ThreadLocalLhsInit =
        ThreadLocalBlocksInitialize<LhsBlock, false>;
    using ThreadLocalRhsInit =
        ThreadLocalBlocksInitialize<RhsBlock, true>;

    // ...
    using ThreadLocalLhsRelease = ThreadLocalBlocksRelease<LhsBlock>;
    using ThreadLocalRhsRelease = ThreadLocalBlocksRelease<RhsBlock>;

    // ...
                       ThreadLocalLhsRelease>
        lhs_thread_local_blocks_;
    // ...
                       ThreadLocalRhsRelease>
        rhs_thread_local_blocks_;

    // ...
    std::atomic<bool>* can_use_thread_local_packed_;

    std::atomic<uint8_t>** state_kernel_[P];
    // ...
    std::atomic<Index> state_packing_ready_[P];
    std::atomic<Index> state_switch_[P];
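
    // The three arrays above implement a pipeline over slices of the
    // contraction (k) dimension, P slices deep (P is defined outside this
    // excerpt): state_packing_ready_[x] counts packing tasks that may still
    // start for slice x, state_kernel_[x][m][n] counts the packed inputs a
    // kernel task is still waiting for, and state_switch_[x] counts the events
    // that must happen before the pipeline advances to the next slice.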
 
    // ... (packed_lhs accessor)
      if (use_thread_local) {
        // ...
        ThreadLocalBlocks<LhsBlock>& blocks = lhs_thread_local_blocks_.local();
        // ...
        return blocks.block(internal::convert_index<int>(grain_index));
      // ...
        return packed_lhs_[k % (P - 1)][m1];
      // ...

    // ... (packed_rhs accessor)
      if (use_thread_local) {
        // ...
        ThreadLocalBlocks<RhsBlock>& blocks = rhs_thread_local_blocks_.local();
        // ...
        return blocks.block(internal::convert_index<int>(grain_index));
      // ...
        return packed_rhs_[k % (P - 1)][n1];
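
    // The two accessors above choose between two storage locations: when
    // use_thread_local is set the block comes from this thread's
    // ThreadLocalBlocks (indexed by grain), otherwise it comes from the shared
    // packed_lhs_/packed_rhs_ slices, which are reused modulo (P - 1) across
    // the k dimension.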
 
    // ... (lhs packing task for strip m of slice k)
      bool use_thread_local = false;

      if (parallelize_by_sharding_dim_only_ && !shard_by_col_ &&
          can_use_thread_local_packed_[m].load(std::memory_order_relaxed)) {
        if (state_kernel_[k % P][m][0].load(std::memory_order_relaxed) == 1) {
          use_thread_local = true;
        // ...
          can_use_thread_local_packed_[m].store(false,
                                                std::memory_order_relaxed);
      // ...

      const Index mend = m * gm_ + gm(m);
      // ...
        kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local),
                        lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));

      if (!parallel_pack_ && shard_by_col_) {
        assert(!use_thread_local);
        // ...
        signal_switch(k + 1);
        for (Index n = nn_ - 1; n >= 0; n--) {
          bool sync = parallelize_by_sharding_dim_only_ || n == 0;
          signal_kernel(m, n, k, sync, use_thread_local);
    // ... (rhs packing task for strip n of slice k)
      bool use_thread_local = false;

      if (parallelize_by_sharding_dim_only_ && shard_by_col_ &&
          can_use_thread_local_packed_[n].load(std::memory_order_relaxed)) {
        if (state_kernel_[k % P][0][n].load(std::memory_order_relaxed) == 1) {
          use_thread_local = true;
        // ...
          can_use_thread_local_packed_[n].store(false,
                                                std::memory_order_relaxed);
      // ...

        if (!TensorContractionKernel::HasBeta && k == 0) {
          // ...
          memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
        // ...
        kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
                        rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
      // ...

      if (parallel_pack_ || shard_by_col_) {
        signal_switch(k + 1);
        for (Index m = nm_ - 1; m >= 0; m--) {
          bool sync = parallelize_by_sharding_dim_only_ || m == 0;
          signal_kernel(m, n, k, sync, use_thread_local);
        // ...
        assert(!use_thread_local);
    // ... (kernel task: runs the contraction kernel on the packed blocks of
    //      slice k)
      const Index mend = m * gm_ + gm(m);
      // ...
          (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);

      // ... (column-sharded loop order)
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            // ...
                packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
                /* ... */);
            // ...
              output_kernel_(output_mapper, tensor_contraction_params_,
                             m1 * bm_, n1 * bn_, bm(m1), bn(n1));

      // ... (row-sharded loop order)
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            // ...
                packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
                /* ... */);
            // ...
              output_kernel_(output_mapper, tensor_contraction_params_,
                             m1 * bm_, n1 * bn_, bm(m1), bn(n1));
      // ...

      signal_kernel(m, n, k + 1, false, false);
      signal_switch(k + 2);
    void signal_packing(Index k) {
      // ...
      Index s = state_packing_ready_[k % P].fetch_sub(1);
      // ...
      state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
      enqueue_packing(k, shard_by_col_);
    // ...

    // ... (signal_kernel(m, n, k, sync, use_thread_local))
                       bool use_thread_local) {
      std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
      // ...
      if (s != 1 && state->fetch_sub(1) != 1) {
        // ...
      state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
      // ...
        kernel(m, n, k, use_thread_local);
      // ...
        device_.enqueueNoNotification(
            [=]() { kernel(m, n, k, use_thread_local); });
    // ... (signal_switch(k, v): advance the pipeline to the next k slice)
      Index s = state_switch_[k % P].fetch_sub(v);
      // ...
      state_switch_[k % P] =
          (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) +
          /* ... */;
      // ...
        if (parallel_pack_) {
          enqueue_packing(k, !shard_by_col_);
          enqueue_packing(k, shard_by_col_);
        } else if (shard_by_col_) {
          enqueue_packing(k, false);
        // ...
          enqueue_packing(k, true);
        // ...
      } else if (k == nk_) {
        signal_switch(k + 1,
                      parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
    void enqueue_packing(Index k, bool rhs) {
      enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs);
    // ...

    // ... (enqueue_packing_helper(start, end, k, rhs))
      if (end - start == 1) {
        // ... (pack the single remaining strip)
      // ...
        while (end - start > 1) {
          // ...
          device_.enqueueNoNotification(
              [=]() { enqueue_packing_helper(mid, end, k, rhs); });
        // ...
          (parallelize_by_sharding_dim_only_ && shard_by_col_ == rhs) &&
          (k > 0 || std::this_thread::get_id() == created_by_thread_id_);
        // ...
          device_.enqueueNoNotification(
              [=]() { enqueue_packing_helper(start, end, k, rhs); });
        // ...
          enqueue_packing_helper(start, end, k, rhs);
    Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
    Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
    Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }
    // ...
    Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
    Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }
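
    // The helpers above return the regular block/grain size for every index
    // except the last one, which absorbs the remainder. For bm(): all but the
    // last of the nm0_ blocks have size bm_, and the last one has
    // m_ + bm_ - bm_ * nm0_ = m_ - bm_ * (nm0_ - 1). As an illustration only
    // (values not taken from the source), m_ = 100 and bm_ = 32 give nm0_ = 4
    // blocks of sizes 32, 32, 32 and 4.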
 
    EvalParallelContext(const EvalParallelContext&) = delete;
    void operator=(const EvalParallelContext&) = delete;
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  using SyncEvalParallelContext =
      EvalParallelContext<NoCallback, lhs_inner_dim_contiguous,
                          rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
                          /* ... */>;
  template <typename DoneCallback>
  struct EvalShardedByInnerDimContext {
    EvalShardedByInnerDimContext(const Self* self, int num_threads,
                                 /* ... */
                                 DoneCallback done_callback)
        // ...
          m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous),
          m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous),
          m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered),
          // ...
          done(std::move(done_callback)),
          buffer_size_bytes(m * n * sizeof(Scalar)),
          block_size(blockSize(k, num_threads)),
          // ...
          l0_ranges(divup<Index>(num_blocks, l0_size)),
          l0_state(l0_ranges),
          block_buffers(num_blocks) {
      // ...
      for (int i = 0; i < l0_ranges; ++i) {
        const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size, i);
        l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
      // ...

      for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
        Scalar* buf = block_idx == 0
                          // ...
                          : static_cast<Scalar*>(evaluator->m_device.allocate(
                                buffer_size_bytes));
        block_buffers.emplace_back(buf);
    ~EvalShardedByInnerDimContext() {
      for (Index i = 1; i < num_blocks; ++i) {
        evaluator->m_device.deallocate(block_buffers[i]);
      // ...

    template <int Alignment>
    // ... (synchronous run())
      Barrier barrier(internal::convert_index<int>(num_blocks));
      eval<Alignment>(barrier, 0, num_blocks);
      // ...
      aggregateL0Blocks<Alignment>();
      // ...
      applyOutputKernel();
    // ...

    template <int Alignment>
    // ... (asynchronous runAsync())
      evalAsync<Alignment>(0, num_blocks);
    // ...

    const Self* evaluator;
    // ...
    bool m_lhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_reordered;
    // ...
    Index buffer_size_bytes;
    // ...
    std::atomic<int> num_pending_blocks;
    // ...
    static const Index l0_size = 4;
    // ...
    MaxSizeVector<std::atomic<int>> l0_state;
    // ...
    MaxSizeVector<Scalar*> block_buffers;
    template <int Alignment>
    // ... (processBlock(block_idx, begin, end))
      Scalar* buf = block_buffers[block_idx];
      // ...
          evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
          /* ... */
           internal::convert_index<int>(num_blocks)));

      const Index l0_index = block_idx / l0_size;
      const int v = l0_state[l0_index].fetch_sub(1);
      // ...
        const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index);
        const Index dst_block_idx = l0_index * l0_size;

        if (rng_size == l0_size) {
          addAllToBuffer<Alignment>(
              // ...
              block_buffers[dst_block_idx + 1],
              block_buffers[dst_block_idx + 2],
              block_buffers[dst_block_idx + 3],
              block_buffers[dst_block_idx]);
        // ...
          for (int i = 1; i < rng_size; ++i) {
            addToBuffer<Alignment>(m * n,
                                   block_buffers[dst_block_idx + i],
                                   block_buffers[dst_block_idx]);
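
      // Once a block's partial result is ready, the code above merges it into
      // its level-0 group: l0_state[l0_index] counts how many of the l0_size
      // (= 4) blocks in the group are still pending, and the last finisher
      // sums the group's buffers into block_buffers[dst_block_idx], taking the
      // 4-way addAllToBuffer path for full groups and the addToBuffer loop for
      // the remainder group.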
 
    template <int Alignment>
    void aggregateL0Blocks() const {
      // ...
      for (; l0_index + 2 < l0_ranges; l0_index += 3) {
        addAllToBuffer<Alignment>(
            // ...
            block_buffers[(l0_index + 0) * l0_size],
            block_buffers[(l0_index + 1) * l0_size],
            block_buffers[(l0_index + 2) * l0_size],
            /* ... */);
      // ...

      for (; l0_index < l0_ranges; ++l0_index) {
        addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size],
                               /* ... */);
      // ...

    void applyOutputKernel() const {
      typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
      evaluator->m_output_kernel(
          OutputMapper(result, m), evaluator->m_tensor_contraction_params,
          /* ... */);
 
    Index actualBlockSize(Index block_idx) const {
      return block_idx + 1 < num_blocks
                 /* ... */
                 : k + block_size - block_size * num_blocks;
    // ...

    // ... (actualRangeSize(num_ranges, range_size, range_idx))
                          Index range_idx) const {
      // ...
      return range_idx + 1 < num_ranges
                 /* ... */
                 : num_blocks + range_size - range_size * num_ranges;
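
    // Both helpers above follow the same remainder convention as the block
    // size accessors in EvalParallelContext: every block (or range) except the
    // last has the nominal size, and the last one absorbs whatever is left,
    // e.g. k + block_size - block_size * num_blocks for the final k block.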
 
    template <int Alignment>
    // ... (addToBuffer(n, src_buf, tgt_buf))
      const int output_packet_size =
          /* ... */;
      // ...
      const size_t num_packets = n / output_packet_size;
      for (; i < output_packet_size * num_packets; i += output_packet_size) {
        // ...
            internal::pload<PacketReturnType>(src_buf + i);
        // ...
            internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
        // ...
        internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i,
                                                               /* ... */);
      // ...
      for (; i < n; ++i) {
        tgt_buf[i] += src_buf[i];
 
    template <int Alignment>
    // ... (addAllToBuffer(n, src_buf0, src_buf1, src_buf2, dst_buf))
      const int output_packet_size =
          /* ... */;
      // ...
      const size_t num_packets = n / output_packet_size;
      for (; i < output_packet_size * num_packets; i += output_packet_size) {
        const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
        const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
        const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);
        // ...
        const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
        // ...
            padd(padd(dst_val, src_val0), padd(src_val1, src_val2));
        // ...
        pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
      // ...
      for (; i < n; ++i) {
        dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
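
    // addToBuffer and addAllToBuffer accumulate partial results with packet
    // (SIMD) loads and stores over the prefix whose length is a multiple of
    // the packet size, then fall back to a plain scalar loop for the tail.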
 
    template <int Alignment>
    void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
        evaluator->m_device.enqueueNoNotification(
            [this, &barrier, mid_block_idx, end_block_idx]() {
              eval<Alignment>(barrier, mid_block_idx, end_block_idx);
            });
        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;
      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);
      // ...
      processBlock<Alignment>(block_idx, block_start, block_end);
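
    // eval() distributes the [start, end) block range by repeatedly splitting
    // it in half: the upper half is enqueued on the thread pool (where it is
    // split further) and the loop continues on the lower half, so the calling
    // thread always ends up processing exactly one block while the rest run in
    // parallel. evalAsync() below uses the same splitting scheme.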
 
    template <int Alignment>
    void evalAsync(Index start_block_idx, Index end_block_idx) {
      while (end_block_idx - start_block_idx > 1) {
        Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
        evaluator->m_device.enqueueNoNotification(
            [this, mid_block_idx, end_block_idx]() {
              evalAsync<Alignment>(mid_block_idx, end_block_idx);
            });
        end_block_idx = mid_block_idx;
      }

      Index block_idx = start_block_idx;

      Index block_start = block_idx * block_size;
      Index block_end = block_start + actualBlockSize(block_idx);

      processBlock<Alignment>(block_idx, block_start, block_end);

      int v = num_pending_blocks.fetch_sub(1);
      // ...
        aggregateL0Blocks<Alignment>();
        // ...
        applyOutputKernel();
        // ...
        DoneCallback done_copy = std::move(done);
 
    static Index blockSize(Index k, int num_threads) {
      const auto round_up = [=](Index index) -> Index {
        const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
        return divup<Index>(index, kmultiple) * kmultiple;
      };

      const Index target_block_size = round_up(divup<Index>(k, num_threads));
      const Index desired_min_block_size = 12 * packet_size;

      return numext::mini<Index>(
          k, numext::maxi<Index>(desired_min_block_size, target_block_size));
    }
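
    // blockSize() aims for roughly one k block per thread, rounded up to a
    // multiple of max(packet_size, 8), clamped from below by 12 * packet_size
    // and from above by k itself. As a purely illustrative example, k = 1000,
    // num_threads = 4 and packet_size = 4 give divup(1000, 4) = 250, rounded
    // up to 256, which exceeds the minimum of 48, so the result is
    // min(1000, 256) = 256.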
 
    EvalShardedByInnerDimContext(const EvalShardedByInnerDimContext&) = delete;
    void operator=(const EvalShardedByInnerDimContext&) = delete;
 
  // ... (shardByCol(m, n, num_threads))
    if (m / num_threads >= Traits::nr &&
        // ...
        (n / num_threads < Traits::nr ||
         // ...
         (n / num_threads < 4 * Traits::nr &&
          (n % (num_threads * Traits::nr)) != 0 &&
          // ...
          ((m % (num_threads * Traits::nr)) == 0 ||
           /* ... */
    // ...
    if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
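
  // shardByCol() decides whether the sharding (task) dimension is the column
  // dimension. The visible checks compare how evenly m and n split across
  // num_threads in units of Traits::nr, and the final check above rejects
  // column sharding when n is much smaller than m (m > n * 32 with fewer than
  // 16 * Traits::nr columns per thread).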
 
  // ... (coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col))
                 int num_threads, bool shard_by_col) const {
    // ...
      while (gm1 <= nm0 && nm1 == divup(nm0, gm1)) gm1++;
      if (gm1 > nm0) break;
      // ...
      int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads,
                           /* ... */);
      // ...
      nm1 = divup(nm0, gm1);
      if (res == 0) continue;
      // ...
 
  // ... (coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col))
                 int num_threads, bool shard_by_col) const {
    // ...
      while (gn1 <= nn0 && nn1 == divup(nn0, gn1)) gn1++;
      if (gn1 > nn0) break;
      // ...
      int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads,
                           /* ... */);
      // ...
      nn1 = divup(nn0, gn1);
      if (res == 0) continue;
      // ...
 
  // ... (checkGrain: is the candidate grain acceptable and better than the
  //      current one?)
                 bool shard_by_col) const {
    const TensorOpCost cost =
        contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
    // ...
        static_cast<double>(bm) * gm * bn * gn, cost);
    if (taskSize < 1) return 1;
    // ...
    if (taskSize > 2) return -1;
    // ...
    double new_parallelism = static_cast<double>(new_tasks) /
                             (divup<int>(new_tasks, num_threads) * num_threads);
    // ...
    double old_parallelism = static_cast<double>(old_tasks) /
                             (divup<int>(old_tasks, num_threads) * num_threads);
    if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
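
  // The parallelism metric above is the fraction of thread "slots" (task
  // count rounded up to a multiple of num_threads) that carry real work, so a
  // candidate grain is accepted when its tasks are small enough, and, in the
  // in-between size range, when it wastes fewer idle slots than the current
  // grain or achieves perfect parallelism of 1.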
 
  // ... (contractionCost(m, n, bm, bn, bk, shard_by_col, prepacked))
                               bool shard_by_col, bool prepacked) const {
    // ...
    const double kd = static_cast<double>(bk);
    double compute_bandwidth = computeBandwidth(false, bm, bn, bk);
    // ...
    TensorOpCost cost = TensorOpCost(0, 0, kd * compute_bandwidth, true, packed_size);
    // ...
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
    // ...
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
    // ...
      lhsCost.dropMemoryCost();
    // ...
      rhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
 
  // ... (shardByInnerDim(m, n, k, num_threads, num_threads_by_k))
                              int num_threads_by_k) {
    std::ptrdiff_t bufsize = m * n * sizeof(Scalar);
    bool shard_by_k = false;
    // ...
        num_threads_by_k < 2 ||
        // ...
        k / num_threads_by_k < 2 * Traits::nr) {
      // ...
               (k / num_threads_by_k > 8 * Traits::nr &&
                // ...
                 num_threads_by_k > num_threads))) {
      // ...
 
  // ... (contractionCostPerInnerDim(m, n, k))
    TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n, true, output_packet_size);
    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * n;
    // ...
    lhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
 
  // ... (numThreadsInnerDim(m, n, k))
    TensorOpCost cost = contractionCostPerInnerDim(m, n, k);
    double total_parallel_cost =
        /* ... */;
    // ... (cost of the final reduction over the per-thread buffers)
        m * n, TensorOpCost(2, 1, 1, true, output_packet_size));
    int num_threads = 1;
    double min_cost = total_parallel_cost;
    double kPerThreadOverHead = 3000;
    double kFixedOverHead = 100000;
    for (int nt = 2; nt <= this->m_device.numThreads(); nt += 2) {
      double sequential_cost =
          kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
      double parallel_cost = total_parallel_cost / nt + sequential_cost;
      if (parallel_cost < min_cost) {
        // ...
        min_cost = parallel_cost;
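
    // The loop above is a simple cost model for sharding over the inner (k)
    // dimension: running with nt threads costs roughly total_parallel_cost/nt
    // for the partial contractions plus a sequential term
    // kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead), and the nt
    // (tried in steps of 2) that minimizes this total is kept as num_threads.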
 
  double computeBandwidth(bool shard_by_col, Index bm, Index bn,
                          // ...
    // ...
    double computeBandwidth =
        /* ... */
                : (shard_by_col ? bn : bm) < Traits::nr ||
                          (shard_by_col ? bm : bn) < Traits::mr
                      /* ... */;

#ifndef EIGEN_VECTORIZE_FMA
    // ...
    if (computeBandwidth == 0.5) computeBandwidth = 1.0;
#endif
    return computeBandwidth;
  // ...

#endif  // EIGEN_USE_THREADS
#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H