00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_SPARSELU_GEMM_KERNEL_H
00011 #define EIGEN_SPARSELU_GEMM_KERNEL_H
00012
00013 namespace Eigen {
00014
00015 namespace internal {
00016
00017
/** \internal
  * Dense product kernel that accumulates A * B into C, i.e. C += A * B.
  *
  * As the scalar loop at the top of the body shows, the three operands are
  * addressed column by column: coefficient (i,k) of A is A[i+k*lda],
  * coefficient (k,j) of B is B[k+j*ldb] and coefficient (i,j) of C is C[i+j*ldc].
  *
  * \param m   number of rows of A and C
  * \param n   number of columns of B and C
  * \param d   number of columns of A, which is also the number of rows of B
  * \param A   pointer to the first coefficient of A (read only)
  * \param lda distance, in scalars, between two consecutive columns of A
  * \param B   pointer to the first coefficient of B (read only)
  * \param ldb distance, in scalars, between two consecutive columns of B
  * \param C   pointer to the first coefficient of C (updated in place)
  * \param ldc distance, in scalars, between two consecutive columns of C
  *
  * Preconditions stated by the internal assertion in the body: lda and ldc are
  * multiples of the packet size, and first_aligned() returns the same value for
  * A and for C.
  */
00024 template<typename Scalar,typename Index>
00025 EIGEN_DONT_INLINE
00026 void sparselu_gemm(Index m, Index n, Index d, const Scalar* A, Index lda, const Scalar* B, Index ldb, Scalar* C, Index ldc)
00027 {
00028 using namespace Eigen::internal;
00029
00030 typedef typename packet_traits<Scalar>::type Packet;
00031 enum {
00032 NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
00033 PacketSize = packet_traits<Scalar>::size,
00034 PM = 8, // number of packets handled by one iteration of the unrolled row loops
00035 RN = 2, // number of columns of B and C processed together
00036 RK = NumberOfRegisters>=16 ? 4 : 2, // number of columns of A (rows of B) processed together
00037 BM = 4096/sizeof(Scalar), // number of rows per chunk, so one column of a chunk spans 4096 bytes
00038 SM = PM*PacketSize // number of rows handled by one iteration of the unrolled row loops
00039 };
00040 Index d_end = (d/RK)*RK; // largest multiple of RK not exceeding d
00041 Index n_end = (n/RN)*RN; // largest multiple of RN not exceeding n
// Rows [0,i0) go through the plain scalar loop below; rows from i0 on are accessed with pload/pstore.
// NOTE(review): i0 is presumably the index of the first packet-aligned row of A - confirm against first_aligned().
00042 Index i0 = internal::first_aligned(A,m);
00043
00044 eigen_internal_assert(((lda%PacketSize)==0) && ((ldc%PacketSize)==0) && (i0==internal::first_aligned(C,m)));
00045
00046
// Leading rows [0,i0): straightforward triple loop, one coefficient of C at a time.
00047 for(Index i=0; i<i0; ++i)
00048 {
00049 for(Index j=0; j<n; ++j)
00050 {
00051 Scalar c = C[i+j*ldc];
00052 for(Index k=0; k<d; ++k)
00053 c += B[k+j*ldb] * A[i+k*lda];
00054 C[i+j*ldc] = c;
00055 }
00056 }
00057
// Remaining rows [i0,m): processed by chunks of at most BM rows.
00058 for(Index ib=i0; ib<m; ib+=BM)
00059 {
00060 Index actual_b = std::min<Index>(BM, m-ib); // number of rows in this chunk
00061 Index actual_b_end1 = (actual_b/SM)*SM; // rows covered by the 8-packet unrolled loop
00062 Index actual_b_end2 = (actual_b/PacketSize)*PacketSize; // rows covered by whole packets
00063
00064
// Columns [0,n_end) of B and C, RN (= 2) columns at a time.
00065 for(Index j=0; j<n_end; j+=RN)
00066 {
00067 const Scalar* Bc0 = B+(j+0)*ldb;
00068 const Scalar* Bc1 = B+(j+1)*ldb;
00069
// Columns [0,d_end) of A, RK at a time; Bc0/Bc1 are advanced by RK at the end of each iteration.
00070 for(Index k=0; k<d_end; k+=RK)
00071 {
00072
00073
// Broadcast the RK x RN coefficients of B used by this step (bXY: row X of the step, column Y).
00074 Packet b00, b10, b20, b30, b01, b11, b21, b31;
00075 b00 = pset1<Packet>(Bc0[0]);
00076 b10 = pset1<Packet>(Bc0[1]);
00077 if(RK==4) b20 = pset1<Packet>(Bc0[2]);
00078 if(RK==4) b30 = pset1<Packet>(Bc0[3]);
00079 b01 = pset1<Packet>(Bc1[0]);
00080 b11 = pset1<Packet>(Bc1[1]);
00081 if(RK==4) b21 = pset1<Packet>(Bc1[2]);
00082 if(RK==4) b31 = pset1<Packet>(Bc1[3]);
00083
00084 Packet a0, a1, a2, a3, c0, c1, t0, t1;
00085
// Start of the current chunk in columns k..k+3 of A (A2 and A3 are only dereferenced when RK==4).
00086 const Scalar* A0 = A+ib+(k+0)*lda;
00087 const Scalar* A1 = A+ib+(k+1)*lda;
00088 const Scalar* A2 = A+ib+(k+2)*lda;
00089 const Scalar* A3 = A+ib+(k+3)*lda;
00090
// Start of the current chunk in columns j and j+1 of C.
00091 Scalar* C0 = C+ib+(j+0)*ldc;
00092 Scalar* C1 = C+ib+(j+1)*ldc;
00093
// Load the first packets of A up front: WORK consumes a0..a3 and then reloads them for its successor.
00094 a0 = pload<Packet>(A0);
00095 a1 = pload<Packet>(A1);
00096 if(RK==4)
00097 {
00098 a2 = pload<Packet>(A2);
00099 a3 = pload<Packet>(A3);
00100 }
00101 else
00102 {
00103
// a2 and a3 are only read under if(RK==4); presumably they are assigned here to silence uninitialized-variable warnings.
00104 a2 = a3 = a0;
00105 }
00106
// KMADD(c,a,b,tmp): c += a*b, using tmp as scratch packet.
// WORK(I): updates the I-th packet (relative to row offset i) of C0 and C1 with the current a0..a3,
// and reloads a0..a3 with the (I+1)-th packets of A for the next invocation.
// NOTE(review): because of this look-ahead, the last WORK of a chunk loads packets of A starting at row
// offset actual_b_end2, i.e. one packet beyond the rows it stores - confirm that callers guarantee these
// reads stay inside their buffers.
00107 #define KMADD(c, a, b, tmp) {tmp = b; tmp = pmul(a,tmp); c = padd(c,tmp);}
00108 #define WORK(I) \
00109 c0 = pload<Packet>(C0+i+(I)*PacketSize); \
00110 c1 = pload<Packet>(C1+i+(I)*PacketSize); \
00111 KMADD(c0, a0, b00, t0) \
00112 KMADD(c1, a0, b01, t1) \
00113 a0 = pload<Packet>(A0+i+(I+1)*PacketSize); \
00114 KMADD(c0, a1, b10, t0) \
00115 KMADD(c1, a1, b11, t1) \
00116 a1 = pload<Packet>(A1+i+(I+1)*PacketSize); \
00117 if(RK==4) KMADD(c0, a2, b20, t0) \
00118 if(RK==4) KMADD(c1, a2, b21, t1) \
00119 if(RK==4) a2 = pload<Packet>(A2+i+(I+1)*PacketSize); \
00120 if(RK==4) KMADD(c0, a3, b30, t0) \
00121 if(RK==4) KMADD(c1, a3, b31, t1) \
00122 if(RK==4) a3 = pload<Packet>(A3+i+(I+1)*PacketSize); \
00123 pstore(C0+i+(I)*PacketSize, c0); \
00124 pstore(C1+i+(I)*PacketSize, c1)
00125
00126
// Rows [0,actual_b_end1): 8 packets per iteration; the step PacketSize*8 equals SM because PM is 8.
00127 for(Index i=0; i<actual_b_end1; i+=PacketSize*8)
00128 {
00129 EIGEN_ASM_COMMENT("SPARSELU_GEMML_KERNEL1");
// Prefetch the columns of A five packets ahead of the current row offset.
00130 prefetch((A0+i+(5)*PacketSize));
00131 prefetch((A1+i+(5)*PacketSize));
00132 if(RK==4) prefetch((A2+i+(5)*PacketSize));
00133 if(RK==4) prefetch((A3+i+(5)*PacketSize));
00134 WORK(0);
00135 WORK(1);
00136 WORK(2);
00137 WORK(3);
00138 WORK(4);
00139 WORK(5);
00140 WORK(6);
00141 WORK(7);
00142 }
00143
// Rows [actual_b_end1,actual_b_end2): one packet per iteration.
00144 for(Index i=actual_b_end1; i<actual_b_end2; i+=PacketSize)
00145 {
00146 WORK(0);
00147 }
00148 #undef WORK
00149
// Rows [actual_b_end2,actual_b): fewer than PacketSize rows left, handled coefficient by coefficient.
00150 for(Index i=actual_b_end2; i<actual_b; ++i)
00151 {
00152 if(RK==4)
00153 {
00154 C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1]+A2[i]*Bc0[2]+A3[i]*Bc0[3];
00155 C1[i] += A0[i]*Bc1[0]+A1[i]*Bc1[1]+A2[i]*Bc1[2]+A3[i]*Bc1[3];
00156 }
00157 else
00158 {
00159 C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1];
00160 C1[i] += A0[i]*Bc1[0]+A1[i]*Bc1[1];
00161 }
00162 }
00163
// Move to the next RK rows of the two current columns of B.
00164 Bc0 += RK;
00165 Bc1 += RK;
00166 }
00167 }
00168
// n is odd: one column of B and C is left. With RN == 2, n-1 and n_end designate that same column.
00169 if((n-n_end)>0)
00170 {
00171 const Scalar* Bc0 = B+(n-1)*ldb;
00172
// Same scheme as above, restricted to a single column of B and C.
00173 for(Index k=0; k<d_end; k+=RK)
00174 {
00175
00176
// Broadcast the RK coefficients of B used by this step.
00177 Packet b00, b10, b20, b30;
00178 b00 = pset1<Packet>(Bc0[0]);
00179 b10 = pset1<Packet>(Bc0[1]);
00180 if(RK==4) b20 = pset1<Packet>(Bc0[2]);
00181 if(RK==4) b30 = pset1<Packet>(Bc0[3]);
00182
00183 Packet a0, a1, a2, a3, c0, t0;
00184
00185 const Scalar* A0 = A+ib+(k+0)*lda;
00186 const Scalar* A1 = A+ib+(k+1)*lda;
00187 const Scalar* A2 = A+ib+(k+2)*lda;
00188 const Scalar* A3 = A+ib+(k+3)*lda;
00189
00190 Scalar* C0 = C+ib+(n_end)*ldc;
00191
// Initial loads of A consumed by the first WORK invocation.
00192 a0 = pload<Packet>(A0);
00193 a1 = pload<Packet>(A1);
00194 if(RK==4)
00195 {
00196 a2 = pload<Packet>(A2);
00197 a3 = pload<Packet>(A3);
00198 }
00199 else
00200 {
00201
// a2 and a3 are only read under if(RK==4); presumably they are assigned here to silence uninitialized-variable warnings.
00202 a2 = a3 = a0;
00203 }
00204
// Single-column variant of WORK: updates the I-th packet of C0 and reloads a0..a3 with the (I+1)-th
// packets of A. It reuses the KMADD macro defined in the two-column section above, and its definition
// already ends with a semicolon.
00205 #define WORK(I) \
00206 c0 = pload<Packet>(C0+i+(I)*PacketSize); \
00207 KMADD(c0, a0, b00, t0) \
00208 a0 = pload<Packet>(A0+i+(I+1)*PacketSize); \
00209 KMADD(c0, a1, b10, t0) \
00210 a1 = pload<Packet>(A1+i+(I+1)*PacketSize); \
00211 if(RK==4) KMADD(c0, a2, b20, t0) \
00212 if(RK==4) a2 = pload<Packet>(A2+i+(I+1)*PacketSize); \
00213 if(RK==4) KMADD(c0, a3, b30, t0) \
00214 if(RK==4) a3 = pload<Packet>(A3+i+(I+1)*PacketSize); \
00215 pstore(C0+i+(I)*PacketSize, c0);
00216
00217
// Rows [0,actual_b_end1): 8 packets per iteration (no prefetch in this variant).
00218 for(Index i=0; i<actual_b_end1; i+=PacketSize*8)
00219 {
00220 EIGEN_ASM_COMMENT("SPARSELU_GEMML_KERNEL2");
00221 WORK(0);
00222 WORK(1);
00223 WORK(2);
00224 WORK(3);
00225 WORK(4);
00226 WORK(5);
00227 WORK(6);
00228 WORK(7);
00229 }
00230
// Rows [actual_b_end1,actual_b_end2): one packet per iteration.
00231 for(Index i=actual_b_end1; i<actual_b_end2; i+=PacketSize)
00232 {
00233 WORK(0);
00234 }
00235
// Rows [actual_b_end2,actual_b): coefficient by coefficient.
00236 for(Index i=actual_b_end2; i<actual_b; ++i)
00237 {
00238 if(RK==4)
00239 C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1]+A2[i]*Bc0[2]+A3[i]*Bc0[3];
00240 else
00241 C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1];
00242 }
00243
00244 Bc0 += RK;
00245 #undef WORK
00246 }
00247 }
00248
00249
// Columns [d_end,d) of A (rows [d_end,d) of B) were skipped by the blocked loops above.
// rd is smaller than RK, hence at most 3.
00250 Index rd = d-d_end;
00251 if(rd>0)
00252 {
// For every column j of C, add the rd remaining contributions for the rows of the current chunk.
00253 for(Index j=0; j<n; ++j)
00254 {
00255 enum {
00256 Alignment = PacketSize>1 ? Aligned : 0 // request aligned maps only when packets hold more than one scalar
00257 };
00258 typedef Map<Matrix<Scalar,Dynamic,1>, Alignment > MapVector;
00259 typedef Map<const Matrix<Scalar,Dynamic,1>, Alignment > ConstMapVector;
00260 if(rd==1) MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b);
00261
00262 else if(rd==2) MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b)
00263 + B[1+d_end+j*ldb] * ConstMapVector(A+(d_end+1)*lda+ib, actual_b);
00264
// Remaining case: rd == 3, which can only occur when RK == 4.
00265 else MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b)
00266 + B[1+d_end+j*ldb] * ConstMapVector(A+(d_end+1)*lda+ib, actual_b)
00267 + B[2+d_end+j*ldb] * ConstMapVector(A+(d_end+2)*lda+ib, actual_b);
00268 }
00269 }
00270
00271 }
00272 }
00273 #undef KMADD
00274
00275 }
00276
00277 }
00278
00279 #endif // EIGEN_SPARSELU_GEMM_KERNEL_H