#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

#pragma GCC target("cpu=power10")

#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
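
// Everything in this header is compiled with the Power10 target above so that
// the MMA (Matrix-Multiply Assist) builtins and the __vector_quad/__vector_pair
// types are available; when __builtin_vsx_assemble_pair is not provided, the
// older __builtin_mma_assemble_pair spelling is used instead.

// bsetzeroMMA: clear a 512-bit MMA accumulator (__vector_quad) before starting
// a new rank-k update.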
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
{
  __builtin_mma_xxsetaccz(acc);
}
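
// storeAccumulator: spill the accumulator back into VSX registers, load the
// existing C tile, add alpha times the accumulated product and store the
// 4-packet result block through the DataMapper (C += alpha * A * B).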
template<typename DataMapper, typename Index, typename Packet, const Index accCols>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
{
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
  bload<DataMapper, Packet, Index, accCols, 0, ColMajor>(tRes, data, i, j);

  bscale<Packet>(tRes, result, alpha);

  data.template storePacketBlock<Packet, 4>(i, j, tRes);
}
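
// storeComplexAccumulator: same as above for complex results; the real and
// imaginary accumulators are disassembled, scaled by the complex alpha
// (bscalec) and re-interleaved into complex packets (bcouple) before the two
// halves of the tile are stored.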
template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC, int N>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
{
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, Index, accColsC, N, ColMajor>(tRes, data, i, j);

  PacketBlock<Packet, 4> taccReal, taccImag;
  bscalec<Packet,4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc1, acc2);

  data.template storePacketBlock<Packetc, 4>(i + N*accColsC, j, acc1);
  data.template storePacketBlock<Packetc, 4>(i + (N+1)*accColsC, j, acc2);
}
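
// pgerMMA wraps the MMA rank-1 update (GER) builtins. The float32 variant
// accumulates the outer product of two 4-float vectors into acc, roughly
//   acc(i,j) += (NegativeAccumulate ? -1 : +1) * a(i) * b(j),
// where the negated form is what lets callers express conjugation.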
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}
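
// float64 variant: the double-precision GER builtins take a 256-bit
// __vector_pair for one operand, so the two Packet2d halves of the rhs block
// are reinterpreted as such a pair.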
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock<Packet2d,2>& a, const Packet2d& b)
{
  __vector_pair* a0 = (__vector_pair *)(&a.packet[0]);
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b);
  }
}
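
// Same update when the rhs has already been loaded as a __vector_pair
// (see ploadRhsMMA below).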
template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b)
{
  if(NegativeAccumulate)
  {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}

template<typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&)
{
  // Just for compilation; this overload is never selected at run time.
}
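
// pgercMMA builds a complex rank-1 update out of real GERs: the products that
// land in the real accumulator and those that land in the imaginary one are
// issued separately, with the NegativeAccumulate flag carrying the sign flips
// implied by conjugation and by i*i = -1. Imaginary inputs are skipped when
// LhsIsReal/RhsIsReal.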
template<typename Scalar, typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi)
{
  pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
  if(LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
  } else {
    if(!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
  }
}
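
// ploadRhsMMA loads one group of rhs values in the layout the GER builtins
// expect; the generic version is a plain packet load, while the
// double-precision specializations below build the required register pair.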
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
  rhsV = ploadRhs<Scalar, Packet>((const Scalar*)(rhs));
}

template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, PacketBlock<Packet2d, 2> >(const double* rhs, PacketBlock<Packet2d, 2>& rhsV)
{
  rhsV.packet[0] = ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs     ));
  rhsV.packet[1] = ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1));
}
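
// For a __vector_pair rhs, the pair is either assembled from two vector loads
// or fetched with a single lxvp instruction via inline asm, depending on the
// compiler.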
template<>
EIGEN_ALWAYS_INLINE void ploadRhsMMA<double, __vector_pair>(const double* rhs, __vector_pair& rhsV)
{
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(&rhsV,
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)(((Packet2d *)rhs) + 1))),
    (__vector unsigned char)(ploadRhs<double, Packet2d>((const double *)((Packet2d *)rhs     ))));
#else
  __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs));
#endif
}
#define MICRO_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }

#define MICRO_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
  }

#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    ploadRhsMMA<Scalar, type>(rhs_ptr + (accRows * peel), rhsV##peel); \
    MICRO_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
    func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); \
  MICRO_MMA_TYPE_PEEL(func,func2,type,8); MICRO_MMA_TYPE_PEEL(func,func2,type,9);

#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0; \
  MICRO_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += (accRows * PEEL_MMA);

#define MICRO_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr += accRows;

#define MICRO_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)

#define MICRO_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE)

#define MICRO_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE)

#define MICRO_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, col, res, pAlpha, &accZero##iter); \
  }

#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
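
// gemm_unrolled_MMA_iteration: compute an (unroll_factor*accCols) x accRows
// tile of C. One accumulator per unrolled row block is zeroed, the peeled and
// remainder k-loops issue the GER updates, and the results are scaled by alpha
// and stored.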
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration(
  const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
  Index depth, Index strideA, Index offsetA, Index& row, Index col, const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL;
  const Scalar* lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_MMA <= depth; k+= PEEL_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  row += unroll_factor*accCols;
}
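
// gemmMMA: driver for the real-valued kernel. It walks the packed blocks one
// column panel at a time, dispatches to the largest unrolled iteration that
// still fits in the remaining rows, and falls back to the extra-row/extra-col
// kernels from the common (non-MMA) code for the leftovers.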
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;
  const Index remaining_cols = cols % accRows;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask  = bmask<Packet>((const int)(remaining_rows));

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
    const Scalar* lhs_base = blockA;
    Index row = 0;

#define MAX_MMA_UNROLL 7
    while(row + MAX_MMA_UNROLL*accCols <= rows) {
      gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
    }
    switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
      case 7:
        gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 6
      case 6:
        gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 5
      case 5:
        gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 4
      case 4:
        gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 3
      case 3:
        gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 2
      case 2:
        gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
#if MAX_MMA_UNROLL > 1
      case 1:
        gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        break;
#endif
      default:
        break;
    }
#undef MAX_MMA_UNROLL

    if(remaining_rows > 0)
    {
      gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
    }
  }

  if(remaining_cols > 0)
  {
    const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
    const Scalar* lhs_base = blockA;

    for(; col < cols; col++)
    {
      Index row = 0;

      gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);

      if (remaining_rows > 0)
      {
        gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
      }
      rhs_base++;
    }
  }
}
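
// Complex kernels. accColsC is the column-block size in complex elements (half
// the real one); advanceRows/advanceCols account for the separate real and
// imaginary planes of the packed lhs/rhs when the operands are complex.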
#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

#define PEEL_COMPLEX_MMA 7

#define MICRO_COMPLEX_MMA_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4)

#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
    lhs_ptr_real##iter += accCols; \
    if(!LhsIsReal) { \
      lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
      lhs_ptr_imag##iter += accCols; \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
    EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
  }

#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
  if (unroll_factor > iter) { \
    pgercMMA<Scalar, Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
    ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
    if(!RhsIsReal) { \
      ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
    } else { \
      EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
    } \
    MICRO_COMPLEX_MMA_UNROLL(func2); \
    func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) func(4,type,peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
  type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
  type rhsVi0, rhsVi1, rhsVi2, rhsVi3, rhsVi4, rhsVi5, rhsVi6, rhsVi7, rhsVi8, rhsVi9; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,4); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,5); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,6); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,7); \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,8); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,9);

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
  type rhsV0, rhsVi0; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0);

#define MICRO_COMPLEX_MMA_ONE_PEEL \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA);

#define MICRO_COMPLEX_MMA_ONE \
  if (sizeof(Scalar) == sizeof(float)) { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \
  } else { \
    MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \
  } \
  rhs_ptr_real += accRows; \
  if(!RhsIsReal) rhs_ptr_imag += accRows;

#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzeroMMA<Scalar, Packet>(&accReal##iter); \
    bsetzeroMMA<Scalar, Packet>(&accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)

#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
    if(!LhsIsReal) { \
      lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
  }

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)

#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
    if(!LhsIsReal) { \
      EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
    } \
  }

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)

#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC, 0>(row + iter*accCols, col, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
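
// gemm_complex_unrolled_MMA_iteration: complex counterpart of the real
// micro-kernel. It keeps one real and one imaginary accumulator per row block
// and skips the imaginary loads when LhsIsReal/RhsIsReal.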
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration(
  const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
  Index depth, Index strideA, Index offsetA, Index strideB,
  Index& row, Index col, const Packet& pAlphaReal, const Packet& pAlphaImag)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag;
  if(!RhsIsReal) {
    rhs_ptr_imag = rhs_base + accRows*strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  }
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
  const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3, accReal4, accImag4;

  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0;
  for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for(; k < depth; k++)
  {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  row += unroll_factor*accCols;
}
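
// gemm_complexMMA: driver for the complex kernel; same blocking as gemmMMA,
// with separate real/imaginary alpha splats and the conjugation flags threaded
// through to pgercMMA.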
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
  const Index remaining_rows = rows % accCols;
  const Index remaining_cols = cols % accRows;

  if( strideA == -1 ) strideA = depth;
  if( strideB == -1 ) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>((const int)(remaining_rows));

  const Scalar* blockA = (Scalar *) blockAc;
  const Scalar* blockB = (Scalar *) blockBc;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
    const Scalar* lhs_base = blockA;
    Index row = 0;

#define MAX_COMPLEX_MMA_UNROLL 4
    while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
      gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
    }
    switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
      case 4:
        gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
        break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
      case 3:
        gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
        break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
      case 2:
        gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
        break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
      case 1:
        gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
        break;
#endif
      default:
        break;
    }
#undef MAX_COMPLEX_MMA_UNROLL

    if(remaining_rows > 0)
    {
      gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
    }
  }

  if(remaining_cols > 0)
  {
    const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
    const Scalar* lhs_base = blockA;

    for(; col < cols; col++)
    {
      Index row = 0;

      gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);

      if (remaining_rows > 0)
      {
        gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
      }
      rhs_base++;
    }
  }
}

#undef accColsC
#undef advanceRows
#undef advanceCols
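
// Restore the compiler's original target options now that the Power10-only
// code is done.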
#pragma GCC reset_options

#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H