#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H

#if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)

template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ __forceinline__ void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;
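  // Each thread block owns one 64 x 64 tile of the output: blockIdx.x indexes
  // the tile rows (m) and blockIdx.y the tile columns (n).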
  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
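  // Note on the layout: the shared-memory tiles hold 72 * 64 Scalars, i.e. they
  // are padded from 64 to 72 entries per column (presumably to keep concurrent
  // threads off the same shared-memory banks), so the eight per-thread store
  // slots above are spaced 576 = 72 * 8 entries apart.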
  const Index lhs_vert = base_m + load_idx_vert;
#define prefetchIntoRegisters(base_k)                           \
    if (!needs_edge_check || lhs_vert < m_size) {                \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8;    \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8;    \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8;    \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8;    \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8;    \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8;    \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8;    \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8;    \
      if (!needs_edge_check || lhs_horiz_7 < k_size) {           \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                    \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                    \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                    \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                    \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                    \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                    \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                    \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7);                    \
      } else if (lhs_horiz_6 < k_size) {                         \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                    \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                    \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                    \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                    \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                    \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                    \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                    \
      } else if (lhs_horiz_5 < k_size) {                         \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                    \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                    \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                    \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                    \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                    \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                    \
      } else if (lhs_horiz_4 < k_size) {                         \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                    \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                    \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                    \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                    \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                    \
      } else if (lhs_horiz_3 < k_size) {                         \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                    \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                    \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                    \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                    \
      } else if (lhs_horiz_2 < k_size) {                         \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                    \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                    \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                    \
      } else if (lhs_horiz_1 < k_size) {                         \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                    \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                    \
      } else if (lhs_horiz_0 < k_size) {                         \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                    \
      }                                                          \
    }                                                            \
                                                                 \
    const Index rhs_vert = base_k + load_idx_vert;               \
    if (!needs_edge_check || rhs_vert < k_size) {                \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8;    \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8;    \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8;    \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8;    \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8;    \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8;    \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8;    \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8;    \
      if (rhs_horiz_7 < n_size) {                                \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                    \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                    \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                    \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                    \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                    \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                    \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                    \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7);                    \
      } else if (rhs_horiz_6 < n_size) {                         \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                    \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                    \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                    \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                    \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                    \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                    \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                    \
      } else if (rhs_horiz_5 < n_size) {                         \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                    \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                    \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                    \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                    \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                    \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                    \
      } else if (rhs_horiz_4 < n_size) {                         \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                    \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                    \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                    \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                    \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                    \
      } else if (rhs_horiz_3 < n_size) {                         \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                    \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                    \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                    \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                    \
      } else if (rhs_horiz_2 < n_size) {                         \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                    \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                    \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                    \
      } else if (rhs_horiz_1 < n_size) {                         \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                    \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                    \
      } else if (rhs_horiz_0 < n_size) {                         \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                    \
      }                                                          \
    }

#define writeRegToShmem(_)                  \
    lhs_shmem[lhs_store_idx_0] = lhs_pf0;   \
    rhs_shmem[rhs_store_idx_0] = rhs_pf0;   \
    lhs_shmem[lhs_store_idx_1] = lhs_pf1;   \
    rhs_shmem[rhs_store_idx_1] = rhs_pf1;   \
    lhs_shmem[lhs_store_idx_2] = lhs_pf2;   \
    rhs_shmem[rhs_store_idx_2] = rhs_pf2;   \
    lhs_shmem[lhs_store_idx_3] = lhs_pf3;   \
    rhs_shmem[rhs_store_idx_3] = rhs_pf3;   \
    lhs_shmem[lhs_store_idx_4] = lhs_pf4;   \
    rhs_shmem[rhs_store_idx_4] = rhs_pf4;   \
    lhs_shmem[lhs_store_idx_5] = lhs_pf5;   \
    rhs_shmem[rhs_store_idx_5] = rhs_pf5;   \
    lhs_shmem[lhs_store_idx_6] = lhs_pf6;   \
    rhs_shmem[rhs_store_idx_6] = rhs_pf6;   \
    lhs_shmem[lhs_store_idx_7] = lhs_pf7;   \
    rhs_shmem[rhs_store_idx_7] = rhs_pf7;

  // Per-thread partial results: an 8 x 8 block of accumulators that is later
  // combined across threads with the shuffle-based reduction below.
#define res(i, j) _res_##i##j
#define initResultRow(i)        \
  Scalar res(i, 0) = conv(0);   \
  Scalar res(i, 1) = conv(0);   \
  Scalar res(i, 2) = conv(0);   \
  Scalar res(i, 3) = conv(0);   \
  Scalar res(i, 4) = conv(0);   \
  Scalar res(i, 5) = conv(0);   \
  Scalar res(i, 6) = conv(0);   \
  Scalar res(i, 7) = conv(0);

  internal::scalar_cast_op<int, Scalar> conv;
  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    prefetchIntoRegisters(base_k);
#undef prefetchIntoRegisters
#undef writeRegToShmem

#define lcol(i) _lcol##i
#define rrow(j) _rrow##j

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]

#define loadData(i, j)           \
    lcol(0) = lhs_element(0, j); \
    rrow(0) = rhs_element(i, 0); \
    lcol(1) = lhs_element(1, j); \
    rrow(1) = rhs_element(i, 1); \
    lcol(2) = lhs_element(2, j); \
    rrow(2) = rhs_element(i, 2); \
    lcol(3) = lhs_element(3, j); \
    rrow(3) = rhs_element(i, 3); \
    lcol(4) = lhs_element(4, j); \
    rrow(4) = rhs_element(i, 4); \
    lcol(5) = lhs_element(5, j); \
    rrow(5) = rhs_element(i, 5); \
    lcol(6) = lhs_element(6, j); \
    rrow(6) = rhs_element(i, 6); \
    lcol(7) = lhs_element(7, j); \
    rrow(7) = rhs_element(i, 7);

#define computeCol(j)               \
    res(0, j) += lcol(0) * rrow(j); \
    res(1, j) += lcol(1) * rrow(j); \
    res(2, j) += lcol(2) * rrow(j); \
    res(3, j) += lcol(3) * rrow(j); \
    res(4, j) += lcol(4) * rrow(j); \
    res(5, j) += lcol(5) * rrow(j); \
    res(6, j) += lcol(6) * rrow(j); \
    res(7, j) += lcol(7) * rrow(j);

#define computePass(i) \

  // Reduce the per-thread partial sums across the warp. __shfl_xor_sync is the
  // required form from CUDA 9.0 on; HIP and older CUDA toolkits use __shfl_xor.
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
#else
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
#endif

#define reduceRow(i, mask)  \
    shuffleInc(i, 0, mask); \
    shuffleInc(i, 1, mask); \
    shuffleInc(i, 2, mask); \
    shuffleInc(i, 3, mask); \
    shuffleInc(i, 4, mask); \
    shuffleInc(i, 5, mask); \
    shuffleInc(i, 6, mask); \
    shuffleInc(i, 7, mask);

#define reduceMatrix(mask)  \
    reduceRow(0, mask);     \
    reduceRow(1, mask);     \
    reduceRow(2, mask);     \
    reduceRow(3, mask);     \
    reduceRow(4, mask);     \
    reduceRow(5, mask);     \
    reduceRow(6, mask);     \
    reduceRow(7, mask);

#define writeResultShmem(i, j) \
    lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j);

#define writeRow(i)             \
    writeResultShmem(i, 0);     \
    writeResultShmem(i, 1);     \
    writeResultShmem(i, 2);     \
    writeResultShmem(i, 3);     \
    writeResultShmem(i, 4);     \
    writeResultShmem(i, 5);     \
    writeResultShmem(i, 6);     \
    writeResultShmem(i, 7);

#undef writeResultShmem

  if (max_j_write == 8) {
    for (int j = 0; j < max_j_write; j++) {
template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(512, 1)
#else
__launch_bounds__(512)
#endif
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // Blocks that map to a full 64 x 64 tile take the fast path without edge checks.
  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                                         const OutputMapper output, float2 lhs_shmem2[][16],
                                         float2 rhs_shmem2[][8], const Index m_size,
                                         const Index n_size, const Index k_size,
                                         const Index base_m, const Index base_n) {
  float4 lhs_pf0, rhs_pf0;

  float4 results[4];
  for (int i = 0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }
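  // Each thread accumulates a 4 x 4 patch of the output tile in registers:
  // results[i] holds column (horiz_base + i) of the four consecutive output
  // rows starting at lhs_vert (see the write-back at the end of this kernel).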
#define prefetch_lhs(reg, row, col)                                   \
    if (!CHECK_LHS_BOUNDARY) {                                        \
      if (col < k_size) {                                             \
        reg = lhs.template loadPacket<float4,Unaligned>(row, col);    \
      }                                                               \
    } else {                                                          \
      if (col < k_size) {                                             \
        if (row + 3 < m_size) {                                       \
          reg = lhs.template loadPacket<float4,Unaligned>(row, col);  \
        } else if (row + 2 < m_size) {                                \
          reg.x = lhs(row + 0, col);                                  \
          reg.y = lhs(row + 1, col);                                  \
          reg.z = lhs(row + 2, col);                                  \
        } else if (row + 1 < m_size) {                                \
          reg.x = lhs(row + 0, col);                                  \
          reg.y = lhs(row + 1, col);                                  \
        } else if (row < m_size) {                                    \
          reg.x = lhs(row + 0, col);                                  \
        }                                                             \
      }                                                               \
    }

  for (Index k = 0; k < k_size; k += 16) {
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
#else
    x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
    x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
#endif
#define add_vals(fl1, fl2, fr1, fr2) \
    results[0].x += fl1.x * fr1.x;   \
    results[0].y += fl1.y * fr1.x;   \
    results[0].z += fl2.x * fr1.x;   \
    results[0].w += fl2.y * fr1.x;   \
                                     \
    results[1].x += fl1.x * fr1.y;   \
    results[1].y += fl1.y * fr1.y;   \
    results[1].z += fl2.x * fr1.y;   \
    results[1].w += fl2.y * fr1.y;   \
                                     \
    results[2].x += fl1.x * fr2.x;   \
    results[2].y += fl1.y * fr2.x;   \
    results[2].z += fl2.x * fr2.x;   \
    results[2].w += fl2.y * fr2.x;   \
                                     \
    results[3].x += fl1.x * fr2.y;   \
    results[3].y += fl1.y * fr2.y;   \
    results[3].z += fl2.x * fr2.y;   \
    results[3].w += fl2.y * fr2.y;

    for (int koff = 0; koff < 16; koff ++) {
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
      float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4];

      add_vals(fl1, fl2, fr1, fr2)
    }
  }
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    // Only the lhs (row) boundary needs checking.
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // Only the rhs (column) boundary needs checking.
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // Both boundaries need checking.
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void
EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                                    const OutputMapper output, float2 lhs_shmem2[][32],
                                    float2 rhs_shmem2[][8], const Index m_size,
                                    const Index n_size, const Index k_size,
                                    const Index base_m, const Index base_n) {
  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
  float4 rhs_pf0, rhs_pf1;

  float4 results[8];
  for (int i = 0; i < 8; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }
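  // In this variant each thread accumulates a 4 x 8 patch of the output:
  // results[i] again holds column (horiz_base + i) for the four consecutive
  // rows starting at lhs_vert.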
  for (Index k = 0; k < k_size; k += 32) {
    lhs_pf0 = internal::pset1<float4>(0);
    lhs_pf1 = internal::pset1<float4>(0);
    lhs_pf2 = internal::pset1<float4>(0);
    lhs_pf3 = internal::pset1<float4>(0);

    rhs_pf0 = internal::pset1<float4>(0);
    rhs_pf1 = internal::pset1<float4>(0);
    if (!CHECK_LHS_BOUNDARY) {
      if ((threadIdx.y/4+k+24) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        lhs_pf3 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
      } else if ((threadIdx.y/4+k+16) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
      } else if ((threadIdx.y/4+k+8) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
      } else if ((threadIdx.y/4+k) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
      }
    } else {
      if (lhs_vert + 3 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
          lhs_pf3 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 2 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
          lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 1 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
        }
      } else if (lhs_vert < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
        }
      }
    }
    Index rhs_horiz1 = threadIdx.y*2+1+base_n;
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
        rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
      }
    } else {
      if (rhs_horiz1 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
          rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
        } else if (rhs_vert + 2 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
          rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
        } else if (k+threadIdx.x*4 + 1 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        } else if (k+threadIdx.x*4 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        }
      } else if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4) \
    results[0].x += a_feat1.x * f1.x; \
    results[1].x += a_feat1.x * f1.y; \
    results[2].x += a_feat1.x * f2.x; \
    results[3].x += a_feat1.x * f2.y; \
    results[4].x += a_feat1.x * f3.x; \
    results[5].x += a_feat1.x * f3.y; \
    results[6].x += a_feat1.x * f4.x; \
    results[7].x += a_feat1.x * f4.y; \
                                      \
    results[0].y += a_feat1.y * f1.x; \
    results[1].y += a_feat1.y * f1.y; \
    results[2].y += a_feat1.y * f2.x; \
    results[3].y += a_feat1.y * f2.y; \
    results[4].y += a_feat1.y * f3.x; \
    results[5].y += a_feat1.y * f3.y; \
    results[6].y += a_feat1.y * f4.x; \
    results[7].y += a_feat1.y * f4.y; \
                                      \
    results[0].z += a_feat2.x * f1.x; \
    results[1].z += a_feat2.x * f1.y; \
    results[2].z += a_feat2.x * f2.x; \
    results[3].z += a_feat2.x * f2.y; \
    results[4].z += a_feat2.x * f3.x; \
    results[5].z += a_feat2.x * f3.y; \
    results[6].z += a_feat2.x * f4.x; \
    results[7].z += a_feat2.x * f4.y; \
                                      \
    results[0].w += a_feat2.y * f1.x; \
    results[1].w += a_feat2.y * f1.y; \
    results[2].w += a_feat2.y * f2.x; \
    results[3].w += a_feat2.y * f2.y; \
    results[4].w += a_feat2.y * f3.x; \
    results[5].w += a_feat2.y * f3.y; \
    results[6].w += a_feat2.y * f4.x; \
    results[7].w += a_feat2.y * f4.y;

    for (int koff = 0; koff < 32; koff ++) {
      int start_feature = (threadIdx.y / 4) * 8;

      float2 br1 = rhs_shmem2[start_feature/2 + (koff % 4) * 32][koff/4];
      float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
      float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
      float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];

      add_vals(a3, a4, br1, br2, br3, br4)
    }
  }
  Index horiz_base = (threadIdx.y/4)*8+base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                            const OutputMapper output,
                            const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[64*32];
  __shared__ float2 rhs_shmem[128*8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  bool check_rhs = (base_n + 63) >= n_size;
  bool check_lhs128 = (base_m + 127) >= m_size;
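  // Only blocks that own a partial tile on the right or bottom edge of the
  // output pay for boundary checks; interior blocks dispatch to the unchecked
  // specializations.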
  if (!check_rhs) {
    if (!check_lhs128) {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (!check_lhs128) {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                                 const OutputMapper output,
                                 const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size) {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}
template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;

  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;

  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;
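  // Each contracted index pair removes one dimension from each operand, which
  // is why the result rank is LDims + RDims - 2 * ContractDims.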
  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  // The evaluator's constructor statically asserts
  // GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS: custom output
  // kernels are not supported on the GPU path.
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);

    this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
    evalTo(this->m_result);
  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        } else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        } else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        } else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        } else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
  template <typename LhsScalar, typename RhsScalar, typename Index,
            typename LhsMapper, typename RhsMapper, typename OutputMapper>
  struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output,
                    Index m, Index n, Index k, const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };
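  // The generic path above launches one 512-thread block (8 x 8 x 8, matching
  // the kernel's __launch_bounds__(512)) per 64 x 64 output tile. For
  // float-valued contractions, the specialization below swaps in the
  // float4-based kernels and picks a launch configuration based on problem size.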
  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
  struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output,
                    Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in the left side, rows in the right side
    const Index k = this->m_k_size;

    // rows in the left side
    const Index m = this->m_i_size;

    // columns in the right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

        LeftEvaluator, left_nocontract_t,
        lhs_inner_dim_contiguous,

        RightEvaluator, right_nocontract_t,
        rhs_inner_dim_contiguous,
        rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);
#if defined(EIGEN_USE_HIP)
    setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
#else
    setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
#endif

    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};

#endif // EIGEN_USE_GPU and EIGEN_GPUCC
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H