$search
00001 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) 00002 // Copyright (C) 2008-2011 Conrad Sanderson 00003 // 00004 // This file is part of the Armadillo C++ library. 00005 // It is provided without any warranty of fitness 00006 // for any purpose. You can redistribute this file 00007 // and/or modify it under the terms of the GNU 00008 // Lesser General Public License (LGPL) as published 00009 // by the Free Software Foundation, either version 3 00010 // of the License or (at your option) any later version. 00011 // (see http://www.opensource.org/licenses for more info) 00012 00013 00016 00017 00018 00020 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> 00021 class gemm_emul_tinysq 00022 { 00023 public: 00024 00025 00026 template<typename eT> 00027 arma_hot 00028 inline 00029 static 00030 void 00031 apply 00032 ( 00033 Mat<eT>& C, 00034 const Mat<eT>& A, 00035 const Mat<eT>& B, 00036 const eT alpha = eT(1), 00037 const eT beta = eT(0) 00038 ) 00039 { 00040 arma_extra_debug_sigprint(); 00041 00042 switch(A.n_rows) 00043 { 00044 case 4: 00045 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta ); 00046 00047 case 3: 00048 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta ); 00049 00050 case 2: 00051 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta ); 00052 00053 case 1: 00054 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta ); 00055 00056 default: 00057 ; 00058 } 00059 } 00060 00061 }; 00062 00063 00064 00065 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> 00066 class gemm_emul_large 00067 { 00068 public: 00069 00070 template<typename eT> 00071 arma_hot 00072 inline 00073 static 00074 void 00075 apply 00076 ( 00077 Mat<eT>& C, 00078 const Mat<eT>& A, 00079 const Mat<eT>& B, 00080 const eT alpha = eT(1), 00081 const eT beta = eT(0) 00082 ) 00083 { 00084 arma_extra_debug_sigprint(); 00085 00086 const uword A_n_rows = A.n_rows; 00087 const uword A_n_cols = A.n_cols; 00088 00089 const uword B_n_rows = B.n_rows; 00090 const uword B_n_cols = B.n_cols; 00091 00092 if( (do_trans_A == false) && (do_trans_B == false) ) 00093 { 00094 arma_aligned podarray<eT> tmp(A_n_cols); 00095 eT* A_rowdata = tmp.memptr(); 00096 00097 for(uword row_A=0; row_A < A_n_rows; ++row_A) 00098 { 00099 tmp.copy_row(A, row_A); 00100 00101 for(uword col_B=0; col_B < B_n_cols; ++col_B) 00102 { 00103 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B)); 00104 00105 if( (use_alpha == false) && (use_beta == false) ) 00106 { 00107 C.at(row_A,col_B) = acc; 00108 } 00109 else 00110 if( (use_alpha == true) && (use_beta == false) ) 00111 { 00112 C.at(row_A,col_B) = alpha * acc; 00113 } 00114 else 00115 if( (use_alpha == false) && (use_beta == true) ) 00116 { 00117 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); 00118 } 00119 else 00120 if( (use_alpha == true) && (use_beta == true) ) 00121 { 00122 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); 00123 } 00124 00125 } 00126 } 00127 } 00128 else 00129 if( (do_trans_A == true) && (do_trans_B == false) ) 00130 { 00131 for(uword col_A=0; col_A < A_n_cols; ++col_A) 00132 { 00133 // col_A is interpreted as row_A when storing the results in matrix C 00134 00135 const eT* A_coldata = A.colptr(col_A); 00136 00137 for(uword col_B=0; col_B < B_n_cols; ++col_B) 00138 { 00139 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B)); 00140 00141 if( (use_alpha == false) && (use_beta == false) ) 00142 { 00143 C.at(col_A,col_B) = acc; 00144 } 00145 else 00146 if( (use_alpha == true) && (use_beta == false) ) 00147 { 00148 C.at(col_A,col_B) = alpha * acc; 00149 } 00150 else 00151 if( (use_alpha == false) && (use_beta == true) ) 00152 { 00153 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); 00154 } 00155 else 00156 if( (use_alpha == true) && (use_beta == true) ) 00157 { 00158 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); 00159 } 00160 00161 } 00162 } 00163 } 00164 else 00165 if( (do_trans_A == false) && (do_trans_B == true) ) 00166 { 00167 Mat<eT> BB; 00168 op_strans::apply_noalias(BB, B); 00169 00170 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta); 00171 } 00172 else 00173 if( (do_trans_A == true) && (do_trans_B == true) ) 00174 { 00175 // mat B_tmp = trans(B); 00176 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 00177 00178 00179 // By using the trans(A)*trans(B) = trans(B*A) equivalency, 00180 // transpose operations are not needed 00181 00182 arma_aligned podarray<eT> tmp(B.n_cols); 00183 eT* B_rowdata = tmp.memptr(); 00184 00185 for(uword row_B=0; row_B < B_n_rows; ++row_B) 00186 { 00187 tmp.copy_row(B, row_B); 00188 00189 for(uword col_A=0; col_A < A_n_cols; ++col_A) 00190 { 00191 const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A)); 00192 00193 if( (use_alpha == false) && (use_beta == false) ) 00194 { 00195 C.at(col_A,row_B) = acc; 00196 } 00197 else 00198 if( (use_alpha == true) && (use_beta == false) ) 00199 { 00200 C.at(col_A,row_B) = alpha * acc; 00201 } 00202 else 00203 if( (use_alpha == false) && (use_beta == true) ) 00204 { 00205 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); 00206 } 00207 else 00208 if( (use_alpha == true) && (use_beta == true) ) 00209 { 00210 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); 00211 } 00212 00213 } 00214 } 00215 00216 } 00217 } 00218 00219 }; 00220 00221 00222 00223 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> 00224 class gemm_emul 00225 { 00226 public: 00227 00228 00229 template<typename eT> 00230 arma_hot 00231 inline 00232 static 00233 void 00234 apply 00235 ( 00236 Mat<eT>& C, 00237 const Mat<eT>& A, 00238 const Mat<eT>& B, 00239 const eT alpha = eT(1), 00240 const eT beta = eT(0), 00241 const typename arma_not_cx<eT>::result* junk = 0 00242 ) 00243 { 00244 arma_extra_debug_sigprint(); 00245 arma_ignore(junk); 00246 00247 const uword A_n_rows = A.n_rows; 00248 const uword A_n_cols = A.n_cols; 00249 00250 const uword B_n_rows = B.n_rows; 00251 const uword B_n_cols = B.n_cols; 00252 00253 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) ) 00254 { 00255 if(do_trans_B == false) 00256 { 00257 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta); 00258 } 00259 else 00260 { 00261 Mat<eT> BB(A_n_rows, A_n_rows); 00262 op_strans::apply_noalias_tinysq(BB, B); 00263 00264 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta); 00265 } 00266 } 00267 else 00268 { 00269 gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta); 00270 } 00271 } 00272 00273 00274 00275 template<typename eT> 00276 arma_hot 00277 inline 00278 static 00279 void 00280 apply 00281 ( 00282 Mat<eT>& C, 00283 const Mat<eT>& A, 00284 const Mat<eT>& B, 00285 const eT alpha = eT(1), 00286 const eT beta = eT(0), 00287 const typename arma_cx_only<eT>::result* junk = 0 00288 ) 00289 { 00290 arma_extra_debug_sigprint(); 00291 arma_ignore(junk); 00292 00293 // "better than nothing" handling of hermitian transposes for complex number matrices 00294 00295 Mat<eT> tmp_A; 00296 Mat<eT> tmp_B; 00297 00298 if(do_trans_A) 00299 { 00300 op_htrans::apply_noalias(tmp_A, A); 00301 } 00302 00303 if(do_trans_B) 00304 { 00305 op_htrans::apply_noalias(tmp_B, B); 00306 } 00307 00308 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A; 00309 const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B; 00310 00311 const uword A_n_rows = AA.n_rows; 00312 const uword A_n_cols = AA.n_cols; 00313 00314 const uword B_n_rows = BB.n_rows; 00315 const uword B_n_cols = BB.n_cols; 00316 00317 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) ) 00318 { 00319 gemm_emul_tinysq<false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); 00320 } 00321 else 00322 { 00323 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); 00324 } 00325 } 00326 00327 }; 00328 00329 00330 00334 00335 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> 00336 class gemm 00337 { 00338 public: 00339 00340 template<typename eT> 00341 inline 00342 static 00343 void 00344 apply_blas_type( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) ) 00345 { 00346 arma_extra_debug_sigprint(); 00347 00348 if( (A.n_elem <= 48u) && (B.n_elem <= 48u) ) 00349 { 00350 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta); 00351 } 00352 else 00353 { 00354 #if defined(ARMA_USE_ATLAS) 00355 { 00356 arma_extra_debug_print("atlas::cblas_gemm()"); 00357 00358 atlas::cblas_gemm<eT> 00359 ( 00360 atlas::CblasColMajor, 00361 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, 00362 (do_trans_B) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, 00363 C.n_rows, 00364 C.n_cols, 00365 (do_trans_A) ? A.n_rows : A.n_cols, 00366 (use_alpha) ? alpha : eT(1), 00367 A.mem, 00368 (do_trans_A) ? A.n_rows : C.n_rows, 00369 B.mem, 00370 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ), 00371 (use_beta) ? beta : eT(0), 00372 C.memptr(), 00373 C.n_rows 00374 ); 00375 } 00376 #elif defined(ARMA_USE_BLAS) 00377 { 00378 arma_extra_debug_print("blas::gemm()"); 00379 00380 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N'; 00381 const char trans_B = (do_trans_B) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N'; 00382 00383 const blas_int m = C.n_rows; 00384 const blas_int n = C.n_cols; 00385 const blas_int k = (do_trans_A) ? A.n_rows : A.n_cols; 00386 00387 const eT local_alpha = (use_alpha) ? alpha : eT(1); 00388 00389 const blas_int lda = (do_trans_A) ? k : m; 00390 const blas_int ldb = (do_trans_B) ? n : k; 00391 00392 const eT local_beta = (use_beta) ? beta : eT(0); 00393 00394 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_A = %c") % trans_A ); 00395 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_B = %c") % trans_B ); 00396 00397 blas::gemm<eT> 00398 ( 00399 &trans_A, 00400 &trans_B, 00401 &m, 00402 &n, 00403 &k, 00404 &local_alpha, 00405 A.mem, 00406 &lda, 00407 B.mem, 00408 &ldb, 00409 &local_beta, 00410 C.memptr(), 00411 &m 00412 ); 00413 } 00414 #else 00415 { 00416 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta); 00417 } 00418 #endif 00419 } 00420 } 00421 00422 00423 00425 template<typename eT> 00426 inline 00427 static 00428 void 00429 apply( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) ) 00430 { 00431 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta); 00432 } 00433 00434 00435 00436 arma_inline 00437 static 00438 void 00439 apply 00440 ( 00441 Mat<float>& C, 00442 const Mat<float>& A, 00443 const Mat<float>& B, 00444 const float alpha = float(1), 00445 const float beta = float(0) 00446 ) 00447 { 00448 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); 00449 } 00450 00451 00452 00453 arma_inline 00454 static 00455 void 00456 apply 00457 ( 00458 Mat<double>& C, 00459 const Mat<double>& A, 00460 const Mat<double>& B, 00461 const double alpha = double(1), 00462 const double beta = double(0) 00463 ) 00464 { 00465 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); 00466 } 00467 00468 00469 00470 arma_inline 00471 static 00472 void 00473 apply 00474 ( 00475 Mat< std::complex<float> >& C, 00476 const Mat< std::complex<float> >& A, 00477 const Mat< std::complex<float> >& B, 00478 const std::complex<float> alpha = std::complex<float>(1), 00479 const std::complex<float> beta = std::complex<float>(0) 00480 ) 00481 { 00482 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); 00483 } 00484 00485 00486 00487 arma_inline 00488 static 00489 void 00490 apply 00491 ( 00492 Mat< std::complex<double> >& C, 00493 const Mat< std::complex<double> >& A, 00494 const Mat< std::complex<double> >& B, 00495 const std::complex<double> alpha = std::complex<double>(1), 00496 const std::complex<double> beta = std::complex<double>(0) 00497 ) 00498 { 00499 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); 00500 } 00501 00502 }; 00503 00504 00505