$search
00001 // Copyright (C) 2010-2011 NICTA (www.nicta.com.au) 00002 // Copyright (C) 2010-2011 Conrad Sanderson 00003 // 00004 // This file is part of the Armadillo C++ library. 00005 // It is provided without any warranty of fitness 00006 // for any purpose. You can redistribute this file 00007 // and/or modify it under the terms of the GNU 00008 // Lesser General Public License (LGPL) as published 00009 // by the Free Software Foundation, either version 3 00010 // of the License or (at your option) any later version. 00011 // (see http://www.opensource.org/licenses for more info) 00012 00013 00016 00017 00018 00019 template<uword N> 00020 struct as_scalar_redirect 00021 { 00022 template<typename T1> 00023 inline static typename T1::elem_type apply(const T1& X); 00024 }; 00025 00026 00027 00028 template<> 00029 struct as_scalar_redirect<2> 00030 { 00031 template<typename T1, typename T2> 00032 inline static typename T1::elem_type apply(const Glue<T1,T2,glue_times>& X); 00033 }; 00034 00035 00036 template<> 00037 struct as_scalar_redirect<3> 00038 { 00039 template<typename T1, typename T2, typename T3> 00040 inline static typename T1::elem_type apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times>& X); 00041 }; 00042 00043 00044 00045 template<uword N> 00046 template<typename T1> 00047 inline 00048 typename T1::elem_type 00049 as_scalar_redirect<N>::apply(const T1& X) 00050 { 00051 arma_extra_debug_sigprint(); 00052 00053 typedef typename T1::elem_type eT; 00054 00055 const unwrap<T1> tmp(X); 00056 const Mat<eT>& A = tmp.M; 00057 00058 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); 00059 00060 return A.mem[0]; 00061 } 00062 00063 00064 00065 template<typename T1, typename T2> 00066 inline 00067 typename T1::elem_type 00068 as_scalar_redirect<2>::apply(const Glue<T1, T2, glue_times>& X) 00069 { 00070 arma_extra_debug_sigprint(); 00071 00072 typedef typename T1::elem_type eT; 00073 00074 // T1 must result in a matrix with one row 00075 // T2 must result in a matrix with one column 00076 00077 const partial_unwrap<T1> tmp1(X.A); 00078 const partial_unwrap<T2> tmp2(X.B); 00079 00080 const Mat<eT>& A = tmp1.M; 00081 const Mat<eT>& B = tmp2.M; 00082 00083 const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; 00084 const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; 00085 00086 const uword B_n_rows = (tmp2.do_trans == false) ? B.n_rows : B.n_cols; 00087 const uword B_n_cols = (tmp2.do_trans == false) ? B.n_cols : B.n_rows; 00088 00089 const eT val = tmp1.get_val() * tmp2.get_val(); 00090 00091 arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" ); 00092 00093 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem); 00094 } 00095 00096 00097 00098 template<typename T1, typename T2, typename T3> 00099 inline 00100 typename T1::elem_type 00101 as_scalar_redirect<3>::apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times >& X) 00102 { 00103 arma_extra_debug_sigprint(); 00104 00105 typedef typename T1::elem_type eT; 00106 00107 // T1 * T2 must result in a matrix with one row 00108 // T3 must result in a matrix with one column 00109 00110 typedef typename strip_inv <T2 >::stored_type T2_stripped_1; 00111 typedef typename strip_diagmat<T2_stripped_1>::stored_type T2_stripped_2; 00112 00113 const strip_inv <T2> strip1(X.A.B); 00114 const strip_diagmat<T2_stripped_1> strip2(strip1.M); 00115 00116 const bool tmp2_do_inv = strip1.do_inv; 00117 const bool tmp2_do_diagmat = strip2.do_diagmat; 00118 00119 if(tmp2_do_diagmat == false) 00120 { 00121 const Mat<eT> tmp(X); 00122 00123 arma_debug_check( (tmp.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); 00124 00125 return tmp[0]; 00126 } 00127 else 00128 { 00129 const partial_unwrap<T1> tmp1(X.A.A); 00130 const partial_unwrap<T2_stripped_2> tmp2(strip2.M); 00131 const partial_unwrap<T3> tmp3(X.B); 00132 00133 const Mat<eT>& A = tmp1.M; 00134 const Mat<eT>& B = tmp2.M; 00135 const Mat<eT>& C = tmp3.M; 00136 00137 const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; 00138 const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; 00139 00140 const bool B_is_vec = B.is_vec(); 00141 00142 const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols ); 00143 const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows ); 00144 00145 const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols; 00146 const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows; 00147 00148 const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val(); 00149 00150 arma_debug_check 00151 ( 00152 (A_n_rows != 1) || 00153 (C_n_cols != 1) || 00154 (A_n_cols != B_n_rows) || 00155 (B_n_cols != C_n_rows) 00156 , 00157 "as_scalar(): incompatible dimensions" 00158 ); 00159 00160 00161 if(B_is_vec == true) 00162 { 00163 if(tmp2_do_inv == true) 00164 { 00165 return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem); 00166 } 00167 else 00168 { 00169 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem); 00170 } 00171 } 00172 else 00173 { 00174 if(tmp2_do_inv == true) 00175 { 00176 return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem); 00177 } 00178 else 00179 { 00180 return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem); 00181 } 00182 } 00183 } 00184 } 00185 00186 00187 00188 template<typename T1> 00189 inline 00190 typename T1::elem_type 00191 as_scalar_diag(const Base<typename T1::elem_type,T1>& X) 00192 { 00193 arma_extra_debug_sigprint(); 00194 00195 typedef typename T1::elem_type eT; 00196 00197 const unwrap<T1> tmp(X.get_ref()); 00198 const Mat<eT>& A = tmp.M; 00199 00200 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); 00201 00202 return A.mem[0]; 00203 } 00204 00205 00206 00207 template<typename T1, typename T2, typename T3> 00208 inline 00209 typename T1::elem_type 00210 as_scalar_diag(const Glue< Glue<T1, T2, glue_times_diag>, T3, glue_times >& X) 00211 { 00212 arma_extra_debug_sigprint(); 00213 00214 typedef typename T1::elem_type eT; 00215 00216 // T1 * T2 must result in a matrix with one row 00217 // T3 must result in a matrix with one column 00218 00219 typedef typename strip_diagmat<T2>::stored_type T2_stripped; 00220 00221 const strip_diagmat<T2> strip(X.A.B); 00222 00223 const partial_unwrap<T1> tmp1(X.A.A); 00224 const partial_unwrap<T2_stripped> tmp2(strip.M); 00225 const partial_unwrap<T3> tmp3(X.B); 00226 00227 const Mat<eT>& A = tmp1.M; 00228 const Mat<eT>& B = tmp2.M; 00229 const Mat<eT>& C = tmp3.M; 00230 00231 00232 const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; 00233 const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; 00234 00235 const bool B_is_vec = B.is_vec(); 00236 00237 const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols ); 00238 const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows ); 00239 00240 const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols; 00241 const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows; 00242 00243 const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val(); 00244 00245 arma_debug_check 00246 ( 00247 (A_n_rows != 1) || 00248 (C_n_cols != 1) || 00249 (A_n_cols != B_n_rows) || 00250 (B_n_cols != C_n_rows) 00251 , 00252 "as_scalar(): incompatible dimensions" 00253 ); 00254 00255 00256 if(B_is_vec == true) 00257 { 00258 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem); 00259 } 00260 else 00261 { 00262 return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem); 00263 } 00264 } 00265 00266 00267 00268 template<typename T1, typename T2> 00269 arma_inline 00270 arma_warn_unused 00271 typename T1::elem_type 00272 as_scalar(const Glue<T1, T2, glue_times>& X, const typename arma_not_cx<typename T1::elem_type>::result* junk = 0) 00273 { 00274 arma_extra_debug_sigprint(); 00275 arma_ignore(junk); 00276 00277 if(is_glue_times_diag<T1>::value == false) 00278 { 00279 const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num; 00280 00281 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat); 00282 00283 return as_scalar_redirect<N_mat>::apply(X); 00284 } 00285 else 00286 { 00287 return as_scalar_diag(X); 00288 } 00289 } 00290 00291 00292 00293 template<typename T1> 00294 inline 00295 arma_warn_unused 00296 typename T1::elem_type 00297 as_scalar(const Base<typename T1::elem_type,T1>& X) 00298 { 00299 arma_extra_debug_sigprint(); 00300 00301 typedef typename T1::elem_type eT; 00302 00303 const unwrap<T1> tmp(X.get_ref()); 00304 const Mat<eT>& A = tmp.M; 00305 00306 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); 00307 00308 return A.mem[0]; 00309 } 00310 00311 00312 00313 template<typename T1> 00314 arma_inline 00315 arma_warn_unused 00316 typename T1::elem_type 00317 as_scalar(const eOp<T1, eop_neg>& X) 00318 { 00319 arma_extra_debug_sigprint(); 00320 00321 return -(as_scalar(X.P.Q)); 00322 } 00323 00324 00325 00326 template<typename T1> 00327 inline 00328 arma_warn_unused 00329 typename T1::elem_type 00330 as_scalar(const BaseCube<typename T1::elem_type,T1>& X) 00331 { 00332 arma_extra_debug_sigprint(); 00333 00334 typedef typename T1::elem_type eT; 00335 00336 const unwrap_cube<T1> tmp(X.get_ref()); 00337 const Cube<eT>& A = tmp.M; 00338 00339 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); 00340 00341 return A.mem[0]; 00342 } 00343 00344 00345 00346 template<typename T> 00347 arma_inline 00348 arma_warn_unused 00349 const typename arma_scalar_only<T>::result & 00350 as_scalar(const T& x) 00351 { 00352 return x; 00353 } 00354 00355 00356