#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H

namespace TensorSycl {

#ifndef EIGEN_SYCL_DISABLE_GEMV
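// Panel-size parameters for the tensor-vector (GEMV-like) kernel: NCWindow sets the
// non-contracting tile handled per work-group, while CFactor and NCFactor scale the
// per-thread workload along the contracting and non-contracting dimensions.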
template <typename Scalar, typename StorageIndex, StorageIndex NCWindow, StorageIndex CFactor, StorageIndex NCFactor>
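// Register-blocking parameters for the tensor-tensor (GEMM-style) kernel: REG_SIZE_M and
// REG_SIZE_N control the per-thread register tile in M and N, TSDK the depth of the K tile.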
template <typename Scalar, typename StorageIndex, StorageIndex REG_SIZE_M, StorageIndex REG_SIZE_N, StorageIndex TSDK>
#ifndef EIGEN_SYCL_REG_M
#ifndef EIGEN_SYCL_REG_N
#ifdef EIGEN_SYCL_DISABLE_DOUBLE_BUFFER
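
// read(): loads one element or one packet from a tensor mapper. The first overload is
// selected when PacketLoad is true and issues an unaligned packet load from the underlying
// tensor; the second falls back to scalar coefficient access, swapping row/column when the
// operand is the right-hand side.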
template <bool PacketLoad, bool is_coalesced_layout, bool, typename PacketType, typename TensorMapper,
          typename StorageIndex>
read(const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex,
     const StorageIndex &ld) {
  const StorageIndex row = (is_coalesced_layout) ? NCIndex : CIndex;
  const StorageIndex col = (is_coalesced_layout) ? CIndex : NCIndex;
  return tensorMapper.get_tensor().template packet<Unaligned>(row + (col * ld));
}
template <bool PacketLoad, bool, bool IsRhs, typename PacketType, typename TensorMapper, typename StorageIndex>
read(const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex,
     const StorageIndex &) {
  const StorageIndex row = (IsRhs) ? CIndex : NCIndex;
  const StorageIndex col = (IsRhs) ? NCIndex : CIndex;
  return tensorMapper(row, col);
}
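
// write(): stores a packet according to the data_source tag. For private/local memory the
// packet is scalarized and written with stride ld; for global memory it is stored as an
// unaligned packet via pstoreu.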
template <typename StorageIndex, StorageIndex ld, data_source dt, typename PacketType, typename DataScalar>
  for (int i = 0; i < PacketSize; i++) {
template <data_source dt, typename PacketType, typename DataScalar>
write(PacketType &packet_data, DataScalar *ptr) {
  ::Eigen::internal::pstoreu<DataScalar, PacketType>(ptr, packet_data);
}
template <data_source dt, typename PacketType, typename DataScalar>
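
// check_boundary<is_internal>: for fully interior tiles the boundary predicate folds to a
// compile-time true so the guard disappears; edge tiles evaluate the runtime condition.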
template <bool is_internal>
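
// BlockProperties describes, at compile time, how an input block is read (transposed or
// not, lhs vs. rhs, packet vs. scalar loads); ThreadProperties bundles the per-work-item
// group/thread offsets and the remaining K size computed once at kernel start.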
template <bool is_transposed, bool is_rhs_, bool packet_load_, typename PacketType>
struct BlockProperties {
template <typename StorageIndex>
struct ThreadProperties {
      const StorageIndex linearLocalThreadId_, const StorageIndex kGroupId_, const StorageIndex mGroupOffset_,
      const StorageIndex nGroupOffset_, const StorageIndex kGroupOffset_, const StorageIndex mLocalOffset_,
      const StorageIndex nLocalOffset_, const StorageIndex mGlobalOffset_, const StorageIndex nGlobalOffset_,
      StorageIndex kSize_, const bool is_internal_)
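
// TensorContractionKernel: the general tensor-tensor (GEMM-like) SYCL kernel. Each
// work-group owns a TileSizeDimM x TileSizeDimN output tile and marches along K in
// TileSizeDimK steps; contraction_tp selects whether panels are staged in local memory or
// kept in registers, and IsFinal distinguishes the single-pass kernel from the partial
// kernels whose per-group results are reduced afterwards.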
template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
          typename RhsMapper, typename StorageIndex, typename Properties, typename TripleDim, bool Vectorizable,
          typename input_mapper_properties, bool IsFinal, contraction_type contraction_tp>
  typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
      PacketReturnType;
  static EIGEN_CONSTEXPR StorageIndex PacketSize =
      Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
  typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
      Scratch;
  typedef cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::local_space> local_ptr;
  static EIGEN_CONSTEXPR StorageIndex LSDL = contraction_tp == contraction_type::local
                                                 ? Properties::TileSizeDimM + Properties::BC
                                                 : Properties::WorkLoadPerThreadM;
  static EIGEN_CONSTEXPR StorageIndex LSDR = contraction_tp == contraction_type::local
                                                 ? Properties::TileSizeDimN + Properties::BC
                                                 : Properties::WorkLoadPerThreadN;
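
  // MemHolder abstracts where a thread's staging buffer lives: a pointer into shared local
  // memory for the "local" strategy, or a zero-initialised private array when the kernel
  // runs without local memory.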
  template <contraction_type, StorageIndex>
  struct MemHolder;

  template <StorageIndex MemSize>
  struct MemHolder<contraction_type::no_local, MemSize> {
    OutScalar ptr[MemSize] = {OutScalar{0}};
  };
    MemHolder<contraction_tp, Properties::WorkLoadPerThreadM * Properties::TileSizeDimK> lhs_scratch_extract;
    MemHolder<contraction_tp, Properties::WorkLoadPerThreadN * Properties::TileSizeDimK> rhs_scratch_extract;
    template <contraction_type tp = contraction_tp>
          lhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})),
          rhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})) {}
    template <contraction_type tp = contraction_tp>
              ((Properties::DoubleBuffer + 1) * LSDL * Properties::TileSizeDimK)},
          lhs_extract_index(
              local_id_extract<LHSBlockProperties, Properties::TileSizeDimM>(thread_properties.linearLocalThreadId)),
          rhs_extract_index(
              local_id_extract<RHSBlockProperties, Properties::TileSizeDimN>(thread_properties.linearLocalThreadId)) {}
      const RhsMapper rhs_, OutAccessor out_res_, const StorageIndex groupSizeM_, const StorageIndex groupSizeN_,
      const StorageIndex numTiles_, const TripleDim triple_dim_)
      const RhsMapper rhs_, OutAccessor out_res_, const StorageIndex groupSizeM_, const StorageIndex numTiles_,
      const TripleDim triple_dim_)
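
  // operator(): decode the linear work-item/work-group ids into (m, n, k) tile
  // coordinates. Work-groups are linearised as mGroupId + groupSizeM * (nGroupId +
  // groupSizeN * kGroupId); the final kernel owns the whole K range, partial kernels only
  // numTiles * TileSizeDimK of it.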
    const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
    const StorageIndex nLocalThreadId = linearLocalThreadId / Properties::LocalThreadSizeM;
    const StorageIndex mLocalThreadId = linearLocalThreadId % Properties::LocalThreadSizeM;
    const StorageIndex mGroupId = itemID.get_group(0) % groupSizeM;
    const StorageIndex tmp = itemID.get_group(0) / groupSizeM;
    const StorageIndex nGroupId = IsFinal ? tmp : tmp % groupSizeN;
    const StorageIndex kGroupId = IsFinal ? 0 : tmp / groupSizeN;
    const StorageIndex mGroupOffset = mGroupId * Properties::TileSizeDimM;
    const StorageIndex nGroupOffset = nGroupId * Properties::TileSizeDimN;
    const StorageIndex mLocalOffset = PacketSize * mLocalThreadId;
    const StorageIndex nLocalOffset = NStride * nLocalThreadId;
    const StorageIndex mGlobalOffset = mGroupOffset + mLocalOffset;
    const StorageIndex nGlobalOffset = nGroupOffset + nLocalOffset;
    const StorageIndex kSizePerWG = IsFinal ? triple_dim.K : numTiles * Properties::TileSizeDimK;
    StorageIndex kGroupOffset = kGroupId * kSizePerWG;
    const bool is_internal = triple_dim.M - mGroupOffset >= Properties::TileSizeDimM &&
                             triple_dim.N - nGroupOffset >= Properties::TileSizeDimN &&
                             triple_dim.K - kGroupOffset >= kSizePerWG;
    kGroupOffset += kSize;
    auto thread_properties =
        ThreadProperties<StorageIndex>(linearLocalThreadId, kGroupId, mGroupOffset, nGroupOffset, kGroupOffset,
                                       mLocalOffset, nLocalOffset, mGlobalOffset, nGlobalOffset, kSize, is_internal);
    (thread_properties.is_internal) ? compute_panel<true>(itemID, thread_properties, out_ptr)
                                    : compute_panel<false>(itemID, thread_properties, out_ptr);
    StorageIndex idx = 0;
    for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN; wLPTN++) {
      StorageIndex lhs_index = 0;
      for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
            lhs_block_ptr + lhs_index);
        lhs_index += lhs_stride;
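
  // store(): flush the per-thread accumulator registers to the output buffer. Interior
  // tiles write whole packets; boundary tiles scalarize each packet and guard every element
  // with chk_bound.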
  template <bool is_internal_block, StorageIndex PrivateNStride, typename OutPtr>
        StorageIndex mGlobalOffset, StorageIndex nGlobalOffset) {
    auto chk_bound = [&](const StorageIndex &mIndex, const StorageIndex &nIndex) EIGEN_DEVICE_FUNC {
    for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN / PrivateNStride; wLPTN++) {
      StorageIndex outputLD = 0;
      for (StorageIndex nId = 0; nId < PrivateNStride; nId++) {
        StorageIndex globalRow = mGlobalOffset;
        for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
          if (check_boundary<is_internal_block>(chk_bound(globalRow, nId))) {
            write<data_source::global_mem>(privetOut, out_ptr + outputLD + globalRow);
          } else {
            for (StorageIndex mId = 0; mId < PacketSize; mId++) {
              StorageIndex mOffset = globalRow + mId;
              out_ptr[mOffset + outputLD] =
                  PacketWrapper<PacketReturnType, PacketSize>::scalarize(mId, privetOut);
            }
          }
          globalRow += (PacketSize * Properties::LocalThreadSizeM);
        }
        privateRes += Properties::WorkLoadPerThreadM / PacketSize;
      }
      out_ptr += (GlobalNStride * outputLD);
      nGlobalOffset += (PrivateNStride * GlobalNStride);
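
  // extract_block (register variant): each thread gathers its own WorkLoadPerThreadNC x
  // TileSizeDimK slice of the lhs/rhs panel directly into private registers, using packet
  // loads when the layout is coalesced and guarded scalar loads near the boundary.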
  template <typename InputBlockProperties, bool is_internal_block, typename Input, typename PrivateReg,
            contraction_type contract_tp = contraction_tp>
                const StorageIndex &ncOffset, const StorageIndex cOffset) {
    EIGEN_CONSTEXPR StorageIndex LocalThreadSizeNC =
        InputBlockProperties::is_rhs ? Properties::LocalThreadSizeN : Properties::LocalThreadSizeM;
    EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadNC =
        InputBlockProperties::is_rhs ? Properties::WorkLoadPerThreadN : Properties::WorkLoadPerThreadM;
    auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
      return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
              (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
    };
    const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
    StorageIndex cIndex = cOffset;
    for (StorageIndex cId = 0; cId < Properties::TileSizeDimK / InputBlockProperties::c_stride; cId++) {
      StorageIndex ncIndex = ncOffset;
      for (StorageIndex ncId = 0; ncId < WorkLoadPerThreadNC / InputBlockProperties::nc_stride; ncId++) {
        if (check_boundary<is_internal_block>(chk_bound(cIndex, ncIndex))) {
          auto val =
              read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
                   InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, ncIndex, cIndex, ld);
          write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
                data_source::private_mem>(val, private_ptr);
        } else {
          for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
            const StorageIndex ncInd = ncIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
            const StorageIndex cInd = cIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
            OutScalar val =
                (ncInd < NC && cInd < triple_dim.K)
                    ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
                          inpt, ncInd, cInd, ld)
                    : OutScalar(0);
            write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
                  data_source::private_mem>(
                val, private_ptr + (InputBlockProperties::is_coalesced_layout ? i : 0) +
                         ((InputBlockProperties::is_coalesced_layout ? 0 : i) * WorkLoadPerThreadNC));
          }
        }
        ncIndex = (!InputBlockProperties::is_rhs && InputBlockProperties::nc_stride == 1 && PacketSize != 1)
                      : (ncIndex + InputBlockProperties::nc_stride * LocalThreadSizeNC);
      }
      private_ptr += (InputBlockProperties::c_stride - 1) * WorkLoadPerThreadNC;
      cIndex += InputBlockProperties::c_stride;
    }
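
  // local_id_extract: map the linear local thread id onto a (non-contracting, contracting)
  // coordinate inside the tile, so neighbouring threads touch consecutive elements of the
  // coalesced dimension.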
  template <typename InputBlockProperties, StorageIndex TileSizeDimNC>
  local_id_extract(const StorageIndex &linearLocalThreadId) {
    const StorageIndex localThreadNC =
        (InputBlockProperties::is_coalesced_layout)
            ? linearLocalThreadId % (TileSizeDimNC / InputBlockProperties::nc_stride)
            : linearLocalThreadId / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
    const StorageIndex localThreadC =
        (InputBlockProperties::is_coalesced_layout)
            ? linearLocalThreadId / (TileSizeDimNC / InputBlockProperties::nc_stride)
            : linearLocalThreadId % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
    return std::pair<StorageIndex, StorageIndex>(localThreadNC, localThreadC);
  }
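
  // sync_mem: synchronisation between the "load next tile" and "compute" phases. With
  // double buffering only the buffer offset is flipped; without it a local barrier is
  // issued; the register-only (no_local) path needs no synchronisation at all.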
  template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
  sync_mem(const cl::sycl::nd_item<1> &, bool &db_offset) noexcept {
    db_offset = !db_offset;
  }
  template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
  sync_mem(const cl::sycl::nd_item<1> &itemID, bool &) noexcept {
    itemID.barrier(cl::sycl::access::fence_space::local_space);
  }
  template <contraction_type ctp = contraction_tp>
  sync_mem(const cl::sycl::nd_item<1> &, bool &) noexcept {
  template <bool need_sync, contraction_type ctp = contraction_tp>
#ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION

#ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION
    itemID.barrier(cl::sycl::access::fence_space::local_space);
  template <bool need_sync, contraction_type ctp = contraction_tp>
    itemID.barrier(cl::sycl::access::fence_space::local_space);
  template <bool need_sync>
  sync_thread(const cl::sycl::nd_item<1> &) {
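
  // compute_tile_per_panel: one K-tile of the panel. Both operands are staged via
  // extract_block, the threads synchronise, the TileSizeDimK-deep inner products are
  // accumulated into the private registers, and the double-buffer offset is flipped.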
  template <bool is_internal_block>
  compute_tile_per_panel(const cl::sycl::nd_item<1> &itemID, ThreadProperties<StorageIndex> &thread_properties,
                         TiledMemory &tiled_input_block, PacketReturnType *privateRes, bool &db_offset) {
    extract_block<RHSBlockProperties, is_internal_block>(
        rhs, tiled_input_block.rhs_scratch_extract.ptr + (db_offset * Properties::TileSizeDimK * LSDR),
        tiled_input_block.rhs_extract_index,
        thread_properties.kGroupOffset - thread_properties.kSize);
    sync_thread<contraction_tp == contraction_type::no_local>(itemID);
    extract_block<LHSBlockProperties, is_internal_block>(
        lhs, tiled_input_block.lhs_scratch_extract.ptr + (db_offset * LSDL * Properties::TileSizeDimK),
        tiled_input_block.lhs_extract_index,
        thread_properties.kGroupOffset - thread_properties.kSize);
    sync_thread<contraction_tp == contraction_type::local>(itemID);
    StorageIndex lhs_offset = (db_offset * LSDL * Properties::TileSizeDimK);
    StorageIndex rhs_offset = (db_offset * Properties::TileSizeDimK * LSDR);
    for (StorageIndex k = 0; k < Properties::TileSizeDimK; k++) {
          tiled_input_block.rhs_scratch_ptr_compute + rhs_offset, privateRes);
    thread_properties.kSize -= Properties::TileSizeDimK;
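
  // compute_panel: march the work-group along K, running full TileSizeDimK tiles on the
  // fast path and one boundary-checked tile for the remainder, then store the result.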
  template <bool is_internal_block, typename OutPtr>
  compute_panel(const cl::sycl::nd_item<1> &itemID, ThreadProperties<StorageIndex> &thread_properties,
                OutPtr out_ptr) {
    auto tiled_input_block = TiledMemory{thread_properties, scratch.get_pointer()};
    while (thread_properties.kSize >= Properties::TileSizeDimK) {
      compute_tile_per_panel<is_internal_block>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
    }
    if (thread_properties.kSize > 0) {
      compute_tile_per_panel<false>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
    }
    store<is_internal_block,
        out_ptr + thread_properties.nGlobalOffset * triple_dim.M, privateRes, thread_properties.mGlobalOffset,
        thread_properties.nGlobalOffset);
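
  // extract_block (local-memory variant): the whole work-group cooperatively copies a
  // TileSizeDimNC x TileSizeDimK panel into shared local memory, each thread loading
  // LoadPerThread elements at the offsets produced by local_id_extract.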
  template <typename InputBlockProperties, bool is_internal_block, typename Input, typename Local,
            contraction_type contract_tp = contraction_tp>
  extract_block(const Input &inpt, Local local_ptr, const std::pair<StorageIndex, StorageIndex> &local_index,
                const StorageIndex &ncOffset, const StorageIndex cOffset) {
    EIGEN_CONSTEXPR StorageIndex TileSizeDimNC =
        InputBlockProperties::is_rhs ? Properties::TileSizeDimN : Properties::TileSizeDimM;
    EIGEN_CONSTEXPR StorageIndex LoadPerThread =
        InputBlockProperties::is_rhs ? Properties::LoadPerThreadRhs : Properties::LoadPerThreadLhs;
    static_assert(((LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride) == 0) &&
                   (LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride) == 0)),
                  " LocalOffset must be divisible by stride");
    StorageIndex localThreadNC = local_index.first;
    StorageIndex localThreadC = local_index.second;
    auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
      return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
              (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
    };
    for (StorageIndex lPT = 0; lPT < LoadPerThread / InputBlockProperties::elements_per_access; lPT++) {
      const StorageIndex CIndex = cOffset + (InputBlockProperties::c_stride * localThreadC);
      const StorageIndex NCIndex = ncOffset + (InputBlockProperties::nc_stride * localThreadNC);
      const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
      if (check_boundary<is_internal_block>(chk_bound(CIndex, NCIndex))) {
        auto val =
            read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
                 InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, NCIndex, CIndex, ld);
        write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
            val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
                     (InputBlockProperties::c_stride * localThreadC * LSD));
      } else {
        for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
          const StorageIndex nCInd = NCIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
          const StorageIndex cInd = CIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
          OutScalar val =
              (nCInd < NC && cInd < triple_dim.K)
                  ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
                        inpt, nCInd, cInd, ld)
                  : OutScalar(0);
          write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
              val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
                       (InputBlockProperties::is_coalesced_layout ? i : 0) +
                       ((InputBlockProperties::c_stride * localThreadC +
                         (InputBlockProperties::is_coalesced_layout ? 0 : i)) *
                        LSD));
        }
      }
      localThreadNC += (InputBlockProperties::is_coalesced_layout)
                           ? LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride)
                           : LocalOffset / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
      localThreadC += (InputBlockProperties::is_coalesced_layout)
                          ? LocalOffset / (TileSizeDimNC / InputBlockProperties::nc_stride)
                          : LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
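
// GeneralVectorTensor: matrix-vector (GEMV-like) contraction kernel used when one side of
// the contraction collapses to a vector. Each work-group reduces a TileSizeDimC slab of the
// contracting dimension for TileSizeDimNC non-contracting outputs, then finishes with a
// tree reduction in local memory.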
#ifndef EIGEN_SYCL_DISABLE_GEMV
template <typename OutScalar, typename OutAccessor, typename VectorMapper, typename TensorMapper,
          typename StorageIndex, typename Properties, StorageIndex KFactor, bool Vectorizable, bool is_lhs_vec,
          bool IsFinal>
struct GeneralVectorTensor {
  typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
      PacketReturnType;
  static EIGEN_CONSTEXPR StorageIndex PacketSize =
      Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
  typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
      Scratch;
      KFactor * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;
  typedef BlockProperties<is_lhs_vec ? false : true, is_lhs_vec ? false : true, Vectorizable, PacketReturnType>

  const VectorMapper vec;
  const TensorMapper mat;

      const TensorMapper mat_, OutAccessor out_res_, const StorageIndex nonContractGroupSize_,
      const StorageIndex nonContractDim_, const StorageIndex contractDim_)
    const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
    StorageIndex nonContractId = is_lhs_vec ? linearLocalThreadId / Properties::LocalThreadSizeC
                                            : linearLocalThreadId % Properties::LocalThreadSizeNC;
    StorageIndex contractId = is_lhs_vec ? linearLocalThreadId % Properties::LocalThreadSizeC
                                         : linearLocalThreadId / Properties::LocalThreadSizeNC;
    const StorageIndex nonContractGroupId =
        is_lhs_vec ? itemID.get_group(0) / nonContractGroupSize : itemID.get_group(0) % nonContractGroupSize;
    const StorageIndex contractGroupId =
        is_lhs_vec ? itemID.get_group(0) % nonContractGroupSize : itemID.get_group(0) / nonContractGroupSize;
    const StorageIndex nonContractGroupOffset = nonContractGroupId * Properties::TileSizeDimNC;
    const StorageIndex contractGroupOffset = contractGroupId * Properties::TileSizeDimC;
    auto outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
    const StorageIndex globalNonContractDimOffset = nonContractGroupOffset + nonContractId;
    const StorageIndex globalContractDimOffset = contractGroupOffset + contractId;
    const bool is_internal = nonContractDim - nonContractGroupOffset >= Properties::TileSizeDimNC &&
                             contractDim - contractGroupOffset >= Properties::TileSizeDimC;
    (is_internal)
        ? compute_panel<true>(itemID, vec, mat, local_output, out_ptr,
                              scratch_ptr, contractGroupOffset,
                              nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex)
        : compute_panel<false>(itemID, vec, mat, local_output, out_ptr,
                               scratch_ptr, contractGroupOffset,
                               nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex);
  template <bool is_internal_block, typename OutPtr>
  compute_panel(const cl::sycl::nd_item<1> &itemID, const VectorMapper &vec, const TensorMapper &mat,
                OutScalar *local_output, OutPtr out_ptr,
                OutScalar *scratch_ptr, const StorageIndex contractGroupOffset,
                const StorageIndex nonContractGroupOffset, const StorageIndex linearLocalThreadId,
                StorageIndex contractDim, StorageIndex nonContractDim, StorageIndex contractId,
                StorageIndex nonContractId, StorageIndex globalContractDimOffset,
                StorageIndex globalNonContractDimOffset, StorageIndex outScratchIndex) {
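    // Each thread accumulates WorkLoadPerThreadNC partial dot products; when local memory
    // is enabled the vector slab is first staged into scratch and re-read from there.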
    OutScalar outScalar[Properties::WorkLoadPerThreadNC] = {OutScalar(0)};
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
    const StorageIndex vectorOffset = contractGroupOffset + linearLocalThreadId;
                  Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC>(vec, scratch_ptr, linearLocalThreadId,
    itemID.barrier(cl::sycl::access::fence_space::local_space);
    auto in_scratch_ptr = scratch_ptr + contractId;
#endif
    StorageIndex privateOffsetC = 0;
    for (StorageIndex i = 0; i < Properties::WorkLoadPerThreadC; i++) {
      StorageIndex privateOffsetNC = 0;
      bool contract_conds = ((globalContractDimOffset + privateOffsetC) < contractDim);
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
      auto vecScalar = *in_scratch_ptr;
#else
      auto vecScalar = (check_boundary<is_internal_block>(contract_conds))
                           ? vec(is_lhs_vec ? StorageIndex(0) : globalContractDimOffset + privateOffsetC,
                                 is_lhs_vec ? globalContractDimOffset + privateOffsetC : StorageIndex(0))
                           : OutScalar(0);
#endif
      for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
        auto matScalar = (check_boundary<is_internal_block>(
                             contract_conds && ((globalNonContractDimOffset + privateOffsetNC) < nonContractDim)))
                             ? mat(is_lhs_vec ? globalContractDimOffset + privateOffsetC
                                              : globalNonContractDimOffset + privateOffsetNC,
                                   is_lhs_vec ? globalNonContractDimOffset + privateOffsetNC
                                              : globalContractDimOffset + privateOffsetC)
                             : OutScalar(0);
        outScalar[j] = cl::sycl::mad(matScalar, vecScalar, outScalar[j]);
        privateOffsetNC += Properties::LocalThreadSizeNC;
      }
      privateOffsetC += Properties::LocalThreadSizeC;
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
      in_scratch_ptr += Properties::LocalThreadSizeC;
#endif
    }
    auto out_scratch_ptr = local_output + outScratchIndex;
    for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
      *out_scratch_ptr = outScalar[j];
      out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
    }
    nonContractId = linearLocalThreadId % Properties::LocalThreadSizeNC;
    contractId = linearLocalThreadId / Properties::LocalThreadSizeNC;
    outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
    out_scratch_ptr = local_output + outScratchIndex;
    for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
      for (StorageIndex offset = Properties::LocalThreadSizeC >> 1; offset > 0; offset >>= 1) {
        itemID.barrier(cl::sycl::access::fence_space::local_space);
        if (contractId < offset) {
          StorageIndex myNeigbourId = (Properties::LocalThreadSizeNC * offset);
          *out_scratch_ptr += out_scratch_ptr[myNeigbourId];
        }
      }
      out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
    }
    if (contractId == 0) {
      out_scratch_ptr = local_output + nonContractId;
      StorageIndex global_final_offset = nonContractGroupOffset + nonContractId;
      out_ptr += global_final_offset;
      for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
        if (check_boundary<is_internal_block>(global_final_offset < nonContractDim)) {
          auto res = *out_scratch_ptr;
          out_ptr += Properties::LocalThreadSizeNC;
        }
        out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
        if (!(is_internal_block)) global_final_offset += Properties::LocalThreadSizeNC;
      }
    }
  }
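
  // extract_block (vector variant): the work-group stages a CFactor-long stretch of the
  // vector operand into local memory, each thread copying c_stride elements per step.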
  template <typename InputBlockProperties, bool is_internal_block, int CFactor, int GroupSize, typename Input,
            typename Local>
  extract_block(const Input &inpt, Local *local_ptr, const StorageIndex &linearLocalThreadId,
                const StorageIndex &cOffset, const StorageIndex &C) {
    local_ptr += InputBlockProperties::c_stride * linearLocalThreadId;
    StorageIndex cIndex = cOffset;
    for (StorageIndex cId = 0; cId < CFactor / InputBlockProperties::c_stride; cId++) {
      if (check_boundary<is_internal_block>(cIndex + InputBlockProperties::c_stride - 1 < C)) {
        auto val = read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
                        InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, StorageIndex(0),
                                                                                              cIndex, StorageIndex(1));
        write<StorageIndex, 1, data_source::local_mem>(val, local_ptr);
      } else {
        for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
          OutScalar val =
              (cIndex + i < C)
                  ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
                        inpt, StorageIndex(0), cIndex + i, StorageIndex(1))
                  : OutScalar(0);
          write<StorageIndex, 1, data_source::local_mem>(val, local_ptr + i);
        }
      }
      local_ptr += InputBlockProperties::c_stride * GroupSize;
      cIndex += InputBlockProperties::c_stride * GroupSize;
    }
  }
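
// GeneralScalarContraction: rank-0 output (full dot product). Every work-item strides
// through the contraction range with a grid-sized step, then the work-group collapses its
// partial sums with a binary tree reduction in local memory.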
#ifndef EIGEN_SYCL_DISABLE_SCALAR
template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
          typename RhsMapper, typename StorageIndex, bool Vectorizable>
struct GeneralScalarContraction {
  typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
      Scratch;

  const LhsMapper lhs;
  const RhsMapper rhs;
  const StorageIndex rng;

      const StorageIndex rng_)
    auto out_ptr = out_res.get_pointer();
    auto scratch_ptr = scratch.get_pointer().get();

    StorageIndex globalid = itemID.get_global_id(0);
    StorageIndex localid = itemID.get_local_id(0);
    OutScalar accumulator = OutScalar(0);
    for (StorageIndex i = globalid; i < rng; i += itemID.get_global_range(0)) {
      accumulator = cl::sycl::mad(lhs(0, i), rhs(i, 0), accumulator);
    }
    auto out_scratch_ptr = scratch_ptr + localid;
    *out_scratch_ptr = accumulator;
    for (StorageIndex offset = itemID.get_local_range(0) >> 1; offset > 0; offset >>= 1) {
      itemID.barrier(cl::sycl::access::fence_space::local_space);
      if (localid < offset) {
        *out_scratch_ptr = (accumulator += out_scratch_ptr[offset]);
      }
    }
    if (localid == 0) {
      out_ptr[itemID.get_group(0)] = accumulator;
    }
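
// Device-side evaluator: the SYCL specialisation of TensorContractionOp. It owns the
// left/right sub-evaluators, allocates the result buffer on the device, and dispatches to
// the scalar, vector or tensor kernels above based on the runtime (M, N, K) sizes.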
template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>,
                       Eigen::SyclDevice>
    : public TensorContractionEvaluatorBase<TensorEvaluator<
          const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Eigen::SyclDevice>> {
  static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
                "SYCL tensor contraction does not support output kernels.");

  typedef Eigen::SyclDevice Device;

  static const int NumDims = LDims + RDims - 2 * ContractDims;
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered>
  struct input_mapper_propertis {
    static EIGEN_CONSTEXPR bool is_lhs_matrix = (LDims == 2 && ContractDims == 1) || lhs_inner_dim_contiguous;
    static EIGEN_CONSTEXPR bool is_rhs_matrix =
        (RDims == 2 && ContractDims == 1) || (rhs_inner_dim_contiguous && !rhs_inner_dim_reordered);
  };
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    this->m_result = this->m_device.get(
        static_cast<Scalar *>(this->m_device.allocate_temp(this->dimensions().TotalSize() * sizeof(Scalar))));
    return (this->m_result != NULL);

  const Eigen::SyclDevice &device() const { return this->m_device; }
  void evalToSycl(typename Base::EvaluatorPointerType buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        } else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        } else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        } else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        } else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
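
  // evalTyped: build the lhs/rhs input mappers for this set of layout flags, then pick a
  // kernel: a scalar reduction for 1x1 outputs, the tensor-vector kernel when exactly one
  // output dimension is 1, and otherwise the tiled tensor-tensor kernel (with a "skinny"
  // variant that splits K across work-groups when the output is small relative to K).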
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(typename Base::EvaluatorPointerType buffer) const {
    const auto triple_dim = TripleDim{this->m_i_size, this->m_j_size, this->m_k_size};
    typedef internal::TensorContractionInputMapper<
        LhsScalar, StorageIndex, internal::Lhs, LeftEvaluator, left_nocontract_t, contract_t,
            rhs_inner_dim_reordered, Unaligned, MakeSYCLPointer>

    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);
    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);
#ifndef EIGEN_SYCL_DISABLE_SCALAR
    if (triple_dim.M == 1 && triple_dim.N == 1) {
      launchSC(buffer, lhs, rhs, triple_dim.K);
    } else
#endif
#ifndef EIGEN_SYCL_DISABLE_GEMV
        if (triple_dim.M != 1 && triple_dim.N == 1) {
      LaunchVT<false>(buffer, rhs, lhs, triple_dim.M, triple_dim.K);
    } else if (triple_dim.M == 1 && triple_dim.N != 1) {
      LaunchVT<true>(buffer, lhs, rhs, triple_dim.N, triple_dim.K);
    } else
#endif
    {
      typedef input_mapper_propertis<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>
          inpt_mapper_properties;
#ifndef EIGEN_SYCL_DISABLE_SKINNY
      bool skinny = false;
      auto platform_name = this->device().getPlatformName();
      if (platform_name.find("AMD") == 0) {
        skinny = (triple_dim.M < triple_dim.K || triple_dim.N < triple_dim.K) &&
                 ((triple_dim.M < 1024 && triple_dim.N < 1024) ||
      } else {
        skinny = (((std::max(triple_dim.K, triple_dim.N) / std::min(triple_dim.K, triple_dim.N)) > 100) ||
                  ((std::max(triple_dim.K, triple_dim.M) / std::min(triple_dim.K, triple_dim.M)) > 100) ||
                  ((std::max(triple_dim.N, triple_dim.M) / std::min(triple_dim.N, triple_dim.M)) > 100));
      }
      if (skinny)
        adjustTT<true, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
      else
#endif  // EIGEN_SYCL_DISABLE_SKINNY
        adjustTT<false, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
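
  // adjustTT: choose the local-memory tiling when the device exposes usable local memory
  // (and the build allows it), otherwise fall back to the register-only variant.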
  template <bool skinny, typename input_mapper_properties, typename LhsMapper, typename RhsMapper>
                const TripleDim &triple_dim) const {
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
    if (device().has_local_memory()) {
      typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 16> PanelParameters;
      launchTT<TensorSycl::internal::contraction_type::local, skinny, input_mapper_properties, PanelParameters>(
          buffer, lhs, rhs, triple_dim);
    }
#endif
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_OFF
    if (!(device().has_local_memory())) {
      launchTT<TensorSycl::internal::contraction_type::no_local, skinny, input_mapper_properties, PanelParameters>(
          buffer, lhs, rhs, triple_dim);
    }
#endif
  }

  template <TensorSycl::internal::contraction_type ct, bool skinny, typename input_mapper_properties,
            typename Properties, typename LhsMapper, typename RhsMapper>
                const TripleDim &triple_dim) const {
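    // launchTT: round M, N and K up to whole tiles, derive the work-group grid, and
    // optionally split K over several work-groups; in that case per-group partial results
    // go to a temporary buffer and are summed by a follow-up reduction kernel.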
    const StorageIndex roundUpM = Eigen::TensorSycl::internal::roundUp(triple_dim.M, Properties::TileSizeDimM);
    const StorageIndex roundUpN = Eigen::TensorSycl::internal::roundUp(triple_dim.N, Properties::TileSizeDimN);
    const StorageIndex groupSizeM = roundUpM / Properties::TileSizeDimM;
    const StorageIndex groupSizeN = roundUpN / Properties::TileSizeDimN;
    const StorageIndex roundUpK = Eigen::TensorSycl::internal::roundUp(triple_dim.K, Properties::TileSizeDimK);
    StorageIndex totalTilesK = roundUpK / Properties::TileSizeDimK;
            (StorageIndex)(device().getPowerOfTwo(device().getNumSyclMultiProcessors(), true) * 4) /
                (groupSizeM * groupSizeN)),
    const StorageIndex numTilesPerGroup = Eigen::TensorSycl::internal::roundUp(totalTilesK, groupSizeK) / groupSizeK;
    const StorageIndex totalGroupSize = groupSizeM * groupSizeN * groupSizeK;
    const StorageIndex localRange = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN;
    const StorageIndex globalRange = totalGroupSize * localRange;
    const StorageIndex scratchSize = (ct == TensorSycl::internal::contraction_type::local)
                                         ? ((Properties::DoubleBuffer + 1) *
                                            (Properties::TileSizeDimM + Properties::BC) * (Properties::TileSizeDimK)) +
                                               ((Properties::DoubleBuffer + 1) * (Properties::TileSizeDimK) *
                                                (Properties::TileSizeDimN + Properties::BC))
                                         : StorageIndex(1);
    auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
    if (groupSizeK == 1) {
          LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          lhs, rhs, buffer, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup, triple_dim);
    } else {
          LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
      CoeffReturnType *temp_pointer = static_cast<CoeffReturnType *>(
          device().allocate_temp(triple_dim.M * triple_dim.N * groupSizeK * sizeof(CoeffReturnType)));
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          lhs, rhs, tmp_global_accessor, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup,
          triple_dim);
      device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
          tmp_global_accessor, buffer,
          cl::sycl::nd_range<1>(cl::sycl::range<1>(StorageIndex(
                                    Eigen::TensorSycl::internal::roundUp(triple_dim.M * triple_dim.N, localRange))),
                                cl::sycl::range<1>(localRange)),
          StorageIndex(1), op, StorageIndex(triple_dim.M * triple_dim.N), groupSizeK);
      device().deallocate_temp(temp_pointer);
    }
  }
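
  // LaunchVT: configure and launch the GeneralVectorTensor kernel. If the contracting
  // dimension is split over more than one group, partial results are written to a
  // temporary buffer and summed by a follow-up reduction kernel, as in launchTT.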
#ifndef EIGEN_SYCL_DISABLE_GEMV
  template <bool is_lhs_vec, typename VectorMapper, typename TensorMapper, typename StorageIndex>
                StorageIndex NC, StorageIndex C) const {
    const StorageIndex nonContractDim = NC;
    const StorageIndex roundUpC = Eigen::TensorSycl::internal::roundUp(C, Properties::TileSizeDimC);
    const StorageIndex cNumGroups = roundUpC / (Properties::LocalThreadSizeC * Properties::WorkLoadPerThreadC);
    const StorageIndex roundUpNC = Eigen::TensorSycl::internal::roundUp(nonContractDim, Properties::TileSizeDimNC);
    const StorageIndex nCNumGroups = roundUpNC / (Properties::LocalThreadSizeNC * Properties::WorkLoadPerThreadNC);
    const StorageIndex globalRange =
        (roundUpNC / (Properties::WorkLoadPerThreadNC)) * (roundUpC / (Properties::WorkLoadPerThreadC));
    const StorageIndex localRange = Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC;
    const StorageIndex scratchSize =
        (Properties::WorkLoadPerThreadNC + CFactor) * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;
    auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
    if (cNumGroups > 1) {
          TensorMapper, StorageIndex, Properties, CFactor, false,
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          vec, mat, tmp_global_accessor, thread_range, scratchSize, nCNumGroups, nonContractDim, C);
      device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
          tmp_global_accessor, buffer,
          cl::sycl::nd_range<1>(cl::sycl::range<1>(Eigen::TensorSycl::internal::roundUp(nonContractDim, localRange)),
                                cl::sycl::range<1>(localRange)),
      device().deallocate_temp(temp_pointer);
    } else {
          TensorMapper, StorageIndex, Properties, CFactor, false,
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
          vec, mat, buffer, thread_range, scratchSize, nCNumGroups, nonContractDim, C);
    }
  }
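
  // launchSC: the degenerate 1x1 output. GeneralScalarContraction produces one partial sum
  // per work-group; if more than one group is used, a generic reduction kernel folds the
  // per-group results into the final scalar.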
#ifndef EIGEN_SYCL_DISABLE_SCALAR
  template <typename LhsMapper, typename RhsMapper>
                StorageIndex K) const {
                          (EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1 - 1)),
                        "The Local thread size must be a power of 2 for the reduction "
    const StorageIndex num_work_group = ((K + (512 * local_range - 1)) / (512 * local_range) > 1 ? local_range : 1);
    const StorageIndex global_range = num_work_group * local_range;
    auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(global_range), cl::sycl::range<1>(local_range));
    if (num_work_group > 1) {
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, tmp_global_accessor,
                                                                                    thread_range, local_range, K);
      device().template unary_kernel_launcher<CoeffReturnType, GenericRKernel>(
          tmp_global_accessor, buffer,
          cl::sycl::nd_range<1>(cl::sycl::range<1>(local_range), cl::sycl::range<1>(local_range)), local_range, Op());
      device().deallocate_temp(temp_pointer);
    } else {
      device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, buffer, thread_range,
                                                                                    local_range, K);
    }
  }
    this->m_leftImpl.cleanup();
    this->m_rightImpl.cleanup();
    if (this->m_result) {
      this->m_device.deallocate_temp(this->m_result);
    }

    this->m_leftImpl.bind(cgh);
    this->m_rightImpl.bind(cgh);
    this->m_result.bind(cgh);

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H