00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00016
00017
00018
00019 template<uword N>
00020 template<typename T1, typename T2>
00021 inline
00022 void
00023 glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00024 {
00025 arma_extra_debug_sigprint();
00026
00027 typedef typename T1::elem_type eT;
00028
00029 const partial_unwrap_check<T1> tmp1(X.A, out);
00030 const partial_unwrap_check<T2> tmp2(X.B, out);
00031
00032 const Mat<eT>& A = tmp1.M;
00033 const Mat<eT>& B = tmp2.M;
00034
00035 const bool do_trans_A = tmp1.do_trans;
00036 const bool do_trans_B = tmp2.do_trans;
00037
00038 const bool use_alpha = tmp1.do_times || tmp2.do_times;
00039 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
00040
00041 glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
00042 }
00043
00044
00045
00046 template<typename T1, typename T2, typename T3>
00047 inline
00048 void
00049 glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
00050 {
00051 arma_extra_debug_sigprint();
00052
00053 typedef typename T1::elem_type eT;
00054
00055
00056
00057
00058 const partial_unwrap_check<T1> tmp1(X.A.A, out);
00059 const partial_unwrap_check<T2> tmp2(X.A.B, out);
00060 const partial_unwrap_check<T3> tmp3(X.B, out);
00061
00062 const Mat<eT>& A = tmp1.M;
00063 const Mat<eT>& B = tmp2.M;
00064 const Mat<eT>& C = tmp3.M;
00065
00066 const bool do_trans_A = tmp1.do_trans;
00067 const bool do_trans_B = tmp2.do_trans;
00068 const bool do_trans_C = tmp3.do_trans;
00069
00070 const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times;
00071 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0);
00072
00073 glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
00074 }
00075
00076
00077
00078 template<typename T1, typename T2, typename T3, typename T4>
00079 inline
00080 void
00081 glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X)
00082 {
00083 arma_extra_debug_sigprint();
00084
00085 typedef typename T1::elem_type eT;
00086
00087
00088
00089
00090 const partial_unwrap_check<T1> tmp1(X.A.A.A, out);
00091 const partial_unwrap_check<T2> tmp2(X.A.A.B, out);
00092 const partial_unwrap_check<T3> tmp3(X.A.B, out);
00093 const partial_unwrap_check<T4> tmp4(X.B, out);
00094
00095 const Mat<eT>& A = tmp1.M;
00096 const Mat<eT>& B = tmp2.M;
00097 const Mat<eT>& C = tmp3.M;
00098 const Mat<eT>& D = tmp4.M;
00099
00100 const bool do_trans_A = tmp1.do_trans;
00101 const bool do_trans_B = tmp2.do_trans;
00102 const bool do_trans_C = tmp3.do_trans;
00103 const bool do_trans_D = tmp4.do_trans;
00104
00105 const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times || tmp4.do_times;
00106 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0);
00107
00108 glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha);
00109 }
00110
00111
00112
00113 template<typename T1, typename T2>
00114 inline
00115 void
00116 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00117 {
00118 arma_extra_debug_sigprint();
00119
00120 typedef typename T1::elem_type eT;
00121
00122 const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00123
00124 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00125
00126 glue_times_redirect<N_mat>::apply(out, X);
00127 }
00128
00129
00130
00131 template<typename T1>
00132 inline
00133 void
00134 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
00135 {
00136 arma_extra_debug_sigprint();
00137
00138 typedef typename T1::elem_type eT;
00139
00140 const unwrap_check<T1> tmp(X, out);
00141 const Mat<eT>& B = tmp.M;
00142
00143 arma_debug_assert_mul_size(out, B, "matrix multiplication");
00144
00145 const uword out_n_rows = out.n_rows;
00146 const uword out_n_cols = out.n_cols;
00147
00148 if(out_n_cols == B.n_cols)
00149 {
00150
00151
00152 podarray<eT> tmp(out_n_cols);
00153
00154 eT* tmp_rowdata = tmp.memptr();
00155
00156 for(uword row=0; row < out_n_rows; ++row)
00157 {
00158 tmp.copy_row(out, row);
00159
00160 for(uword col=0; col < out_n_cols; ++col)
00161 {
00162 out.at(row,col) = op_dot::direct_dot( out_n_cols, tmp_rowdata, B.colptr(col) );
00163 }
00164 }
00165
00166 }
00167 else
00168 {
00169 const Mat<eT> tmp(out);
00170 glue_times::apply(out, tmp, B, eT(1), false, false, false);
00171 }
00172
00173 }
00174
00175
00176
00177 template<typename T1, typename T2>
00178 arma_hot
00179 inline
00180 void
00181 glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign)
00182 {
00183 arma_extra_debug_sigprint();
00184
00185 typedef typename T1::elem_type eT;
00186
00187 const partial_unwrap_check<T1> tmp1(X.A, out);
00188 const partial_unwrap_check<T2> tmp2(X.B, out);
00189
00190 const Mat<eT>& A = tmp1.M;
00191 const Mat<eT>& B = tmp2.M;
00192 const eT alpha = tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) );
00193
00194 const bool do_trans_A = tmp1.do_trans;
00195 const bool do_trans_B = tmp2.do_trans;
00196 const bool use_alpha = tmp1.do_times || tmp2.do_times || (sign < sword(0));
00197
00198 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
00199
00200 const uword result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00201 const uword result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00202
00203 arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "addition");
00204
00205 if(out.n_elem > 0)
00206 {
00207 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
00208 {
00209 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
00210 {
00211 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00212 }
00213 else
00214 if(B.n_cols == 1)
00215 {
00216 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00217 }
00218 else
00219 {
00220 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1));
00221 }
00222 }
00223 else
00224 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
00225 {
00226 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
00227 {
00228 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00229 }
00230 else
00231 if(B.n_cols == 1)
00232 {
00233 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00234 }
00235 else
00236 {
00237 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1));
00238 }
00239 }
00240 else
00241 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
00242 {
00243 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
00244 {
00245 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00246 }
00247 else
00248 if(B.n_cols == 1)
00249 {
00250 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00251 }
00252 else
00253 {
00254 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1));
00255 }
00256 }
00257 else
00258 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
00259 {
00260 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
00261 {
00262 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00263 }
00264 else
00265 if(B.n_cols == 1)
00266 {
00267 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00268 }
00269 else
00270 {
00271 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1));
00272 }
00273 }
00274 else
00275 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
00276 {
00277 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
00278 {
00279 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00280 }
00281 else
00282 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
00283 {
00284 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00285 }
00286 else
00287 {
00288 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1));
00289 }
00290 }
00291 else
00292 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
00293 {
00294 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
00295 {
00296 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00297 }
00298 else
00299 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
00300 {
00301 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00302 }
00303 else
00304 {
00305 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1));
00306 }
00307 }
00308 else
00309 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
00310 {
00311 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
00312 {
00313 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00314 }
00315 else
00316 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
00317 {
00318 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00319 }
00320 else
00321 {
00322 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1));
00323 }
00324 }
00325 else
00326 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
00327 {
00328 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
00329 {
00330 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
00331 }
00332 else
00333 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
00334 {
00335 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
00336 }
00337 else
00338 {
00339 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1));
00340 }
00341 }
00342 }
00343
00344
00345 }
00346
00347
00348
00349 template<typename eT>
00350 arma_inline
00351 uword
00352 glue_times::mul_storage_cost(const Mat<eT>& A, const Mat<eT>& B, const bool do_trans_A, const bool do_trans_B)
00353 {
00354 const uword final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00355 const uword final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00356
00357 return final_A_n_rows * final_B_n_cols;
00358 }
00359
00360
00361
00362 template<typename eT>
00363 arma_hot
00364 inline
00365 void
00366 glue_times::apply
00367 (
00368 Mat<eT>& out,
00369 const Mat<eT>& A,
00370 const Mat<eT>& B,
00371 const eT alpha,
00372 const bool do_trans_A,
00373 const bool do_trans_B,
00374 const bool use_alpha
00375 )
00376 {
00377 arma_extra_debug_sigprint();
00378
00379 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
00380
00381 const uword final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
00382 const uword final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
00383
00384 out.set_size(final_n_rows, final_n_cols);
00385
00386 if( (A.n_elem > 0) && (B.n_elem > 0) )
00387 {
00388 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
00389 {
00390 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
00391 {
00392 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
00393 }
00394 else
00395 if(B.n_cols == 1)
00396 {
00397 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
00398 }
00399 else
00400 {
00401 gemm<false, false, false, false>::apply(out, A, B);
00402 }
00403 }
00404 else
00405 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
00406 {
00407 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
00408 {
00409 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00410 }
00411 else
00412 if(B.n_cols == 1)
00413 {
00414 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00415 }
00416 else
00417 {
00418 gemm<false, false, true, false>::apply(out, A, B, alpha);
00419 }
00420 }
00421 else
00422 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
00423 {
00424 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
00425 {
00426 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
00427 }
00428 else
00429 if(B.n_cols == 1)
00430 {
00431 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
00432 }
00433 else
00434 {
00435 gemm<true, false, false, false>::apply(out, A, B);
00436 }
00437 }
00438 else
00439 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
00440 {
00441 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
00442 {
00443 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00444 }
00445 else
00446 if(B.n_cols == 1)
00447 {
00448 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00449 }
00450 else
00451 {
00452 gemm<true, false, true, false>::apply(out, A, B, alpha);
00453 }
00454 }
00455 else
00456 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
00457 {
00458 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
00459 {
00460 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
00461 }
00462 else
00463 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
00464 {
00465 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
00466 }
00467 else
00468 {
00469 gemm<false, true, false, false>::apply(out, A, B);
00470 }
00471 }
00472 else
00473 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
00474 {
00475 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
00476 {
00477 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00478 }
00479 else
00480 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
00481 {
00482 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00483 }
00484 else
00485 {
00486 gemm<false, true, true, false>::apply(out, A, B, alpha);
00487 }
00488 }
00489 else
00490 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
00491 {
00492 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
00493 {
00494 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
00495 }
00496 else
00497 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
00498 {
00499 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
00500 }
00501 else
00502 {
00503 gemm<true, true, false, false>::apply(out, A, B);
00504 }
00505 }
00506 else
00507 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
00508 {
00509 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
00510 {
00511 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
00512 }
00513 else
00514 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
00515 {
00516 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
00517 }
00518 else
00519 {
00520 gemm<true, true, true, false>::apply(out, A, B, alpha);
00521 }
00522 }
00523 }
00524 else
00525 {
00526 out.zeros();
00527 }
00528 }
00529
00530
00531
00532 template<typename eT>
00533 inline
00534 void
00535 glue_times::apply
00536 (
00537 Mat<eT>& out,
00538 const Mat<eT>& A,
00539 const Mat<eT>& B,
00540 const Mat<eT>& C,
00541 const eT alpha,
00542 const bool do_trans_A,
00543 const bool do_trans_B,
00544 const bool do_trans_C,
00545 const bool use_alpha
00546 )
00547 {
00548 arma_extra_debug_sigprint();
00549
00550 Mat<eT> tmp;
00551
00552 if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) )
00553 {
00554
00555 glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
00556 glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false );
00557 }
00558 else
00559 {
00560
00561 glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha);
00562 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false );
00563 }
00564 }
00565
00566
00567
00568 template<typename eT>
00569 inline
00570 void
00571 glue_times::apply
00572 (
00573 Mat<eT>& out,
00574 const Mat<eT>& A,
00575 const Mat<eT>& B,
00576 const Mat<eT>& C,
00577 const Mat<eT>& D,
00578 const eT alpha,
00579 const bool do_trans_A,
00580 const bool do_trans_B,
00581 const bool do_trans_C,
00582 const bool do_trans_D,
00583 const bool use_alpha
00584 )
00585 {
00586 arma_extra_debug_sigprint();
00587
00588 Mat<eT> tmp;
00589
00590 if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) )
00591 {
00592
00593 glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
00594
00595 glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false);
00596 }
00597 else
00598 {
00599
00600 glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha);
00601
00602 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false);
00603 }
00604 }
00605
00606
00607
00608
00609
00610
00611
00612 template<typename T1, typename T2>
00613 arma_hot
00614 inline
00615 void
00616 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X)
00617 {
00618 arma_extra_debug_sigprint();
00619
00620 typedef typename T1::elem_type eT;
00621
00622 const strip_diagmat<T1> S1(X.A);
00623 const strip_diagmat<T2> S2(X.B);
00624
00625 typedef typename strip_diagmat<T1>::stored_type T1_stripped;
00626 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
00627
00628 if( (S1.do_diagmat == true) && (S2.do_diagmat == false) )
00629 {
00630 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
00631
00632 const unwrap_check<T2> tmp(X.B, out);
00633 const Mat<eT>& B = tmp.M;
00634
00635 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiplication");
00636
00637 out.set_size(A.n_elem, B.n_cols);
00638
00639 for(uword col=0; col<B.n_cols; ++col)
00640 {
00641 eT* out_coldata = out.colptr(col);
00642 const eT* B_coldata = B.colptr(col);
00643
00644 for(uword row=0; row<B.n_rows; ++row)
00645 {
00646 out_coldata[row] = A[row] * B_coldata[row];
00647 }
00648 }
00649 }
00650 else
00651 if( (S1.do_diagmat == false) && (S2.do_diagmat == true) )
00652 {
00653 const unwrap_check<T1> tmp(X.A, out);
00654 const Mat<eT>& A = tmp.M;
00655
00656 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
00657
00658 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiplication");
00659
00660 out.set_size(A.n_rows, B.n_elem);
00661
00662 for(uword col=0; col<A.n_cols; ++col)
00663 {
00664 const eT val = B[col];
00665
00666 eT* out_coldata = out.colptr(col);
00667 const eT* A_coldata = A.colptr(col);
00668
00669 for(uword row=0; row<A.n_rows; ++row)
00670 {
00671 out_coldata[row] = A_coldata[row] * val;
00672 }
00673 }
00674 }
00675 else
00676 if( (S1.do_diagmat == true) && (S2.do_diagmat == true) )
00677 {
00678 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
00679 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
00680
00681 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiplication");
00682
00683 out.zeros(A.n_elem, A.n_elem);
00684
00685 for(uword i=0; i<A.n_elem; ++i)
00686 {
00687 out.at(i,i) = A[i] * B[i];
00688 }
00689 }
00690 }
00691
00692
00693