00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00016
00017
00018
00019
00021 template<typename eT>
00022 arma_hot
00023 arma_pure
00024 inline
00025 eT
00026 op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B)
00027 {
00028 arma_extra_debug_sigprint();
00029
00030 eT val1 = eT(0);
00031 eT val2 = eT(0);
00032
00033 uword i, j;
00034
00035 for(i=0, j=1; j<n_elem; i+=2, j+=2)
00036 {
00037 val1 += A[i] * B[i];
00038 val2 += A[j] * B[j];
00039 }
00040
00041 if(i < n_elem)
00042 {
00043 val1 += A[i] * B[i];
00044 }
00045
00046 return val1 + val2;
00047 }
00048
00049
00050
00052 template<typename eT>
00053 arma_hot
00054 arma_pure
00055 inline
00056 typename arma_float_only<eT>::result
00057 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
00058 {
00059 arma_extra_debug_sigprint();
00060
00061 if( n_elem <= (128/sizeof(eT)) )
00062 {
00063 return op_dot::direct_dot_arma(n_elem, A, B);
00064 }
00065 else
00066 {
00067 #if defined(ARMA_USE_ATLAS)
00068 {
00069 arma_extra_debug_print("atlas::cblas_dot()");
00070
00071 return atlas::cblas_dot(n_elem, A, B);
00072 }
00073 #elif defined(ARMA_USE_BLAS)
00074 {
00075 arma_extra_debug_print("blas::dot()");
00076
00077 return blas::dot(n_elem, A, B);
00078 }
00079 #else
00080 {
00081 return op_dot::direct_dot_arma(n_elem, A, B);
00082 }
00083 #endif
00084 }
00085 }
00086
00087
00088
00090 template<typename eT>
00091 inline
00092 arma_hot
00093 arma_pure
00094 typename arma_cx_only<eT>::result
00095 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
00096 {
00097 #if defined(ARMA_USE_ATLAS)
00098 {
00099 arma_extra_debug_print("atlas::cx_cblas_dot()");
00100
00101 return atlas::cx_cblas_dot(n_elem, A, B);
00102 }
00103 #elif defined(ARMA_USE_BLAS)
00104 {
00105
00106 return op_dot::direct_dot_arma(n_elem, A, B);
00107 }
00108 #else
00109 {
00110 return op_dot::direct_dot_arma(n_elem, A, B);
00111 }
00112 #endif
00113 }
00114
00115
00116
00118 template<typename eT>
00119 arma_hot
00120 arma_pure
00121 inline
00122 typename arma_integral_only<eT>::result
00123 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
00124 {
00125 return op_dot::direct_dot_arma(n_elem, A, B);
00126 }
00127
00128
00129
00130
00132 template<typename eT>
00133 arma_hot
00134 arma_pure
00135 inline
00136 eT
00137 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C)
00138 {
00139 arma_extra_debug_sigprint();
00140
00141 eT val = eT(0);
00142
00143 for(uword i=0; i<n_elem; ++i)
00144 {
00145 val += A[i] * B[i] * C[i];
00146 }
00147
00148 return val;
00149 }
00150
00151
00152
00153 template<typename T1, typename T2>
00154 arma_hot
00155 arma_inline
00156 typename T1::elem_type
00157 op_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00158 {
00159 arma_extra_debug_sigprint();
00160
00161 if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
00162 {
00163 return op_dot::apply_unwrap(X,Y);
00164 }
00165 else
00166 {
00167 return op_dot::apply_proxy(X,Y);
00168 }
00169 }
00170
00171
00172
00173 template<typename T1, typename T2>
00174 arma_hot
00175 arma_inline
00176 typename T1::elem_type
00177 op_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00178 {
00179 arma_extra_debug_sigprint();
00180
00181 typedef typename T1::elem_type eT;
00182
00183 const unwrap<T1> tmp1(X.get_ref());
00184 const unwrap<T2> tmp2(Y.get_ref());
00185
00186 const Mat<eT>& A = tmp1.M;
00187 const Mat<eT>& B = tmp2.M;
00188
00189 arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
00190
00191 return op_dot::direct_dot(A.n_elem, A.mem, B.mem);
00192 }
00193
00194
00195
00196 template<typename T1, typename T2>
00197 arma_hot
00198 inline
00199 typename T1::elem_type
00200 op_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00201 {
00202 arma_extra_debug_sigprint();
00203
00204 typedef typename T1::elem_type eT;
00205 typedef typename Proxy<T1>::ea_type ea_type1;
00206 typedef typename Proxy<T2>::ea_type ea_type2;
00207
00208 const Proxy<T1> A(X.get_ref());
00209 const Proxy<T2> B(Y.get_ref());
00210
00211 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
00212
00213 if(prefer_at_accessor == false)
00214 {
00215 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "dot(): objects must have the same number of elements" );
00216
00217 const uword N = A.get_n_elem();
00218 ea_type1 PA = A.get_ea();
00219 ea_type2 PB = B.get_ea();
00220
00221 eT val1 = eT(0);
00222 eT val2 = eT(0);
00223
00224 uword i,j;
00225
00226 for(i=0, j=1; j<N; i+=2, j+=2)
00227 {
00228 val1 += PA[i] * PB[i];
00229 val2 += PA[j] * PB[j];
00230 }
00231
00232 if(i < N)
00233 {
00234 val1 += PA[i] * PB[i];
00235 }
00236
00237 return val1 + val2;
00238 }
00239 else
00240 {
00241 return op_dot::apply_unwrap(A.Q, B.Q);
00242 }
00243 }
00244
00245
00246
00247
00248
00249
00250
00251
00252 template<typename T1, typename T2>
00253 arma_hot
00254 inline
00255 typename T1::elem_type
00256 op_norm_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00257 {
00258 arma_extra_debug_sigprint();
00259
00260 typedef typename T1::elem_type eT;
00261 typedef typename Proxy<T1>::ea_type ea_type1;
00262 typedef typename Proxy<T2>::ea_type ea_type2;
00263
00264 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
00265
00266 if(prefer_at_accessor == false)
00267 {
00268 const Proxy<T1> A(X.get_ref());
00269 const Proxy<T2> B(Y.get_ref());
00270
00271 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "norm_dot(): objects must have the same number of elements" );
00272
00273 const uword N = A.get_n_elem();
00274 ea_type1 PA = A.get_ea();
00275 ea_type2 PB = B.get_ea();
00276
00277 eT acc1 = eT(0);
00278 eT acc2 = eT(0);
00279 eT acc3 = eT(0);
00280
00281 for(uword i=0; i<N; ++i)
00282 {
00283 const eT tmpA = PA[i];
00284 const eT tmpB = PB[i];
00285
00286 acc1 += tmpA * tmpA;
00287 acc2 += tmpB * tmpB;
00288 acc3 += tmpA * tmpB;
00289 }
00290
00291 return acc3 / ( std::sqrt(acc1 * acc2) );
00292 }
00293 else
00294 {
00295 return op_norm_dot::apply_unwrap(X, Y);
00296 }
00297 }
00298
00299
00300
00301 template<typename T1, typename T2>
00302 arma_hot
00303 inline
00304 typename T1::elem_type
00305 op_norm_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00306 {
00307 arma_extra_debug_sigprint();
00308
00309 typedef typename T1::elem_type eT;
00310
00311 const unwrap<T1> tmp1(X.get_ref());
00312 const unwrap<T2> tmp2(Y.get_ref());
00313
00314 const Mat<eT>& A = tmp1.M;
00315 const Mat<eT>& B = tmp2.M;
00316
00317
00318 arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
00319
00320 const uword N = A.n_elem;
00321
00322 const eT* A_mem = A.memptr();
00323 const eT* B_mem = B.memptr();
00324
00325 eT acc1 = eT(0);
00326 eT acc2 = eT(0);
00327 eT acc3 = eT(0);
00328
00329 for(uword i=0; i<N; ++i)
00330 {
00331 const eT tmpA = A_mem[i];
00332 const eT tmpB = B_mem[i];
00333
00334 acc1 += tmpA * tmpA;
00335 acc2 += tmpB * tmpB;
00336 acc3 += tmpA * tmpB;
00337 }
00338
00339 return acc3 / ( std::sqrt(acc1 * acc2) );
00340 }
00341
00342
00343
00344
00345
00346
00347
00348
00349 template<typename T1, typename T2>
00350 arma_hot
00351 arma_inline
00352 typename T1::elem_type
00353 op_cdot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00354 {
00355 arma_extra_debug_sigprint();
00356
00357 typedef typename T1::elem_type eT;
00358 typedef typename Proxy<T1>::ea_type ea_type1;
00359 typedef typename Proxy<T2>::ea_type ea_type2;
00360
00361 const Proxy<T1> A(X.get_ref());
00362 const Proxy<T2> B(Y.get_ref());
00363
00364 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "cdot(): objects must have the same number of elements" );
00365
00366 const uword N = A.get_n_elem();
00367 ea_type1 PA = A.get_ea();
00368 ea_type2 PB = B.get_ea();
00369
00370 eT val1 = eT(0);
00371 eT val2 = eT(0);
00372
00373 uword i,j;
00374 for(i=0, j=1; j<N; i+=2, j+=2)
00375 {
00376 val1 += std::conj(PA[i]) * PB[i];
00377 val2 += std::conj(PA[j]) * PB[j];
00378 }
00379
00380 if(i < N)
00381 {
00382 val1 += std::conj(PA[i]) * PB[i];
00383 }
00384
00385 return val1 + val2;
00386 }
00387
00388
00389