00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef EIGEN_GENERAL_MATRIX_MATRIX_H
00011 #define EIGEN_GENERAL_MATRIX_MATRIX_H
00012
00013 namespace Eigen {
00014
00015 namespace internal {
00016
// Forward declaration of the blocking descriptor taken by reference by the
// product kernels declared below.
template<typename _LhsScalar, typename _RhsScalar>
class level3_blocking;
00018
00019
00020 template<
00021 typename Index,
00022 typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
00023 typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
00024 struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor>
00025 {
00026 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
00027 static EIGEN_STRONG_INLINE void run(
00028 Index rows, Index cols, Index depth,
00029 const LhsScalar* lhs, Index lhsStride,
00030 const RhsScalar* rhs, Index rhsStride,
00031 ResScalar* res, Index resStride,
00032 ResScalar alpha,
00033 level3_blocking<RhsScalar,LhsScalar>& blocking,
00034 GemmParallelInfo<Index>* info = 0)
00035 {
00036
00037 general_matrix_matrix_product<Index,
00038 RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
00039 LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
00040 ColMajor>
00041 ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info);
00042 }
00043 };
00044
00045
00046
00047 template<
00048 typename Index,
00049 typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
00050 typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
00051 struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor>
00052 {
00053
00054 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
00055 static void run(Index rows, Index cols, Index depth,
00056 const LhsScalar* _lhs, Index lhsStride,
00057 const RhsScalar* _rhs, Index rhsStride,
00058 ResScalar* res, Index resStride,
00059 ResScalar alpha,
00060 level3_blocking<LhsScalar,RhsScalar>& blocking,
00061 GemmParallelInfo<Index>* info = 0)
00062 {
00063 const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
00064 const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
00065
00066 typedef gebp_traits<LhsScalar,RhsScalar> Traits;
00067
00068 Index kc = blocking.kc();
00069 Index mc = (std::min)(rows,blocking.mc());
00070
00071
00072 gemm_pack_lhs<LhsScalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
00073 gemm_pack_rhs<RhsScalar, Index, Traits::nr, RhsStorageOrder> pack_rhs;
00074 gebp_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
00075
00076 #ifdef EIGEN_HAS_OPENMP
00077 if(info)
00078 {
00079
00080 Index tid = omp_get_thread_num();
00081 Index threads = omp_get_num_threads();
00082
00083 std::size_t sizeA = kc*mc;
00084 std::size_t sizeW = kc*Traits::WorkSpaceFactor;
00085 ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, 0);
00086 ei_declare_aligned_stack_constructed_variable(RhsScalar, w, sizeW, 0);
00087
00088 RhsScalar* blockB = blocking.blockB();
00089 eigen_internal_assert(blockB!=0);
00090
00091
00092 for(Index k=0; k<depth; k+=kc)
00093 {
00094 const Index actual_kc = (std::min)(k+kc,depth)-k;
00095
00096
00097
00098 pack_lhs(blockA, &lhs(0,k), lhsStride, actual_kc, mc);
00099
00100
00101
00102
00103
00104
00105
00106 while(info[tid].users!=0) {}
00107 info[tid].users += threads;
00108
00109 pack_rhs(blockB+info[tid].rhs_start*actual_kc, &rhs(k,info[tid].rhs_start), rhsStride, actual_kc, info[tid].rhs_length);
00110
00111
00112 info[tid].sync = k;
00113
00114
00115 for(Index shift=0; shift<threads; ++shift)
00116 {
00117 Index j = (tid+shift)%threads;
00118
00119
00120
00121
00122 if(shift>0)
00123 while(info[j].sync!=k) {}
00124
00125 gebp(res+info[j].rhs_start*resStride, resStride, blockA, blockB+info[j].rhs_start*actual_kc, mc, actual_kc, info[j].rhs_length, alpha, -1,-1,0,0, w);
00126 }
00127
00128
00129 for(Index i=mc; i<rows; i+=mc)
00130 {
00131 const Index actual_mc = (std::min)(i+mc,rows)-i;
00132
00133
00134 pack_lhs(blockA, &lhs(i,k), lhsStride, actual_kc, actual_mc);
00135
00136
00137 gebp(res+i, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha, -1,-1,0,0, w);
00138 }
00139
00140
00141
00142 for(Index j=0; j<threads; ++j)
00143 #pragma omp atomic
00144 --(info[j].users);
00145 }
00146 }
00147 else
00148 #endif // EIGEN_HAS_OPENMP
00149 {
00150 EIGEN_UNUSED_VARIABLE(info);
00151
00152
00153 std::size_t sizeA = kc*mc;
00154 std::size_t sizeB = kc*cols;
00155 std::size_t sizeW = kc*Traits::WorkSpaceFactor;
00156
00157 ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA());
00158 ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB());
00159 ei_declare_aligned_stack_constructed_variable(RhsScalar, blockW, sizeW, blocking.blockW());
00160
00161
00162
00163 for(Index k2=0; k2<depth; k2+=kc)
00164 {
00165 const Index actual_kc = (std::min)(k2+kc,depth)-k2;
00166
00167
00168
00169
00170
00171 pack_rhs(blockB, &rhs(k2,0), rhsStride, actual_kc, cols);
00172
00173
00174
00175 for(Index i2=0; i2<rows; i2+=mc)
00176 {
00177 const Index actual_mc = (std::min)(i2+mc,rows)-i2;
00178
00179
00180
00181
00182 pack_lhs(blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc);
00183
00184
00185 gebp(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha, -1, -1, 0, 0, blockW);
00186 }
00187 }
00188 }
00189 }
00190
00191 };
00192
00193
00194
00195
00196
00197
00198 template<typename Lhs, typename Rhs>
00199 struct traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
00200 : traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> >
00201 {};
00202
00203 template<typename Scalar, typename Index, typename Gemm, typename Lhs, typename Rhs, typename Dest, typename BlockingType>
00204 struct gemm_functor
00205 {
00206 gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, const Scalar& actualAlpha,
00207 BlockingType& blocking)
00208 : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha), m_blocking(blocking)
00209 {}
00210
00211 void initParallelSession() const
00212 {
00213 m_blocking.allocateB();
00214 }
00215
00216 void operator() (Index row, Index rows, Index col=0, Index cols=-1, GemmParallelInfo<Index>* info=0) const
00217 {
00218 if(cols==-1)
00219 cols = m_rhs.cols();
00220
00221 Gemm::run(rows, cols, m_lhs.cols(),
00222 &m_lhs.coeffRef(row,0), m_lhs.outerStride(),
00223 &m_rhs.coeffRef(0,col), m_rhs.outerStride(),
00224 (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(),
00225 m_actualAlpha, m_blocking, info);
00226 }
00227
00228 protected:
00229 const Lhs& m_lhs;
00230 const Rhs& m_rhs;
00231 Dest& m_dest;
00232 Scalar m_actualAlpha;
00233 BlockingType& m_blocking;
00234 };
00235
00236 template<int StorageOrder, typename LhsScalar, typename RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor=1,
00237 bool FiniteAtCompileTime = MaxRows!=Dynamic && MaxCols!=Dynamic && MaxDepth != Dynamic> class gemm_blocking_space;
00238
00239 template<typename _LhsScalar, typename _RhsScalar>
00240 class level3_blocking
00241 {
00242 typedef _LhsScalar LhsScalar;
00243 typedef _RhsScalar RhsScalar;
00244
00245 protected:
00246 LhsScalar* m_blockA;
00247 RhsScalar* m_blockB;
00248 RhsScalar* m_blockW;
00249
00250 DenseIndex m_mc;
00251 DenseIndex m_nc;
00252 DenseIndex m_kc;
00253
00254 public:
00255
00256 level3_blocking()
00257 : m_blockA(0), m_blockB(0), m_blockW(0), m_mc(0), m_nc(0), m_kc(0)
00258 {}
00259
00260 inline DenseIndex mc() const { return m_mc; }
00261 inline DenseIndex nc() const { return m_nc; }
00262 inline DenseIndex kc() const { return m_kc; }
00263
00264 inline LhsScalar* blockA() { return m_blockA; }
00265 inline RhsScalar* blockB() { return m_blockB; }
00266 inline RhsScalar* blockW() { return m_blockW; }
00267 };
00268
00269 template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
00270 class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, KcFactor, true>
00271 : public level3_blocking<
00272 typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
00273 typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
00274 {
00275 enum {
00276 Transpose = StorageOrder==RowMajor,
00277 ActualRows = Transpose ? MaxCols : MaxRows,
00278 ActualCols = Transpose ? MaxRows : MaxCols
00279 };
00280 typedef typename conditional<Transpose,_RhsScalar,_LhsScalar>::type LhsScalar;
00281 typedef typename conditional<Transpose,_LhsScalar,_RhsScalar>::type RhsScalar;
00282 typedef gebp_traits<LhsScalar,RhsScalar> Traits;
00283 enum {
00284 SizeA = ActualRows * MaxDepth,
00285 SizeB = ActualCols * MaxDepth,
00286 SizeW = MaxDepth * Traits::WorkSpaceFactor
00287 };
00288
00289 EIGEN_ALIGN16 LhsScalar m_staticA[SizeA];
00290 EIGEN_ALIGN16 RhsScalar m_staticB[SizeB];
00291 EIGEN_ALIGN16 RhsScalar m_staticW[SizeW];
00292
00293 public:
00294
00295 gemm_blocking_space(DenseIndex , DenseIndex , DenseIndex )
00296 {
00297 this->m_mc = ActualRows;
00298 this->m_nc = ActualCols;
00299 this->m_kc = MaxDepth;
00300 this->m_blockA = m_staticA;
00301 this->m_blockB = m_staticB;
00302 this->m_blockW = m_staticW;
00303 }
00304
00305 inline void allocateA() {}
00306 inline void allocateB() {}
00307 inline void allocateW() {}
00308 inline void allocateAll() {}
00309 };
00310
00311 template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
00312 class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, KcFactor, false>
00313 : public level3_blocking<
00314 typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
00315 typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
00316 {
00317 enum {
00318 Transpose = StorageOrder==RowMajor
00319 };
00320 typedef typename conditional<Transpose,_RhsScalar,_LhsScalar>::type LhsScalar;
00321 typedef typename conditional<Transpose,_LhsScalar,_RhsScalar>::type RhsScalar;
00322 typedef gebp_traits<LhsScalar,RhsScalar> Traits;
00323
00324 DenseIndex m_sizeA;
00325 DenseIndex m_sizeB;
00326 DenseIndex m_sizeW;
00327
00328 public:
00329
00330 gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth)
00331 {
00332 this->m_mc = Transpose ? cols : rows;
00333 this->m_nc = Transpose ? rows : cols;
00334 this->m_kc = depth;
00335
00336 computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, this->m_mc, this->m_nc);
00337 m_sizeA = this->m_mc * this->m_kc;
00338 m_sizeB = this->m_kc * this->m_nc;
00339 m_sizeW = this->m_kc*Traits::WorkSpaceFactor;
00340 }
00341
00342 void allocateA()
00343 {
00344 if(this->m_blockA==0)
00345 this->m_blockA = aligned_new<LhsScalar>(m_sizeA);
00346 }
00347
00348 void allocateB()
00349 {
00350 if(this->m_blockB==0)
00351 this->m_blockB = aligned_new<RhsScalar>(m_sizeB);
00352 }
00353
00354 void allocateW()
00355 {
00356 if(this->m_blockW==0)
00357 this->m_blockW = aligned_new<RhsScalar>(m_sizeW);
00358 }
00359
00360 void allocateAll()
00361 {
00362 allocateA();
00363 allocateB();
00364 allocateW();
00365 }
00366
00367 ~gemm_blocking_space()
00368 {
00369 aligned_delete(this->m_blockA, m_sizeA);
00370 aligned_delete(this->m_blockB, m_sizeB);
00371 aligned_delete(this->m_blockW, m_sizeW);
00372 }
00373 };
00374
00375 }
00376
00377 template<typename Lhs, typename Rhs>
00378 class GeneralProduct<Lhs, Rhs, GemmProduct>
00379 : public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
00380 {
00381 enum {
00382 MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime)
00383 };
00384 public:
00385 EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct)
00386
00387 typedef typename Lhs::Scalar LhsScalar;
00388 typedef typename Rhs::Scalar RhsScalar;
00389 typedef Scalar ResScalar;
00390
00391 GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
00392 {
00393 typedef internal::scalar_product_op<LhsScalar,RhsScalar> BinOp;
00394 EIGEN_CHECK_BINARY_COMPATIBILIY(BinOp,LhsScalar,RhsScalar);
00395 }
00396
00397 template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
00398 {
00399 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
00400
00401 typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(m_lhs);
00402 typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(m_rhs);
00403
00404 Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
00405 * RhsBlasTraits::extractScalarFactor(m_rhs);
00406
00407 typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
00408 Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
00409
00410 typedef internal::gemm_functor<
00411 Scalar, Index,
00412 internal::general_matrix_matrix_product<
00413 Index,
00414 LhsScalar, (_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
00415 RhsScalar, (_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
00416 (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
00417 _ActualLhsType, _ActualRhsType, Dest, BlockingType> GemmFunctor;
00418
00419 BlockingType blocking(dst.rows(), dst.cols(), lhs.cols());
00420
00421 internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit);
00422 }
00423 };
00424
00425 }
00426
00427 #endif // EIGEN_GENERAL_MATRIX_MATRIX_H