op_dot_meat.hpp
Go to the documentation of this file.
00001 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
00002 // Copyright (C) 2008-2011 Conrad Sanderson
00003 // 
00004 // This file is part of the Armadillo C++ library.
00005 // It is provided without any warranty of fitness
00006 // for any purpose. You can redistribute this file
00007 // and/or modify it under the terms of the GNU
00008 // Lesser General Public License (LGPL) as published
00009 // by the Free Software Foundation, either version 3
00010 // of the License or (at your option) any later version.
00011 // (see http://www.opensource.org/licenses for more info)
00012 
00013 
00016 
00017 
00018 
00019 
00021 template<typename eT>
00022 arma_hot
00023 arma_pure
00024 inline
00025 eT
00026 op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B)
00027   {
00028   arma_extra_debug_sigprint();
00029   
00030   eT val1 = eT(0);
00031   eT val2 = eT(0);
00032   
00033   uword i, j;
00034   
00035   for(i=0, j=1; j<n_elem; i+=2, j+=2)
00036     {
00037     val1 += A[i] * B[i];
00038     val2 += A[j] * B[j];
00039     }
00040   
00041   if(i < n_elem)
00042     {
00043     val1 += A[i] * B[i];
00044     }
00045   
00046   return val1 + val2;
00047   }
00048 
00049 
00050 
00052 template<typename eT>
00053 arma_hot
00054 arma_pure
00055 inline
00056 typename arma_float_only<eT>::result
00057 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
00058   {
00059   arma_extra_debug_sigprint();
00060   
00061   if( n_elem <= (128/sizeof(eT)) )
00062     {
00063     return op_dot::direct_dot_arma(n_elem, A, B);
00064     }
00065   else
00066     {
00067     #if defined(ARMA_USE_ATLAS)
00068       {
00069       arma_extra_debug_print("atlas::cblas_dot()");
00070       
00071       return atlas::cblas_dot(n_elem, A, B);
00072       }
00073     #elif defined(ARMA_USE_BLAS)
00074       {
00075       arma_extra_debug_print("blas::dot()");
00076       
00077       return blas::dot(n_elem, A, B);
00078       }
00079     #else
00080       {
00081       return op_dot::direct_dot_arma(n_elem, A, B);
00082       }
00083     #endif
00084     }
00085   }
00086 
00087 
00088 
00090 template<typename eT>
00091 inline
00092 arma_hot
00093 arma_pure
00094 typename arma_cx_only<eT>::result
00095 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
00096   {
00097   #if defined(ARMA_USE_ATLAS)
00098     {
00099     arma_extra_debug_print("atlas::cx_cblas_dot()");
00100     
00101     return atlas::cx_cblas_dot(n_elem, A, B);
00102     }
00103   #elif defined(ARMA_USE_BLAS)
00104     {
00105     // TODO: work out the mess with zdotu() and zdotu_sub() in BLAS
00106     return op_dot::direct_dot_arma(n_elem, A, B);
00107     }
00108   #else
00109     {
00110     return op_dot::direct_dot_arma(n_elem, A, B);
00111     }
00112   #endif
00113   }
00114 
00115 
00116 
00118 template<typename eT>
00119 arma_hot
00120 arma_pure
00121 inline
00122 typename arma_integral_only<eT>::result
00123 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
00124   {
00125   return op_dot::direct_dot_arma(n_elem, A, B);
00126   }
00127 
00128 
00129 
00130 
00132 template<typename eT>
00133 arma_hot
00134 arma_pure
00135 inline
00136 eT
00137 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C)
00138   {
00139   arma_extra_debug_sigprint();
00140   
00141   eT val = eT(0);
00142   
00143   for(uword i=0; i<n_elem; ++i)
00144     {
00145     val += A[i] * B[i] * C[i];
00146     }
00147 
00148   return val;
00149   }
00150 
00151 
00152 
00153 template<typename T1, typename T2>
00154 arma_hot
00155 arma_inline
00156 typename T1::elem_type
00157 op_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00158   {
00159   arma_extra_debug_sigprint();
00160   
00161   if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
00162     {
00163     return op_dot::apply_unwrap(X,Y);
00164     }
00165   else
00166     {
00167     return op_dot::apply_proxy(X,Y);
00168     }
00169   }
00170 
00171 
00172 
00173 template<typename T1, typename T2>
00174 arma_hot
00175 arma_inline
00176 typename T1::elem_type
00177 op_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00178   {
00179   arma_extra_debug_sigprint();
00180   
00181   typedef typename T1::elem_type eT;
00182   
00183   const unwrap<T1> tmp1(X.get_ref());
00184   const unwrap<T2> tmp2(Y.get_ref());
00185   
00186   const Mat<eT>& A = tmp1.M;
00187   const Mat<eT>& B = tmp2.M;
00188   
00189   arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
00190   
00191   return op_dot::direct_dot(A.n_elem, A.mem, B.mem);
00192   }
00193 
00194 
00195 
00196 template<typename T1, typename T2>
00197 arma_hot
00198 inline
00199 typename T1::elem_type
00200 op_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00201   {
00202   arma_extra_debug_sigprint();
00203   
00204   typedef typename T1::elem_type      eT;
00205   typedef typename Proxy<T1>::ea_type ea_type1;
00206   typedef typename Proxy<T2>::ea_type ea_type2;
00207   
00208   const Proxy<T1> A(X.get_ref());
00209   const Proxy<T2> B(Y.get_ref());
00210   
00211   const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
00212   
00213   if(prefer_at_accessor == false)
00214     {
00215     arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "dot(): objects must have the same number of elements" );
00216   
00217     const uword    N  = A.get_n_elem();
00218           ea_type1 PA = A.get_ea();
00219           ea_type2 PB = B.get_ea();
00220     
00221     eT val1 = eT(0);
00222     eT val2 = eT(0);
00223     
00224     uword i,j;
00225     
00226     for(i=0, j=1; j<N; i+=2, j+=2)
00227       {
00228       val1 += PA[i] * PB[i];
00229       val2 += PA[j] * PB[j];
00230       }
00231     
00232     if(i < N)
00233       {
00234       val1 += PA[i] * PB[i];
00235       }
00236     
00237     return val1 + val2;
00238     }
00239   else
00240     {
00241     return op_dot::apply_unwrap(A.Q, B.Q);
00242     }
00243   }
00244 
00245 
00246 
00247 //
00248 // op_norm_dot
00249 
00250 
00251 
00252 template<typename T1, typename T2>
00253 arma_hot
00254 inline
00255 typename T1::elem_type
00256 op_norm_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00257   {
00258   arma_extra_debug_sigprint();
00259   
00260   typedef typename T1::elem_type      eT;
00261   typedef typename Proxy<T1>::ea_type ea_type1;
00262   typedef typename Proxy<T2>::ea_type ea_type2;
00263   
00264   const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
00265   
00266   if(prefer_at_accessor == false)
00267     {
00268     const Proxy<T1> A(X.get_ref());
00269     const Proxy<T2> B(Y.get_ref());
00270     
00271     arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "norm_dot(): objects must have the same number of elements" );
00272     
00273     const uword    N  = A.get_n_elem();
00274           ea_type1 PA = A.get_ea();
00275           ea_type2 PB = B.get_ea();
00276     
00277     eT acc1 = eT(0);
00278     eT acc2 = eT(0);
00279     eT acc3 = eT(0);
00280     
00281     for(uword i=0; i<N; ++i)
00282       {
00283       const eT tmpA = PA[i];
00284       const eT tmpB = PB[i];
00285       
00286       acc1 += tmpA * tmpA;
00287       acc2 += tmpB * tmpB;
00288       acc3 += tmpA * tmpB;
00289       }
00290       
00291     return acc3 / ( std::sqrt(acc1 * acc2) );
00292     }
00293   else
00294     {
00295     return op_norm_dot::apply_unwrap(X, Y);
00296     }
00297   }
00298 
00299 
00300 
00301 template<typename T1, typename T2>
00302 arma_hot
00303 inline
00304 typename T1::elem_type
00305 op_norm_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00306   {
00307   arma_extra_debug_sigprint();
00308   
00309   typedef typename T1::elem_type eT;
00310   
00311   const unwrap<T1> tmp1(X.get_ref());
00312   const unwrap<T2> tmp2(Y.get_ref());
00313   
00314   const Mat<eT>& A = tmp1.M;
00315   const Mat<eT>& B = tmp2.M;
00316   
00317   
00318   arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
00319   
00320   const uword N = A.n_elem;
00321   
00322   const eT* A_mem = A.memptr();
00323   const eT* B_mem = B.memptr();
00324   
00325   eT acc1 = eT(0);
00326   eT acc2 = eT(0);
00327   eT acc3 = eT(0);
00328   
00329   for(uword i=0; i<N; ++i)
00330     {
00331     const eT tmpA = A_mem[i];
00332     const eT tmpB = B_mem[i];
00333     
00334     acc1 += tmpA * tmpA;
00335     acc2 += tmpB * tmpB;
00336     acc3 += tmpA * tmpB;
00337     }
00338     
00339   return acc3 / ( std::sqrt(acc1 * acc2) );
00340   }
00341 
00342 
00343 
00344 //
00345 // op_cdot
00346 
00347 
00348 
00349 template<typename T1, typename T2>
00350 arma_hot
00351 arma_inline
00352 typename T1::elem_type
00353 op_cdot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00354   {
00355   arma_extra_debug_sigprint();
00356   
00357   typedef typename T1::elem_type      eT;
00358   typedef typename Proxy<T1>::ea_type ea_type1;
00359   typedef typename Proxy<T2>::ea_type ea_type2;
00360   
00361   const Proxy<T1> A(X.get_ref());
00362   const Proxy<T2> B(Y.get_ref());
00363   
00364   arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "cdot(): objects must have the same number of elements" );
00365   
00366   const uword    N  = A.get_n_elem();
00367         ea_type1 PA = A.get_ea();
00368         ea_type2 PB = B.get_ea();
00369   
00370   eT val1 = eT(0);
00371   eT val2 = eT(0);
00372   
00373   uword i,j;
00374   for(i=0, j=1; j<N; i+=2, j+=2)
00375     {
00376     val1 += std::conj(PA[i]) * PB[i];
00377     val2 += std::conj(PA[j]) * PB[j];
00378     }
00379   
00380   if(i < N)
00381     {
00382     val1 += std::conj(PA[i]) * PB[i];
00383     }
00384   
00385   return val1 + val2;
00386   }
00387 
00388 
00389 


armadillo_matrix
Author(s): Conrad Sanderson - NICTA (www.nicta.com.au), (Wrapper by Sjoerd van den Dries)
autogenerated on Tue Jan 7 2014 11:42:05