$search
00001 // Copyright (C) 2009-2011 NICTA (www.nicta.com.au) 00002 // Copyright (C) 2009-2011 Conrad Sanderson 00003 // Copyright (C) 2009-2010 Dimitrios Bouzas 00004 // Copyright (C) 2011 Stanislav Funiak 00005 // 00006 // This file is part of the Armadillo C++ library. 00007 // It is provided without any warranty of fitness 00008 // for any purpose. You can redistribute this file 00009 // and/or modify it under the terms of the GNU 00010 // Lesser General Public License (LGPL) as published 00011 // by the Free Software Foundation, either version 3 00012 // of the License or (at your option) any later version. 00013 // (see http://www.opensource.org/licenses for more info) 00014 00015 00016 00019 00020 00021 00022 template<typename eT> 00023 inline 00024 void 00025 op_pinv::direct_pinv(Mat<eT>& out, const Mat<eT>& A, const eT in_tol) 00026 { 00027 arma_extra_debug_sigprint(); 00028 00029 typedef typename get_pod_type<eT>::result T; 00030 00031 T tol = access::tmp_real(in_tol); 00032 00033 arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0"); 00034 00035 const uword n_rows = A.n_rows; 00036 const uword n_cols = A.n_cols; 00037 00038 // economical SVD decomposition 00039 Mat<eT> U; 00040 Col< T> s; 00041 Mat<eT> V; 00042 00043 const bool status = (n_cols > n_rows) ? auxlib::svd_econ(U,s,V,trans(A),'b') : auxlib::svd_econ(U,s,V,A,'b'); 00044 00045 if(status == false) 00046 { 00047 out.reset(); 00048 arma_bad("pinv(): svd failed"); 00049 return; 00050 } 00051 00052 const uword s_n_elem = s.n_elem; 00053 const T* s_mem = s.memptr(); 00054 00055 // set tolerance to default if it hasn't been specified as an argument 00056 if( (tol == T(0)) && (s_n_elem > 0) ) 00057 { 00058 tol = (std::max)(n_rows, n_cols) * eop_aux::direct_eps( op_max::direct_max(s_mem, s_n_elem) ); 00059 } 00060 00061 00062 // count non zero valued elements in s 00063 00064 uword count = 0; 00065 00066 for(uword i = 0; i < s_n_elem; ++i) 00067 { 00068 if(s_mem[i] > tol) 00069 { 00070 ++count; 00071 } 00072 } 00073 00074 if(count != 0) 00075 { 00076 Col<T> s2(count); 00077 00078 T* s2_mem = s2.memptr(); 00079 00080 uword count2 = 0; 00081 00082 for(uword i=0; i < s_n_elem; ++i) 00083 { 00084 const T val = s_mem[i]; 00085 00086 if(val > tol) 00087 { 00088 s2_mem[count2] = T(1) / val; 00089 ++count2; 00090 } 00091 } 00092 00093 00094 if(n_rows >= n_cols) 00095 { 00096 out = ( V.n_cols > count ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( U.n_cols > count ? U.cols(0,count-1) : U ); 00097 } 00098 else 00099 { 00100 out = ( U.n_cols > count ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( V.n_cols > count ? V.cols(0,count-1) : V ); 00101 } 00102 } 00103 else 00104 { 00105 out.zeros(n_cols, n_rows); 00106 } 00107 } 00108 00109 00110 00111 template<typename T1> 00112 inline 00113 void 00114 op_pinv::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_pinv>& in) 00115 { 00116 arma_extra_debug_sigprint(); 00117 00118 typedef typename T1::elem_type eT; 00119 00120 const unwrap<T1> tmp(in.m); 00121 const Mat<eT>& A = tmp.M; 00122 00123 op_pinv::direct_pinv(out, A, in.aux); 00124 } 00125 00126 00127