gemm.hpp
Go to the documentation of this file.
00001 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
00002 // Copyright (C) 2008-2011 Conrad Sanderson
00003 // 
00004 // This file is part of the Armadillo C++ library.
00005 // It is provided without any warranty of fitness
00006 // for any purpose. You can redistribute this file
00007 // and/or modify it under the terms of the GNU
00008 // Lesser General Public License (LGPL) as published
00009 // by the Free Software Foundation, either version 3
00010 // of the License or (at your option) any later version.
00011 // (see http://www.opensource.org/licenses for more info)
00012 
00013 
00016 
00017 
00018 
00020 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00021 class gemm_emul_tinysq
00022   {
00023   public:
00024   
00025   
00026   template<typename eT>
00027   arma_hot
00028   inline
00029   static
00030   void
00031   apply
00032     (
00033           Mat<eT>& C,
00034     const Mat<eT>& A,
00035     const Mat<eT>& B,
00036     const eT alpha = eT(1),
00037     const eT beta  = eT(0)
00038     )
00039     {
00040     arma_extra_debug_sigprint();
00041     
00042     switch(A.n_rows)
00043       {
00044       case 4:
00045         gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta );
00046         
00047       case 3:
00048         gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta );
00049         
00050       case 2:
00051         gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta );
00052         
00053       case 1:
00054         gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta );
00055         
00056       default:
00057         ;
00058       }
00059     }
00060   
00061   };
00062 
00063 
00064 
00065 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00066 class gemm_emul_large
00067   {
00068   public:
00069   
00070   template<typename eT>
00071   arma_hot
00072   inline
00073   static
00074   void
00075   apply
00076     (
00077           Mat<eT>& C,
00078     const Mat<eT>& A,
00079     const Mat<eT>& B,
00080     const eT alpha = eT(1),
00081     const eT beta  = eT(0)
00082     )
00083     {
00084     arma_extra_debug_sigprint();
00085 
00086     const uword A_n_rows = A.n_rows;
00087     const uword A_n_cols = A.n_cols;
00088     
00089     const uword B_n_rows = B.n_rows;
00090     const uword B_n_cols = B.n_cols;
00091     
00092     if( (do_trans_A == false) && (do_trans_B == false) )
00093       {
00094       arma_aligned podarray<eT> tmp(A_n_cols);
00095       eT* A_rowdata = tmp.memptr();
00096       
00097       for(uword row_A=0; row_A < A_n_rows; ++row_A)
00098         {
00099         tmp.copy_row(A, row_A);
00100         
00101         for(uword col_B=0; col_B < B_n_cols; ++col_B)
00102           {
00103           const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B));
00104           
00105           if( (use_alpha == false) && (use_beta == false) )
00106             {
00107             C.at(row_A,col_B) = acc;
00108             }
00109           else
00110           if( (use_alpha == true) && (use_beta == false) )
00111             {
00112             C.at(row_A,col_B) = alpha * acc;
00113             }
00114           else
00115           if( (use_alpha == false) && (use_beta == true) )
00116             {
00117             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00118             }
00119           else
00120           if( (use_alpha == true) && (use_beta == true) )
00121             {
00122             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00123             }
00124           
00125           }
00126         }
00127       }
00128     else
00129     if( (do_trans_A == true) && (do_trans_B == false) )
00130       {
00131       for(uword col_A=0; col_A < A_n_cols; ++col_A)
00132         {
00133         // col_A is interpreted as row_A when storing the results in matrix C
00134         
00135         const eT* A_coldata = A.colptr(col_A);
00136         
00137         for(uword col_B=0; col_B < B_n_cols; ++col_B)
00138           {
00139           const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B));
00140           
00141           if( (use_alpha == false) && (use_beta == false) )
00142             {
00143             C.at(col_A,col_B) = acc;
00144             }
00145           else
00146           if( (use_alpha == true) && (use_beta == false) )
00147             {
00148             C.at(col_A,col_B) = alpha * acc;
00149             }
00150           else
00151           if( (use_alpha == false) && (use_beta == true) )
00152             {
00153             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00154             }
00155           else
00156           if( (use_alpha == true) && (use_beta == true) )
00157             {
00158             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00159             }
00160           
00161           }
00162         }
00163       }
00164     else
00165     if( (do_trans_A == false) && (do_trans_B == true) )
00166       {
00167       Mat<eT> BB;
00168       op_strans::apply_noalias(BB, B);
00169       
00170       gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
00171       }
00172     else
00173     if( (do_trans_A == true) && (do_trans_B == true) )
00174       {
00175       // mat B_tmp = trans(B);
00176       // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00177       
00178       
00179       // By using the trans(A)*trans(B) = trans(B*A) equivalency,
00180       // transpose operations are not needed
00181       
00182       arma_aligned podarray<eT> tmp(B.n_cols);
00183       eT* B_rowdata = tmp.memptr();
00184       
00185       for(uword row_B=0; row_B < B_n_rows; ++row_B)
00186         {
00187         tmp.copy_row(B, row_B);
00188         
00189         for(uword col_A=0; col_A < A_n_cols; ++col_A)
00190           {
00191           const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A));
00192           
00193           if( (use_alpha == false) && (use_beta == false) )
00194             {
00195             C.at(col_A,row_B) = acc;
00196             }
00197           else
00198           if( (use_alpha == true) && (use_beta == false) )
00199             {
00200             C.at(col_A,row_B) = alpha * acc;
00201             }
00202           else
00203           if( (use_alpha == false) && (use_beta == true) )
00204             {
00205             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00206             }
00207           else
00208           if( (use_alpha == true) && (use_beta == true) )
00209             {
00210             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00211             }
00212           
00213           }
00214         }
00215       
00216       }
00217     }
00218   
00219   };
00220     
00221   
00222   
00223 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00224 class gemm_emul
00225   {
00226   public:
00227   
00228   
00229   template<typename eT>
00230   arma_hot
00231   inline
00232   static
00233   void
00234   apply
00235     (
00236           Mat<eT>& C,
00237     const Mat<eT>& A,
00238     const Mat<eT>& B,
00239     const eT alpha = eT(1),
00240     const eT beta  = eT(0),
00241     const typename arma_not_cx<eT>::result* junk = 0
00242     )
00243     {
00244     arma_extra_debug_sigprint();
00245     arma_ignore(junk);
00246     
00247     const uword A_n_rows = A.n_rows;
00248     const uword A_n_cols = A.n_cols;
00249     
00250     const uword B_n_rows = B.n_rows;
00251     const uword B_n_cols = B.n_cols;
00252     
00253     if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
00254       {
00255       if(do_trans_B == false)
00256         {
00257         gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
00258         }
00259       else
00260         {
00261         Mat<eT> BB(A_n_rows, A_n_rows);
00262         op_strans::apply_noalias_tinysq(BB, B);
00263         
00264         gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
00265         }
00266       }
00267     else
00268       {
00269       gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
00270       }
00271     }
00272   
00273 
00274 
00275   template<typename eT>
00276   arma_hot
00277   inline
00278   static
00279   void
00280   apply
00281     (
00282           Mat<eT>& C,
00283     const Mat<eT>& A,
00284     const Mat<eT>& B,
00285     const eT alpha = eT(1),
00286     const eT beta  = eT(0),
00287     const typename arma_cx_only<eT>::result* junk = 0
00288     )
00289     {
00290     arma_extra_debug_sigprint();
00291     arma_ignore(junk);
00292     
00293     // "better than nothing" handling of hermitian transposes for complex number matrices
00294     
00295     Mat<eT> tmp_A;
00296     Mat<eT> tmp_B;
00297     
00298     if(do_trans_A)
00299       {
00300       op_htrans::apply_noalias(tmp_A, A);
00301       }
00302     
00303     if(do_trans_B)
00304       {
00305       op_htrans::apply_noalias(tmp_B, B);
00306       }
00307     
00308     const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
00309     const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B;
00310     
00311     const uword A_n_rows = AA.n_rows;
00312     const uword A_n_cols = AA.n_cols;
00313     
00314     const uword B_n_rows = BB.n_rows;
00315     const uword B_n_cols = BB.n_cols;
00316     
00317     if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
00318       {
00319       gemm_emul_tinysq<false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
00320       }
00321     else
00322       {
00323       gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
00324       }
00325     }
00326 
00327   };
00328 
00329 
00330 
00334 
00335 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00336 class gemm
00337   {
00338   public:
00339   
00340   template<typename eT>
00341   inline
00342   static
00343   void
00344   apply_blas_type( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00345     {
00346     arma_extra_debug_sigprint();
00347     
00348     if( (A.n_elem <= 48u) && (B.n_elem <= 48u) )
00349       {
00350       gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00351       }
00352     else
00353       {
00354       #if defined(ARMA_USE_ATLAS)
00355         {
00356         arma_extra_debug_print("atlas::cblas_gemm()");
00357         
00358         atlas::cblas_gemm<eT>
00359           (
00360           atlas::CblasColMajor,
00361           (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
00362           (do_trans_B) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
00363           C.n_rows,
00364           C.n_cols,
00365           (do_trans_A) ? A.n_rows : A.n_cols,
00366           (use_alpha) ? alpha : eT(1),
00367           A.mem,
00368           (do_trans_A) ? A.n_rows : C.n_rows,
00369           B.mem,
00370           (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
00371           (use_beta) ? beta : eT(0),
00372           C.memptr(),
00373           C.n_rows
00374           );
00375         }
00376       #elif defined(ARMA_USE_BLAS)
00377         {
00378         arma_extra_debug_print("blas::gemm()");
00379         
00380         const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
00381         const char trans_B = (do_trans_B) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
00382         
00383         const blas_int m   = C.n_rows;
00384         const blas_int n   = C.n_cols;
00385         const blas_int k   = (do_trans_A) ? A.n_rows : A.n_cols;
00386         
00387         const eT local_alpha = (use_alpha) ? alpha : eT(1);
00388         
00389         const blas_int lda = (do_trans_A) ? k : m;
00390         const blas_int ldb = (do_trans_B) ? n : k;
00391         
00392         const eT local_beta  = (use_beta) ? beta : eT(0);
00393         
00394         arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_A = %c") % trans_A );
00395         arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_B = %c") % trans_B );
00396         
00397         blas::gemm<eT>
00398           (
00399           &trans_A,
00400           &trans_B,
00401           &m,
00402           &n,
00403           &k,
00404           &local_alpha,
00405           A.mem,
00406           &lda,
00407           B.mem,
00408           &ldb,
00409           &local_beta,
00410           C.memptr(),
00411           &m
00412           );
00413         }
00414       #else
00415         {
00416         gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00417         }
00418       #endif
00419       }
00420     }
00421   
00422   
00423   
00425   template<typename eT>
00426   inline
00427   static
00428   void
00429   apply( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00430     {
00431     gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00432     }
00433   
00434   
00435   
00436   arma_inline
00437   static
00438   void
00439   apply
00440     (
00441           Mat<float>& C,
00442     const Mat<float>& A,
00443     const Mat<float>& B,
00444     const float alpha = float(1),
00445     const float beta  = float(0)
00446     )
00447     {
00448     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00449     }
00450   
00451   
00452   
00453   arma_inline
00454   static
00455   void
00456   apply
00457     (
00458           Mat<double>& C,
00459     const Mat<double>& A,
00460     const Mat<double>& B,
00461     const double alpha = double(1),
00462     const double beta  = double(0)
00463     )
00464     {
00465     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00466     }
00467   
00468   
00469   
00470   arma_inline
00471   static
00472   void
00473   apply
00474     (
00475           Mat< std::complex<float> >& C,
00476     const Mat< std::complex<float> >& A,
00477     const Mat< std::complex<float> >& B,
00478     const std::complex<float> alpha = std::complex<float>(1),
00479     const std::complex<float> beta  = std::complex<float>(0)
00480     )
00481     {
00482     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00483     }
00484   
00485   
00486   
00487   arma_inline
00488   static
00489   void
00490   apply
00491     (
00492           Mat< std::complex<double> >& C,
00493     const Mat< std::complex<double> >& A,
00494     const Mat< std::complex<double> >& B,
00495     const std::complex<double> alpha = std::complex<double>(1),
00496     const std::complex<double> beta  = std::complex<double>(0)
00497     )
00498     {
00499     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00500     }
00501   
00502   };
00503 
00504 
00505 


armadillo_matrix
Author(s): Conrad Sanderson - NICTA (www.nicta.com.au), (Wrapper by Sjoerd van den Dries)
autogenerated on Tue Jan 7 2014 11:42:04