$search
00001 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) 00002 // Copyright (C) 2008-2011 Conrad Sanderson 00003 // 00004 // This file is part of the Armadillo C++ library. 00005 // It is provided without any warranty of fitness 00006 // for any purpose. You can redistribute this file 00007 // and/or modify it under the terms of the GNU 00008 // Lesser General Public License (LGPL) as published 00009 // by the Free Software Foundation, either version 3 00010 // of the License or (at your option) any later version. 00011 // (see http://www.opensource.org/licenses for more info) 00012 00013 00016 00017 00018 00019 00021 template<typename eT> 00022 arma_hot 00023 arma_pure 00024 inline 00025 eT 00026 op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) 00027 { 00028 arma_extra_debug_sigprint(); 00029 00030 eT val1 = eT(0); 00031 eT val2 = eT(0); 00032 00033 uword i, j; 00034 00035 for(i=0, j=1; j<n_elem; i+=2, j+=2) 00036 { 00037 val1 += A[i] * B[i]; 00038 val2 += A[j] * B[j]; 00039 } 00040 00041 if(i < n_elem) 00042 { 00043 val1 += A[i] * B[i]; 00044 } 00045 00046 return val1 + val2; 00047 } 00048 00049 00050 00052 template<typename eT> 00053 arma_hot 00054 arma_pure 00055 inline 00056 typename arma_float_only<eT>::result 00057 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) 00058 { 00059 arma_extra_debug_sigprint(); 00060 00061 if( n_elem <= (128/sizeof(eT)) ) 00062 { 00063 return op_dot::direct_dot_arma(n_elem, A, B); 00064 } 00065 else 00066 { 00067 #if defined(ARMA_USE_ATLAS) 00068 { 00069 arma_extra_debug_print("atlas::cblas_dot()"); 00070 00071 return atlas::cblas_dot(n_elem, A, B); 00072 } 00073 #elif defined(ARMA_USE_BLAS) 00074 { 00075 arma_extra_debug_print("blas::dot()"); 00076 00077 return blas::dot(n_elem, A, B); 00078 } 00079 #else 00080 { 00081 return op_dot::direct_dot_arma(n_elem, A, B); 00082 } 00083 #endif 00084 } 00085 } 00086 00087 00088 00090 template<typename eT> 00091 inline 00092 arma_hot 00093 arma_pure 00094 typename arma_cx_only<eT>::result 00095 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) 00096 { 00097 #if defined(ARMA_USE_ATLAS) 00098 { 00099 arma_extra_debug_print("atlas::cx_cblas_dot()"); 00100 00101 return atlas::cx_cblas_dot(n_elem, A, B); 00102 } 00103 #elif defined(ARMA_USE_BLAS) 00104 { 00105 // TODO: work out the mess with zdotu() and zdotu_sub() in BLAS 00106 return op_dot::direct_dot_arma(n_elem, A, B); 00107 } 00108 #else 00109 { 00110 return op_dot::direct_dot_arma(n_elem, A, B); 00111 } 00112 #endif 00113 } 00114 00115 00116 00118 template<typename eT> 00119 arma_hot 00120 arma_pure 00121 inline 00122 typename arma_integral_only<eT>::result 00123 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) 00124 { 00125 return op_dot::direct_dot_arma(n_elem, A, B); 00126 } 00127 00128 00129 00130 00132 template<typename eT> 00133 arma_hot 00134 arma_pure 00135 inline 00136 eT 00137 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C) 00138 { 00139 arma_extra_debug_sigprint(); 00140 00141 eT val = eT(0); 00142 00143 for(uword i=0; i<n_elem; ++i) 00144 { 00145 val += A[i] * B[i] * C[i]; 00146 } 00147 00148 return val; 00149 } 00150 00151 00152 00153 template<typename T1, typename T2> 00154 arma_hot 00155 arma_inline 00156 typename T1::elem_type 00157 op_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) 00158 { 00159 arma_extra_debug_sigprint(); 00160 00161 if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) ) 00162 { 00163 return op_dot::apply_unwrap(X,Y); 00164 } 00165 else 00166 { 00167 return op_dot::apply_proxy(X,Y); 00168 } 00169 } 00170 00171 00172 00173 template<typename T1, typename T2> 00174 arma_hot 00175 arma_inline 00176 typename T1::elem_type 00177 op_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) 00178 { 00179 arma_extra_debug_sigprint(); 00180 00181 typedef typename T1::elem_type eT; 00182 00183 const unwrap<T1> tmp1(X.get_ref()); 00184 const unwrap<T2> tmp2(Y.get_ref()); 00185 00186 const Mat<eT>& A = tmp1.M; 00187 const Mat<eT>& B = tmp2.M; 00188 00189 arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" ); 00190 00191 return op_dot::direct_dot(A.n_elem, A.mem, B.mem); 00192 } 00193 00194 00195 00196 template<typename T1, typename T2> 00197 arma_hot 00198 inline 00199 typename T1::elem_type 00200 op_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) 00201 { 00202 arma_extra_debug_sigprint(); 00203 00204 typedef typename T1::elem_type eT; 00205 typedef typename Proxy<T1>::ea_type ea_type1; 00206 typedef typename Proxy<T2>::ea_type ea_type2; 00207 00208 const Proxy<T1> A(X.get_ref()); 00209 const Proxy<T2> B(Y.get_ref()); 00210 00211 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor); 00212 00213 if(prefer_at_accessor == false) 00214 { 00215 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "dot(): objects must have the same number of elements" ); 00216 00217 const uword N = A.get_n_elem(); 00218 ea_type1 PA = A.get_ea(); 00219 ea_type2 PB = B.get_ea(); 00220 00221 eT val1 = eT(0); 00222 eT val2 = eT(0); 00223 00224 uword i,j; 00225 00226 for(i=0, j=1; j<N; i+=2, j+=2) 00227 { 00228 val1 += PA[i] * PB[i]; 00229 val2 += PA[j] * PB[j]; 00230 } 00231 00232 if(i < N) 00233 { 00234 val1 += PA[i] * PB[i]; 00235 } 00236 00237 return val1 + val2; 00238 } 00239 else 00240 { 00241 return op_dot::apply_unwrap(A.Q, B.Q); 00242 } 00243 } 00244 00245 00246 00247 // 00248 // op_norm_dot 00249 00250 00251 00252 template<typename T1, typename T2> 00253 arma_hot 00254 inline 00255 typename T1::elem_type 00256 op_norm_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) 00257 { 00258 arma_extra_debug_sigprint(); 00259 00260 typedef typename T1::elem_type eT; 00261 typedef typename Proxy<T1>::ea_type ea_type1; 00262 typedef typename Proxy<T2>::ea_type ea_type2; 00263 00264 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor); 00265 00266 if(prefer_at_accessor == false) 00267 { 00268 const Proxy<T1> A(X.get_ref()); 00269 const Proxy<T2> B(Y.get_ref()); 00270 00271 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "norm_dot(): objects must have the same number of elements" ); 00272 00273 const uword N = A.get_n_elem(); 00274 ea_type1 PA = A.get_ea(); 00275 ea_type2 PB = B.get_ea(); 00276 00277 eT acc1 = eT(0); 00278 eT acc2 = eT(0); 00279 eT acc3 = eT(0); 00280 00281 for(uword i=0; i<N; ++i) 00282 { 00283 const eT tmpA = PA[i]; 00284 const eT tmpB = PB[i]; 00285 00286 acc1 += tmpA * tmpA; 00287 acc2 += tmpB * tmpB; 00288 acc3 += tmpA * tmpB; 00289 } 00290 00291 return acc3 / ( std::sqrt(acc1 * acc2) ); 00292 } 00293 else 00294 { 00295 return op_norm_dot::apply_unwrap(X, Y); 00296 } 00297 } 00298 00299 00300 00301 template<typename T1, typename T2> 00302 arma_hot 00303 inline 00304 typename T1::elem_type 00305 op_norm_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) 00306 { 00307 arma_extra_debug_sigprint(); 00308 00309 typedef typename T1::elem_type eT; 00310 00311 const unwrap<T1> tmp1(X.get_ref()); 00312 const unwrap<T2> tmp2(Y.get_ref()); 00313 00314 const Mat<eT>& A = tmp1.M; 00315 const Mat<eT>& B = tmp2.M; 00316 00317 00318 arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" ); 00319 00320 const uword N = A.n_elem; 00321 00322 const eT* A_mem = A.memptr(); 00323 const eT* B_mem = B.memptr(); 00324 00325 eT acc1 = eT(0); 00326 eT acc2 = eT(0); 00327 eT acc3 = eT(0); 00328 00329 for(uword i=0; i<N; ++i) 00330 { 00331 const eT tmpA = A_mem[i]; 00332 const eT tmpB = B_mem[i]; 00333 00334 acc1 += tmpA * tmpA; 00335 acc2 += tmpB * tmpB; 00336 acc3 += tmpA * tmpB; 00337 } 00338 00339 return acc3 / ( std::sqrt(acc1 * acc2) ); 00340 } 00341 00342 00343 00344 // 00345 // op_cdot 00346 00347 00348 00349 template<typename T1, typename T2> 00350 arma_hot 00351 arma_inline 00352 typename T1::elem_type 00353 op_cdot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) 00354 { 00355 arma_extra_debug_sigprint(); 00356 00357 typedef typename T1::elem_type eT; 00358 typedef typename Proxy<T1>::ea_type ea_type1; 00359 typedef typename Proxy<T2>::ea_type ea_type2; 00360 00361 const Proxy<T1> A(X.get_ref()); 00362 const Proxy<T2> B(Y.get_ref()); 00363 00364 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "cdot(): objects must have the same number of elements" ); 00365 00366 const uword N = A.get_n_elem(); 00367 ea_type1 PA = A.get_ea(); 00368 ea_type2 PB = B.get_ea(); 00369 00370 eT val1 = eT(0); 00371 eT val2 = eT(0); 00372 00373 uword i,j; 00374 for(i=0, j=1; j<N; i+=2, j+=2) 00375 { 00376 val1 += std::conj(PA[i]) * PB[i]; 00377 val2 += std::conj(PA[j]) * PB[j]; 00378 } 00379 00380 if(i < N) 00381 { 00382 val1 += std::conj(PA[i]) * PB[i]; 00383 } 00384 00385 return val1 + val2; 00386 } 00387 00388 00389