00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00016
00017
00018
00023
00024 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00025 class gemm_mixed_large
00026 {
00027 public:
00028
00029 template<typename out_eT, typename in_eT1, typename in_eT2>
00030 arma_hot
00031 inline
00032 static
00033 void
00034 apply
00035 (
00036 Mat<out_eT>& C,
00037 const Mat<in_eT1>& A,
00038 const Mat<in_eT2>& B,
00039 const out_eT alpha = out_eT(1),
00040 const out_eT beta = out_eT(0)
00041 )
00042 {
00043 arma_extra_debug_sigprint();
00044
00045 const uword A_n_rows = A.n_rows;
00046 const uword A_n_cols = A.n_cols;
00047
00048 const uword B_n_rows = B.n_rows;
00049 const uword B_n_cols = B.n_cols;
00050
00051 if( (do_trans_A == false) && (do_trans_B == false) )
00052 {
00053 podarray<in_eT1> tmp(A_n_cols);
00054 in_eT1* A_rowdata = tmp.memptr();
00055
00056 for(uword row_A=0; row_A < A_n_rows; ++row_A)
00057 {
00058 tmp.copy_row(A, row_A);
00059
00060 for(uword col_B=0; col_B < B_n_cols; ++col_B)
00061 {
00062 const in_eT2* B_coldata = B.colptr(col_B);
00063
00064 out_eT acc = out_eT(0);
00065 for(uword i=0; i < B_n_rows; ++i)
00066 {
00067 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00068 }
00069
00070 if( (use_alpha == false) && (use_beta == false) )
00071 {
00072 C.at(row_A,col_B) = acc;
00073 }
00074 else
00075 if( (use_alpha == true) && (use_beta == false) )
00076 {
00077 C.at(row_A,col_B) = alpha * acc;
00078 }
00079 else
00080 if( (use_alpha == false) && (use_beta == true) )
00081 {
00082 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00083 }
00084 else
00085 if( (use_alpha == true) && (use_beta == true) )
00086 {
00087 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00088 }
00089
00090 }
00091 }
00092 }
00093 else
00094 if( (do_trans_A == true) && (do_trans_B == false) )
00095 {
00096 for(uword col_A=0; col_A < A_n_cols; ++col_A)
00097 {
00098
00099
00100 const in_eT1* A_coldata = A.colptr(col_A);
00101
00102 for(uword col_B=0; col_B < B_n_cols; ++col_B)
00103 {
00104 const in_eT2* B_coldata = B.colptr(col_B);
00105
00106 out_eT acc = out_eT(0);
00107 for(uword i=0; i < B_n_rows; ++i)
00108 {
00109 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00110 }
00111
00112 if( (use_alpha == false) && (use_beta == false) )
00113 {
00114 C.at(col_A,col_B) = acc;
00115 }
00116 else
00117 if( (use_alpha == true) && (use_beta == false) )
00118 {
00119 C.at(col_A,col_B) = alpha * acc;
00120 }
00121 else
00122 if( (use_alpha == false) && (use_beta == true) )
00123 {
00124 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00125 }
00126 else
00127 if( (use_alpha == true) && (use_beta == true) )
00128 {
00129 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00130 }
00131
00132 }
00133 }
00134 }
00135 else
00136 if( (do_trans_A == false) && (do_trans_B == true) )
00137 {
00138 Mat<in_eT2> B_tmp;
00139
00140 op_strans::apply_noalias(B_tmp, B);
00141
00142 gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00143 }
00144 else
00145 if( (do_trans_A == true) && (do_trans_B == true) )
00146 {
00147
00148
00149
00150
00151
00152
00153
00154 podarray<in_eT2> tmp(B_n_cols);
00155 in_eT2* B_rowdata = tmp.memptr();
00156
00157 for(uword row_B=0; row_B < B_n_rows; ++row_B)
00158 {
00159 tmp.copy_row(B, row_B);
00160
00161 for(uword col_A=0; col_A < A_n_cols; ++col_A)
00162 {
00163 const in_eT1* A_coldata = A.colptr(col_A);
00164
00165 out_eT acc = out_eT(0);
00166 for(uword i=0; i < A_n_rows; ++i)
00167 {
00168 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00169 }
00170
00171 if( (use_alpha == false) && (use_beta == false) )
00172 {
00173 C.at(col_A,row_B) = acc;
00174 }
00175 else
00176 if( (use_alpha == true) && (use_beta == false) )
00177 {
00178 C.at(col_A,row_B) = alpha * acc;
00179 }
00180 else
00181 if( (use_alpha == false) && (use_beta == true) )
00182 {
00183 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00184 }
00185 else
00186 if( (use_alpha == true) && (use_beta == true) )
00187 {
00188 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00189 }
00190
00191 }
00192 }
00193
00194 }
00195 }
00196
00197 };
00198
00199
00200
00204 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00205 class gemm_mixed_small
00206 {
00207 public:
00208
00209 template<typename out_eT, typename in_eT1, typename in_eT2>
00210 arma_hot
00211 inline
00212 static
00213 void
00214 apply
00215 (
00216 Mat<out_eT>& C,
00217 const Mat<in_eT1>& A,
00218 const Mat<in_eT2>& B,
00219 const out_eT alpha = out_eT(1),
00220 const out_eT beta = out_eT(0)
00221 )
00222 {
00223 arma_extra_debug_sigprint();
00224
00225 const uword A_n_rows = A.n_rows;
00226 const uword A_n_cols = A.n_cols;
00227
00228 const uword B_n_rows = B.n_rows;
00229 const uword B_n_cols = B.n_cols;
00230
00231 if( (do_trans_A == false) && (do_trans_B == false) )
00232 {
00233 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
00234 {
00235 for(uword col_B = 0; col_B < B_n_cols; ++col_B)
00236 {
00237 const in_eT2* B_coldata = B.colptr(col_B);
00238
00239 out_eT acc = out_eT(0);
00240 for(uword i = 0; i < B_n_rows; ++i)
00241 {
00242 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
00243 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00244 acc += val1 * val2;
00245
00246 }
00247
00248 if( (use_alpha == false) && (use_beta == false) )
00249 {
00250 C.at(row_A,col_B) = acc;
00251 }
00252 else
00253 if( (use_alpha == true) && (use_beta == false) )
00254 {
00255 C.at(row_A,col_B) = alpha * acc;
00256 }
00257 else
00258 if( (use_alpha == false) && (use_beta == true) )
00259 {
00260 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00261 }
00262 else
00263 if( (use_alpha == true) && (use_beta == true) )
00264 {
00265 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00266 }
00267 }
00268 }
00269 }
00270 else
00271 if( (do_trans_A == true) && (do_trans_B == false) )
00272 {
00273 for(uword col_A=0; col_A < A_n_cols; ++col_A)
00274 {
00275
00276
00277 const in_eT1* A_coldata = A.colptr(col_A);
00278
00279 for(uword col_B=0; col_B < B_n_cols; ++col_B)
00280 {
00281 const in_eT2* B_coldata = B.colptr(col_B);
00282
00283 out_eT acc = out_eT(0);
00284 for(uword i=0; i < B_n_rows; ++i)
00285 {
00286 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00287 }
00288
00289 if( (use_alpha == false) && (use_beta == false) )
00290 {
00291 C.at(col_A,col_B) = acc;
00292 }
00293 else
00294 if( (use_alpha == true) && (use_beta == false) )
00295 {
00296 C.at(col_A,col_B) = alpha * acc;
00297 }
00298 else
00299 if( (use_alpha == false) && (use_beta == true) )
00300 {
00301 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00302 }
00303 else
00304 if( (use_alpha == true) && (use_beta == true) )
00305 {
00306 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00307 }
00308
00309 }
00310 }
00311 }
00312 else
00313 if( (do_trans_A == false) && (do_trans_B == true) )
00314 {
00315 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
00316 {
00317 for(uword row_B = 0; row_B < B_n_rows; ++row_B)
00318 {
00319 out_eT acc = out_eT(0);
00320 for(uword i = 0; i < B_n_cols; ++i)
00321 {
00322 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
00323 }
00324
00325 if( (use_alpha == false) && (use_beta == false) )
00326 {
00327 C.at(row_A,row_B) = acc;
00328 }
00329 else
00330 if( (use_alpha == true) && (use_beta == false) )
00331 {
00332 C.at(row_A,row_B) = alpha * acc;
00333 }
00334 else
00335 if( (use_alpha == false) && (use_beta == true) )
00336 {
00337 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00338 }
00339 else
00340 if( (use_alpha == true) && (use_beta == true) )
00341 {
00342 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00343 }
00344 }
00345 }
00346 }
00347 else
00348 if( (do_trans_A == true) && (do_trans_B == true) )
00349 {
00350 for(uword row_B=0; row_B < B_n_rows; ++row_B)
00351 {
00352
00353 for(uword col_A=0; col_A < A_n_cols; ++col_A)
00354 {
00355 const in_eT1* A_coldata = A.colptr(col_A);
00356
00357 out_eT acc = out_eT(0);
00358 for(uword i=0; i < A_n_rows; ++i)
00359 {
00360 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00361 }
00362
00363 if( (use_alpha == false) && (use_beta == false) )
00364 {
00365 C.at(col_A,row_B) = acc;
00366 }
00367 else
00368 if( (use_alpha == true) && (use_beta == false) )
00369 {
00370 C.at(col_A,row_B) = alpha * acc;
00371 }
00372 else
00373 if( (use_alpha == false) && (use_beta == true) )
00374 {
00375 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00376 }
00377 else
00378 if( (use_alpha == true) && (use_beta == true) )
00379 {
00380 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00381 }
00382
00383 }
00384 }
00385
00386 }
00387 }
00388
00389 };
00390
00391
00392
00393
00394
00397
00398 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00399 class gemm_mixed
00400 {
00401 public:
00402
00404 template<typename out_eT, typename in_eT1, typename in_eT2>
00405 inline
00406 static
00407 void
00408 apply
00409 (
00410 Mat<out_eT>& C,
00411 const Mat<in_eT1>& A,
00412 const Mat<in_eT2>& B,
00413 const out_eT alpha = out_eT(1),
00414 const out_eT beta = out_eT(0)
00415 )
00416 {
00417 arma_extra_debug_sigprint();
00418
00419 Mat<in_eT1> tmp_A;
00420 Mat<in_eT2> tmp_B;
00421
00422 const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) );
00423 const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) );
00424
00425 if(do_trans_A)
00426 {
00427 op_htrans::apply_noalias(tmp_A, A);
00428 }
00429
00430 if(do_trans_B)
00431 {
00432 op_htrans::apply_noalias(tmp_B, B);
00433 }
00434
00435 const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A;
00436 const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B;
00437
00438 if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) )
00439 {
00440 gemm_mixed_small<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
00441 }
00442 else
00443 {
00444 gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
00445 }
00446 }
00447
00448
00449 };
00450
00451
00452