00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00016
00017
00018
00020 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00021 class gemv_emul_tinysq
00022 {
00023 public:
00024
00025
00026 template<const uword row, const uword col>
00027 struct pos
00028 {
00029 static const uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2);
00030 static const uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3);
00031 static const uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4);
00032 };
00033
00034
00035
00036 template<typename eT, const uword i>
00037 arma_hot
00038 arma_inline
00039 static
00040 void
00041 assign(eT* y, const eT acc, const eT alpha, const eT beta)
00042 {
00043 if(use_beta == false)
00044 {
00045 y[i] = (use_alpha == false) ? acc : alpha*acc;
00046 }
00047 else
00048 {
00049 const eT tmp = y[i];
00050
00051 y[i] = beta*tmp + ( (use_alpha == false) ? acc : alpha*acc );
00052 }
00053 }
00054
00055
00056
00057 template<typename eT>
00058 arma_hot
00059 inline
00060 static
00061 void
00062 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00063 {
00064 arma_extra_debug_sigprint();
00065
00066 const eT* Am = A.memptr();
00067
00068 switch(A.n_rows)
00069 {
00070 case 1:
00071 {
00072 const eT acc = Am[0] * x[0];
00073
00074 assign<eT, 0>(y, acc, alpha, beta);
00075 }
00076 break;
00077
00078
00079 case 2:
00080 {
00081 const eT x0 = x[0];
00082 const eT x1 = x[1];
00083
00084 const eT acc0 = Am[pos<0,0>::n2]*x0 + Am[pos<0,1>::n2]*x1;
00085 const eT acc1 = Am[pos<1,0>::n2]*x0 + Am[pos<1,1>::n2]*x1;
00086
00087 assign<eT, 0>(y, acc0, alpha, beta);
00088 assign<eT, 1>(y, acc1, alpha, beta);
00089 }
00090 break;
00091
00092
00093 case 3:
00094 {
00095 const eT x0 = x[0];
00096 const eT x1 = x[1];
00097 const eT x2 = x[2];
00098
00099 const eT acc0 = Am[pos<0,0>::n3]*x0 + Am[pos<0,1>::n3]*x1 + Am[pos<0,2>::n3]*x2;
00100 const eT acc1 = Am[pos<1,0>::n3]*x0 + Am[pos<1,1>::n3]*x1 + Am[pos<1,2>::n3]*x2;
00101 const eT acc2 = Am[pos<2,0>::n3]*x0 + Am[pos<2,1>::n3]*x1 + Am[pos<2,2>::n3]*x2;
00102
00103 assign<eT, 0>(y, acc0, alpha, beta);
00104 assign<eT, 1>(y, acc1, alpha, beta);
00105 assign<eT, 2>(y, acc2, alpha, beta);
00106 }
00107 break;
00108
00109
00110 case 4:
00111 {
00112 const eT x0 = x[0];
00113 const eT x1 = x[1];
00114 const eT x2 = x[2];
00115 const eT x3 = x[3];
00116
00117 const eT acc0 = Am[pos<0,0>::n4]*x0 + Am[pos<0,1>::n4]*x1 + Am[pos<0,2>::n4]*x2 + Am[pos<0,3>::n4]*x3;
00118 const eT acc1 = Am[pos<1,0>::n4]*x0 + Am[pos<1,1>::n4]*x1 + Am[pos<1,2>::n4]*x2 + Am[pos<1,3>::n4]*x3;
00119 const eT acc2 = Am[pos<2,0>::n4]*x0 + Am[pos<2,1>::n4]*x1 + Am[pos<2,2>::n4]*x2 + Am[pos<2,3>::n4]*x3;
00120 const eT acc3 = Am[pos<3,0>::n4]*x0 + Am[pos<3,1>::n4]*x1 + Am[pos<3,2>::n4]*x2 + Am[pos<3,3>::n4]*x3;
00121
00122 assign<eT, 0>(y, acc0, alpha, beta);
00123 assign<eT, 1>(y, acc1, alpha, beta);
00124 assign<eT, 2>(y, acc2, alpha, beta);
00125 assign<eT, 3>(y, acc3, alpha, beta);
00126 }
00127 break;
00128
00129
00130 default:
00131 ;
00132 }
00133 }
00134
00135 };
00136
00137
00138
00142
00143 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00144 class gemv_emul_large
00145 {
00146 public:
00147
00148 template<typename eT>
00149 arma_hot
00150 inline
00151 static
00152 void
00153 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00154 {
00155 arma_extra_debug_sigprint();
00156
00157 const uword A_n_rows = A.n_rows;
00158 const uword A_n_cols = A.n_cols;
00159
00160 if(do_trans_A == false)
00161 {
00162 if(A_n_rows == 1)
00163 {
00164 const eT acc = op_dot::direct_dot_arma(A_n_cols, A.mem, x);
00165
00166 if( (use_alpha == false) && (use_beta == false) )
00167 {
00168 y[0] = acc;
00169 }
00170 else
00171 if( (use_alpha == true) && (use_beta == false) )
00172 {
00173 y[0] = alpha * acc;
00174 }
00175 else
00176 if( (use_alpha == false) && (use_beta == true) )
00177 {
00178 y[0] = acc + beta*y[0];
00179 }
00180 else
00181 if( (use_alpha == true) && (use_beta == true) )
00182 {
00183 y[0] = alpha*acc + beta*y[0];
00184 }
00185 }
00186 else
00187 for(uword row=0; row < A_n_rows; ++row)
00188 {
00189 eT acc = eT(0);
00190
00191 for(uword i=0; i < A_n_cols; ++i)
00192 {
00193 acc += A.at(row,i) * x[i];
00194 }
00195
00196 if( (use_alpha == false) && (use_beta == false) )
00197 {
00198 y[row] = acc;
00199 }
00200 else
00201 if( (use_alpha == true) && (use_beta == false) )
00202 {
00203 y[row] = alpha * acc;
00204 }
00205 else
00206 if( (use_alpha == false) && (use_beta == true) )
00207 {
00208 y[row] = acc + beta*y[row];
00209 }
00210 else
00211 if( (use_alpha == true) && (use_beta == true) )
00212 {
00213 y[row] = alpha*acc + beta*y[row];
00214 }
00215 }
00216 }
00217 else
00218 if(do_trans_A == true)
00219 {
00220 for(uword col=0; col < A_n_cols; ++col)
00221 {
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233 const eT acc = op_dot::direct_dot_arma(A_n_rows, A.colptr(col), x);
00234
00235 if( (use_alpha == false) && (use_beta == false) )
00236 {
00237 y[col] = acc;
00238 }
00239 else
00240 if( (use_alpha == true) && (use_beta == false) )
00241 {
00242 y[col] = alpha * acc;
00243 }
00244 else
00245 if( (use_alpha == false) && (use_beta == true) )
00246 {
00247 y[col] = acc + beta*y[col];
00248 }
00249 else
00250 if( (use_alpha == true) && (use_beta == true) )
00251 {
00252 y[col] = alpha*acc + beta*y[col];
00253 }
00254
00255 }
00256 }
00257 }
00258
00259 };
00260
00261
00262
00263 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00264 class gemv_emul
00265 {
00266 public:
00267
00268 template<typename eT>
00269 arma_hot
00270 inline
00271 static
00272 void
00273 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx<eT>::result* junk = 0 )
00274 {
00275 arma_extra_debug_sigprint();
00276 arma_ignore(junk);
00277
00278 const uword A_n_rows = A.n_rows;
00279 const uword A_n_cols = A.n_cols;
00280
00281 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) )
00282 {
00283 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
00284 }
00285 else
00286 {
00287 gemv_emul_large<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
00288 }
00289 }
00290
00291
00292
00293 template<typename eT>
00294 arma_hot
00295 inline
00296 static
00297 void
00298 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_cx_only<eT>::result* junk = 0 )
00299 {
00300 arma_extra_debug_sigprint();
00301
00302 Mat<eT> tmp_A;
00303
00304 if(do_trans_A)
00305 {
00306 op_htrans::apply_noalias(tmp_A, A);
00307 }
00308
00309 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
00310
00311 const uword AA_n_rows = AA.n_rows;
00312 const uword AA_n_cols = AA.n_cols;
00313
00314 if( (AA_n_rows <= 4) && (AA_n_rows == AA_n_cols) )
00315 {
00316 gemv_emul_tinysq<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
00317 }
00318 else
00319 {
00320 gemv_emul_large<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
00321 }
00322 }
00323 };
00324
00325
00326
00330
00331 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00332 class gemv
00333 {
00334 public:
00335
00336 template<typename eT>
00337 inline
00338 static
00339 void
00340 apply_blas_type( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00341 {
00342 arma_extra_debug_sigprint();
00343
00344 if(A.n_elem <= 64u)
00345 {
00346 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00347 }
00348 else
00349 {
00350 #if defined(ARMA_USE_ATLAS)
00351 {
00352 arma_extra_debug_print("atlas::cblas_gemv()");
00353
00354 atlas::cblas_gemv<eT>
00355 (
00356 atlas::CblasColMajor,
00357 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
00358 A.n_rows,
00359 A.n_cols,
00360 (use_alpha) ? alpha : eT(1),
00361 A.mem,
00362 A.n_rows,
00363 x,
00364 1,
00365 (use_beta) ? beta : eT(0),
00366 y,
00367 1
00368 );
00369 }
00370 #elif defined(ARMA_USE_BLAS)
00371 {
00372 arma_extra_debug_print("blas::gemv()");
00373
00374 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
00375 const blas_int m = A.n_rows;
00376 const blas_int n = A.n_cols;
00377 const eT local_alpha = (use_alpha) ? alpha : eT(1);
00378
00379 const blas_int inc = 1;
00380 const eT local_beta = (use_beta) ? beta : eT(0);
00381
00382 arma_extra_debug_print( arma_boost::format("blas::gemv(): trans_A = %c") % trans_A );
00383
00384 blas::gemv<eT>
00385 (
00386 &trans_A,
00387 &m,
00388 &n,
00389 &local_alpha,
00390 A.mem,
00391 &m,
00392 x,
00393 &inc,
00394 &local_beta,
00395 y,
00396 &inc
00397 );
00398 }
00399 #else
00400 {
00401 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00402 }
00403 #endif
00404 }
00405
00406 }
00407
00408
00409
00410 template<typename eT>
00411 arma_inline
00412 static
00413 void
00414 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00415 {
00416 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00417 }
00418
00419
00420
00421 arma_inline
00422 static
00423 void
00424 apply
00425 (
00426 float* y,
00427 const Mat<float>& A,
00428 const float* x,
00429 const float alpha = float(1),
00430 const float beta = float(0)
00431 )
00432 {
00433 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00434 }
00435
00436
00437
00438 arma_inline
00439 static
00440 void
00441 apply
00442 (
00443 double* y,
00444 const Mat<double>& A,
00445 const double* x,
00446 const double alpha = double(1),
00447 const double beta = double(0)
00448 )
00449 {
00450 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00451 }
00452
00453
00454
00455 arma_inline
00456 static
00457 void
00458 apply
00459 (
00460 std::complex<float>* y,
00461 const Mat< std::complex<float > >& A,
00462 const std::complex<float>* x,
00463 const std::complex<float> alpha = std::complex<float>(1),
00464 const std::complex<float> beta = std::complex<float>(0)
00465 )
00466 {
00467 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00468 }
00469
00470
00471
00472 arma_inline
00473 static
00474 void
00475 apply
00476 (
00477 std::complex<double>* y,
00478 const Mat< std::complex<double> >& A,
00479 const std::complex<double>* x,
00480 const std::complex<double> alpha = std::complex<double>(1),
00481 const std::complex<double> beta = std::complex<double>(0)
00482 )
00483 {
00484 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00485 }
00486
00487
00488
00489 };
00490
00491