gemm_mixed.hpp
Go to the documentation of this file.
00001 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
00002 // Copyright (C) 2008-2011 Conrad Sanderson
00003 // 
00004 // This file is part of the Armadillo C++ library.
00005 // It is provided without any warranty of fitness
00006 // for any purpose. You can redistribute this file
00007 // and/or modify it under the terms of the GNU
00008 // Lesser General Public License (LGPL) as published
00009 // by the Free Software Foundation, either version 3
00010 // of the License or (at your option) any later version.
00011 // (see http://www.opensource.org/licenses for more info)
00012 
00013 
00016 
00017 
00018 
00023 
00024 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00025 class gemm_mixed_large
00026   {
00027   public:
00028   
00029   template<typename out_eT, typename in_eT1, typename in_eT2>
00030   arma_hot
00031   inline
00032   static
00033   void
00034   apply
00035     (
00036           Mat<out_eT>& C,
00037     const Mat<in_eT1>& A,
00038     const Mat<in_eT2>& B,
00039     const out_eT alpha = out_eT(1),
00040     const out_eT beta  = out_eT(0)
00041     )
00042     {
00043     arma_extra_debug_sigprint();
00044     
00045     const uword A_n_rows = A.n_rows;
00046     const uword A_n_cols = A.n_cols;
00047     
00048     const uword B_n_rows = B.n_rows;
00049     const uword B_n_cols = B.n_cols;
00050     
00051     if( (do_trans_A == false) && (do_trans_B == false) )
00052       {
00053       podarray<in_eT1> tmp(A_n_cols);
00054       in_eT1* A_rowdata = tmp.memptr();
00055       
00056       for(uword row_A=0; row_A < A_n_rows; ++row_A)
00057         {
00058         tmp.copy_row(A, row_A);
00059         
00060         for(uword col_B=0; col_B < B_n_cols; ++col_B)
00061           {
00062           const in_eT2* B_coldata = B.colptr(col_B);
00063           
00064           out_eT acc = out_eT(0);
00065           for(uword i=0; i < B_n_rows; ++i)
00066             {
00067             acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00068             }
00069         
00070           if( (use_alpha == false) && (use_beta == false) )
00071             {
00072             C.at(row_A,col_B) = acc;
00073             }
00074           else
00075           if( (use_alpha == true) && (use_beta == false) )
00076             {
00077             C.at(row_A,col_B) = alpha * acc;
00078             }
00079           else
00080           if( (use_alpha == false) && (use_beta == true) )
00081             {
00082             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00083             }
00084           else
00085           if( (use_alpha == true) && (use_beta == true) )
00086             {
00087             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00088             }
00089           
00090           }
00091         }
00092       }
00093     else
00094     if( (do_trans_A == true) && (do_trans_B == false) )
00095       {
00096       for(uword col_A=0; col_A < A_n_cols; ++col_A)
00097         {
00098         // col_A is interpreted as row_A when storing the results in matrix C
00099         
00100         const in_eT1* A_coldata = A.colptr(col_A);
00101         
00102         for(uword col_B=0; col_B < B_n_cols; ++col_B)
00103           {
00104           const in_eT2* B_coldata = B.colptr(col_B);
00105           
00106           out_eT acc = out_eT(0);
00107           for(uword i=0; i < B_n_rows; ++i)
00108             {
00109             acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00110             }
00111         
00112           if( (use_alpha == false) && (use_beta == false) )
00113             {
00114             C.at(col_A,col_B) = acc;
00115             }
00116           else
00117           if( (use_alpha == true) && (use_beta == false) )
00118             {
00119             C.at(col_A,col_B) = alpha * acc;
00120             }
00121           else
00122           if( (use_alpha == false) && (use_beta == true) )
00123             {
00124             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00125             }
00126           else
00127           if( (use_alpha == true) && (use_beta == true) )
00128             {
00129             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00130             }
00131           
00132           }
00133         }
00134       }
00135     else
00136     if( (do_trans_A == false) && (do_trans_B == true) )
00137       {
00138       Mat<in_eT2> B_tmp;
00139       
00140       op_strans::apply_noalias(B_tmp, B);
00141       
00142       gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00143       }
00144     else
00145     if( (do_trans_A == true) && (do_trans_B == true) )
00146       {
00147       // mat B_tmp = trans(B);
00148       // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00149       
00150       
00151       // By using the trans(A)*trans(B) = trans(B*A) equivalency,
00152       // transpose operations are not needed
00153       
00154       podarray<in_eT2> tmp(B_n_cols);
00155       in_eT2* B_rowdata = tmp.memptr();
00156       
00157       for(uword row_B=0; row_B < B_n_rows; ++row_B)
00158         {
00159         tmp.copy_row(B, row_B);
00160         
00161         for(uword col_A=0; col_A < A_n_cols; ++col_A)
00162           {
00163           const in_eT1* A_coldata = A.colptr(col_A);
00164           
00165           out_eT acc = out_eT(0);
00166           for(uword i=0; i < A_n_rows; ++i)
00167             {
00168             acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00169             }
00170         
00171           if( (use_alpha == false) && (use_beta == false) )
00172             {
00173             C.at(col_A,row_B) = acc;
00174             }
00175           else
00176           if( (use_alpha == true) && (use_beta == false) )
00177             {
00178             C.at(col_A,row_B) = alpha * acc;
00179             }
00180           else
00181           if( (use_alpha == false) && (use_beta == true) )
00182             {
00183             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00184             }
00185           else
00186           if( (use_alpha == true) && (use_beta == true) )
00187             {
00188             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00189             }
00190           
00191           }
00192         }
00193       
00194       }
00195     }
00196     
00197   };
00198 
00199 
00200 
00204 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00205 class gemm_mixed_small
00206   {
00207   public:
00208   
00209   template<typename out_eT, typename in_eT1, typename in_eT2>
00210   arma_hot
00211   inline
00212   static
00213   void
00214   apply
00215     (
00216           Mat<out_eT>& C,
00217     const Mat<in_eT1>& A,
00218     const Mat<in_eT2>& B,
00219     const out_eT alpha = out_eT(1),
00220     const out_eT beta  = out_eT(0)
00221     )
00222     {
00223     arma_extra_debug_sigprint();
00224     
00225     const uword A_n_rows = A.n_rows;
00226     const uword A_n_cols = A.n_cols;
00227     
00228     const uword B_n_rows = B.n_rows;
00229     const uword B_n_cols = B.n_cols;
00230     
00231     if( (do_trans_A == false) && (do_trans_B == false) )
00232       {
00233       for(uword row_A = 0; row_A < A_n_rows; ++row_A)
00234         {
00235         for(uword col_B = 0; col_B < B_n_cols; ++col_B)
00236           {
00237           const in_eT2* B_coldata = B.colptr(col_B);
00238           
00239           out_eT acc = out_eT(0);
00240           for(uword i = 0; i < B_n_rows; ++i)
00241             {
00242             const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
00243             const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00244             acc += val1 * val2;
00245             //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00246             }
00247           
00248           if( (use_alpha == false) && (use_beta == false) )
00249             {
00250             C.at(row_A,col_B) = acc;
00251             }
00252           else
00253           if( (use_alpha == true) && (use_beta == false) )
00254             {
00255             C.at(row_A,col_B) = alpha * acc;
00256             }
00257           else
00258           if( (use_alpha == false) && (use_beta == true) )
00259             {
00260             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00261             }
00262           else
00263           if( (use_alpha == true) && (use_beta == true) )
00264             {
00265             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00266             }
00267           }
00268         }
00269       }
00270     else
00271     if( (do_trans_A == true) && (do_trans_B == false) )
00272       {
00273       for(uword col_A=0; col_A < A_n_cols; ++col_A)
00274         {
00275         // col_A is interpreted as row_A when storing the results in matrix C
00276         
00277         const in_eT1* A_coldata = A.colptr(col_A);
00278         
00279         for(uword col_B=0; col_B < B_n_cols; ++col_B)
00280           {
00281           const in_eT2* B_coldata = B.colptr(col_B);
00282           
00283           out_eT acc = out_eT(0);
00284           for(uword i=0; i < B_n_rows; ++i)
00285             {
00286             acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00287             }
00288         
00289           if( (use_alpha == false) && (use_beta == false) )
00290             {
00291             C.at(col_A,col_B) = acc;
00292             }
00293           else
00294           if( (use_alpha == true) && (use_beta == false) )
00295             {
00296             C.at(col_A,col_B) = alpha * acc;
00297             }
00298           else
00299           if( (use_alpha == false) && (use_beta == true) )
00300             {
00301             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00302             }
00303           else
00304           if( (use_alpha == true) && (use_beta == true) )
00305             {
00306             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00307             }
00308           
00309           }
00310         }
00311       }
00312     else
00313     if( (do_trans_A == false) && (do_trans_B == true) )
00314       {
00315       for(uword row_A = 0; row_A < A_n_rows; ++row_A)
00316         {
00317         for(uword row_B = 0; row_B < B_n_rows; ++row_B)
00318           {
00319           out_eT acc = out_eT(0);
00320           for(uword i = 0; i < B_n_cols; ++i)
00321             {
00322             acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
00323             }
00324           
00325           if( (use_alpha == false) && (use_beta == false) )
00326             {
00327             C.at(row_A,row_B) = acc;
00328             }
00329           else
00330           if( (use_alpha == true) && (use_beta == false) )
00331             {
00332             C.at(row_A,row_B) = alpha * acc;
00333             }
00334           else
00335           if( (use_alpha == false) && (use_beta == true) )
00336             {
00337             C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00338             }
00339           else
00340           if( (use_alpha == true) && (use_beta == true) )
00341             {
00342             C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00343             }
00344           }
00345         }
00346       }
00347     else
00348     if( (do_trans_A == true) && (do_trans_B == true) )
00349       {
00350       for(uword row_B=0; row_B < B_n_rows; ++row_B)
00351         {
00352         
00353         for(uword col_A=0; col_A < A_n_cols; ++col_A)
00354           {
00355           const in_eT1* A_coldata = A.colptr(col_A);
00356           
00357           out_eT acc = out_eT(0);
00358           for(uword i=0; i < A_n_rows; ++i)
00359             {
00360             acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00361             }
00362         
00363           if( (use_alpha == false) && (use_beta == false) )
00364             {
00365             C.at(col_A,row_B) = acc;
00366             }
00367           else
00368           if( (use_alpha == true) && (use_beta == false) )
00369             {
00370             C.at(col_A,row_B) = alpha * acc;
00371             }
00372           else
00373           if( (use_alpha == false) && (use_beta == true) )
00374             {
00375             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00376             }
00377           else
00378           if( (use_alpha == true) && (use_beta == true) )
00379             {
00380             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00381             }
00382           
00383           }
00384         }
00385       
00386       }
00387     }
00388     
00389   };
00390 
00391 
00392 
00393 
00394 
00397 
00398 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00399 class gemm_mixed
00400   {
00401   public:
00402   
00404   template<typename out_eT, typename in_eT1, typename in_eT2>
00405   inline
00406   static
00407   void
00408   apply
00409     (
00410           Mat<out_eT>& C,
00411     const Mat<in_eT1>& A,
00412     const Mat<in_eT2>& B,
00413     const out_eT alpha = out_eT(1),
00414     const out_eT beta  = out_eT(0)
00415     )
00416     {
00417     arma_extra_debug_sigprint();
00418     
00419     Mat<in_eT1> tmp_A;
00420     Mat<in_eT2> tmp_B;
00421     
00422     const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) );
00423     const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) );
00424     
00425     if(do_trans_A)
00426       {
00427       op_htrans::apply_noalias(tmp_A, A);
00428       }
00429     
00430     if(do_trans_B)
00431       {
00432       op_htrans::apply_noalias(tmp_B, B);
00433       }
00434      
00435     const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A;
00436     const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B;
00437     
00438     if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) )
00439       {
00440       gemm_mixed_small<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
00441       }
00442     else
00443       {
00444       gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
00445       }
00446     }
00447   
00448   
00449   };
00450 
00451 
00452 


armadillo_matrix
Author(s): Conrad Sanderson - NICTA (www.nicta.com.au), (Wrapper by Sjoerd van den Dries)
autogenerated on Tue Jan 7 2014 11:42:04