gemv.hpp
Go to the documentation of this file.
00001 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
00002 // Copyright (C) 2008-2011 Conrad Sanderson
00003 // 
00004 // This file is part of the Armadillo C++ library.
00005 // It is provided without any warranty of fitness
00006 // for any purpose. You can redistribute this file
00007 // and/or modify it under the terms of the GNU
00008 // Lesser General Public License (LGPL) as published
00009 // by the Free Software Foundation, either version 3
00010 // of the License or (at your option) any later version.
00011 // (see http://www.opensource.org/licenses for more info)
00012 
00013 
00016 
00017 
00018 
00020 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00021 class gemv_emul_tinysq
00022   {
00023   public:
00024   
00025   
00026   template<const uword row, const uword col>
00027   struct pos
00028     {
00029     static const uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2);
00030     static const uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3);
00031     static const uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4);
00032     };
00033   
00034   
00035   
00036   template<typename eT, const uword i>
00037   arma_hot
00038   arma_inline
00039   static
00040   void
00041   assign(eT* y, const eT acc, const eT alpha, const eT beta)
00042     {
00043     if(use_beta == false)
00044       {
00045       y[i] = (use_alpha == false) ? acc : alpha*acc;
00046       }
00047     else
00048       {
00049       const eT tmp = y[i];
00050       
00051       y[i] = beta*tmp + ( (use_alpha == false) ? acc : alpha*acc );
00052       }
00053     }
00054   
00055   
00056 
00057   template<typename eT>
00058   arma_hot
00059   inline
00060   static
00061   void
00062   apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00063     {
00064     arma_extra_debug_sigprint();
00065     
00066     const eT*  Am = A.memptr();
00067     
00068     switch(A.n_rows)
00069       {
00070       case 1:
00071         {
00072         const eT acc = Am[0] * x[0];
00073         
00074         assign<eT, 0>(y, acc, alpha, beta);
00075         }
00076         break;
00077       
00078       
00079       case 2:
00080         {
00081         const eT x0 = x[0];
00082         const eT x1 = x[1];
00083         
00084         const eT acc0 = Am[pos<0,0>::n2]*x0 + Am[pos<0,1>::n2]*x1;
00085         const eT acc1 = Am[pos<1,0>::n2]*x0 + Am[pos<1,1>::n2]*x1;
00086         
00087         assign<eT, 0>(y, acc0, alpha, beta);
00088         assign<eT, 1>(y, acc1, alpha, beta);
00089         }
00090         break;
00091       
00092         
00093       case 3:
00094         {
00095         const eT x0 = x[0];
00096         const eT x1 = x[1];
00097         const eT x2 = x[2];
00098         
00099         const eT acc0 = Am[pos<0,0>::n3]*x0 + Am[pos<0,1>::n3]*x1 + Am[pos<0,2>::n3]*x2;
00100         const eT acc1 = Am[pos<1,0>::n3]*x0 + Am[pos<1,1>::n3]*x1 + Am[pos<1,2>::n3]*x2;
00101         const eT acc2 = Am[pos<2,0>::n3]*x0 + Am[pos<2,1>::n3]*x1 + Am[pos<2,2>::n3]*x2;
00102         
00103         assign<eT, 0>(y, acc0, alpha, beta);
00104         assign<eT, 1>(y, acc1, alpha, beta);
00105         assign<eT, 2>(y, acc2, alpha, beta);
00106         }
00107         break;
00108       
00109       
00110       case 4:
00111         {
00112         const eT x0 = x[0];
00113         const eT x1 = x[1];
00114         const eT x2 = x[2];
00115         const eT x3 = x[3];
00116         
00117         const eT acc0 = Am[pos<0,0>::n4]*x0 + Am[pos<0,1>::n4]*x1 + Am[pos<0,2>::n4]*x2 + Am[pos<0,3>::n4]*x3;
00118         const eT acc1 = Am[pos<1,0>::n4]*x0 + Am[pos<1,1>::n4]*x1 + Am[pos<1,2>::n4]*x2 + Am[pos<1,3>::n4]*x3;
00119         const eT acc2 = Am[pos<2,0>::n4]*x0 + Am[pos<2,1>::n4]*x1 + Am[pos<2,2>::n4]*x2 + Am[pos<2,3>::n4]*x3;
00120         const eT acc3 = Am[pos<3,0>::n4]*x0 + Am[pos<3,1>::n4]*x1 + Am[pos<3,2>::n4]*x2 + Am[pos<3,3>::n4]*x3;
00121         
00122         assign<eT, 0>(y, acc0, alpha, beta);
00123         assign<eT, 1>(y, acc1, alpha, beta);
00124         assign<eT, 2>(y, acc2, alpha, beta);
00125         assign<eT, 3>(y, acc3, alpha, beta);
00126         }
00127         break;
00128       
00129       
00130       default:
00131         ;
00132       }
00133     }
00134     
00135   };
00136 
00137 
00138 
00142 
00143 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00144 class gemv_emul_large
00145   {
00146   public:
00147   
00148   template<typename eT>
00149   arma_hot
00150   inline
00151   static
00152   void
00153   apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00154     {
00155     arma_extra_debug_sigprint();
00156     
00157     const uword A_n_rows = A.n_rows;
00158     const uword A_n_cols = A.n_cols;
00159     
00160     if(do_trans_A == false)
00161       {
00162       if(A_n_rows == 1)
00163         {
00164         const eT acc = op_dot::direct_dot_arma(A_n_cols, A.mem, x);
00165         
00166         if( (use_alpha == false) && (use_beta == false) )
00167           {
00168           y[0] = acc;
00169           }
00170         else
00171         if( (use_alpha == true) && (use_beta == false) )
00172           {
00173           y[0] = alpha * acc;
00174           }
00175         else
00176         if( (use_alpha == false) && (use_beta == true) )
00177           {
00178           y[0] = acc + beta*y[0];
00179           }
00180         else
00181         if( (use_alpha == true) && (use_beta == true) )
00182           {
00183           y[0] = alpha*acc + beta*y[0];
00184           }
00185         }
00186       else
00187       for(uword row=0; row < A_n_rows; ++row)
00188         {
00189         eT acc = eT(0);
00190         
00191         for(uword i=0; i < A_n_cols; ++i)
00192           {
00193           acc += A.at(row,i) * x[i];
00194           }
00195         
00196         if( (use_alpha == false) && (use_beta == false) )
00197           {
00198           y[row] = acc;
00199           }
00200         else
00201         if( (use_alpha == true) && (use_beta == false) )
00202           {
00203           y[row] = alpha * acc;
00204           }
00205         else
00206         if( (use_alpha == false) && (use_beta == true) )
00207           {
00208           y[row] = acc + beta*y[row];
00209           }
00210         else
00211         if( (use_alpha == true) && (use_beta == true) )
00212           {
00213           y[row] = alpha*acc + beta*y[row];
00214           }
00215         }
00216       }
00217     else
00218     if(do_trans_A == true)
00219       {
00220       for(uword col=0; col < A_n_cols; ++col)
00221         {
00222         // col is interpreted as row when storing the results in 'y'
00223         
00224         
00225         // const eT* A_coldata = A.colptr(col);
00226         // 
00227         // eT acc = eT(0);
00228         // for(uword row=0; row < A_n_rows; ++row)
00229         //   {
00230         //   acc += A_coldata[row] * x[row];
00231         //   }
00232         
00233         const eT acc = op_dot::direct_dot_arma(A_n_rows, A.colptr(col), x);
00234         
00235         if( (use_alpha == false) && (use_beta == false) )
00236           {
00237           y[col] = acc;
00238           }
00239         else
00240         if( (use_alpha == true) && (use_beta == false) )
00241           {
00242           y[col] = alpha * acc;
00243           }
00244         else
00245         if( (use_alpha == false) && (use_beta == true) )
00246           {
00247           y[col] = acc + beta*y[col];
00248           }
00249         else
00250         if( (use_alpha == true) && (use_beta == true) )
00251           {
00252           y[col] = alpha*acc + beta*y[col];
00253           }
00254         
00255         }
00256       }
00257     }
00258   
00259   };
00260 
00261 
00262 
00263 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00264 class gemv_emul
00265   {
00266   public:
00267   
00268   template<typename eT>
00269   arma_hot
00270   inline
00271   static
00272   void
00273   apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx<eT>::result* junk = 0 )
00274     {
00275     arma_extra_debug_sigprint();
00276     arma_ignore(junk);
00277     
00278     const uword A_n_rows = A.n_rows;
00279     const uword A_n_cols = A.n_cols;
00280     
00281     if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) )
00282       {
00283       gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
00284       }
00285     else
00286       {
00287       gemv_emul_large<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
00288       }
00289     }
00290   
00291   
00292   
00293   template<typename eT>
00294   arma_hot
00295   inline
00296   static
00297   void
00298   apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_cx_only<eT>::result* junk = 0 )
00299     {
00300     arma_extra_debug_sigprint();
00301     
00302     Mat<eT> tmp_A;
00303     
00304     if(do_trans_A)
00305       {
00306       op_htrans::apply_noalias(tmp_A, A);
00307       }
00308     
00309     const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
00310     
00311     const uword AA_n_rows = AA.n_rows;
00312     const uword AA_n_cols = AA.n_cols;
00313     
00314     if( (AA_n_rows <= 4) && (AA_n_rows == AA_n_cols) )
00315       {
00316       gemv_emul_tinysq<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
00317       }
00318     else
00319       {
00320       gemv_emul_large<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
00321       }
00322     }
00323   };
00324 
00325 
00326 
00330 
00331 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00332 class gemv
00333   {
00334   public:
00335   
00336   template<typename eT>
00337   inline
00338   static
00339   void
00340   apply_blas_type( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00341     {
00342     arma_extra_debug_sigprint();
00343     
00344     if(A.n_elem <= 64u)
00345       {
00346       gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00347       }
00348     else
00349       {
00350       #if defined(ARMA_USE_ATLAS)
00351         {
00352         arma_extra_debug_print("atlas::cblas_gemv()");
00353         
00354         atlas::cblas_gemv<eT>
00355           (
00356           atlas::CblasColMajor,
00357           (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
00358           A.n_rows,
00359           A.n_cols,
00360           (use_alpha) ? alpha : eT(1),
00361           A.mem,
00362           A.n_rows,
00363           x,
00364           1,
00365           (use_beta) ? beta : eT(0),
00366           y,
00367           1
00368           );
00369         }
00370       #elif defined(ARMA_USE_BLAS)
00371         {
00372         arma_extra_debug_print("blas::gemv()");
00373         
00374         const char      trans_A     = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
00375         const blas_int  m           = A.n_rows;
00376         const blas_int  n           = A.n_cols;
00377         const eT        local_alpha = (use_alpha) ? alpha : eT(1);
00378         //const blas_int  lda         = A.n_rows;
00379         const blas_int  inc         = 1;
00380         const eT        local_beta  = (use_beta) ? beta : eT(0);
00381         
00382         arma_extra_debug_print( arma_boost::format("blas::gemv(): trans_A = %c") % trans_A );
00383         
00384         blas::gemv<eT>
00385           (
00386           &trans_A,
00387           &m,
00388           &n,
00389           &local_alpha,
00390           A.mem,
00391           &m,  // lda
00392           x,
00393           &inc,
00394           &local_beta,
00395           y,
00396           &inc
00397           );
00398         }
00399       #else
00400         {
00401         gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00402         }
00403       #endif
00404       }
00405     
00406     }
00407   
00408   
00409   
00410   template<typename eT>
00411   arma_inline
00412   static
00413   void
00414   apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00415     {
00416     gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00417     }
00418   
00419   
00420   
00421   arma_inline
00422   static
00423   void
00424   apply
00425     (
00426           float*      y,
00427     const Mat<float>& A,
00428     const float*      x,
00429     const float       alpha = float(1),
00430     const float       beta  = float(0)
00431     )
00432     {
00433     gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00434     }
00435 
00436 
00437   
00438   arma_inline
00439   static
00440   void
00441   apply
00442     (
00443           double*      y,
00444     const Mat<double>& A,
00445     const double*      x,
00446     const double       alpha = double(1),
00447     const double       beta  = double(0)
00448     )
00449     {
00450     gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00451     }
00452 
00453 
00454   
00455   arma_inline
00456   static
00457   void
00458   apply
00459     (
00460           std::complex<float>*         y,
00461     const Mat< std::complex<float > >& A,
00462     const std::complex<float>*         x,
00463     const std::complex<float>          alpha = std::complex<float>(1),
00464     const std::complex<float>          beta  = std::complex<float>(0)
00465     )
00466     {
00467     gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00468     }
00469 
00470 
00471   
00472   arma_inline
00473   static
00474   void
00475   apply
00476     (
00477           std::complex<double>*        y,
00478     const Mat< std::complex<double> >& A,
00479     const std::complex<double>*        x,
00480     const std::complex<double>         alpha = std::complex<double>(1),
00481     const std::complex<double>         beta  = std::complex<double>(0)
00482     )
00483     {
00484     gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
00485     }
00486 
00487 
00488   
00489   };
00490 
00491 


armadillo_matrix
Author(s): Conrad Sanderson - NICTA (www.nicta.com.au), (Wrapper by Sjoerd van den Dries)
autogenerated on Tue Jan 7 2014 11:42:04