fn_as_scalar.hpp
Go to the documentation of this file.
00001 // Copyright (C) 2010-2011 NICTA (www.nicta.com.au)
00002 // Copyright (C) 2010-2011 Conrad Sanderson
00003 // 
00004 // This file is part of the Armadillo C++ library.
00005 // It is provided without any warranty of fitness
00006 // for any purpose. You can redistribute this file
00007 // and/or modify it under the terms of the GNU
00008 // Lesser General Public License (LGPL) as published
00009 // by the Free Software Foundation, either version 3
00010 // of the License or (at your option) any later version.
00011 // (see http://www.opensource.org/licenses for more info)
00012 
00013 
00016 
00017 
00018 
00019 template<uword N>
00020 struct as_scalar_redirect
00021   {
00022   template<typename T1>
00023   inline static typename T1::elem_type apply(const T1& X);
00024   };
00025 
00026 
00027 
00028 template<>
00029 struct as_scalar_redirect<2>
00030   {
00031   template<typename T1, typename T2>
00032   inline static typename T1::elem_type apply(const Glue<T1,T2,glue_times>& X);
00033   };
00034 
00035 
00036 template<>
00037 struct as_scalar_redirect<3>
00038   {
00039   template<typename T1, typename T2, typename T3>
00040   inline static typename T1::elem_type apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times>& X);
00041   };
00042 
00043 
00044 
00045 template<uword N>
00046 template<typename T1>
00047 inline
00048 typename T1::elem_type
00049 as_scalar_redirect<N>::apply(const T1& X)
00050   {
00051   arma_extra_debug_sigprint();
00052   
00053   typedef typename T1::elem_type eT;
00054   
00055   const unwrap<T1>   tmp(X);
00056   const Mat<eT>& A = tmp.M;
00057   
00058   arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00059   
00060   return A.mem[0];
00061   }
00062 
00063 
00064 
00065 template<typename T1, typename T2>
00066 inline
00067 typename T1::elem_type
00068 as_scalar_redirect<2>::apply(const Glue<T1, T2, glue_times>& X)
00069   {
00070   arma_extra_debug_sigprint();
00071   
00072   typedef typename T1::elem_type eT;
00073   
00074   // T1 must result in a matrix with one row
00075   // T2 must result in a matrix with one column
00076   
00077   const partial_unwrap<T1> tmp1(X.A);
00078   const partial_unwrap<T2> tmp2(X.B);
00079   
00080   const Mat<eT>& A = tmp1.M;
00081   const Mat<eT>& B = tmp2.M;
00082   
00083   const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
00084   const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
00085   
00086   const uword B_n_rows = (tmp2.do_trans == false) ? B.n_rows : B.n_cols;
00087   const uword B_n_cols = (tmp2.do_trans == false) ? B.n_cols : B.n_rows;
00088   
00089   const eT val = tmp1.get_val() * tmp2.get_val();
00090   
00091   arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" );
00092   
00093   return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem);
00094   }
00095 
00096 
00097 
00098 template<typename T1, typename T2, typename T3>
00099 inline
00100 typename T1::elem_type
00101 as_scalar_redirect<3>::apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times >& X)
00102   {
00103   arma_extra_debug_sigprint();
00104   
00105   typedef typename T1::elem_type eT;
00106   
00107   // T1 * T2 must result in a matrix with one row
00108   // T3 must result in a matrix with one column
00109   
00110   typedef typename strip_inv    <T2           >::stored_type T2_stripped_1;
00111   typedef typename strip_diagmat<T2_stripped_1>::stored_type T2_stripped_2;
00112   
00113   const strip_inv    <T2>            strip1(X.A.B);
00114   const strip_diagmat<T2_stripped_1> strip2(strip1.M);
00115   
00116   const bool tmp2_do_inv     = strip1.do_inv;
00117   const bool tmp2_do_diagmat = strip2.do_diagmat;
00118   
00119   if(tmp2_do_diagmat == false)
00120     {
00121     const Mat<eT> tmp(X);
00122     
00123     arma_debug_check( (tmp.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00124     
00125     return tmp[0];
00126     }
00127   else
00128     {
00129     const partial_unwrap<T1>            tmp1(X.A.A);
00130     const partial_unwrap<T2_stripped_2> tmp2(strip2.M);
00131     const partial_unwrap<T3>            tmp3(X.B);
00132     
00133     const Mat<eT>& A = tmp1.M;
00134     const Mat<eT>& B = tmp2.M;
00135     const Mat<eT>& C = tmp3.M;
00136     
00137     const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
00138     const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
00139     
00140     const bool B_is_vec = B.is_vec();
00141     
00142     const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
00143     const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
00144     
00145     const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
00146     const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
00147     
00148     const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val();
00149     
00150     arma_debug_check
00151       (
00152       (A_n_rows != 1)        ||
00153       (C_n_cols != 1)        ||
00154       (A_n_cols != B_n_rows) ||
00155       (B_n_cols != C_n_rows)
00156       ,
00157       "as_scalar(): incompatible dimensions"
00158       );
00159     
00160     
00161     if(B_is_vec == true)
00162       {
00163       if(tmp2_do_inv == true)
00164         {
00165         return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem);
00166         }
00167       else
00168         {
00169         return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
00170         }
00171       }
00172     else
00173       {
00174       if(tmp2_do_inv == true)
00175         {
00176         return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem);
00177         }
00178       else
00179         {
00180         return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
00181         }
00182       }
00183     }
00184   }
00185 
00186 
00187 
00188 template<typename T1>
00189 inline
00190 typename T1::elem_type
00191 as_scalar_diag(const Base<typename T1::elem_type,T1>& X)
00192   {
00193   arma_extra_debug_sigprint();
00194   
00195   typedef typename T1::elem_type eT;
00196   
00197   const unwrap<T1>   tmp(X.get_ref());
00198   const Mat<eT>& A = tmp.M;
00199   
00200   arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00201   
00202   return A.mem[0];
00203   }
00204 
00205 
00206 
00207 template<typename T1, typename T2, typename T3>
00208 inline
00209 typename T1::elem_type
00210 as_scalar_diag(const Glue< Glue<T1, T2, glue_times_diag>, T3, glue_times >& X)
00211   {
00212   arma_extra_debug_sigprint();
00213   
00214   typedef typename T1::elem_type eT;
00215   
00216   // T1 * T2 must result in a matrix with one row
00217   // T3 must result in a matrix with one column
00218   
00219   typedef typename strip_diagmat<T2>::stored_type T2_stripped;
00220   
00221   const strip_diagmat<T2> strip(X.A.B);
00222   
00223   const partial_unwrap<T1>          tmp1(X.A.A);
00224   const partial_unwrap<T2_stripped> tmp2(strip.M);
00225   const partial_unwrap<T3>          tmp3(X.B);
00226   
00227   const Mat<eT>& A = tmp1.M;
00228   const Mat<eT>& B = tmp2.M;
00229   const Mat<eT>& C = tmp3.M;
00230   
00231   
00232   const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
00233   const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
00234   
00235   const bool B_is_vec = B.is_vec();
00236   
00237   const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
00238   const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
00239   
00240   const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
00241   const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
00242   
00243   const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val();
00244   
00245   arma_debug_check
00246     (
00247     (A_n_rows != 1)        ||
00248     (C_n_cols != 1)        ||
00249     (A_n_cols != B_n_rows) ||
00250     (B_n_cols != C_n_rows)
00251     ,
00252     "as_scalar(): incompatible dimensions"
00253     );
00254   
00255   
00256   if(B_is_vec == true)
00257     {
00258     return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
00259     }
00260   else
00261     {
00262     return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
00263     }
00264   }
00265 
00266 
00267 
00268 template<typename T1, typename T2>
00269 arma_inline
00270 arma_warn_unused
00271 typename T1::elem_type
00272 as_scalar(const Glue<T1, T2, glue_times>& X, const typename arma_not_cx<typename T1::elem_type>::result* junk = 0)
00273   {
00274   arma_extra_debug_sigprint();
00275   arma_ignore(junk);
00276   
00277   if(is_glue_times_diag<T1>::value == false)
00278     {
00279     const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00280     
00281     arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00282     
00283     return as_scalar_redirect<N_mat>::apply(X);
00284     }
00285   else
00286     {
00287     return as_scalar_diag(X);
00288     }
00289   }
00290 
00291 
00292 
00293 template<typename T1>
00294 inline
00295 arma_warn_unused
00296 typename T1::elem_type
00297 as_scalar(const Base<typename T1::elem_type,T1>& X)
00298   {
00299   arma_extra_debug_sigprint();
00300   
00301   typedef typename T1::elem_type eT;
00302   
00303   const unwrap<T1>   tmp(X.get_ref());
00304   const Mat<eT>& A = tmp.M;
00305   
00306   arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00307   
00308   return A.mem[0];
00309   }
00310 
00311 
00312 
00313 template<typename T1>
00314 arma_inline
00315 arma_warn_unused
00316 typename T1::elem_type
00317 as_scalar(const eOp<T1, eop_neg>& X)
00318   {
00319   arma_extra_debug_sigprint();
00320   
00321   return -(as_scalar(X.P.Q));
00322   }
00323 
00324 
00325 
00326 template<typename T1>
00327 inline
00328 arma_warn_unused
00329 typename T1::elem_type
00330 as_scalar(const BaseCube<typename T1::elem_type,T1>& X)
00331   {
00332   arma_extra_debug_sigprint();
00333   
00334   typedef typename T1::elem_type eT;
00335   
00336   const unwrap_cube<T1> tmp(X.get_ref());
00337   const Cube<eT>& A   = tmp.M;
00338   
00339   arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
00340   
00341   return A.mem[0];
00342   }
00343 
00344 
00345 
00346 template<typename T>
00347 arma_inline
00348 arma_warn_unused
00349 const typename arma_scalar_only<T>::result &
00350 as_scalar(const T& x)
00351   {
00352   return x;
00353   }
00354 
00355 
00356 


armadillo_matrix
Author(s): Conrad Sanderson - NICTA (www.nicta.com.au), (Wrapper by Sjoerd van den Dries)
autogenerated on Tue Jan 7 2014 11:42:03