00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00016
00017
00018
00020 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00021 class gemm_emul_tinysq
00022 {
00023 public:
00024
00025
00026 template<typename eT>
00027 arma_hot
00028 inline
00029 static
00030 void
00031 apply
00032 (
00033 Mat<eT>& C,
00034 const Mat<eT>& A,
00035 const Mat<eT>& B,
00036 const eT alpha = eT(1),
00037 const eT beta = eT(0)
00038 )
00039 {
00040 arma_extra_debug_sigprint();
00041
00042 switch(A.n_rows)
00043 {
00044 case 4:
00045 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta );
00046
00047 case 3:
00048 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta );
00049
00050 case 2:
00051 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta );
00052
00053 case 1:
00054 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta );
00055
00056 default:
00057 ;
00058 }
00059 }
00060
00061 };
00062
00063
00064
00065 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00066 class gemm_emul_large
00067 {
00068 public:
00069
00070 template<typename eT>
00071 arma_hot
00072 inline
00073 static
00074 void
00075 apply
00076 (
00077 Mat<eT>& C,
00078 const Mat<eT>& A,
00079 const Mat<eT>& B,
00080 const eT alpha = eT(1),
00081 const eT beta = eT(0)
00082 )
00083 {
00084 arma_extra_debug_sigprint();
00085
00086 const uword A_n_rows = A.n_rows;
00087 const uword A_n_cols = A.n_cols;
00088
00089 const uword B_n_rows = B.n_rows;
00090 const uword B_n_cols = B.n_cols;
00091
00092 if( (do_trans_A == false) && (do_trans_B == false) )
00093 {
00094 arma_aligned podarray<eT> tmp(A_n_cols);
00095 eT* A_rowdata = tmp.memptr();
00096
00097 for(uword row_A=0; row_A < A_n_rows; ++row_A)
00098 {
00099 tmp.copy_row(A, row_A);
00100
00101 for(uword col_B=0; col_B < B_n_cols; ++col_B)
00102 {
00103 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B));
00104
00105 if( (use_alpha == false) && (use_beta == false) )
00106 {
00107 C.at(row_A,col_B) = acc;
00108 }
00109 else
00110 if( (use_alpha == true) && (use_beta == false) )
00111 {
00112 C.at(row_A,col_B) = alpha * acc;
00113 }
00114 else
00115 if( (use_alpha == false) && (use_beta == true) )
00116 {
00117 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00118 }
00119 else
00120 if( (use_alpha == true) && (use_beta == true) )
00121 {
00122 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00123 }
00124
00125 }
00126 }
00127 }
00128 else
00129 if( (do_trans_A == true) && (do_trans_B == false) )
00130 {
00131 for(uword col_A=0; col_A < A_n_cols; ++col_A)
00132 {
00133
00134
00135 const eT* A_coldata = A.colptr(col_A);
00136
00137 for(uword col_B=0; col_B < B_n_cols; ++col_B)
00138 {
00139 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B));
00140
00141 if( (use_alpha == false) && (use_beta == false) )
00142 {
00143 C.at(col_A,col_B) = acc;
00144 }
00145 else
00146 if( (use_alpha == true) && (use_beta == false) )
00147 {
00148 C.at(col_A,col_B) = alpha * acc;
00149 }
00150 else
00151 if( (use_alpha == false) && (use_beta == true) )
00152 {
00153 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00154 }
00155 else
00156 if( (use_alpha == true) && (use_beta == true) )
00157 {
00158 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00159 }
00160
00161 }
00162 }
00163 }
00164 else
00165 if( (do_trans_A == false) && (do_trans_B == true) )
00166 {
00167 Mat<eT> BB;
00168 op_strans::apply_noalias(BB, B);
00169
00170 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
00171 }
00172 else
00173 if( (do_trans_A == true) && (do_trans_B == true) )
00174 {
00175
00176
00177
00178
00179
00180
00181
00182 arma_aligned podarray<eT> tmp(B.n_cols);
00183 eT* B_rowdata = tmp.memptr();
00184
00185 for(uword row_B=0; row_B < B_n_rows; ++row_B)
00186 {
00187 tmp.copy_row(B, row_B);
00188
00189 for(uword col_A=0; col_A < A_n_cols; ++col_A)
00190 {
00191 const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A));
00192
00193 if( (use_alpha == false) && (use_beta == false) )
00194 {
00195 C.at(col_A,row_B) = acc;
00196 }
00197 else
00198 if( (use_alpha == true) && (use_beta == false) )
00199 {
00200 C.at(col_A,row_B) = alpha * acc;
00201 }
00202 else
00203 if( (use_alpha == false) && (use_beta == true) )
00204 {
00205 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00206 }
00207 else
00208 if( (use_alpha == true) && (use_beta == true) )
00209 {
00210 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00211 }
00212
00213 }
00214 }
00215
00216 }
00217 }
00218
00219 };
00220
00221
00222
00223 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00224 class gemm_emul
00225 {
00226 public:
00227
00228
00229 template<typename eT>
00230 arma_hot
00231 inline
00232 static
00233 void
00234 apply
00235 (
00236 Mat<eT>& C,
00237 const Mat<eT>& A,
00238 const Mat<eT>& B,
00239 const eT alpha = eT(1),
00240 const eT beta = eT(0),
00241 const typename arma_not_cx<eT>::result* junk = 0
00242 )
00243 {
00244 arma_extra_debug_sigprint();
00245 arma_ignore(junk);
00246
00247 const uword A_n_rows = A.n_rows;
00248 const uword A_n_cols = A.n_cols;
00249
00250 const uword B_n_rows = B.n_rows;
00251 const uword B_n_cols = B.n_cols;
00252
00253 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
00254 {
00255 if(do_trans_B == false)
00256 {
00257 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
00258 }
00259 else
00260 {
00261 Mat<eT> BB(A_n_rows, A_n_rows);
00262 op_strans::apply_noalias_tinysq(BB, B);
00263
00264 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
00265 }
00266 }
00267 else
00268 {
00269 gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
00270 }
00271 }
00272
00273
00274
00275 template<typename eT>
00276 arma_hot
00277 inline
00278 static
00279 void
00280 apply
00281 (
00282 Mat<eT>& C,
00283 const Mat<eT>& A,
00284 const Mat<eT>& B,
00285 const eT alpha = eT(1),
00286 const eT beta = eT(0),
00287 const typename arma_cx_only<eT>::result* junk = 0
00288 )
00289 {
00290 arma_extra_debug_sigprint();
00291 arma_ignore(junk);
00292
00293
00294
00295 Mat<eT> tmp_A;
00296 Mat<eT> tmp_B;
00297
00298 if(do_trans_A)
00299 {
00300 op_htrans::apply_noalias(tmp_A, A);
00301 }
00302
00303 if(do_trans_B)
00304 {
00305 op_htrans::apply_noalias(tmp_B, B);
00306 }
00307
00308 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
00309 const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B;
00310
00311 const uword A_n_rows = AA.n_rows;
00312 const uword A_n_cols = AA.n_cols;
00313
00314 const uword B_n_rows = BB.n_rows;
00315 const uword B_n_cols = BB.n_cols;
00316
00317 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
00318 {
00319 gemm_emul_tinysq<false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
00320 }
00321 else
00322 {
00323 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
00324 }
00325 }
00326
00327 };
00328
00329
00330
00334
00335 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00336 class gemm
00337 {
00338 public:
00339
00340 template<typename eT>
00341 inline
00342 static
00343 void
00344 apply_blas_type( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00345 {
00346 arma_extra_debug_sigprint();
00347
00348 if( (A.n_elem <= 48u) && (B.n_elem <= 48u) )
00349 {
00350 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00351 }
00352 else
00353 {
00354 #if defined(ARMA_USE_ATLAS)
00355 {
00356 arma_extra_debug_print("atlas::cblas_gemm()");
00357
00358 atlas::cblas_gemm<eT>
00359 (
00360 atlas::CblasColMajor,
00361 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
00362 (do_trans_B) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
00363 C.n_rows,
00364 C.n_cols,
00365 (do_trans_A) ? A.n_rows : A.n_cols,
00366 (use_alpha) ? alpha : eT(1),
00367 A.mem,
00368 (do_trans_A) ? A.n_rows : C.n_rows,
00369 B.mem,
00370 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
00371 (use_beta) ? beta : eT(0),
00372 C.memptr(),
00373 C.n_rows
00374 );
00375 }
00376 #elif defined(ARMA_USE_BLAS)
00377 {
00378 arma_extra_debug_print("blas::gemm()");
00379
00380 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
00381 const char trans_B = (do_trans_B) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
00382
00383 const blas_int m = C.n_rows;
00384 const blas_int n = C.n_cols;
00385 const blas_int k = (do_trans_A) ? A.n_rows : A.n_cols;
00386
00387 const eT local_alpha = (use_alpha) ? alpha : eT(1);
00388
00389 const blas_int lda = (do_trans_A) ? k : m;
00390 const blas_int ldb = (do_trans_B) ? n : k;
00391
00392 const eT local_beta = (use_beta) ? beta : eT(0);
00393
00394 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_A = %c") % trans_A );
00395 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_B = %c") % trans_B );
00396
00397 blas::gemm<eT>
00398 (
00399 &trans_A,
00400 &trans_B,
00401 &m,
00402 &n,
00403 &k,
00404 &local_alpha,
00405 A.mem,
00406 &lda,
00407 B.mem,
00408 &ldb,
00409 &local_beta,
00410 C.memptr(),
00411 &m
00412 );
00413 }
00414 #else
00415 {
00416 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00417 }
00418 #endif
00419 }
00420 }
00421
00422
00423
00425 template<typename eT>
00426 inline
00427 static
00428 void
00429 apply( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00430 {
00431 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00432 }
00433
00434
00435
00436 arma_inline
00437 static
00438 void
00439 apply
00440 (
00441 Mat<float>& C,
00442 const Mat<float>& A,
00443 const Mat<float>& B,
00444 const float alpha = float(1),
00445 const float beta = float(0)
00446 )
00447 {
00448 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00449 }
00450
00451
00452
00453 arma_inline
00454 static
00455 void
00456 apply
00457 (
00458 Mat<double>& C,
00459 const Mat<double>& A,
00460 const Mat<double>& B,
00461 const double alpha = double(1),
00462 const double beta = double(0)
00463 )
00464 {
00465 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00466 }
00467
00468
00469
00470 arma_inline
00471 static
00472 void
00473 apply
00474 (
00475 Mat< std::complex<float> >& C,
00476 const Mat< std::complex<float> >& A,
00477 const Mat< std::complex<float> >& B,
00478 const std::complex<float> alpha = std::complex<float>(1),
00479 const std::complex<float> beta = std::complex<float>(0)
00480 )
00481 {
00482 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00483 }
00484
00485
00486
00487 arma_inline
00488 static
00489 void
00490 apply
00491 (
00492 Mat< std::complex<double> >& C,
00493 const Mat< std::complex<double> >& A,
00494 const Mat< std::complex<double> >& B,
00495 const std::complex<double> alpha = std::complex<double>(1),
00496 const std::complex<double> beta = std::complex<double>(0)
00497 )
00498 {
00499 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00500 }
00501
00502 };
00503
00504
00505