#ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
#define EIGEN_GENERAL_MATRIX_VECTOR_H

template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
ResScalar* res,
Index resIncr,
template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
#ifdef _EIGEN_ACCUMULATE_PACKETS
#error _EIGEN_ACCUMULATE_PACKETS has already been defined
#endif

#define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) \
  pstore(&res[j], \
    padd(pload<ResPacket>(&res[j]), \
      padd( \
        padd(pcj.pmul(lhs0.template load<LhsPacket, Alignment0>(j),  ptmp0), \
             pcj.pmul(lhs1.template load<LhsPacket, Alignment13>(j), ptmp1)), \
        padd(pcj.pmul(lhs2.template load<LhsPacket, Alignment2>(j),  ptmp2), \
             pcj.pmul(lhs3.template load<LhsPacket, Alignment13>(j), ptmp3)) )))

typedef typename LhsMapper::VectorMapper LhsScalars;
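// Column-major kernel: each iteration of the main loop below processes four columns of the lhs
// at once ("columnsAtOnce"), broadcasting the corresponding rhs coefficients (pre-scaled by
// alpha) into ptmp0..ptmp3 and accumulating their contributions into packets of res. The macro
// above performs that update for one result packet; "peels" controls how many result packets
// the FirstAligned path unrolls per iteration.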
enum { AllAligned = 0, EvenAligned, FirstAligned, NoneAligned };
const Index columnsAtOnce = 4;
const Index peels = 2;
const Index LhsPacketAlignedMask = LhsPacketSize-1;
const Index ResPacketAlignedMask = ResPacketSize-1;

const Index lhsStride = lhs.stride();

Index alignedSize = ResPacketSize>1 ? alignedStart + ((size-alignedStart) & ~ResPacketAlignedMask) : 0;
const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;

const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
Index alignmentPattern = alignmentStep==0 ? AllAligned
                       : alignmentStep==(LhsPacketSize/2) ? EvenAligned
                       : FirstAligned;
const Index lhsAlignmentOffset = lhs.firstAligned(size);

Index skipColumns = 0;
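// The branches below pick the vectorization strategy: fall back to NoneAligned when the result
// pointer itself is misaligned or the packets are too wide to realign cheaply; otherwise search
// for how many leading columns to skip (skipColumns) so that the remaining columns follow a
// fixed per-column alignment pattern (AllAligned / EvenAligned / FirstAligned).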
if( (lhsAlignmentOffset < 0) || (lhsAlignmentOffset == size) || (UIntPtr(res)%sizeof(ResScalar)) )
  alignmentPattern = NoneAligned;
else if(LhsPacketSize > 4)
  alignmentPattern = NoneAligned;
else if (LhsPacketSize>1)
  while (skipColumns<LhsPacketSize &&
         alignedStart != ((lhsAlignmentOffset + alignmentStep*skipColumns)%LhsPacketSize))
    ++skipColumns;
  if (skipColumns==LhsPacketSize)
    alignmentPattern = NoneAligned;
  skipColumns = (std::min)(skipColumns,cols);
else if(Vectorizable)
  alignmentPattern = AllAligned;

const Index offset1 = (FirstAligned && alignmentStep==1)?3:1;
const Index offset3 = (FirstAligned && alignmentStep==1)?1:3;
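// When consecutive columns are shifted by a single scalar (alignmentStep==1), columns i+1 and
// i+3 have their relative misalignments swapped (3 and 1 instead of 1 and 3), so offset1/offset3
// swap which columns play the roles of lhs1 and lhs3; this keeps the hard-coded palign<1> and
// palign<3> shifts in the peeled loop below correct.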
Index columnBound = ((cols-skipColumns)/columnsAtOnce)*columnsAtOnce + skipColumns;
for (Index i=skipColumns; i<columnBound; i+=columnsAtOnce)
  RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs(i, 0)),
            ptmp1 = pset1<RhsPacket>(alpha*rhs(i+offset1, 0)),
            ptmp2 = pset1<RhsPacket>(alpha*rhs(i+2, 0)),
            ptmp3 = pset1<RhsPacket>(alpha*rhs(i+offset3, 0));

  const LhsScalars lhs0 = lhs.getVectorMapper(0, i+0), lhs1 = lhs.getVectorMapper(0, i+offset1),
                   lhs2 = lhs.getVectorMapper(0, i+2), lhs3 = lhs.getVectorMapper(0, i+offset3);
  for (Index j=0; j<alignedStart; ++j)
  if (alignedSize>alignedStart)
    switch(alignmentPattern)
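      // Dispatch on the alignment pattern: the AllAligned and EvenAligned cases can feed the
      // accumulation macro directly with suitable load alignments, FirstAligned takes the
      // explicitly peeled loop further below, and anything else falls back to unaligned loads.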
      for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
        _EIGEN_ACCUMULATE_PACKETS(Aligned,Aligned,Aligned);

      for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
        _EIGEN_ACCUMULATE_PACKETS(Aligned,EvenAligned,Aligned);
      Index j = alignedStart;

      LhsPacket A00, A01, A02, A03, A10, A11, A12, A13;
      ResPacket T0, T1;

      A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1);
      A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2);
      A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3);
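      // Peeled FirstAligned loop: lhs1/lhs2/lhs3 are misaligned by 1/2/3 scalars, so aligned
      // packets are loaded at j-1/j-2/j-3 and shifted into place with palign<1>/<2>/<3>;
      // two result packets (peels==2) are produced per iteration.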
      for (; j<peeledSize; j+=peels*ResPacketSize)
      {
        A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize);  palign<1>(A01,A11);
        A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize);  palign<2>(A02,A12);
        A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize);  palign<3>(A03,A13);

        A00 = lhs0.template load<LhsPacket, Aligned>(j);
        A10 = lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize);
        T0  = pcj.pmadd(A00, ptmp0, pload<ResPacket>(&res[j]));
        T1  = pcj.pmadd(A10, ptmp0, pload<ResPacket>(&res[j+ResPacketSize]));

        T0  = pcj.pmadd(A01, ptmp1, T0);
        A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize);  palign<1>(A11,A01);
        T0  = pcj.pmadd(A02, ptmp2, T0);
        A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize);  palign<2>(A12,A02);
        T0  = pcj.pmadd(A03, ptmp3, T0);
        pstore(&res[j],T0);
        A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize);  palign<3>(A13,A03);
        T1  = pcj.pmadd(A11, ptmp1, T1);
        T1  = pcj.pmadd(A12, ptmp2, T1);
        T1  = pcj.pmadd(A13, ptmp3, T1);
        pstore(&res[j+ResPacketSize],T1);
      }
      for (; j<alignedSize; j+=ResPacketSize)
        _EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Aligned);

      for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
        _EIGEN_ACCUMULATE_PACKETS(Unaligned,Unaligned,Unaligned);
Index start = columnBound;
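// The remaining columns (the tail past columnBound and, on the second pass of the enclosing
// do/while, the skipColumns that were skipped above) are processed one column at a time.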
for (Index k=start; k<end; ++k)
  RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs(k, 0));
  const LhsScalars lhs0 = lhs.getVectorMapper(0, k);
  for (Index j=0; j<alignedStart; ++j)
  if (lhs0.template aligned<LhsPacket>(alignedStart))
    for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
      pstore(&res[i], pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(i), ptmp0, pload<ResPacket>(&res[i])));
  else
    for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
      pstore(&res[i], pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(i), ptmp0, pload<ResPacket>(&res[i])));
} while(Vectorizable);
#undef _EIGEN_ACCUMULATE_PACKETS

template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
const LhsMapper& lhs,
const RhsMapper& rhs,
ResScalar* res,
Index resIncr,
template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
const LhsMapper& lhs,
const RhsMapper& rhs,
#ifdef _EIGEN_ACCUMULATE_PACKETS
#error _EIGEN_ACCUMULATE_PACKETS has already been defined
#endif

#define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) {\
  RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0); \
  ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Alignment0>(j),  b, ptmp0); \
  ptmp1 = pcj.pmadd(lhs1.template load<LhsPacket, Alignment13>(j), b, ptmp1); \
  ptmp2 = pcj.pmadd(lhs2.template load<LhsPacket, Alignment2>(j),  b, ptmp2); \
  ptmp3 = pcj.pmadd(lhs3.template load<LhsPacket, Alignment13>(j), b, ptmp3); }

typedef typename LhsMapper::VectorMapper LhsScalars;
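// Row-major kernel: each iteration accumulates dot products for four rows of the lhs with the
// rhs vector. The macro above loads one rhs packet b and fused-multiply-adds it with the
// matching packet of each of the four rows into the running accumulators ptmp0..ptmp3.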
enum { AllAligned=0, EvenAligned=1, FirstAligned=2, NoneAligned=3 };
const Index rowsAtOnce = 4;
const Index peels = 2;
const Index RhsPacketAlignedMask = RhsPacketSize-1;
const Index LhsPacketAlignedMask = LhsPacketSize-1;
const Index depth = cols;
const Index lhsStride = lhs.stride();

Index alignedStart = rhs.firstAligned(depth);
Index alignedSize = RhsPacketSize>1 ? alignedStart + ((depth-alignedStart) & ~RhsPacketAlignedMask) : 0;
const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;

const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
Index alignmentPattern = alignmentStep==0 ? AllAligned
                       : alignmentStep==(LhsPacketSize/2) ? EvenAligned
                       : FirstAligned;
const Index lhsAlignmentOffset = lhs.firstAligned(depth);
const Index rhsAlignmentOffset = rhs.firstAligned(rows);
if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) ||
    (lhsAlignmentOffset < 0) || (lhsAlignmentOffset == depth) ||
    (rhsAlignmentOffset < 0) || (rhsAlignmentOffset == rows) )
  alignmentPattern = NoneAligned;
else if(LhsPacketSize > 4)
  alignmentPattern = NoneAligned;
else if (LhsPacketSize>1)
  while (skipRows<LhsPacketSize &&
         alignedStart != ((lhsAlignmentOffset + alignmentStep*skipRows)%LhsPacketSize))
    ++skipRows;
  if (skipRows==LhsPacketSize)
    alignmentPattern = NoneAligned;
  skipRows = (std::min)(skipRows,Index(rows));
else if(Vectorizable)
  alignmentPattern = AllAligned;

const Index offset1 = (FirstAligned && alignmentStep==1)?3:1;
const Index offset3 = (FirstAligned && alignmentStep==1)?1:3;
Index rowBound = ((rows-skipRows)/rowsAtOnce)*rowsAtOnce + skipRows;
for (Index i=skipRows; i<rowBound; i+=rowsAtOnce)
  const LhsScalars lhs0 = lhs.getVectorMapper(i+0, 0), lhs1 = lhs.getVectorMapper(i+offset1, 0),
                   lhs2 = lhs.getVectorMapper(i+2, 0), lhs3 = lhs.getVectorMapper(i+offset3, 0);
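  // For each block of four rows the depth is split into a scalar head (up to alignedStart),
  // a vectorized middle section, and a scalar tail (alignedSize to depth). In the full kernel
  // the scalar accumulators tmp0..tmp3 and packet accumulators ptmp0..ptmp3 are zero-initialized
  // here, and the packet accumulators are reduced with predux into tmp0..tmp3 before the final
  // writes to res below.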
  for (Index j=0; j<alignedStart; ++j)
  {
    RhsScalar b = rhs(j, 0);
    tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b);
    tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b);
  }
  if (alignedSize>alignedStart)
    switch(alignmentPattern)
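      // Same alignment dispatch as in the column-major kernel: AllAligned/EvenAligned feed the
      // macro directly, FirstAligned uses the peeled loop below, otherwise loads are unaligned.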
      for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
        _EIGEN_ACCUMULATE_PACKETS(Aligned,Aligned,Aligned);

      for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
        _EIGEN_ACCUMULATE_PACKETS(Aligned,EvenAligned,Aligned);
      Index j = alignedStart;

      LhsPacket A01, A02, A03, A11, A12, A13;
      A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1);
      A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2);
      A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3);
      for (; j<peeledSize; j+=peels*RhsPacketSize)
      {
        RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0);
        A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize);  palign<1>(A01,A11);
        A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize);  palign<2>(A02,A12);
        A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize);  palign<3>(A03,A13);

        ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), b, ptmp0);
        ptmp1 = pcj.pmadd(A01, b, ptmp1);
        A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize);  palign<1>(A11,A01);
        ptmp2 = pcj.pmadd(A02, b, ptmp2);
        A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize);  palign<2>(A12,A02);
        ptmp3 = pcj.pmadd(A03, b, ptmp3);
        A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize);  palign<3>(A13,A03);

        b = rhs.getVectorMapper(j+RhsPacketSize, 0).template load<RhsPacket, Aligned>(0);
        ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize), b, ptmp0);
        ptmp1 = pcj.pmadd(A11, b, ptmp1);
        ptmp2 = pcj.pmadd(A12, b, ptmp2);
        ptmp3 = pcj.pmadd(A13, b, ptmp3);
      }
      for (; j<alignedSize; j+=RhsPacketSize)
        _EIGEN_ACCUMULATE_PACKETS(Aligned,Unaligned,Aligned);

      for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
        _EIGEN_ACCUMULATE_PACKETS(Unaligned,Unaligned,Unaligned);
  for (Index j=alignedSize; j<depth; ++j)
  {
    RhsScalar b = rhs(j, 0);
    tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b);
    tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b);
  }
  res[i*resIncr]           += alpha*tmp0;
  res[(i+offset1)*resIncr] += alpha*tmp1;
  res[(i+2)*resIncr]       += alpha*tmp2;
  res[(i+offset3)*resIncr] += alpha*tmp3;
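  // The four accumulated dot products are scaled by alpha and scattered into the result with
  // stride resIncr, so res does not have to be a contiguous vector.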
Index start = rowBound;
for (Index i=start; i<end; ++i)
  ResScalar tmp0 = ResScalar(0);
  ResPacket ptmp0 = pset1<ResPacket>(tmp0);
  const LhsScalars lhs0 = lhs.getVectorMapper(i, 0);
  for (Index j=0; j<alignedStart; ++j)
    tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
  if (alignedSize>alignedStart)
  {
    if (lhs0.template aligned<LhsPacket>(alignedStart))
      for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
        ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
    else
      for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
        ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
    tmp0 += predux(ptmp0);
  }
  for (Index j=alignedSize; j<depth; ++j)
    tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
  res[i*resIncr] += alpha*tmp0;
} while(Vectorizable);
#undef _EIGEN_ACCUMULATE_PACKETS

#endif // EIGEN_GENERAL_MATRIX_VECTOR_H
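// -----------------------------------------------------------------------------------------------
// Minimal usage sketch (not part of this header): products written with the public Eigen API are
// what ultimately dispatch to the kernels above; a column-major matrix hits the ColMajor
// specialization and a row-major one hits the RowMajor specialization. The sizes and the value of
// alpha below are arbitrary example values.

#include <Eigen/Dense>
#include <iostream>

int main()
{
  // Column-major storage (Eigen's default) -> the first kernel in this file.
  Eigen::MatrixXf A = Eigen::MatrixXf::Random(64, 48);
  Eigen::VectorXf x = Eigen::VectorXf::Random(48);
  Eigen::VectorXf y = Eigen::VectorXf::Zero(64);

  const float alpha = 2.0f;
  // y += alpha * A * x; noalias() lets Eigen evaluate the product directly into y,
  // matching the "res += alpha * lhs * rhs" form implemented by the kernels above.
  y.noalias() += alpha * (A * x);

  // A row-major matrix is instead handled by the RowMajor specialization.
  Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> B = A;
  y.noalias() += alpha * (B * x);

  std::cout << y.head(4).transpose() << std::endl;
  return 0;
}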