$search
00001 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) 00002 // Copyright (C) 2008-2011 Conrad Sanderson 00003 // 00004 // This file is part of the Armadillo C++ library. 00005 // It is provided without any warranty of fitness 00006 // for any purpose. You can redistribute this file 00007 // and/or modify it under the terms of the GNU 00008 // Lesser General Public License (LGPL) as published 00009 // by the Free Software Foundation, either version 3 00010 // of the License or (at your option) any later version. 00011 // (see http://www.opensource.org/licenses for more info) 00012 00013 00016 00017 00018 00023 00024 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> 00025 class gemm_mixed_large 00026 { 00027 public: 00028 00029 template<typename out_eT, typename in_eT1, typename in_eT2> 00030 arma_hot 00031 inline 00032 static 00033 void 00034 apply 00035 ( 00036 Mat<out_eT>& C, 00037 const Mat<in_eT1>& A, 00038 const Mat<in_eT2>& B, 00039 const out_eT alpha = out_eT(1), 00040 const out_eT beta = out_eT(0) 00041 ) 00042 { 00043 arma_extra_debug_sigprint(); 00044 00045 const uword A_n_rows = A.n_rows; 00046 const uword A_n_cols = A.n_cols; 00047 00048 const uword B_n_rows = B.n_rows; 00049 const uword B_n_cols = B.n_cols; 00050 00051 if( (do_trans_A == false) && (do_trans_B == false) ) 00052 { 00053 podarray<in_eT1> tmp(A_n_cols); 00054 in_eT1* A_rowdata = tmp.memptr(); 00055 00056 for(uword row_A=0; row_A < A_n_rows; ++row_A) 00057 { 00058 tmp.copy_row(A, row_A); 00059 00060 for(uword col_B=0; col_B < B_n_cols; ++col_B) 00061 { 00062 const in_eT2* B_coldata = B.colptr(col_B); 00063 00064 out_eT acc = out_eT(0); 00065 for(uword i=0; i < B_n_rows; ++i) 00066 { 00067 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00068 } 00069 00070 if( (use_alpha == false) && (use_beta == false) ) 00071 { 00072 C.at(row_A,col_B) = acc; 00073 } 00074 else 00075 if( (use_alpha == true) && (use_beta == false) ) 00076 { 00077 C.at(row_A,col_B) = alpha * acc; 00078 } 00079 else 00080 if( (use_alpha == false) && (use_beta == true) ) 00081 { 00082 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); 00083 } 00084 else 00085 if( (use_alpha == true) && (use_beta == true) ) 00086 { 00087 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); 00088 } 00089 00090 } 00091 } 00092 } 00093 else 00094 if( (do_trans_A == true) && (do_trans_B == false) ) 00095 { 00096 for(uword col_A=0; col_A < A_n_cols; ++col_A) 00097 { 00098 // col_A is interpreted as row_A when storing the results in matrix C 00099 00100 const in_eT1* A_coldata = A.colptr(col_A); 00101 00102 for(uword col_B=0; col_B < B_n_cols; ++col_B) 00103 { 00104 const in_eT2* B_coldata = B.colptr(col_B); 00105 00106 out_eT acc = out_eT(0); 00107 for(uword i=0; i < B_n_rows; ++i) 00108 { 00109 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00110 } 00111 00112 if( (use_alpha == false) && (use_beta == false) ) 00113 { 00114 C.at(col_A,col_B) = acc; 00115 } 00116 else 00117 if( (use_alpha == true) && (use_beta == false) ) 00118 { 00119 C.at(col_A,col_B) = alpha * acc; 00120 } 00121 else 00122 if( (use_alpha == false) && (use_beta == true) ) 00123 { 00124 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); 00125 } 00126 else 00127 if( (use_alpha == true) && (use_beta == true) ) 00128 { 00129 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); 00130 } 00131 00132 } 00133 } 00134 } 00135 else 00136 if( (do_trans_A == false) && (do_trans_B == true) ) 00137 { 00138 Mat<in_eT2> B_tmp; 00139 00140 op_strans::apply_noalias(B_tmp, B); 00141 00142 gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 00143 } 00144 else 00145 if( (do_trans_A == true) && (do_trans_B == true) ) 00146 { 00147 // mat B_tmp = trans(B); 00148 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 00149 00150 00151 // By using the trans(A)*trans(B) = trans(B*A) equivalency, 00152 // transpose operations are not needed 00153 00154 podarray<in_eT2> tmp(B_n_cols); 00155 in_eT2* B_rowdata = tmp.memptr(); 00156 00157 for(uword row_B=0; row_B < B_n_rows; ++row_B) 00158 { 00159 tmp.copy_row(B, row_B); 00160 00161 for(uword col_A=0; col_A < A_n_cols; ++col_A) 00162 { 00163 const in_eT1* A_coldata = A.colptr(col_A); 00164 00165 out_eT acc = out_eT(0); 00166 for(uword i=0; i < A_n_rows; ++i) 00167 { 00168 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]); 00169 } 00170 00171 if( (use_alpha == false) && (use_beta == false) ) 00172 { 00173 C.at(col_A,row_B) = acc; 00174 } 00175 else 00176 if( (use_alpha == true) && (use_beta == false) ) 00177 { 00178 C.at(col_A,row_B) = alpha * acc; 00179 } 00180 else 00181 if( (use_alpha == false) && (use_beta == true) ) 00182 { 00183 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); 00184 } 00185 else 00186 if( (use_alpha == true) && (use_beta == true) ) 00187 { 00188 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); 00189 } 00190 00191 } 00192 } 00193 00194 } 00195 } 00196 00197 }; 00198 00199 00200 00204 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> 00205 class gemm_mixed_small 00206 { 00207 public: 00208 00209 template<typename out_eT, typename in_eT1, typename in_eT2> 00210 arma_hot 00211 inline 00212 static 00213 void 00214 apply 00215 ( 00216 Mat<out_eT>& C, 00217 const Mat<in_eT1>& A, 00218 const Mat<in_eT2>& B, 00219 const out_eT alpha = out_eT(1), 00220 const out_eT beta = out_eT(0) 00221 ) 00222 { 00223 arma_extra_debug_sigprint(); 00224 00225 const uword A_n_rows = A.n_rows; 00226 const uword A_n_cols = A.n_cols; 00227 00228 const uword B_n_rows = B.n_rows; 00229 const uword B_n_cols = B.n_cols; 00230 00231 if( (do_trans_A == false) && (do_trans_B == false) ) 00232 { 00233 for(uword row_A = 0; row_A < A_n_rows; ++row_A) 00234 { 00235 for(uword col_B = 0; col_B < B_n_cols; ++col_B) 00236 { 00237 const in_eT2* B_coldata = B.colptr(col_B); 00238 00239 out_eT acc = out_eT(0); 00240 for(uword i = 0; i < B_n_rows; ++i) 00241 { 00242 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)); 00243 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00244 acc += val1 * val2; 00245 //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00246 } 00247 00248 if( (use_alpha == false) && (use_beta == false) ) 00249 { 00250 C.at(row_A,col_B) = acc; 00251 } 00252 else 00253 if( (use_alpha == true) && (use_beta == false) ) 00254 { 00255 C.at(row_A,col_B) = alpha * acc; 00256 } 00257 else 00258 if( (use_alpha == false) && (use_beta == true) ) 00259 { 00260 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); 00261 } 00262 else 00263 if( (use_alpha == true) && (use_beta == true) ) 00264 { 00265 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); 00266 } 00267 } 00268 } 00269 } 00270 else 00271 if( (do_trans_A == true) && (do_trans_B == false) ) 00272 { 00273 for(uword col_A=0; col_A < A_n_cols; ++col_A) 00274 { 00275 // col_A is interpreted as row_A when storing the results in matrix C 00276 00277 const in_eT1* A_coldata = A.colptr(col_A); 00278 00279 for(uword col_B=0; col_B < B_n_cols; ++col_B) 00280 { 00281 const in_eT2* B_coldata = B.colptr(col_B); 00282 00283 out_eT acc = out_eT(0); 00284 for(uword i=0; i < B_n_rows; ++i) 00285 { 00286 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00287 } 00288 00289 if( (use_alpha == false) && (use_beta == false) ) 00290 { 00291 C.at(col_A,col_B) = acc; 00292 } 00293 else 00294 if( (use_alpha == true) && (use_beta == false) ) 00295 { 00296 C.at(col_A,col_B) = alpha * acc; 00297 } 00298 else 00299 if( (use_alpha == false) && (use_beta == true) ) 00300 { 00301 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); 00302 } 00303 else 00304 if( (use_alpha == true) && (use_beta == true) ) 00305 { 00306 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); 00307 } 00308 00309 } 00310 } 00311 } 00312 else 00313 if( (do_trans_A == false) && (do_trans_B == true) ) 00314 { 00315 for(uword row_A = 0; row_A < A_n_rows; ++row_A) 00316 { 00317 for(uword row_B = 0; row_B < B_n_rows; ++row_B) 00318 { 00319 out_eT acc = out_eT(0); 00320 for(uword i = 0; i < B_n_cols; ++i) 00321 { 00322 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)); 00323 } 00324 00325 if( (use_alpha == false) && (use_beta == false) ) 00326 { 00327 C.at(row_A,row_B) = acc; 00328 } 00329 else 00330 if( (use_alpha == true) && (use_beta == false) ) 00331 { 00332 C.at(row_A,row_B) = alpha * acc; 00333 } 00334 else 00335 if( (use_alpha == false) && (use_beta == true) ) 00336 { 00337 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B); 00338 } 00339 else 00340 if( (use_alpha == true) && (use_beta == true) ) 00341 { 00342 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B); 00343 } 00344 } 00345 } 00346 } 00347 else 00348 if( (do_trans_A == true) && (do_trans_B == true) ) 00349 { 00350 for(uword row_B=0; row_B < B_n_rows; ++row_B) 00351 { 00352 00353 for(uword col_A=0; col_A < A_n_cols; ++col_A) 00354 { 00355 const in_eT1* A_coldata = A.colptr(col_A); 00356 00357 out_eT acc = out_eT(0); 00358 for(uword i=0; i < A_n_rows; ++i) 00359 { 00360 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]); 00361 } 00362 00363 if( (use_alpha == false) && (use_beta == false) ) 00364 { 00365 C.at(col_A,row_B) = acc; 00366 } 00367 else 00368 if( (use_alpha == true) && (use_beta == false) ) 00369 { 00370 C.at(col_A,row_B) = alpha * acc; 00371 } 00372 else 00373 if( (use_alpha == false) && (use_beta == true) ) 00374 { 00375 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); 00376 } 00377 else 00378 if( (use_alpha == true) && (use_beta == true) ) 00379 { 00380 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); 00381 } 00382 00383 } 00384 } 00385 00386 } 00387 } 00388 00389 }; 00390 00391 00392 00393 00394 00397 00398 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> 00399 class gemm_mixed 00400 { 00401 public: 00402 00404 template<typename out_eT, typename in_eT1, typename in_eT2> 00405 inline 00406 static 00407 void 00408 apply 00409 ( 00410 Mat<out_eT>& C, 00411 const Mat<in_eT1>& A, 00412 const Mat<in_eT2>& B, 00413 const out_eT alpha = out_eT(1), 00414 const out_eT beta = out_eT(0) 00415 ) 00416 { 00417 arma_extra_debug_sigprint(); 00418 00419 Mat<in_eT1> tmp_A; 00420 Mat<in_eT2> tmp_B; 00421 00422 const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) ); 00423 const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) ); 00424 00425 if(do_trans_A) 00426 { 00427 op_htrans::apply_noalias(tmp_A, A); 00428 } 00429 00430 if(do_trans_B) 00431 { 00432 op_htrans::apply_noalias(tmp_B, B); 00433 } 00434 00435 const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A; 00436 const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B; 00437 00438 if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) ) 00439 { 00440 gemm_mixed_small<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); 00441 } 00442 else 00443 { 00444 gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); 00445 } 00446 } 00447 00448 00449 }; 00450 00451 00452