10 #ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
11 #define EIGEN_GENERAL_MATRIX_VECTOR_H
23 template <
int N,
typename T1,
typename T2,
typename T3>
26 template <
typename T1,
typename T2,
typename T3>
29 template <
typename T1,
typename T2,
typename T3>
32 template<
typename LhsScalar,
typename RhsScalar,
int _PacketSize=GEMVPacketFull>
37 #define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \
38 typedef typename gemv_packet_cond<packet_size, \
39 typename packet_traits<name ## Scalar>::type, \
40 typename packet_traits<name ## Scalar>::half, \
41 typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \
42 prefix ## name ## Packet
47 #undef PACKET_DECL_COND_PREFIX
78 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
101 const LhsMapper& lhs,
102 const RhsMapper& rhs,
107 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
110 const LhsMapper& alhs,
111 const RhsMapper& rhs,
127 const Index lhsStride = lhs.stride();
130 ResPacketSize = Traits::ResPacketSize,
131 ResPacketSizeHalf = HalfTraits::ResPacketSize,
132 ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
133 LhsPacketSize = Traits::LhsPacketSize,
134 HasHalf = (
int)ResPacketSizeHalf < (
int)ResPacketSize,
135 HasQuarter = (
int)ResPacketSizeQuarter < (
int)ResPacketSizeHalf
143 const Index n_half =
rows-1*ResPacketSizeHalf+1;
144 const Index n_quarter =
rows-1*ResPacketSizeQuarter+1;
147 const Index block_cols =
cols<128 ?
cols : (lhsStride*
sizeof(LhsScalar)<32000?16:4);
156 for(;
i<n8;
i+=ResPacketSize*8)
170 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*0,
j),b0,c0);
171 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*1,
j),b0,
c1);
172 c2 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*2,
j),b0,
c2);
173 c3 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*3,
j),b0,c3);
174 c4 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*4,
j),b0,c4);
175 c5 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*5,
j),b0,c5);
176 c6 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*6,
j),b0,c6);
177 c7 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*7,
j),b0,c7);
198 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*0,
j),b0,c0);
199 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*1,
j),b0,
c1);
200 c2 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*2,
j),b0,
c2);
201 c3 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*3,
j),b0,c3);
219 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*0,
j),b0,c0);
220 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*1,
j),b0,
c1);
221 c2 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*2,
j),b0,
c2);
237 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*0,
j),b0,c0);
238 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*1,
j),b0,
c1);
250 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+0,
j),b0,c0);
255 if(HasHalf &&
i<n_half)
261 c0 = pcj_half.
pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(
i+0,
j),b0,c0);
263 pstoreu(
res+
i+ResPacketSizeHalf*0,
pmadd(c0,palpha_half,ploadu<ResPacketHalf>(
res+
i+ResPacketSizeHalf*0)));
264 i+=ResPacketSizeHalf;
266 if(HasQuarter &&
i<n_quarter)
272 c0 = pcj_quarter.
pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(
i+0,
j),b0,c0);
274 pstoreu(
res+
i+ResPacketSizeQuarter*0,
pmadd(c0,palpha_quarter,ploadu<ResPacketQuarter>(
res+
i+ResPacketSizeQuarter*0)));
275 i+=ResPacketSizeQuarter;
281 c0 += cj.
pmul(lhs(
i,
j), rhs(
j,0));
297 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
320 const LhsMapper& lhs,
321 const RhsMapper& rhs,
326 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
329 const LhsMapper& alhs,
330 const RhsMapper& rhs,
346 const Index n8 = lhs.stride()*
sizeof(LhsScalar)>32000 ? 0 :
rows-7;
352 ResPacketSize = Traits::ResPacketSize,
353 ResPacketSizeHalf = HalfTraits::ResPacketSize,
354 ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
355 LhsPacketSize = Traits::LhsPacketSize,
356 LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
357 LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
358 HasHalf = (
int)ResPacketSizeHalf < (
int)ResPacketSize,
359 HasQuarter = (
int)ResPacketSizeQuarter < (
int)ResPacketSizeHalf
375 for(;
j+LhsPacketSize<=
cols;
j+=LhsPacketSize)
377 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(
j,0);
379 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+0,
j),b0,c0);
380 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+1,
j),b0,
c1);
381 c2 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+2,
j),b0,
c2);
382 c3 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+3,
j),b0,c3);
383 c4 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+4,
j),b0,c4);
384 c5 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+5,
j),b0,c5);
385 c6 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+6,
j),b0,c6);
386 c7 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+7,
j),b0,c7);
398 RhsScalar b0 = rhs(
j,0);
400 cc0 += cj.
pmul(lhs(
i+0,
j), b0);
401 cc1 += cj.
pmul(lhs(
i+1,
j), b0);
402 cc2 += cj.
pmul(lhs(
i+2,
j), b0);
403 cc3 += cj.
pmul(lhs(
i+3,
j), b0);
404 cc4 += cj.
pmul(lhs(
i+4,
j), b0);
405 cc5 += cj.
pmul(lhs(
i+5,
j), b0);
406 cc6 += cj.
pmul(lhs(
i+6,
j), b0);
407 cc7 += cj.
pmul(lhs(
i+7,
j), b0);
426 for(;
j+LhsPacketSize<=
cols;
j+=LhsPacketSize)
428 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(
j,0);
430 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+0,
j),b0,c0);
431 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+1,
j),b0,
c1);
432 c2 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+2,
j),b0,
c2);
433 c3 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+3,
j),b0,c3);
441 RhsScalar b0 = rhs(
j,0);
443 cc0 += cj.
pmul(lhs(
i+0,
j), b0);
444 cc1 += cj.
pmul(lhs(
i+1,
j), b0);
445 cc2 += cj.
pmul(lhs(
i+2,
j), b0);
446 cc3 += cj.
pmul(lhs(
i+3,
j), b0);
459 for(;
j+LhsPacketSize<=
cols;
j+=LhsPacketSize)
461 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(
j,0);
463 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+0,
j),b0,c0);
464 c1 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+1,
j),b0,
c1);
470 RhsScalar b0 = rhs(
j,0);
472 cc0 += cj.
pmul(lhs(
i+0,
j), b0);
473 cc1 += cj.
pmul(lhs(
i+1,
j), b0);
484 for(;
j+LhsPacketSize<=
cols;
j+=LhsPacketSize)
486 RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(
j,0);
487 c0 = pcj.
pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i,
j),b0,c0);
491 for(;
j+LhsPacketSizeHalf<=
cols;
j+=LhsPacketSizeHalf)
493 RhsPacketHalf b0 = rhs.template load<RhsPacketHalf,Unaligned>(
j,0);
494 c0_h = pcj_half.
pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(
i,
j),b0,c0_h);
499 for(;
j+LhsPacketSizeQuarter<=
cols;
j+=LhsPacketSizeQuarter)
502 c0_q = pcj_quarter.
pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(
i,
j),b0,c0_q);
508 cc0 += cj.
pmul(lhs(
i,
j), rhs(
j,0));
518 #endif // EIGEN_GENERAL_MATRIX_VECTOR_H