TensorContractionGpu.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 // Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
6 // Copyright (C) 2014 Eric Martin <eric@ericmart.in>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla
9 // Public License v. 2.0. If a copy of the MPL was not distributed
10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11 
12 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
13 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
14 
15 #if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)
16 
17 namespace Eigen {
18 
19 template<typename Scalar, typename Index, typename LhsMapper,
20  typename RhsMapper, typename OutputMapper, bool needs_edge_check>
21 __device__ EIGEN_STRONG_INLINE void
22 EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
23  const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
24  const Index m_size, const Index n_size, const Index k_size) {
25 
26  const Index m_block_idx = blockIdx.x;
27  const Index n_block_idx = blockIdx.y;
28 
29  const Index base_m = 64 * m_block_idx;
30  const Index base_n = 64 * n_block_idx;
31 
32  // declare and initialize 64 registers for output 8x8 block
33 
34  // prefetch registers
35  Scalar lhs_pf0;
36  Scalar lhs_pf1;
37  Scalar lhs_pf2;
38  Scalar lhs_pf3;
39  Scalar lhs_pf4;
40  Scalar lhs_pf5;
41  Scalar lhs_pf6;
42  Scalar lhs_pf7;
43 
44  Scalar rhs_pf0;
45  Scalar rhs_pf1;
46  Scalar rhs_pf2;
47  Scalar rhs_pf3;
48  Scalar rhs_pf4;
49  Scalar rhs_pf5;
50  Scalar rhs_pf6;
51  Scalar rhs_pf7;
52 
53  // shared memory is formatted
54  // (contract idx in block, nocontract idx in block, block idx)
55  // where block idx is column major. This transposition limits the number of
56  // bank conflicts when reading the LHS. The core idea is that since the contracting
57  // index is shared by both sides, then the contracting index should be in threadIdx.x.
58 
59  // On the LHS, we pad each row inside of each block with an extra element. This makes
60  // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
61  // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.
62 
63  // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
64  // conflicts on writes and also none on reads.
65 
66  // storage indices
67  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
68  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;
69 
70  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
71  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
72  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
73  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
74  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
75  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
76  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
77  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;
78 
79  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
80  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
81  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
82  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
83  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
84  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
85  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
86  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
87 
88  // in the loading code, the following variables are important:
89  // threadIdx.x: the vertical position in an 8x8 block
90  // threadIdx.y: the vertical index of the 8x8 block in the grid
91  // threadIdx.z: the horizontal position in an 8x8 block
92  // k: the horizontal index of the 8x8 block in the grid
93  //
94  // The k parameter is implicit (it was the loop counter for a loop that went
95  // from 0 to <8, but now that loop is unrolled in the below code.
96 
97  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
98  const Index lhs_vert = base_m + load_idx_vert;
99 
100 #define prefetchIntoRegisters(base_k) \
101  { \
102  lhs_pf0 = conv(0); \
103  lhs_pf1 = conv(0); \
104  lhs_pf2 = conv(0); \
105  lhs_pf3 = conv(0); \
106  lhs_pf4 = conv(0); \
107  lhs_pf5 = conv(0); \
108  lhs_pf6 = conv(0); \
109  lhs_pf7 = conv(0); \
110  \
111  rhs_pf0 = conv(0); \
112  rhs_pf1 = conv(0); \
113  rhs_pf2 = conv(0); \
114  rhs_pf3 = conv(0); \
115  rhs_pf4 = conv(0); \
116  rhs_pf5 = conv(0); \
117  rhs_pf6 = conv(0); \
118  rhs_pf7 = conv(0); \
119  \
120  if (!needs_edge_check || lhs_vert < m_size) { \
121  const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
122  const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
123  const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
124  const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
125  const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
126  const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
127  const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
128  const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
129  \
130  if (!needs_edge_check || lhs_horiz_7 < k_size) { \
131  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
132  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
133  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
134  lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
135  lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
136  lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
137  lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
138  lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
139  } else if (lhs_horiz_6 < k_size) { \
140  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
141  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
142  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
143  lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
144  lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
145  lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
146  lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
147  } else if (lhs_horiz_5 < k_size) { \
148  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
149  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
150  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
151  lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
152  lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
153  lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
154  } else if (lhs_horiz_4 < k_size) { \
155  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
156  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
157  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
158  lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
159  lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
160  } else if (lhs_horiz_3 < k_size) { \
161  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
162  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
163  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
164  lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
165  } else if (lhs_horiz_2 < k_size) { \
166  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
167  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
168  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
169  } else if (lhs_horiz_1 < k_size) { \
170  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
171  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
172  } else if (lhs_horiz_0 < k_size) { \
173  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
174  } \
175  } \
176  \
177  const Index rhs_vert = base_k + load_idx_vert; \
178  if (!needs_edge_check || rhs_vert < k_size) { \
179  const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
180  const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
181  const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
182  const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
183  const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
184  const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
185  const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
186  const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
187  \
188  if (rhs_horiz_7 < n_size) { \
189  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
190  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
191  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
192  rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
193  rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
194  rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
195  rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
196  rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
197  } else if (rhs_horiz_6 < n_size) { \
198  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
199  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
200  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
201  rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
202  rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
203  rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
204  rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
205  } else if (rhs_horiz_5 < n_size) { \
206  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
207  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
208  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
209  rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
210  rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
211  rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
212  } else if (rhs_horiz_4 < n_size) { \
213  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
214  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
215  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
216  rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
217  rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
218  } else if (rhs_horiz_3 < n_size) { \
219  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
220  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
221  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
222  rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
223  } else if (rhs_horiz_2 < n_size) { \
224  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
225  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
226  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
227  } else if (rhs_horiz_1 < n_size) { \
228  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
229  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
230  } else if (rhs_horiz_0 < n_size) { \
231  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
232  } \
233  } \
234  } \
235 
236 #define writeRegToShmem(_) \
237  lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
238  rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
239  \
240  lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
241  rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
242  \
243  lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
244  rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
245  \
246  lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
247  rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
248  \
249  lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
250  rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
251  \
252  lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
253  rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
254  \
255  lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
256  rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
257  \
258  lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
259  rhs_shmem[rhs_store_idx_7] = rhs_pf7; \
260 
261  // declare and initialize result array
262 #define res(i, j) _res_##i##j
263 #define initResultRow(i) \
264  Scalar res(i, 0) = conv(0); \
265  Scalar res(i, 1) = conv(0); \
266  Scalar res(i, 2) = conv(0); \
267  Scalar res(i, 3) = conv(0); \
268  Scalar res(i, 4) = conv(0); \
269  Scalar res(i, 5) = conv(0); \
270  Scalar res(i, 6) = conv(0); \
271  Scalar res(i, 7) = conv(0); \
272 
273  internal::scalar_cast_op<int, Scalar> conv;
274  initResultRow(0);
275  initResultRow(1);
276  initResultRow(2);
277  initResultRow(3);
278  initResultRow(4);
279  initResultRow(5);
280  initResultRow(6);
281  initResultRow(7);
282 #undef initResultRow
283 
284  for (Index base_k = 0; base_k < k_size; base_k += 64) {
285  // wait for previous iteration to finish with shmem. Despite common sense,
286  // the code is a bit faster with this here then at bottom of loop
287  __syncthreads();
288 
289  prefetchIntoRegisters(base_k);
290  writeRegToShmem();
291 
292  #undef prefetchIntoRegisters
293  #undef writeRegToShmem
294 
295  // wait for shared mem packing to be done before starting computation
296  __syncthreads();
297 
298  // compute 8x8 matrix product by outer product. This involves packing one column
299  // of LHS and one row of RHS into registers (takes 16 registers).
300 
301 #define lcol(i) _lcol##i
302  Scalar lcol(0);
303  Scalar lcol(1);
304  Scalar lcol(2);
305  Scalar lcol(3);
306  Scalar lcol(4);
307  Scalar lcol(5);
308  Scalar lcol(6);
309  Scalar lcol(7);
310 
311 #define rrow(j) _rrow##j
312  Scalar rrow(0);
313  Scalar rrow(1);
314  Scalar rrow(2);
315  Scalar rrow(3);
316  Scalar rrow(4);
317  Scalar rrow(5);
318  Scalar rrow(6);
319  Scalar rrow(7);
320 
321  // Now x corresponds to k, y to m, and z to n
322  const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
323  const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
324 
325 #define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
326 #define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
327 
328 #define loadData(i, j) \
329  lcol(0) = lhs_element(0, j); \
330  rrow(0) = rhs_element(i, 0); \
331  lcol(1) = lhs_element(1, j); \
332  rrow(1) = rhs_element(i, 1); \
333  lcol(2) = lhs_element(2, j); \
334  rrow(2) = rhs_element(i, 2); \
335  lcol(3) = lhs_element(3, j); \
336  rrow(3) = rhs_element(i, 3); \
337  lcol(4) = lhs_element(4, j); \
338  rrow(4) = rhs_element(i, 4); \
339  lcol(5) = lhs_element(5, j); \
340  rrow(5) = rhs_element(i, 5); \
341  lcol(6) = lhs_element(6, j); \
342  rrow(6) = rhs_element(i, 6); \
343  lcol(7) = lhs_element(7, j); \
344  rrow(7) = rhs_element(i, 7); \
345 
346 #define computeCol(j) \
347  res(0, j) += lcol(0) * rrow(j); \
348  res(1, j) += lcol(1) * rrow(j); \
349  res(2, j) += lcol(2) * rrow(j); \
350  res(3, j) += lcol(3) * rrow(j); \
351  res(4, j) += lcol(4) * rrow(j); \
352  res(5, j) += lcol(5) * rrow(j); \
353  res(6, j) += lcol(6) * rrow(j); \
354  res(7, j) += lcol(7) * rrow(j); \
355 
356 #define computePass(i) \
357  loadData(i, i); \
358  \
359  computeCol(0); \
360  computeCol(1); \
361  computeCol(2); \
362  computeCol(3); \
363  computeCol(4); \
364  computeCol(5); \
365  computeCol(6); \
366  computeCol(7); \
367 
368  computePass(0);
369  computePass(1);
370  computePass(2);
371  computePass(3);
372  computePass(4);
373  computePass(5);
374  computePass(6);
375  computePass(7);
376 
377 #undef lcol
378 #undef rrow
379 #undef lhs_element
380 #undef rhs_element
381 #undef loadData
382 #undef computeCol
383 #undef computePass
384  } // end loop over k
385 
386  // we've now iterated over all of the large (ie width 64) k blocks and
387  // accumulated results in registers. At this point thread (x, y, z) contains
388  // the sum across all big k blocks of the product of little k block of index (x, y)
389  // with block of index (y, z). To compute the final output, we need to reduce
390  // the 8 threads over y by summation.
391 #if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
392 #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
393 #else
394 #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
395 #endif
396 
397 #define reduceRow(i, mask) \
398  shuffleInc(i, 0, mask); \
399  shuffleInc(i, 1, mask); \
400  shuffleInc(i, 2, mask); \
401  shuffleInc(i, 3, mask); \
402  shuffleInc(i, 4, mask); \
403  shuffleInc(i, 5, mask); \
404  shuffleInc(i, 6, mask); \
405  shuffleInc(i, 7, mask); \
406 
407 #define reduceMatrix(mask) \
408  reduceRow(0, mask); \
409  reduceRow(1, mask); \
410  reduceRow(2, mask); \
411  reduceRow(3, mask); \
412  reduceRow(4, mask); \
413  reduceRow(5, mask); \
414  reduceRow(6, mask); \
415  reduceRow(7, mask); \
416 
417  // actually perform the reduction, now each thread of index (_, y, z)
418  // contains the correct values in its registers that belong in the output
419  // block
420  reduceMatrix(1);
421  reduceMatrix(2);
422  reduceMatrix(4);
423 
424 #undef shuffleInc
425 #undef reduceRow
426 #undef reduceMatrix
427 
428  // now we need to copy the 64 values into main memory. We can't split work
429  // among threads because all variables are in registers. There's 2 ways
430  // to do this:
431  // (1) have 1 thread do 64 writes from registers into global memory
432  // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
433  // each do 8 writes into global memory. We can just overwrite the shared
434  // memory from the problem we just solved.
435  // (2) is slightly faster than (1) due to less branching and more ILP
436 
437  // TODO: won't yield much gain, but could just use currently unused shared mem
438  // and then we won't have to sync
439  // wait for shared mem to be out of use
440  __syncthreads();
441 
442 #define writeResultShmem(i, j) \
443  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \
444 
445 #define writeRow(i) \
446  writeResultShmem(i, 0); \
447  writeResultShmem(i, 1); \
448  writeResultShmem(i, 2); \
449  writeResultShmem(i, 3); \
450  writeResultShmem(i, 4); \
451  writeResultShmem(i, 5); \
452  writeResultShmem(i, 6); \
453  writeResultShmem(i, 7); \
454 
455  if (threadIdx.x == 0) {
456  writeRow(0);
457  writeRow(1);
458  writeRow(2);
459  writeRow(3);
460  writeRow(4);
461  writeRow(5);
462  writeRow(6);
463  writeRow(7);
464  }
465 #undef writeResultShmem
466 #undef writeRow
467 
468  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
469  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
470 
471  if (threadIdx.x < max_i_write) {
472  if (max_j_write == 8) {
473  // TODO: can i trade bank conflicts for coalesced writes?
474  Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
475  Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
476  Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
477  Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
478  Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
479  Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
480  Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
481  Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];
482 
483  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
484  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
485  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
486  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
487  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
488  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
489  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
490  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
491  } else {
492 #pragma unroll 7
493  for (int j = 0; j < max_j_write; j++) {
494  Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
495  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
496  }
497  }
498  }
499 #undef res
500 }
501 
502 
503 template<typename Scalar, typename Index, typename LhsMapper,
504  typename RhsMapper, typename OutputMapper>
505 __global__ void
506 #if defined(EIGEN_HIPCC)
507 __launch_bounds__(512, 1)
508 #else
509 __launch_bounds__(512)
510 #endif
511 EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
512  const OutputMapper output,
513  const Index m_size, const Index n_size, const Index k_size) {
514  __shared__ Scalar lhs_shmem[72 * 64];
515  __shared__ Scalar rhs_shmem[72 * 64];
516 
517  const Index m_block_idx = blockIdx.x;
518  const Index n_block_idx = blockIdx.y;
519 
520  const Index base_m = 64 * m_block_idx;
521  const Index base_n = 64 * n_block_idx;
522 
523  if (base_m + 63 < m_size && base_n + 63 < n_size) {
524  EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
525  } else {
526  EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
527  }
528 }
529 
530 
531 template<typename Index, typename LhsMapper,
532  typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
533  bool CHECK_RHS_BOUNDARY>
534 __device__ __forceinline__ void
535 EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
536  const OutputMapper output, float2 lhs_shmem2[][16],
537  float2 rhs_shmem2[][8], const Index m_size,
538  const Index n_size, const Index k_size,
539  const Index base_m, const Index base_n) {
540 
541  // prefetch registers
542  float4 lhs_pf0, rhs_pf0;
543 
544  float4 results[4];
545  for (int i=0; i < 4; i++) {
546  results[i].x = results[i].y = results[i].z = results[i].w = 0;
547  }
548 
549 #define prefetch_lhs(reg, row, col) \
550  if (!CHECK_LHS_BOUNDARY) { \
551  if (col < k_size) { \
552  reg =lhs.template loadPacket<float4,Unaligned>(row, col); \
553  } \
554  } else { \
555  if (col < k_size) { \
556  if (row + 3 < m_size) { \
557  reg =lhs.template loadPacket<float4,Unaligned>(row, col); \
558  } else if (row + 2 < m_size) { \
559  reg.x =lhs(row + 0, col); \
560  reg.y =lhs(row + 1, col); \
561  reg.z =lhs(row + 2, col); \
562  } else if (row + 1 < m_size) { \
563  reg.x =lhs(row + 0, col); \
564  reg.y =lhs(row + 1, col); \
565  } else if (row < m_size) { \
566  reg.x =lhs(row + 0, col); \
567  } \
568  } \
569  } \
570 
571  Index lhs_vert = base_m+threadIdx.x*4;
572 
573  for (Index k = 0; k < k_size; k += 16) {
574 
575  lhs_pf0 = internal::pset1<float4>(0);
576  rhs_pf0 = internal::pset1<float4>(0);
577 
578  Index lhs_horiz = threadIdx.y+k;
579  prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
580 
581  Index rhs_vert = k+(threadIdx.x%4)*4;
582  Index rhs_horiz0 = (threadIdx.x>>2)+threadIdx.y*4+base_n;
583 
584  if (!CHECK_RHS_BOUNDARY) {
585  if ((rhs_vert + 3) < k_size) {
586  // just CHECK_RHS_BOUNDARY
587  rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
588  } else if (rhs_vert + 2 < k_size) {
589  // just CHECK_RHS_BOUNDARY
590  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
591  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
592  rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
593  } else if (rhs_vert + 1 < k_size) {
594  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
595  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
596  } else if (rhs_vert < k_size) {
597  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
598  }
599  } else {
600  if (rhs_horiz0 < n_size) {
601  if ((rhs_vert + 3) < k_size) {
602  rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
603  } else if ((rhs_vert + 2) < k_size) {
604  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
605  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
606  rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
607  } else if ((rhs_vert + 1) < k_size) {
608  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
609  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
610  } else if (rhs_vert < k_size) {
611  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
612  }
613  }
614  }
615  float x1, x2 ;
616  // the following can be a bitwise operation..... some day.
617  if((threadIdx.x%8) < 4) {
618  x1 = rhs_pf0.y;
619  x2 = rhs_pf0.w;
620  } else {
621  x1 = rhs_pf0.x;
622  x2 = rhs_pf0.z;
623  }
624  #if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
625  x1 = __shfl_xor(x1, 4);
626  x2 = __shfl_xor(x2, 4);
627  #else
628  x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
629  x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
630  #endif
631  if((threadIdx.x%8) < 4) {
632  rhs_pf0.y = x1;
633  rhs_pf0.w = x2;
634  } else {
635  rhs_pf0.x = x1;
636  rhs_pf0.z = x2;
637  }
638 
639  // We have 64 features.
640  // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
641  // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
642  // ...
643  // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
644  // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
645  // ...
646  rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2][threadIdx.x%8] = make_float2(rhs_pf0.x, rhs_pf0.y);
647  rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2+32][threadIdx.x%8] = make_float2(rhs_pf0.z, rhs_pf0.w);
648 
649  // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
650  // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
651  // ...
652  // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
653  // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63)
654  // ...
655 
656  lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
657  lhs_shmem2[threadIdx.y+16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);
658 
659 
660 #define add_vals(fl1, fl2, fr1, fr2)\
661  results[0].x += fl1.x * fr1.x;\
662  results[0].y += fl1.y * fr1.x;\
663  results[0].z += fl2.x * fr1.x;\
664  results[0].w += fl2.y * fr1.x;\
665 \
666  results[1].x += fl1.x * fr1.y;\
667  results[1].y += fl1.y * fr1.y;\
668  results[1].z += fl2.x * fr1.y;\
669  results[1].w += fl2.y * fr1.y;\
670 \
671  results[2].x += fl1.x * fr2.x;\
672  results[2].y += fl1.y * fr2.x;\
673  results[2].z += fl2.x * fr2.x;\
674  results[2].w += fl2.y * fr2.x;\
675 \
676  results[3].x += fl1.x * fr2.y;\
677  results[3].y += fl1.y * fr2.y;\
678  results[3].z += fl2.x * fr2.y;\
679  results[3].w += fl2.y * fr2.y;\
680 
681  __syncthreads();
682 
683  // Do the multiplies.
684  #pragma unroll
685  for (int koff = 0; koff < 16; koff ++) {
686  // 32 x threads.
687  float2 fl1 = lhs_shmem2[koff][threadIdx.x];
688  float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];
689 
690  int start_feature = threadIdx.y * 4;
691  float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
692  float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
693 
694  add_vals(fl1, fl2, fr1, fr2)
695  }
696  __syncthreads();
697  }
698 
699 #undef prefetch_lhs
700 #undef add_vals
701 
702  Index horiz_base = threadIdx.y*4+base_n;
703  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
704  for (int i = 0; i < 4; i++) {
705  output(lhs_vert, horiz_base + i) = results[i].x;
706  output(lhs_vert + 1, horiz_base + i) = results[i].y;
707  output(lhs_vert + 2, horiz_base + i) = results[i].z;
708  output(lhs_vert + 3, horiz_base + i) = results[i].w;
709  }
710  } else if (!CHECK_RHS_BOUNDARY) {
711  // CHECK LHS
712  if (lhs_vert + 3 < m_size) {
713  for (int i = 0; i < 4; i++) {
714  output(lhs_vert, horiz_base + i) = results[i].x;
715  output(lhs_vert + 1, horiz_base + i) = results[i].y;
716  output(lhs_vert + 2, horiz_base + i) = results[i].z;
717  output(lhs_vert + 3, horiz_base + i) = results[i].w;
718  }
719  } else if (lhs_vert + 2 < m_size) {
720  for (int i = 0; i < 4; i++) {
721  output(lhs_vert, horiz_base + i) = results[i].x;
722  output(lhs_vert + 1, horiz_base + i) = results[i].y;
723  output(lhs_vert + 2, horiz_base + i) = results[i].z;
724  }
725  } else if (lhs_vert + 1 < m_size) {
726  for (int i = 0; i < 4; i++) {
727  output(lhs_vert, horiz_base + i) = results[i].x;
728  output(lhs_vert + 1, horiz_base + i) = results[i].y;
729  }
730  } else if (lhs_vert < m_size) {
731  for (int i = 0; i < 4; i++) {
732  output(lhs_vert, horiz_base + i) = results[i].x;
733  }
734  }
735  } else if (!CHECK_LHS_BOUNDARY) {
736  // CHECK RHS
737  /*
738  int ncols_rem = fminf(n_size- horiz_base, 4);
739  for (int i = 0; i < ncols_rem; i++) {
740  output(lhs_vert, horiz_base + i) = results[i].x;
741  output(lhs_vert + 1, horiz_base + i) = results[i].y;
742  output(lhs_vert + 2, horiz_base + i) = results[i].z;
743  output(lhs_vert + 3, horiz_base + i) = results[i].w;
744  }*/
745  for (int i = 0; i < 4; i++) {
746  if (horiz_base+i < n_size) {
747  output(lhs_vert, horiz_base + i) = results[i].x;
748  output(lhs_vert + 1, horiz_base + i) = results[i].y;
749  output(lhs_vert + 2, horiz_base + i) = results[i].z;
750  output(lhs_vert + 3, horiz_base + i) = results[i].w;
751  }
752  }
753  } else {
754  // CHECK both boundaries.
755  for (int i = 0; i < 4; i++) {
756  if (horiz_base+i < n_size) {
757  if (lhs_vert < m_size)
758  output(lhs_vert, horiz_base + i) = results[i].x;
759  if (lhs_vert + 1 < m_size)
760  output(lhs_vert + 1, horiz_base + i) = results[i].y;
761  if (lhs_vert + 2 < m_size)
762  output(lhs_vert + 2, horiz_base + i) = results[i].z;
763  if (lhs_vert + 3 < m_size)
764  output(lhs_vert + 3, horiz_base + i) = results[i].w;
765  }
766  }
767  }
768 }
769 
770 
771 template<typename Index, typename LhsMapper,
772  typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
773  bool CHECK_RHS_BOUNDARY>
774 __device__ __forceinline__ void
775 EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
776  const OutputMapper output, float2 lhs_shmem2[][32],
777  float2 rhs_shmem2[][8], const Index m_size,
778  const Index n_size, const Index k_size,
779  const Index base_m, const Index base_n) {
780 
781  // prefetch registers
782  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
783  float4 rhs_pf0, rhs_pf1;
784 
785  float4 results[8];
786  for (int i=0; i < 8; i++) {
787  results[i].x = results[i].y = results[i].z = results[i].w = 0;
788  }
789 
790  Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32;
791  for (Index k = 0; k < k_size; k += 32) {
792  lhs_pf0 = internal::pset1<float4>(0);
793  lhs_pf1 = internal::pset1<float4>(0);
794  lhs_pf2 = internal::pset1<float4>(0);
795  lhs_pf3 = internal::pset1<float4>(0);
796 
797  rhs_pf0 = internal::pset1<float4>(0);
798  rhs_pf1 = internal::pset1<float4>(0);
799 
800  if (!CHECK_LHS_BOUNDARY) {
801  if ((threadIdx.y/4+k+24) < k_size) {
802  lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
803  lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
804  lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
805  lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
806  } else if ((threadIdx.y/4+k+16) < k_size) {
807  lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
808  lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
809  lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
810  } else if ((threadIdx.y/4+k+8) < k_size) {
811  lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
812  lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
813  } else if ((threadIdx.y/4+k) < k_size) {
814  lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
815  }
816  } else {
817  // just CHECK_LHS_BOUNDARY
818  if (lhs_vert + 3 < m_size) {
819  if ((threadIdx.y/4+k+24) < k_size) {
820  lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
821  lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
822  lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
823  lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
824  } else if ((threadIdx.y/4+k+16) < k_size) {
825  lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
826  lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
827  lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
828  } else if ((threadIdx.y/4+k+8) < k_size) {
829  lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
830  lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
831  } else if ((threadIdx.y/4+k) < k_size) {
832  lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
833  }
834  } else if (lhs_vert + 2 < m_size) {
835  if ((threadIdx.y/4+k+24) < k_size) {
836  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
837  lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
838  lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
839  lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
840  lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
841  lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
842  lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
843  lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
844  lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
845  lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
846  lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
847  lhs_pf3.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
848  } else if ((threadIdx.y/4+k+16) < k_size) {
849  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
850  lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
851  lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
852  lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
853  lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
854  lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
855  lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
856  lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
857  lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
858  } else if ((threadIdx.y/4+k+8) < k_size) {
859  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
860  lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
861  lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
862  lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
863  lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
864  lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
865  } else if ((threadIdx.y/4+k) < k_size) {
866  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
867  lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
868  lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
869  }
870  } else if (lhs_vert + 1 < m_size) {
871  if ((threadIdx.y/4+k+24) < k_size) {
872  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
873  lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
874  lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
875  lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
876  lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
877  lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
878  lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
879  lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
880  } else if ((threadIdx.y/4+k+16) < k_size) {
881  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
882  lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
883  lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
884  lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
885  lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
886  lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
887  } else if ((threadIdx.y/4+k+8) < k_size) {
888  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
889  lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
890  lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
891  lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
892  } else if ((threadIdx.y/4+k) < k_size) {
893  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
894  lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
895  }
896  } else if (lhs_vert < m_size) {
897  if ((threadIdx.y/4+k+24) < k_size) {
898  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
899  lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
900  lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
901  lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
902  } else if ((threadIdx.y/4+k+16) < k_size) {
903  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
904  lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
905  lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
906  } else if ((threadIdx.y/4+k+8) < k_size) {
907  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
908  lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
909  } else if ((threadIdx.y/4+k) < k_size) {
910  lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
911  }
912  }
913  }
914  __syncthreads();
915  Index rhs_vert = k+threadIdx.x*4;
916  Index rhs_horiz0 = threadIdx.y*2+base_n;
917  Index rhs_horiz1 = threadIdx.y*2+1+base_n;
918  if (!CHECK_RHS_BOUNDARY) {
919  if ((rhs_vert + 3) < k_size) {
920  // just CHECK_RHS_BOUNDARY
921  rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
922  rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
923  } else if (rhs_vert + 2 < k_size) {
924  // just CHECK_RHS_BOUNDARY
925  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
926  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
927  rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
928  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
929  rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
930  rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
931  } else if (rhs_vert + 1 < k_size) {
932  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
933  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
934  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
935  rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
936  } else if (rhs_vert < k_size) {
937  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
938  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
939  }
940  } else {
941  if (rhs_horiz1 < n_size) {
942  if ((rhs_vert + 3) < k_size) {
943  // just CHECK_RHS_BOUNDARY
944  rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
945  rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
946  } else if (rhs_vert + 2 < k_size) {
947  // just CHECK_RHS_BOUNDARY
948  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
949  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
950  rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
951  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
952  rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
953  rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
954  } else if (k+threadIdx.x*4 + 1 < k_size) {
955  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
956  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
957  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
958  rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
959  } else if (k+threadIdx.x*4 < k_size) {
960  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
961  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
962  }
963  } else if (rhs_horiz0 < n_size) {
964  if ((rhs_vert + 3) < k_size) {
965  // just CHECK_RHS_BOUNDARY
966  rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
967  } else if ((rhs_vert + 2) < k_size) {
968  // just CHECK_RHS_BOUNDARY
969  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
970  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
971  rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
972  } else if ((rhs_vert + 1) < k_size) {
973  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
974  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
975  } else if (rhs_vert < k_size) {
976  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
977  }
978  }
979  }
980  __syncthreads();
981  // Loaded. Do computation
982  // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
983  // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
984  // ..
985  // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
986  rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
987  // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
988  // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
989  // ..
990  rhs_shmem2[threadIdx.y+32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
991  // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
992  // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
993  rhs_shmem2[threadIdx.y+64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
994  // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
995  // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
996  rhs_shmem2[threadIdx.y+96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
997 
998  // LHS.
999  // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
1000  // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
1001  // ...
1002  // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)
1003  // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)
1004 
1005 
1006 #define add_vals(a_feat1, a_feat2, f1, f2, f3, f4)\
1007  results[0].x += a_feat1.x * f1.x;\
1008  results[1].x += a_feat1.x * f1.y;\
1009  results[2].x += a_feat1.x * f2.x;\
1010  results[3].x += a_feat1.x * f2.y;\
1011  results[4].x += a_feat1.x * f3.x;\
1012  results[5].x += a_feat1.x * f3.y;\
1013  results[6].x += a_feat1.x * f4.x;\
1014  results[7].x += a_feat1.x * f4.y;\
1015 \
1016  results[0].y += a_feat1.y * f1.x;\
1017  results[1].y += a_feat1.y * f1.y;\
1018  results[2].y += a_feat1.y * f2.x;\
1019  results[3].y += a_feat1.y * f2.y;\
1020  results[4].y += a_feat1.y * f3.x;\
1021  results[5].y += a_feat1.y * f3.y;\
1022  results[6].y += a_feat1.y * f4.x;\
1023  results[7].y += a_feat1.y * f4.y;\
1024 \
1025  results[0].z += a_feat2.x * f1.x;\
1026  results[1].z += a_feat2.x * f1.y;\
1027  results[2].z += a_feat2.x * f2.x;\
1028  results[3].z += a_feat2.x * f2.y;\
1029  results[4].z += a_feat2.x * f3.x;\
1030  results[5].z += a_feat2.x * f3.y;\
1031  results[6].z += a_feat2.x * f4.x;\
1032  results[7].z += a_feat2.x * f4.y;\
1033 \
1034  results[0].w += a_feat2.y * f1.x;\
1035  results[1].w += a_feat2.y * f1.y;\
1036  results[2].w += a_feat2.y * f2.x;\
1037  results[3].w += a_feat2.y * f2.y;\
1038  results[4].w += a_feat2.y * f3.x;\
1039  results[5].w += a_feat2.y * f3.y;\
1040  results[6].w += a_feat2.y * f4.x;\
1041  results[7].w += a_feat2.y * f4.y;\
1042 
1043  lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
1044  lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
1045  lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
1046  lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);
1047 
1048  lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
1049  lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
1050  lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
1051  lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);
1052 
1053  __syncthreads();
1054 
1055  // Do the multiplies.
1056  #pragma unroll
1057  for (int koff = 0; koff < 32; koff ++) {
1058  float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
1059  float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];
1060 
1061  // first feature is at (threadIdx.y/4) * 8 last is at start + 8.
1062  int start_feature = (threadIdx.y / 4) * 8;
1063 
1064  float2 br1 = rhs_shmem2[start_feature/2 + (koff % 4) * 32][koff/4];
1065  float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
1066  float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
1067  float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];
1068 
1069  add_vals(a3, a4, br1, br2, br3, br4)
1070  }
1071  __syncthreads();
1072  } // end loop over k
1073 
1074  __syncthreads();
1075  Index horiz_base = (threadIdx.y/4)*8+base_n;
1076  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
1077  for (int i = 0; i < 8; i++) {
1078  output(lhs_vert, horiz_base + i) = results[i].x;
1079  output(lhs_vert + 1, horiz_base + i) = results[i].y;
1080  output(lhs_vert + 2, horiz_base + i) = results[i].z;
1081  output(lhs_vert + 3, horiz_base + i) = results[i].w;
1082  }
1083  } else if (!CHECK_RHS_BOUNDARY) {
1084  if (lhs_vert + 3 < m_size) {
1085  for (int i = 0; i < 8; i++) {
1086  output(lhs_vert, horiz_base + i) = results[i].x;
1087  output(lhs_vert + 1, horiz_base + i) = results[i].y;
1088  output(lhs_vert + 2, horiz_base + i) = results[i].z;
1089  output(lhs_vert + 3, horiz_base + i) = results[i].w;
1090  }
1091  } else if (lhs_vert + 2 < m_size) {
1092  for (int i = 0; i < 8; i++) {
1093  output(lhs_vert, horiz_base + i) = results[i].x;
1094  output(lhs_vert + 1, horiz_base + i) = results[i].y;
1095  output(lhs_vert + 2, horiz_base + i) = results[i].z;
1096  }
1097  } else if (lhs_vert + 1 < m_size) {
1098  for (int i = 0; i < 8; i++) {
1099  output(lhs_vert, horiz_base + i) = results[i].x;
1100  output(lhs_vert + 1, horiz_base + i) = results[i].y;
1101  }
1102  } else if (lhs_vert < m_size) {
1103  for (int i = 0; i < 8; i++) {
1104  output(lhs_vert, horiz_base + i) = results[i].x;
1105  }
1106  }
1107  } else if (!CHECK_LHS_BOUNDARY) {
1108  // CHECK BOUNDARY_B
1109  for (int i = 0; i < 8; i++) {
1110  if (horiz_base + i < n_size) {
1111  output(lhs_vert, horiz_base + i) = results[i].x;
1112  output(lhs_vert + 1, horiz_base + i) = results[i].y;
1113  output(lhs_vert + 2, horiz_base + i) = results[i].z;
1114  output(lhs_vert + 3, horiz_base + i) = results[i].w;
1115  }
1116  }
1117  } else {
1118  // CHECK both boundaries.
1119  for (int i = 0; i < 8; i++) {
1120  if (horiz_base + i < n_size) {
1121  if (lhs_vert < m_size)
1122  output(lhs_vert, horiz_base + i) = results[i].x;
1123  if (lhs_vert + 1 < m_size)
1124  output(lhs_vert + 1, horiz_base + i) = results[i].y;
1125  if (lhs_vert + 2 < m_size)
1126  output(lhs_vert + 2, horiz_base + i) = results[i].z;
1127  if (lhs_vert + 3 < m_size)
1128  output(lhs_vert + 3, horiz_base + i) = results[i].w;
1129  }
1130  }
1131  }
1132 }
1133 
1134 
// Entry point of the float contraction kernel that works on 128 x 64 output
// tiles (base_m advances by 128 per blockIdx.x, base_n by 64 per blockIdx.y).
//
// The kernel itself only
//   1. reserves the shared-memory staging buffers,
//   2. works out which output tile this block owns, and
//   3. forwards to the EigenFloatContractionKernelInternal instantiation whose
//      two boolean template arguments (LHS boundary check, RHS boundary check)
//      match the position of the tile relative to the m/n extents.
//
// The selection depends on blockIdx and the sizes only, so every thread of a
// block takes the same branch.
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                            const OutputMapper output,
                            const Index m_size, const Index n_size, const Index k_size) {
  // Two-dimensional views laid over the flat shared buffers below.
  typedef float2 LhsTile[64][32];
  typedef float2 RhsTile[128][8];

  __shared__ float2 lhs_buf[64*32];
  __shared__ float2 rhs_buf[128*8];

  LhsTile* const lhs_tile = (LhsTile*) lhs_buf;
  RhsTile* const rhs_tile = (RhsTile*) rhs_buf;

  // Top-left coordinate of the output tile handled by this block.
  const Index tile_row = blockIdx.x;
  const Index tile_col = blockIdx.y;
  const Index base_m = 128 * tile_row;
  const Index base_n = 64 * tile_col;

  // A tile is "full" when all of its rows / columns lie inside the output.
  const bool lhs_full = (base_m + 127) < m_size;  // >= 128 rows left
  const bool rhs_full = (base_n + 63) < n_size;   // >= 64 columns left

  if (lhs_full && rhs_full) {
    EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
        lhs, rhs, output, *lhs_tile, *rhs_tile, m_size, n_size, k_size, base_m, base_n);
  } else if (rhs_full) {
    EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
        lhs, rhs, output, *lhs_tile, *rhs_tile, m_size, n_size, k_size, base_m, base_n);
  } else if (lhs_full) {
    EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
        lhs, rhs, output, *lhs_tile, *rhs_tile, m_size, n_size, k_size, base_m, base_n);
  } else {
    EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
        lhs, rhs, output, *lhs_tile, *rhs_tile, m_size, n_size, k_size, base_m, base_n);
  }
}
1181 
// Entry point of the float contraction kernel that works on 64 x 64 output
// tiles (both base_m and base_n advance by 64 per block index).
//
// It reserves the shared-memory tiles, locates the output tile owned by this
// block and forwards to the EigenFloatContractionKernelInternal16x16
// instantiation whose boolean template arguments (LHS boundary check, RHS
// boundary check) are enabled only when the tile sticks out of the m / n
// extent. The choice depends on blockIdx and the sizes only, so it is the same
// for every thread of a block.
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                                 const OutputMapper output,
                                 const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  // Top-left coordinate of the output tile handled by this block.
  const Index tile_row = blockIdx.x;
  const Index tile_col = blockIdx.y;
  const Index base_m = 64 * tile_row;
  const Index base_n = 64 * tile_col;

  // An "edge" tile has fewer than 64 rows / columns left in the output.
  const bool lhs_edge = !(base_m + 63 < m_size);
  const bool rhs_edge = !(base_n + 63 < n_size);

  if (lhs_edge) {
    if (rhs_edge) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (rhs_edge) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}
1216 
1217 
// Specialization of TensorEvaluator for a TensorContractionOp evaluated on a
// GpuDevice. The contraction is computed by the GPU kernels launched from
// LaunchKernels::Run (see evalTyped).
//
// NOTE(review): the declarations of Scalar, PacketReturnType, Layout,
// LhsScalar and RhsScalar, the opening line of the EIGEN_STATIC_ASSERT, the
// signature of evalSubExprsIfNeeded and the EIGEN_UNUSED_VARIABLE(k) line were
// absent from this copy of the file and have been reinstated; confirm them
// against the upstream Eigen header.
template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  // Forwards to the base evaluator and rejects, at compile time, any output
  // kernel other than NoOpOutputKernel.
  TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device)
  {
    EIGEN_STATIC_ASSERT( (internal::is_same<OutputKernelType, const NoOpOutputKernel>::value),
                          GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);
  }

  // We need to redefine this method to make nvcc happy
  //
  // Evaluates the sub-expressions of both operands, then the contraction.
  // When `data` is non-null the result is written there and false is
  // returned; otherwise a buffer is allocated through the device, stored in
  // m_result, filled with the result, and true is returned.
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }

  // Translates the three runtime flags (lhs inner dim contiguous, rhs inner
  // dim contiguous, rhs inner dim reordered) into the matching compile-time
  // instantiation of evalTyped and writes the contraction into `buffer`.
  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

  // Generic launcher: one block of 8 x 8 x 8 threads per 64 x 64 output tile,
  // running EigenContractionKernel.
  template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      // Round up so that partially filled tiles on the m / n edges get a block.
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };

  // Launcher used when both operands are float. Picks between the two float
  // kernels according to the output size.
  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        // Either extent is below 768: 64 x 64 tiles, 16 x 16 threads per block.
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        // Both extents are at least 768: 128 x 64 tiles, 8 x 32 threads per block.
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };

  // Computes the contraction into `buffer` for one fixed combination of the
  // layout flags: zeroes the buffer, builds the input / output mappers,
  // selects the shared-memory bank size and launches the kernel.
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    EIGEN_UNUSED_VARIABLE(k);

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;


    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // Request eight-byte shared-memory banks before launching.
#if defined(EIGEN_USE_HIP)
    setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
#else
    setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
#endif

    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};
1409 
1410 } // end namespace Eigen
1411 
1412 #endif // EIGEN_USE_GPU and EIGEN_GPUCC
1413 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
Matrix3f m
SCALAR Scalar
Definition: bench_gemm.cpp:46
#define EIGEN_STRONG_INLINE
Definition: Macros.h:917
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const
Derived::Scalar CoeffReturnType
Point2 a3
Definition: testPose2.cpp:771
PyObject * conv(PyObject *o)
Definition: numpy.h:680
int n
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Namespace containing all symbols from the Eigen library.
Definition: jet.h:637
dim3 threadIdx
Definition: gpu_common.h:19
Pose3 x2(Rot3::Ypr(0.0, 0.0, 0.0), l2)
if((m *x).isApprox(y))
#define EIGEN_STATIC_ASSERT(CONDITION, MSG)
Definition: StaticAssert.h:127
const Device EIGEN_DEVICE_REF m_device
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:74
#define NULL
Definition: ccolamd.c:609
PacketType< CoeffReturnType, Device >::type PacketReturnType
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
Derived::Dimensions Dimensions
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType dest)
std::map< std::string, Array< float, 1, 8, DontAlign|RowMajor > > results
Pose3 x1
Definition: testPose3.cpp:663
std::vector< size_t > Indices
dim3 blockIdx
Definition: gpu_common.h:19
EIGEN_STRONG_INLINE TensorEvaluator(const Derived &m, const Device &device)
Derived::Scalar Scalar
std::ptrdiff_t j
#define EIGEN_UNUSED_VARIABLE(var)
Definition: Macros.h:1076
internal::packet_traits< Scalar >::type type
Definition: TensorMeta.h:51
Definition: pytypes.h:1370


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:36:47