$search
00001 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) 00002 // Copyright (C) 2008-2011 Conrad Sanderson 00003 // Copyright (C) 2009 Edmund Highcock 00004 // Copyright (C) 2011 James Sanders 00005 // Copyright (C) 2011 Stanislav Funiak 00006 // 00007 // This file is part of the Armadillo C++ library. 00008 // It is provided without any warranty of fitness 00009 // for any purpose. You can redistribute this file 00010 // and/or modify it under the terms of the GNU 00011 // Lesser General Public License (LGPL) as published 00012 // by the Free Software Foundation, either version 3 00013 // of the License or (at your option) any later version. 00014 // (see http://www.opensource.org/licenses for more info) 00015 00016 00019 00020 00021 00023 template<typename eT, typename T1> 00024 inline 00025 bool 00026 auxlib::inv(Mat<eT>& out, const Base<eT,T1>& X, const bool slow) 00027 { 00028 arma_extra_debug_sigprint(); 00029 00030 out = X.get_ref(); 00031 00032 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); 00033 00034 bool status = false; 00035 00036 const uword N = out.n_rows; 00037 00038 if( (N <= 4) && (slow == false) ) 00039 { 00040 status = auxlib::inv_inplace_tinymat(out, N); 00041 } 00042 00043 if( (N > 4) || (status == false) ) 00044 { 00045 status = auxlib::inv_inplace_lapack(out); 00046 } 00047 00048 return status; 00049 } 00050 00051 00052 00053 template<typename eT> 00054 inline 00055 bool 00056 auxlib::inv(Mat<eT>& out, const Mat<eT>& X, const bool slow) 00057 { 00058 arma_extra_debug_sigprint(); 00059 00060 arma_debug_check( (X.is_square() == false), "inv(): given matrix is not square" ); 00061 00062 bool status = false; 00063 00064 const uword N = X.n_rows; 00065 00066 if( (N <= 4) && (slow == false) ) 00067 { 00068 status = (&out != &X) ? 
auxlib::inv_noalias_tinymat(out, X, N) : auxlib::inv_inplace_tinymat(out, N); 00069 } 00070 00071 if( (N > 4) || (status == false) ) 00072 { 00073 out = X; 00074 status = auxlib::inv_inplace_lapack(out); 00075 } 00076 00077 return status; 00078 } 00079 00080 00081 00082 template<typename eT> 00083 inline 00084 bool 00085 auxlib::inv_noalias_tinymat(Mat<eT>& out, const Mat<eT>& X, const uword N) 00086 { 00087 arma_extra_debug_sigprint(); 00088 00089 bool det_ok = true; 00090 00091 out.set_size(N,N); 00092 00093 switch(N) 00094 { 00095 case 1: 00096 { 00097 out[0] = eT(1) / X[0]; 00098 }; 00099 break; 00100 00101 case 2: 00102 { 00103 const eT* Xm = X.memptr(); 00104 00105 const eT a = Xm[pos<0,0>::n2]; 00106 const eT b = Xm[pos<0,1>::n2]; 00107 const eT c = Xm[pos<1,0>::n2]; 00108 const eT d = Xm[pos<1,1>::n2]; 00109 00110 const eT tmp_det = (a*d - b*c); 00111 00112 if(tmp_det != eT(0)) 00113 { 00114 eT* outm = out.memptr(); 00115 00116 outm[pos<0,0>::n2] = d / tmp_det; 00117 outm[pos<0,1>::n2] = -b / tmp_det; 00118 outm[pos<1,0>::n2] = -c / tmp_det; 00119 outm[pos<1,1>::n2] = a / tmp_det; 00120 } 00121 else 00122 { 00123 det_ok = false; 00124 } 00125 }; 00126 break; 00127 00128 case 3: 00129 { 00130 const eT* X_col0 = X.colptr(0); 00131 const eT a11 = X_col0[0]; 00132 const eT a21 = X_col0[1]; 00133 const eT a31 = X_col0[2]; 00134 00135 const eT* X_col1 = X.colptr(1); 00136 const eT a12 = X_col1[0]; 00137 const eT a22 = X_col1[1]; 00138 const eT a32 = X_col1[2]; 00139 00140 const eT* X_col2 = X.colptr(2); 00141 const eT a13 = X_col2[0]; 00142 const eT a23 = X_col2[1]; 00143 const eT a33 = X_col2[2]; 00144 00145 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13); 00146 00147 if(tmp_det != eT(0)) 00148 { 00149 eT* out_col0 = out.colptr(0); 00150 out_col0[0] = (a33*a22 - a32*a23) / tmp_det; 00151 out_col0[1] = -(a33*a21 - a31*a23) / tmp_det; 00152 out_col0[2] = (a32*a21 - a31*a22) / tmp_det; 00153 00154 eT* out_col1 = 
out.colptr(1); 00155 out_col1[0] = -(a33*a12 - a32*a13) / tmp_det; 00156 out_col1[1] = (a33*a11 - a31*a13) / tmp_det; 00157 out_col1[2] = -(a32*a11 - a31*a12) / tmp_det; 00158 00159 eT* out_col2 = out.colptr(2); 00160 out_col2[0] = (a23*a12 - a22*a13) / tmp_det; 00161 out_col2[1] = -(a23*a11 - a21*a13) / tmp_det; 00162 out_col2[2] = (a22*a11 - a21*a12) / tmp_det; 00163 } 00164 else 00165 { 00166 det_ok = false; 00167 } 00168 }; 00169 break; 00170 00171 case 4: 00172 { 00173 const eT tmp_det = det(X); 00174 00175 if(tmp_det != eT(0)) 00176 { 00177 const eT* Xm = X.memptr(); 00178 eT* outm = out.memptr(); 00179 00180 outm[pos<0,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; 00181 outm[pos<1,0>::n4] = ( Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; 00182 outm[pos<2,0>::n4] = ( Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; 00183 outm[pos<3,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - 
Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det; 00184 00185 outm[pos<0,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; 00186 outm[pos<1,1>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; 00187 outm[pos<2,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; 00188 outm[pos<3,1>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det; 00189 00190 outm[pos<0,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; 00191 outm[pos<1,2>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] 
- Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; 00192 outm[pos<2,2>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; 00193 outm[pos<3,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det; 00194 00195 outm[pos<0,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det; 00196 outm[pos<1,3>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det; 00197 outm[pos<2,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] - 
Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] ) / tmp_det; 00198 outm[pos<3,3>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] ) / tmp_det; 00199 } 00200 else 00201 { 00202 det_ok = false; 00203 } 00204 }; 00205 break; 00206 00207 default: 00208 ; 00209 } 00210 00211 return det_ok; 00212 } 00213 00214 00215 00216 template<typename eT> 00217 inline 00218 bool 00219 auxlib::inv_inplace_tinymat(Mat<eT>& X, const uword N) 00220 { 00221 arma_extra_debug_sigprint(); 00222 00223 bool det_ok = true; 00224 00225 // for more info, see: 00226 // http://www.dr-lex.34sp.com/random/matrix_inv.html 00227 // http://www.cvl.iis.u-tokyo.ac.jp/~miyazaki/tech/teche23.html 00228 // http://www.euclideanspace.com/maths/algebra/matrix/functions/inverse/fourD/index.htm 00229 // http://www.geometrictools.com//LibFoundation/Mathematics/Wm4Matrix4.inl 00230 00231 switch(N) 00232 { 00233 case 1: 00234 { 00235 X[0] = eT(1) / X[0]; 00236 }; 00237 break; 00238 00239 case 2: 00240 { 00241 const eT a = X[pos<0,0>::n2]; 00242 const eT b = X[pos<0,1>::n2]; 00243 const eT c = X[pos<1,0>::n2]; 00244 const eT d = X[pos<1,1>::n2]; 00245 00246 const eT tmp_det = (a*d - b*c); 00247 00248 if(tmp_det != eT(0)) 00249 { 00250 X[pos<0,0>::n2] = d / tmp_det; 00251 X[pos<0,1>::n2] = -b / tmp_det; 00252 X[pos<1,0>::n2] = -c / tmp_det; 00253 X[pos<1,1>::n2] = a / tmp_det; 00254 } 00255 else 00256 { 00257 det_ok = false; 00258 } 00259 }; 00260 break; 00261 00262 case 3: 00263 { 00264 eT* X_col0 = X.colptr(0); 00265 eT* X_col1 = X.colptr(1); 00266 eT* X_col2 = X.colptr(2); 00267 00268 const eT a11 = X_col0[0]; 00269 const eT a21 = X_col0[1]; 00270 const eT a31 = X_col0[2]; 00271 00272 const eT a12 = X_col1[0]; 00273 const eT a22 = 
X_col1[1]; 00274 const eT a32 = X_col1[2]; 00275 00276 const eT a13 = X_col2[0]; 00277 const eT a23 = X_col2[1]; 00278 const eT a33 = X_col2[2]; 00279 00280 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13); 00281 00282 if(tmp_det != eT(0)) 00283 { 00284 X_col0[0] = (a33*a22 - a32*a23) / tmp_det; 00285 X_col0[1] = -(a33*a21 - a31*a23) / tmp_det; 00286 X_col0[2] = (a32*a21 - a31*a22) / tmp_det; 00287 00288 X_col1[0] = -(a33*a12 - a32*a13) / tmp_det; 00289 X_col1[1] = (a33*a11 - a31*a13) / tmp_det; 00290 X_col1[2] = -(a32*a11 - a31*a12) / tmp_det; 00291 00292 X_col2[0] = (a23*a12 - a22*a13) / tmp_det; 00293 X_col2[1] = -(a23*a11 - a21*a13) / tmp_det; 00294 X_col2[2] = (a22*a11 - a21*a12) / tmp_det; 00295 } 00296 else 00297 { 00298 det_ok = false; 00299 } 00300 }; 00301 break; 00302 00303 case 4: 00304 { 00305 const eT tmp_det = det(X); 00306 00307 if(tmp_det != eT(0)) 00308 { 00309 const Mat<eT> A(X); 00310 00311 const eT* Am = A.memptr(); 00312 eT* Xm = X.memptr(); 00313 00314 Xm[pos<0,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] - Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] + Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; 00315 Xm[pos<1,0>::n4] = ( Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; 00316 Xm[pos<2,0>::n4] = ( Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - 
Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det; 00317 Xm[pos<3,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det; 00318 00319 Xm[pos<0,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] + Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] - Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; 00320 Xm[pos<1,1>::n4] = ( Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; 00321 Xm[pos<2,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det; 00322 Xm[pos<3,1>::n4] = ( Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det; 00323 00324 Xm[pos<0,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - 
Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] + Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; 00325 Xm[pos<1,2>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; 00326 Xm[pos<2,2>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det; 00327 Xm[pos<3,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det; 00328 00329 Xm[pos<0,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] + Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det; 00330 Xm[pos<1,3>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] + 
Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det; 00331 Xm[pos<2,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] ) / tmp_det; 00332 Xm[pos<3,3>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] ) / tmp_det; 00333 } 00334 else 00335 { 00336 det_ok = false; 00337 } 00338 }; 00339 break; 00340 00341 default: 00342 ; 00343 } 00344 00345 return det_ok; 00346 } 00347 00348 00349 00350 template<typename eT> 00351 inline 00352 bool 00353 auxlib::inv_inplace_lapack(Mat<eT>& out) 00354 { 00355 arma_extra_debug_sigprint(); 00356 00357 if(out.is_empty()) 00358 { 00359 return true; 00360 } 00361 00362 #if defined(ARMA_USE_ATLAS) 00363 { 00364 podarray<int> ipiv(out.n_rows); 00365 00366 int info = atlas::clapack_getrf(atlas::CblasColMajor, out.n_rows, out.n_cols, out.memptr(), out.n_rows, ipiv.memptr()); 00367 00368 if(info == 0) 00369 { 00370 info = atlas::clapack_getri(atlas::CblasColMajor, out.n_rows, out.memptr(), out.n_rows, ipiv.memptr()); 00371 } 00372 00373 return (info == 0); 00374 } 00375 #elif defined(ARMA_USE_LAPACK) 00376 { 00377 blas_int n_rows = out.n_rows; 00378 blas_int n_cols = out.n_cols; 00379 blas_int info = 0; 00380 00381 podarray<blas_int> ipiv(out.n_rows); 00382 00383 // 84 was empirically found -- it is the maximum value suggested by LAPACK (as provided by ATLAS v3.6) 00384 // based on tests with various matrix types on 32-bit and 64-bit machines 00385 // 00386 // the "work" array is deliberately 
long so that a secondary (time-consuming) 00387 // memory allocation is avoided, if possible 00388 00389 blas_int work_len = (std::max)(blas_int(1), n_rows*84); 00390 podarray<eT> work( static_cast<uword>(work_len) ); 00391 00392 lapack::getrf(&n_rows, &n_cols, out.memptr(), &n_rows, ipiv.memptr(), &info); 00393 00394 if(info == 0) 00395 { 00396 // query for optimum size of work_len 00397 00398 blas_int work_len_tmp = -1; 00399 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len_tmp, &info); 00400 00401 if(info == 0) 00402 { 00403 blas_int proposed_work_len = static_cast<blas_int>(access::tmp_real(work[0])); 00404 00405 // if necessary, allocate more memory 00406 if(work_len < proposed_work_len) 00407 { 00408 work_len = proposed_work_len; 00409 work.set_size( static_cast<uword>(work_len) ); 00410 } 00411 } 00412 00413 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len, &info); 00414 } 00415 00416 return (info == 0); 00417 } 00418 #else 00419 { 00420 arma_ignore(out); 00421 arma_stop("inv(): use of ATLAS or LAPACK needs to be enabled"); 00422 return false; 00423 } 00424 #endif 00425 } 00426 00427 00428 00429 template<typename eT, typename T1> 00430 inline 00431 bool 00432 auxlib::inv_tr(Mat<eT>& out, const Base<eT,T1>& X, const uword layout) 00433 { 00434 arma_extra_debug_sigprint(); 00435 00436 out = X.get_ref(); 00437 00438 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); 00439 00440 if(out.is_empty()) 00441 { 00442 return true; 00443 } 00444 00445 bool status; 00446 00447 #if defined(ARMA_USE_LAPACK) 00448 { 00449 char uplo = (layout == 0) ? 
'U' : 'L'; 00450 char diag = 'N'; 00451 blas_int n = blas_int(out.n_rows); 00452 blas_int info = 0; 00453 00454 lapack::trtri(&uplo, &diag, &n, out.memptr(), &n, &info); 00455 00456 status = (info == 0); 00457 } 00458 #else 00459 { 00460 arma_ignore(layout); 00461 arma_stop("inv(): use of LAPACK needs to be enabled"); 00462 status = false; 00463 } 00464 #endif 00465 00466 00467 if(status == true) 00468 { 00469 if(layout == 0) 00470 { 00471 // upper triangular 00472 out = trimatu(out); 00473 } 00474 else 00475 { 00476 // lower triangular 00477 out = trimatl(out); 00478 } 00479 } 00480 00481 return status; 00482 } 00483 00484 00485 00486 template<typename eT, typename T1> 00487 inline 00488 bool 00489 auxlib::inv_sym(Mat<eT>& out, const Base<eT,T1>& X, const uword layout) 00490 { 00491 arma_extra_debug_sigprint(); 00492 00493 out = X.get_ref(); 00494 00495 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); 00496 00497 if(out.is_empty()) 00498 { 00499 return true; 00500 } 00501 00502 bool status; 00503 00504 #if defined(ARMA_USE_LAPACK) 00505 { 00506 char uplo = (layout == 0) ? 'U' : 'L'; 00507 blas_int n = blas_int(out.n_rows); 00508 blas_int lwork = n*n; // TODO: use lwork = -1 to determine optimal size 00509 blas_int info = 0; 00510 00511 podarray<blas_int> ipiv; 00512 ipiv.set_size(out.n_rows); 00513 00514 podarray<eT> work; 00515 work.set_size( uword(lwork) ); 00516 00517 lapack::sytrf(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &lwork, &info); 00518 00519 status = (info == 0); 00520 00521 if(status == true) 00522 { 00523 lapack::sytri(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &info); 00524 00525 out = (layout == 0) ? 
symmatu(out) : symmatl(out); 00526 00527 status = (info == 0); 00528 } 00529 } 00530 #else 00531 { 00532 arma_ignore(layout); 00533 arma_stop("inv(): use of LAPACK needs to be enabled"); 00534 status = false; 00535 } 00536 #endif 00537 00538 return status; 00539 } 00540 00541 00542 00543 template<typename eT, typename T1> 00544 inline 00545 bool 00546 auxlib::inv_sympd(Mat<eT>& out, const Base<eT,T1>& X, const uword layout) 00547 { 00548 arma_extra_debug_sigprint(); 00549 00550 out = X.get_ref(); 00551 00552 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); 00553 00554 if(out.is_empty()) 00555 { 00556 return true; 00557 } 00558 00559 bool status; 00560 00561 #if defined(ARMA_USE_LAPACK) 00562 { 00563 char uplo = (layout == 0) ? 'U' : 'L'; 00564 blas_int n = blas_int(out.n_rows); 00565 blas_int info = 0; 00566 00567 lapack::potrf(&uplo, &n, out.memptr(), &n, &info); 00568 00569 status = (info == 0); 00570 00571 if(status == true) 00572 { 00573 lapack::potri(&uplo, &n, out.memptr(), &n, &info); 00574 00575 out = (layout == 0) ? symmatu(out) : symmatl(out); 00576 00577 status = (info == 0); 00578 } 00579 } 00580 #else 00581 { 00582 arma_ignore(layout); 00583 arma_stop("inv(): use of LAPACK needs to be enabled"); 00584 status = false; 00585 } 00586 #endif 00587 00588 return status; 00589 } 00590 00591 00592 00593 template<typename eT, typename T1> 00594 inline 00595 eT 00596 auxlib::det(const Base<eT,T1>& X, const bool slow) 00597 { 00598 const unwrap<T1> tmp(X.get_ref()); 00599 const Mat<eT>& A = tmp.M; 00600 00601 arma_debug_check( (A.is_square() == false), "det(): matrix is not square" ); 00602 00603 const bool make_copy = (is_Mat<T1>::value == true) ? 
true : false; 00604 00605 if(slow == false) 00606 { 00607 const uword N = A.n_rows; 00608 00609 switch(N) 00610 { 00611 case 0: 00612 case 1: 00613 case 2: 00614 return auxlib::det_tinymat(A, N); 00615 break; 00616 00617 case 3: 00618 case 4: 00619 { 00620 const eT tmp_det = auxlib::det_tinymat(A, N); 00621 return (tmp_det != eT(0)) ? tmp_det : auxlib::det_lapack(A, make_copy); 00622 } 00623 break; 00624 00625 default: 00626 return auxlib::det_lapack(A, make_copy); 00627 } 00628 } 00629 else 00630 { 00631 return auxlib::det_lapack(A, make_copy); 00632 } 00633 } 00634 00635 00636 00637 template<typename eT> 00638 inline 00639 eT 00640 auxlib::det_tinymat(const Mat<eT>& X, const uword N) 00641 { 00642 arma_extra_debug_sigprint(); 00643 00644 switch(N) 00645 { 00646 case 0: 00647 return eT(1); 00648 break; 00649 00650 case 1: 00651 return X[0]; 00652 break; 00653 00654 case 2: 00655 { 00656 const eT* Xm = X.memptr(); 00657 00658 return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] ); 00659 } 00660 break; 00661 00662 case 3: 00663 { 00664 // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2); 00665 // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0); 00666 // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1); 00667 // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2); 00668 // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0); 00669 // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1); 00670 // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6); 00671 00672 const eT* a_col0 = X.colptr(0); 00673 const eT a11 = a_col0[0]; 00674 const eT a21 = a_col0[1]; 00675 const eT a31 = a_col0[2]; 00676 00677 const eT* a_col1 = X.colptr(1); 00678 const eT a12 = a_col1[0]; 00679 const eT a22 = a_col1[1]; 00680 const eT a32 = a_col1[2]; 00681 00682 const eT* a_col2 = X.colptr(2); 00683 const eT a13 = a_col2[0]; 00684 const eT a23 = a_col2[1]; 00685 const eT a33 = a_col2[2]; 00686 00687 return ( a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + 
a31*(a23*a12 - a22*a13) ); 00688 } 00689 break; 00690 00691 case 4: 00692 { 00693 const eT* Xm = X.memptr(); 00694 00695 const eT val = \ 00696 Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \ 00697 - Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \ 00698 - Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \ 00699 + Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \ 00700 + Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \ 00701 - Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \ 00702 - Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \ 00703 + Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \ 00704 + Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \ 00705 - Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \ 00706 - Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \ 00707 + Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \ 00708 + Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \ 00709 - Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \ 00710 - Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \ 00711 + Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \ 00712 + Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \ 00713 - Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \ 00714 - Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \ 00715 + Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \ 00716 + Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \ 00717 - Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \ 00718 - 
Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \ 00719 + Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \ 00720 ; 00721 00722 return val; 00723 } 00724 break; 00725 00726 default: 00727 return eT(0); 00728 ; 00729 } 00730 } 00731 00732 00733 00735 template<typename eT> 00736 inline 00737 eT 00738 auxlib::det_lapack(const Mat<eT>& X, const bool make_copy) 00739 { 00740 arma_extra_debug_sigprint(); 00741 00742 Mat<eT> X_copy; 00743 00744 if(make_copy == true) 00745 { 00746 X_copy = X; 00747 } 00748 00749 Mat<eT>& tmp = (make_copy == true) ? X_copy : const_cast< Mat<eT>& >(X); 00750 00751 if(tmp.is_empty()) 00752 { 00753 return eT(1); 00754 } 00755 00756 00757 #if defined(ARMA_USE_ATLAS) 00758 { 00759 podarray<int> ipiv(tmp.n_rows); 00760 00761 //const int info = 00762 atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr()); 00763 00764 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero 00765 eT val = tmp.at(0,0); 00766 for(uword i=1; i < tmp.n_rows; ++i) 00767 { 00768 val *= tmp.at(i,i); 00769 } 00770 00771 int sign = +1; 00772 for(uword i=0; i < tmp.n_rows; ++i) 00773 { 00774 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0 00775 { 00776 sign *= -1; 00777 } 00778 } 00779 00780 return ( (sign < 0) ? 
-val : val ); 00781 } 00782 #elif defined(ARMA_USE_LAPACK) 00783 { 00784 podarray<blas_int> ipiv(tmp.n_rows); 00785 00786 blas_int info = 0; 00787 blas_int n_rows = blas_int(tmp.n_rows); 00788 blas_int n_cols = blas_int(tmp.n_cols); 00789 00790 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info); 00791 00792 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero 00793 eT val = tmp.at(0,0); 00794 for(uword i=1; i < tmp.n_rows; ++i) 00795 { 00796 val *= tmp.at(i,i); 00797 } 00798 00799 blas_int sign = +1; 00800 for(uword i=0; i < tmp.n_rows; ++i) 00801 { 00802 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1 00803 { 00804 sign *= -1; 00805 } 00806 } 00807 00808 return ( (sign < 0) ? -val : val ); 00809 } 00810 #else 00811 { 00812 arma_ignore(X); 00813 arma_ignore(make_copy); 00814 arma_ignore(tmp); 00815 arma_stop("det(): use of ATLAS or LAPACK needs to be enabled"); 00816 return eT(0); 00817 } 00818 #endif 00819 } 00820 00821 00822 00824 template<typename eT, typename T1> 00825 inline 00826 bool 00827 auxlib::log_det(eT& out_val, typename get_pod_type<eT>::result& out_sign, const Base<eT,T1>& X) 00828 { 00829 arma_extra_debug_sigprint(); 00830 00831 typedef typename get_pod_type<eT>::result T; 00832 00833 #if defined(ARMA_USE_ATLAS) 00834 { 00835 Mat<eT> tmp(X.get_ref()); 00836 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" ); 00837 00838 if(tmp.is_empty()) 00839 { 00840 out_val = eT(0); 00841 out_sign = T(1); 00842 return true; 00843 } 00844 00845 podarray<int> ipiv(tmp.n_rows); 00846 00847 const int info = atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr()); 00848 00849 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero 00850 00851 sword sign = (is_complex<eT>::value == false) ? 
( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1; 00852 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) ); 00853 00854 for(uword i=1; i < tmp.n_rows; ++i) 00855 { 00856 const eT x = tmp.at(i,i); 00857 00858 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1; 00859 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); 00860 } 00861 00862 for(uword i=0; i < tmp.n_rows; ++i) 00863 { 00864 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0 00865 { 00866 sign *= -1; 00867 } 00868 } 00869 00870 out_val = val; 00871 out_sign = T(sign); 00872 00873 return (info == 0); 00874 } 00875 #elif defined(ARMA_USE_LAPACK) 00876 { 00877 Mat<eT> tmp(X.get_ref()); 00878 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" ); 00879 00880 if(tmp.is_empty()) 00881 { 00882 out_val = eT(0); 00883 out_sign = T(1); 00884 return true; 00885 } 00886 00887 podarray<blas_int> ipiv(tmp.n_rows); 00888 00889 blas_int info = 0; 00890 blas_int n_rows = blas_int(tmp.n_rows); 00891 blas_int n_cols = blas_int(tmp.n_cols); 00892 00893 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info); 00894 00895 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero 00896 00897 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1; 00898 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) ); 00899 00900 for(uword i=1; i < tmp.n_rows; ++i) 00901 { 00902 const eT x = tmp.at(i,i); 00903 00904 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? 
-1 : +1 ) : +1; 00905 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); 00906 } 00907 00908 for(uword i=0; i < tmp.n_rows; ++i) 00909 { 00910 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1 00911 { 00912 sign *= -1; 00913 } 00914 } 00915 00916 out_val = val; 00917 out_sign = T(sign); 00918 00919 return (info == 0); 00920 } 00921 #else 00922 { 00923 out_val = eT(0); 00924 out_sign = T(0); 00925 00926 arma_stop("log_det(): use of ATLAS or LAPACK needs to be enabled"); 00927 00928 return false; 00929 } 00930 #endif 00931 } 00932 00933 00934 00936 template<typename eT, typename T1> 00937 inline 00938 bool 00939 auxlib::lu(Mat<eT>& L, Mat<eT>& U, podarray<blas_int>& ipiv, const Base<eT,T1>& X) 00940 { 00941 arma_extra_debug_sigprint(); 00942 00943 U = X.get_ref(); 00944 00945 const uword U_n_rows = U.n_rows; 00946 const uword U_n_cols = U.n_cols; 00947 00948 if(U.is_empty()) 00949 { 00950 L.set_size(U_n_rows, 0); 00951 U.set_size(0, U_n_cols); 00952 ipiv.reset(); 00953 return true; 00954 } 00955 00956 #if defined(ARMA_USE_ATLAS) || defined(ARMA_USE_LAPACK) 00957 { 00958 bool status; 00959 00960 #if defined(ARMA_USE_ATLAS) 00961 { 00962 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) ); 00963 00964 int info = atlas::clapack_getrf(atlas::CblasColMajor, U_n_rows, U_n_cols, U.memptr(), U_n_rows, ipiv.memptr()); 00965 00966 status = (info == 0); 00967 } 00968 #elif defined(ARMA_USE_LAPACK) 00969 { 00970 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) ); 00971 00972 blas_int info = 0; 00973 00974 blas_int n_rows = U_n_rows; 00975 blas_int n_cols = U_n_cols; 00976 00977 00978 lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info); 00979 00980 // take into account that Fortran counts from 1 00981 arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem); 00982 00983 status = (info == 0); 00984 } 00985 #endif 00986 00987 L.copy_size(U); 00988 
00989 for(uword col=0; col < U_n_cols; ++col) 00990 { 00991 for(uword row=0; (row < col) && (row < U_n_rows); ++row) 00992 { 00993 L.at(row,col) = eT(0); 00994 } 00995 00996 if( L.in_range(col,col) == true ) 00997 { 00998 L.at(col,col) = eT(1); 00999 } 01000 01001 for(uword row = (col+1); row < U_n_rows; ++row) 01002 { 01003 L.at(row,col) = U.at(row,col); 01004 U.at(row,col) = eT(0); 01005 } 01006 } 01007 01008 return status; 01009 } 01010 #else 01011 { 01012 arma_stop("lu(): use of ATLAS or LAPACK needs to be enabled"); 01013 01014 return false; 01015 } 01016 #endif 01017 } 01018 01019 01020 01021 template<typename eT, typename T1> 01022 inline 01023 bool 01024 auxlib::lu(Mat<eT>& L, Mat<eT>& U, Mat<eT>& P, const Base<eT,T1>& X) 01025 { 01026 arma_extra_debug_sigprint(); 01027 01028 podarray<blas_int> ipiv1; 01029 const bool status = auxlib::lu(L, U, ipiv1, X); 01030 01031 if(status == true) 01032 { 01033 if(U.is_empty()) 01034 { 01035 // L and U have been already set to the correct empty matrices 01036 P.eye(L.n_rows, L.n_rows); 01037 return true; 01038 } 01039 01040 const uword n = ipiv1.n_elem; 01041 const uword P_rows = U.n_rows; 01042 01043 podarray<blas_int> ipiv2(P_rows); 01044 01045 const blas_int* ipiv1_mem = ipiv1.memptr(); 01046 blas_int* ipiv2_mem = ipiv2.memptr(); 01047 01048 for(uword i=0; i<P_rows; ++i) 01049 { 01050 ipiv2_mem[i] = blas_int(i); 01051 } 01052 01053 for(uword i=0; i<n; ++i) 01054 { 01055 const uword k = static_cast<uword>(ipiv1_mem[i]); 01056 01057 if( ipiv2_mem[i] != ipiv2_mem[k] ) 01058 { 01059 std::swap( ipiv2_mem[i], ipiv2_mem[k] ); 01060 } 01061 } 01062 01063 P.zeros(P_rows, P_rows); 01064 01065 for(uword row=0; row<P_rows; ++row) 01066 { 01067 P.at(row, static_cast<uword>(ipiv2_mem[row])) = eT(1); 01068 } 01069 01070 if(L.n_cols > U.n_rows) 01071 { 01072 L.shed_cols(U.n_rows, L.n_cols-1); 01073 } 01074 01075 if(U.n_rows > L.n_cols) 01076 { 01077 U.shed_rows(L.n_cols, U.n_rows-1); 01078 } 01079 } 01080 01081 return status; 01082 
} 01083 01084 01085 01086 template<typename eT, typename T1> 01087 inline 01088 bool 01089 auxlib::lu(Mat<eT>& L, Mat<eT>& U, const Base<eT,T1>& X) 01090 { 01091 arma_extra_debug_sigprint(); 01092 01093 podarray<blas_int> ipiv1; 01094 const bool status = auxlib::lu(L, U, ipiv1, X); 01095 01096 if(status == true) 01097 { 01098 if(U.is_empty()) 01099 { 01100 // L and U have been already set to the correct empty matrices 01101 return true; 01102 } 01103 01104 const uword n = ipiv1.n_elem; 01105 const uword P_rows = U.n_rows; 01106 01107 podarray<blas_int> ipiv2(P_rows); 01108 01109 const blas_int* ipiv1_mem = ipiv1.memptr(); 01110 blas_int* ipiv2_mem = ipiv2.memptr(); 01111 01112 for(uword i=0; i<P_rows; ++i) 01113 { 01114 ipiv2_mem[i] = blas_int(i); 01115 } 01116 01117 for(uword i=0; i<n; ++i) 01118 { 01119 const uword k = static_cast<uword>(ipiv1_mem[i]); 01120 01121 if( ipiv2_mem[i] != ipiv2_mem[k] ) 01122 { 01123 std::swap( ipiv2_mem[i], ipiv2_mem[k] ); 01124 L.swap_rows( static_cast<uword>(ipiv2_mem[i]), static_cast<uword>(ipiv2_mem[k]) ); 01125 } 01126 } 01127 01128 if(L.n_cols > U.n_rows) 01129 { 01130 L.shed_cols(U.n_rows, L.n_cols-1); 01131 } 01132 01133 if(U.n_rows > L.n_cols) 01134 { 01135 U.shed_rows(L.n_cols, U.n_rows-1); 01136 } 01137 } 01138 01139 return status; 01140 } 01141 01142 01143 01145 template<typename eT, typename T1> 01146 inline 01147 bool 01148 auxlib::eig_sym(Col<eT>& eigval, const Base<eT,T1>& X) 01149 { 01150 arma_extra_debug_sigprint(); 01151 01152 #if defined(ARMA_USE_LAPACK) 01153 { 01154 Mat<eT> A(X.get_ref()); 01155 01156 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square"); 01157 01158 if(A.is_empty()) 01159 { 01160 eigval.reset(); 01161 return true; 01162 } 01163 01164 // rudimentary "better-than-nothing" test for symmetry 01165 //arma_debug_check( (A.at(A.n_rows-1, 0) != A.at(0, A.n_cols-1)), "auxlib::eig(): given matrix is not symmetric" ); 01166 01167 char jobz = 'N'; 01168 char uplo = 'U'; 01169 
01170 blas_int n_rows = A.n_rows; 01171 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1); 01172 01173 eigval.set_size( static_cast<uword>(n_rows) ); 01174 podarray<eT> work( static_cast<uword>(lwork) ); 01175 01176 blas_int info; 01177 01178 arma_extra_debug_print("lapack::syev()"); 01179 lapack::syev(&jobz, &uplo, &n_rows, A.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info); 01180 01181 return (info == 0); 01182 } 01183 #else 01184 { 01185 arma_ignore(eigval); 01186 arma_ignore(X); 01187 arma_stop("eig_sym(): use of LAPACK needs to be enabled"); 01188 return false; 01189 } 01190 #endif 01191 } 01192 01193 01194 01196 template<typename T, typename T1> 01197 inline 01198 bool 01199 auxlib::eig_sym(Col<T>& eigval, const Base<std::complex<T>,T1>& X) 01200 { 01201 arma_extra_debug_sigprint(); 01202 01203 typedef typename std::complex<T> eT; 01204 01205 #if defined(ARMA_USE_LAPACK) 01206 { 01207 Mat<eT> A(X.get_ref()); 01208 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not hermitian"); 01209 01210 if(A.is_empty()) 01211 { 01212 eigval.reset(); 01213 return true; 01214 } 01215 01216 char jobz = 'N'; 01217 char uplo = 'U'; 01218 01219 blas_int n_rows = A.n_rows; 01220 blas_int lda = A.n_rows; 01221 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1); // TODO: automatically find best size of lwork 01222 01223 eigval.set_size( static_cast<uword>(n_rows) ); 01224 01225 podarray<eT> work( static_cast<uword>(lwork) ); 01226 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) ); 01227 01228 blas_int info; 01229 01230 arma_extra_debug_print("lapack::heev()"); 01231 lapack::heev(&jobz, &uplo, &n_rows, A.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); 01232 01233 return (info == 0); 01234 } 01235 #else 01236 { 01237 arma_ignore(eigval); 01238 arma_ignore(X); 01239 arma_stop("eig_sym(): use of LAPACK needs to be enabled"); 01240 return false; 01241 } 01242 #endif 01243 } 01244 
01245 01246 01248 template<typename eT, typename T1> 01249 inline 01250 bool 01251 auxlib::eig_sym(Col<eT>& eigval, Mat<eT>& eigvec, const Base<eT,T1>& X) 01252 { 01253 arma_extra_debug_sigprint(); 01254 01255 #if defined(ARMA_USE_LAPACK) 01256 { 01257 eigvec = X.get_ref(); 01258 01259 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" ); 01260 01261 if(eigvec.is_empty()) 01262 { 01263 eigval.reset(); 01264 eigvec.reset(); 01265 return true; 01266 } 01267 01268 // rudimentary "better-than-nothing" test for symmetry 01269 //arma_debug_check( (A.at(A.n_rows-1, 0) != A.at(0, A.n_cols-1)), "auxlib::eig(): given matrix is not symmetric" ); 01270 01271 char jobz = 'V'; 01272 char uplo = 'U'; 01273 01274 blas_int n_rows = eigvec.n_rows; 01275 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1); 01276 01277 eigval.set_size( static_cast<uword>(n_rows) ); 01278 podarray<eT> work( static_cast<uword>(lwork) ); 01279 01280 blas_int info; 01281 01282 arma_extra_debug_print("lapack::syev()"); 01283 lapack::syev(&jobz, &uplo, &n_rows, eigvec.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info); 01284 01285 return (info == 0); 01286 } 01287 #else 01288 { 01289 arma_ignore(eigval); 01290 arma_ignore(eigvec); 01291 arma_stop("eig_sym(): use of LAPACK needs to be enabled"); 01292 01293 return false; 01294 } 01295 #endif 01296 } 01297 01298 01299 01301 template<typename T, typename T1> 01302 inline 01303 bool 01304 auxlib::eig_sym(Col<T>& eigval, Mat< std::complex<T> >& eigvec, const Base<std::complex<T>,T1>& X) 01305 { 01306 arma_extra_debug_sigprint(); 01307 01308 typedef typename std::complex<T> eT; 01309 01310 #if defined(ARMA_USE_LAPACK) 01311 { 01312 eigvec = X.get_ref(); 01313 01314 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not hermitian" ); 01315 01316 if(eigvec.is_empty()) 01317 { 01318 eigval.reset(); 01319 eigvec.reset(); 01320 return true; 01321 } 01322 01323 char jobz = 'V'; 01324 char 
uplo = 'U'; 01325 01326 blas_int n_rows = eigvec.n_rows; 01327 blas_int lda = eigvec.n_rows; 01328 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1); // TODO: automatically find best size of lwork 01329 01330 eigval.set_size( static_cast<uword>(n_rows) ); 01331 01332 podarray<eT> work( static_cast<uword>(lwork) ); 01333 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) ); 01334 01335 blas_int info; 01336 01337 arma_extra_debug_print("lapack::heev()"); 01338 lapack::heev(&jobz, &uplo, &n_rows, eigvec.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); 01339 01340 return (info == 0); 01341 } 01342 #else 01343 { 01344 arma_ignore(eigval); 01345 arma_ignore(eigvec); 01346 arma_ignore(X); 01347 arma_stop("eig_sym(): use of LAPACK needs to be enabled"); 01348 return false; 01349 } 01350 #endif 01351 } 01352 01353 01354 01358 template<typename T, typename T1> 01359 inline 01360 bool 01361 auxlib::eig_gen 01362 ( 01363 Col< std::complex<T> >& eigval, 01364 Mat<T>& l_eigvec, 01365 Mat<T>& r_eigvec, 01366 const Base<T,T1>& X, 01367 const char side 01368 ) 01369 { 01370 arma_extra_debug_sigprint(); 01371 01372 #if defined(ARMA_USE_LAPACK) 01373 { 01374 char jobvl; 01375 char jobvr; 01376 01377 switch(side) 01378 { 01379 case 'l': // left 01380 jobvl = 'V'; 01381 jobvr = 'N'; 01382 break; 01383 01384 case 'r': // right 01385 jobvl = 'N'; 01386 jobvr = 'V'; 01387 break; 01388 01389 case 'b': // both 01390 jobvl = 'V'; 01391 jobvr = 'V'; 01392 break; 01393 01394 case 'n': // neither 01395 jobvl = 'N'; 01396 jobvr = 'N'; 01397 break; 01398 01399 default: 01400 arma_stop("eig_gen(): parameter 'side' is invalid"); 01401 return false; 01402 } 01403 01404 Mat<T> A(X.get_ref()); 01405 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" ); 01406 01407 if(A.is_empty()) 01408 { 01409 eigval.reset(); 01410 l_eigvec.reset(); 01411 r_eigvec.reset(); 01412 return true; 01413 } 01414 01415 uword 
A_n_rows = A.n_rows; 01416 01417 blas_int n_rows = A_n_rows; 01418 blas_int lda = A_n_rows; 01419 blas_int lwork = (std::max)(blas_int(1), 4*n_rows); // TODO: automatically find best size of lwork 01420 01421 eigval.set_size(A_n_rows); 01422 l_eigvec.set_size(A_n_rows, A_n_rows); 01423 r_eigvec.set_size(A_n_rows, A_n_rows); 01424 01425 podarray<T> work( static_cast<uword>(lwork) ); 01426 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) ); 01427 01428 podarray<T> wr(A_n_rows); 01429 podarray<T> wi(A_n_rows); 01430 01431 Mat<T> A_copy = A; 01432 blas_int info; 01433 01434 arma_extra_debug_print("lapack::geev()"); 01435 lapack::geev(&jobvl, &jobvr, &n_rows, A_copy.memptr(), &lda, wr.memptr(), wi.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, &info); 01436 01437 01438 eigval.set_size(A_n_rows); 01439 for(uword i=0; i<A_n_rows; ++i) 01440 { 01441 eigval[i] = std::complex<T>(wr[i], wi[i]); 01442 } 01443 01444 return (info == 0); 01445 } 01446 #else 01447 { 01448 arma_ignore(eigval); 01449 arma_ignore(l_eigvec); 01450 arma_ignore(r_eigvec); 01451 arma_ignore(X); 01452 arma_ignore(side); 01453 arma_stop("eig_gen(): use of LAPACK needs to be enabled"); 01454 return false; 01455 } 01456 #endif 01457 } 01458 01459 01460 01461 01462 01466 template<typename T, typename T1> 01467 inline 01468 bool 01469 auxlib::eig_gen 01470 ( 01471 Col< std::complex<T> >& eigval, 01472 Mat< std::complex<T> >& l_eigvec, 01473 Mat< std::complex<T> >& r_eigvec, 01474 const Base< std::complex<T>, T1 >& X, 01475 const char side 01476 ) 01477 { 01478 arma_extra_debug_sigprint(); 01479 01480 typedef typename std::complex<T> eT; 01481 01482 #if defined(ARMA_USE_LAPACK) 01483 { 01484 char jobvl; 01485 char jobvr; 01486 01487 switch(side) 01488 { 01489 case 'l': // left 01490 jobvl = 'V'; 01491 jobvr = 'N'; 01492 break; 01493 01494 case 'r': // right 01495 jobvl = 'N'; 01496 jobvr = 'V'; 01497 break; 01498 01499 case 'b': // both 01500 
jobvl = 'V'; 01501 jobvr = 'V'; 01502 break; 01503 01504 case 'n': // neither 01505 jobvl = 'N'; 01506 jobvr = 'N'; 01507 break; 01508 01509 default: 01510 arma_stop("eig_gen(): parameter 'side' is invalid"); 01511 return false; 01512 } 01513 01514 Mat<eT> A(X.get_ref()); 01515 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" ); 01516 01517 if(A.is_empty()) 01518 { 01519 eigval.reset(); 01520 l_eigvec.reset(); 01521 r_eigvec.reset(); 01522 return true; 01523 } 01524 01525 uword A_n_rows = A.n_rows; 01526 01527 blas_int n_rows = A_n_rows; 01528 blas_int lda = A_n_rows; 01529 blas_int lwork = (std::max)(blas_int(1), 4*n_rows); // TODO: automatically find best size of lwork 01530 01531 eigval.set_size(A_n_rows); 01532 l_eigvec.set_size(A_n_rows, A_n_rows); 01533 r_eigvec.set_size(A_n_rows, A_n_rows); 01534 01535 podarray<eT> work( static_cast<uword>(lwork) ); 01536 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) ); // was 2,3 01537 01538 blas_int info; 01539 01540 arma_extra_debug_print("lapack::cx_geev()"); 01541 lapack::cx_geev(&jobvl, &jobvr, &n_rows, A.memptr(), &lda, eigval.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, rwork.memptr(), &info); 01542 01543 return (info == 0); 01544 } 01545 #else 01546 { 01547 arma_ignore(eigval); 01548 arma_ignore(l_eigvec); 01549 arma_ignore(r_eigvec); 01550 arma_ignore(X); 01551 arma_ignore(side); 01552 arma_stop("eig_gen(): use of LAPACK needs to be enabled"); 01553 return false; 01554 } 01555 #endif 01556 } 01557 01558 01559 01560 template<typename eT, typename T1> 01561 inline 01562 bool 01563 auxlib::chol(Mat<eT>& out, const Base<eT,T1>& X) 01564 { 01565 arma_extra_debug_sigprint(); 01566 01567 #if defined(ARMA_USE_LAPACK) 01568 { 01569 out = X.get_ref(); 01570 01571 arma_debug_check( (out.is_square() == false), "chol(): given matrix is not square" ); 01572 01573 if(out.is_empty()) 01574 { 01575 return true; 01576 } 01577 
01578 const uword out_n_rows = out.n_rows; 01579 01580 char uplo = 'U'; 01581 blas_int n = out_n_rows; 01582 blas_int info; 01583 01584 lapack::potrf(&uplo, &n, out.memptr(), &n, &info); 01585 01586 for(uword col=0; col<out_n_rows; ++col) 01587 { 01588 eT* colptr = out.colptr(col); 01589 01590 for(uword row=(col+1); row < out_n_rows; ++row) 01591 { 01592 colptr[row] = eT(0); 01593 } 01594 } 01595 01596 return (info == 0); 01597 } 01598 #else 01599 { 01600 arma_ignore(out); 01601 arma_stop("chol(): use of LAPACK needs to be enabled"); 01602 return false; 01603 } 01604 #endif 01605 } 01606 01607 01608 01609 template<typename eT, typename T1> 01610 inline 01611 bool 01612 auxlib::qr(Mat<eT>& Q, Mat<eT>& R, const Base<eT,T1>& X) 01613 { 01614 arma_extra_debug_sigprint(); 01615 01616 #if defined(ARMA_USE_LAPACK) 01617 { 01618 R = X.get_ref(); 01619 01620 const uword R_n_rows = R.n_rows; 01621 const uword R_n_cols = R.n_cols; 01622 01623 if(R.is_empty()) 01624 { 01625 Q.eye(R_n_rows, R_n_rows); 01626 return true; 01627 } 01628 01629 blas_int m = static_cast<blas_int>(R_n_rows); 01630 blas_int n = static_cast<blas_int>(R_n_cols); 01631 blas_int work_len = (std::max)(blas_int(1),n); 01632 blas_int work_len_tmp; 01633 blas_int k = (std::min)(m,n); 01634 blas_int info; 01635 01636 podarray<eT> tau( static_cast<uword>(k) ); 01637 podarray<eT> work( static_cast<uword>(work_len) ); 01638 01639 // query for the optimum value of work_len 01640 work_len_tmp = -1; 01641 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info); 01642 01643 if(info == 0) 01644 { 01645 work_len = static_cast<blas_int>(access::tmp_real(work[0])); 01646 work.set_size( static_cast<uword>(work_len) ); 01647 } 01648 01649 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info); 01650 01651 Q.set_size(R_n_rows, R_n_rows); 01652 01653 arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) ); 01654 01655 // 01656 // construct R 01657 
01658 for(uword col=0; col < R_n_cols; ++col) 01659 { 01660 for(uword row=(col+1); row < R_n_rows; ++row) 01661 { 01662 R.at(row,col) = eT(0); 01663 } 01664 } 01665 01666 01667 if( (is_float<eT>::value == true) || (is_double<eT>::value == true) ) 01668 { 01669 // query for the optimum value of work_len 01670 work_len_tmp = -1; 01671 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info); 01672 01673 if(info == 0) 01674 { 01675 work_len = static_cast<blas_int>(access::tmp_real(work[0])); 01676 work.set_size( static_cast<uword>(work_len) ); 01677 } 01678 01679 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info); 01680 } 01681 else 01682 if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) ) 01683 { 01684 // query for the optimum value of work_len 01685 work_len_tmp = -1; 01686 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info); 01687 01688 if(info == 0) 01689 { 01690 work_len = static_cast<blas_int>(access::tmp_real(work[0])); 01691 work.set_size( static_cast<uword>(work_len) ); 01692 } 01693 01694 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info); 01695 } 01696 01697 return (info == 0); 01698 } 01699 #else 01700 { 01701 arma_ignore(Q); 01702 arma_ignore(R); 01703 arma_ignore(X); 01704 arma_stop("qr(): use of LAPACK needs to be enabled"); 01705 return false; 01706 } 01707 #endif 01708 } 01709 01710 01711 01712 template<typename eT, typename T1> 01713 inline 01714 bool 01715 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X, uword& X_n_rows, uword& X_n_cols) 01716 { 01717 arma_extra_debug_sigprint(); 01718 01719 #if defined(ARMA_USE_LAPACK) 01720 { 01721 Mat<eT> A(X.get_ref()); 01722 01723 X_n_rows = A.n_rows; 01724 X_n_cols = A.n_cols; 01725 01726 if(A.is_empty()) 01727 { 01728 S.reset(); 01729 return true; 01730 } 01731 01732 Mat<eT> U(1, 1); 01733 Mat<eT> V(1, 
A.n_cols); 01734 01735 char jobu = 'N'; 01736 char jobvt = 'N'; 01737 01738 blas_int m = A.n_rows; 01739 blas_int n = A.n_cols; 01740 blas_int lda = A.n_rows; 01741 blas_int ldu = U.n_rows; 01742 blas_int ldvt = V.n_rows; 01743 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) ); 01744 blas_int info; 01745 01746 S.set_size( static_cast<uword>((std::min)(m, n)) ); 01747 01748 podarray<eT> work( static_cast<uword>(lwork) ); 01749 01750 01751 // let gesvd_() calculate the optimum size of the workspace 01752 blas_int lwork_tmp = -1; 01753 01754 lapack::gesvd<eT> 01755 ( 01756 &jobu, &jobvt, 01757 &m,&n, 01758 A.memptr(), &lda, 01759 S.memptr(), 01760 U.memptr(), &ldu, 01761 V.memptr(), &ldvt, 01762 work.memptr(), &lwork_tmp, 01763 &info 01764 ); 01765 01766 if(info == 0) 01767 { 01768 blas_int proposed_lwork = static_cast<blas_int>(work[0]); 01769 01770 if(proposed_lwork > lwork) 01771 { 01772 lwork = proposed_lwork; 01773 work.set_size( static_cast<uword>(lwork) ); 01774 } 01775 01776 lapack::gesvd<eT> 01777 ( 01778 &jobu, &jobvt, 01779 &m, &n, 01780 A.memptr(), &lda, 01781 S.memptr(), 01782 U.memptr(), &ldu, 01783 V.memptr(), &ldvt, 01784 work.memptr(), &lwork, 01785 &info 01786 ); 01787 } 01788 01789 return (info == 0); 01790 } 01791 #else 01792 { 01793 arma_ignore(S); 01794 arma_ignore(X); 01795 arma_ignore(X_n_rows); 01796 arma_ignore(X_n_cols); 01797 arma_stop("svd(): use of LAPACK needs to be enabled"); 01798 return false; 01799 } 01800 #endif 01801 } 01802 01803 01804 01805 template<typename T, typename T1> 01806 inline 01807 bool 01808 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X, uword& X_n_rows, uword& X_n_cols) 01809 { 01810 arma_extra_debug_sigprint(); 01811 01812 typedef std::complex<T> eT; 01813 01814 #if defined(ARMA_USE_LAPACK) 01815 { 01816 Mat<eT> A(X.get_ref()); 01817 01818 X_n_rows = A.n_rows; 01819 X_n_cols = A.n_cols; 01820 01821 if(A.is_empty()) 01822 { 01823 
S.reset(); 01824 return true; 01825 } 01826 01827 Mat<eT> U(1, 1); 01828 Mat<eT> V(1, A.n_cols); 01829 01830 char jobu = 'N'; 01831 char jobvt = 'N'; 01832 01833 blas_int m = A.n_rows; 01834 blas_int n = A.n_cols; 01835 blas_int lda = A.n_rows; 01836 blas_int ldu = U.n_rows; 01837 blas_int ldvt = V.n_rows; 01838 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) ); 01839 blas_int info; 01840 01841 S.set_size( static_cast<uword>((std::min)(m,n)) ); 01842 01843 podarray<eT> work( static_cast<uword>(lwork) ); 01844 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) ); 01845 01846 // let gesvd_() calculate the optimum size of the workspace 01847 blas_int lwork_tmp = -1; 01848 01849 lapack::cx_gesvd<T> 01850 ( 01851 &jobu, &jobvt, 01852 &m, &n, 01853 A.memptr(), &lda, 01854 S.memptr(), 01855 U.memptr(), &ldu, 01856 V.memptr(), &ldvt, 01857 work.memptr(), &lwork_tmp, 01858 rwork.memptr(), 01859 &info 01860 ); 01861 01862 if(info == 0) 01863 { 01864 blas_int proposed_lwork = static_cast<blas_int>(real(work[0])); 01865 if(proposed_lwork > lwork) 01866 { 01867 lwork = proposed_lwork; 01868 work.set_size( static_cast<uword>(lwork) ); 01869 } 01870 01871 lapack::cx_gesvd<T> 01872 ( 01873 &jobu, &jobvt, 01874 &m, &n, 01875 A.memptr(), &lda, 01876 S.memptr(), 01877 U.memptr(), &ldu, 01878 V.memptr(), &ldvt, 01879 work.memptr(), &lwork, 01880 rwork.memptr(), 01881 &info 01882 ); 01883 } 01884 01885 return (info == 0); 01886 } 01887 #else 01888 { 01889 arma_ignore(S); 01890 arma_ignore(X); 01891 arma_ignore(X_n_rows); 01892 arma_ignore(X_n_cols); 01893 01894 arma_stop("svd(): use of LAPACK needs to be enabled"); 01895 return false; 01896 } 01897 #endif 01898 } 01899 01900 01901 01902 template<typename eT, typename T1> 01903 inline 01904 bool 01905 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X) 01906 { 01907 arma_extra_debug_sigprint(); 01908 01909 uword junk; 01910 return auxlib::svd(S, X, junk, junk); 01911 } 01912 01913 01914 01915 
template<typename T, typename T1> 01916 inline 01917 bool 01918 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X) 01919 { 01920 arma_extra_debug_sigprint(); 01921 01922 uword junk; 01923 return auxlib::svd(S, X, junk, junk); 01924 } 01925 01926 01927 01928 template<typename eT, typename T1> 01929 inline 01930 bool 01931 auxlib::svd(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X) 01932 { 01933 arma_extra_debug_sigprint(); 01934 01935 #if defined(ARMA_USE_LAPACK) 01936 { 01937 Mat<eT> A(X.get_ref()); 01938 01939 if(A.is_empty()) 01940 { 01941 U.eye(A.n_rows, A.n_rows); 01942 S.reset(); 01943 V.eye(A.n_cols, A.n_cols); 01944 return true; 01945 } 01946 01947 U.set_size(A.n_rows, A.n_rows); 01948 V.set_size(A.n_cols, A.n_cols); 01949 01950 char jobu = 'A'; 01951 char jobvt = 'A'; 01952 01953 blas_int m = A.n_rows; 01954 blas_int n = A.n_cols; 01955 blas_int lda = A.n_rows; 01956 blas_int ldu = U.n_rows; 01957 blas_int ldvt = V.n_rows; 01958 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) ); 01959 blas_int info; 01960 01961 01962 S.set_size( static_cast<uword>((std::min)(m,n)) ); 01963 podarray<eT> work( static_cast<uword>(lwork) ); 01964 01965 // let gesvd_() calculate the optimum size of the workspace 01966 blas_int lwork_tmp = -1; 01967 01968 lapack::gesvd<eT> 01969 ( 01970 &jobu, &jobvt, 01971 &m, &n, 01972 A.memptr(), &lda, 01973 S.memptr(), 01974 U.memptr(), &ldu, 01975 V.memptr(), &ldvt, 01976 work.memptr(), &lwork_tmp, 01977 &info 01978 ); 01979 01980 if(info == 0) 01981 { 01982 blas_int proposed_lwork = static_cast<blas_int>(work[0]); 01983 if(proposed_lwork > lwork) 01984 { 01985 lwork = proposed_lwork; 01986 work.set_size( static_cast<uword>(lwork) ); 01987 } 01988 01989 lapack::gesvd<eT> 01990 ( 01991 &jobu, &jobvt, 01992 &m, &n, 01993 A.memptr(), &lda, 01994 S.memptr(), 01995 U.memptr(), &ldu, 01996 V.memptr(), &ldvt, 01997 work.memptr(), &lwork, 01998 &info 01999 ); 02000 
02001 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done 02002 } 02003 02004 return (info == 0); 02005 } 02006 #else 02007 { 02008 arma_ignore(U); 02009 arma_ignore(S); 02010 arma_ignore(V); 02011 arma_ignore(X); 02012 arma_stop("svd(): use of LAPACK needs to be enabled"); 02013 return false; 02014 } 02015 #endif 02016 } 02017 02018 02019 02020 template<typename T, typename T1> 02021 inline 02022 bool 02023 auxlib::svd(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X) 02024 { 02025 arma_extra_debug_sigprint(); 02026 02027 typedef std::complex<T> eT; 02028 02029 #if defined(ARMA_USE_LAPACK) 02030 { 02031 Mat<eT> A(X.get_ref()); 02032 02033 if(A.is_empty()) 02034 { 02035 U.eye(A.n_rows, A.n_rows); 02036 S.reset(); 02037 V.eye(A.n_cols, A.n_cols); 02038 return true; 02039 } 02040 02041 U.set_size(A.n_rows, A.n_rows); 02042 V.set_size(A.n_cols, A.n_cols); 02043 02044 char jobu = 'A'; 02045 char jobvt = 'A'; 02046 02047 blas_int m = A.n_rows; 02048 blas_int n = A.n_cols; 02049 blas_int lda = A.n_rows; 02050 blas_int ldu = U.n_rows; 02051 blas_int ldvt = V.n_rows; 02052 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) ); 02053 blas_int info; 02054 02055 S.set_size( static_cast<uword>((std::min)(m,n)) ); 02056 02057 podarray<eT> work( static_cast<uword>(lwork) ); 02058 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) ); 02059 02060 // let gesvd_() calculate the optimum size of the workspace 02061 blas_int lwork_tmp = -1; 02062 lapack::cx_gesvd<T> 02063 ( 02064 &jobu, &jobvt, 02065 &m, &n, 02066 A.memptr(), &lda, 02067 S.memptr(), 02068 U.memptr(), &ldu, 02069 V.memptr(), &ldvt, 02070 work.memptr(), &lwork_tmp, 02071 rwork.memptr(), 02072 &info 02073 ); 02074 02075 if(info == 0) 02076 { 02077 blas_int proposed_lwork = static_cast<blas_int>(real(work[0])); 02078 if(proposed_lwork > lwork) 02079 { 02080 lwork = proposed_lwork; 02081 work.set_size( 
static_cast<uword>(lwork) ); 02082 } 02083 02084 lapack::cx_gesvd<T> 02085 ( 02086 &jobu, &jobvt, 02087 &m, &n, 02088 A.memptr(), &lda, 02089 S.memptr(), 02090 U.memptr(), &ldu, 02091 V.memptr(), &ldvt, 02092 work.memptr(), &lwork, 02093 rwork.memptr(), 02094 &info 02095 ); 02096 02097 op_htrans::apply(V,V); // op_htrans will work out that an in-place transpose can be done 02098 } 02099 02100 return (info == 0); 02101 } 02102 #else 02103 { 02104 arma_ignore(U); 02105 arma_ignore(S); 02106 arma_ignore(V); 02107 arma_ignore(X); 02108 arma_stop("svd(): use of LAPACK needs to be enabled"); 02109 return false; 02110 } 02111 #endif 02112 02113 } 02114 02115 02116 02117 template<typename eT, typename T1> 02118 inline 02119 bool 02120 auxlib::svd_econ(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X, const char mode) 02121 { 02122 arma_extra_debug_sigprint(); 02123 02124 #if defined(ARMA_USE_LAPACK) 02125 { 02126 Mat<eT> A(X.get_ref()); 02127 02128 blas_int m = A.n_rows; 02129 blas_int n = A.n_cols; 02130 blas_int lda = A.n_rows; 02131 02132 S.set_size( static_cast<uword>((std::min)(m,n)) ); 02133 02134 blas_int ldu = 0; 02135 blas_int ldvt = 0; 02136 02137 char jobu; 02138 char jobvt; 02139 02140 switch(mode) 02141 { 02142 case 'l': 02143 jobu = 'S'; 02144 jobvt = 'N'; 02145 02146 ldu = m; 02147 ldvt = 1; 02148 02149 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) ); 02150 V.reset(); 02151 02152 break; 02153 02154 02155 case 'r': 02156 jobu = 'N'; 02157 jobvt = 'S'; 02158 02159 ldu = 1; 02160 ldvt = (std::min)(m,n); 02161 02162 U.reset(); 02163 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) ); 02164 02165 break; 02166 02167 02168 case 'b': 02169 jobu = 'S'; 02170 jobvt = 'S'; 02171 02172 ldu = m; 02173 ldvt = (std::min)(m,n); 02174 02175 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) ); 02176 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) ); 02177 02178 break; 02179 02180 02181 default: 
02182 U.reset(); 02183 S.reset(); 02184 V.reset(); 02185 return false; 02186 } 02187 02188 02189 if(A.is_empty()) 02190 { 02191 U.eye(); 02192 S.reset(); 02193 V.eye(); 02194 return true; 02195 } 02196 02197 02198 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) ); 02199 blas_int info = 0; 02200 02201 02202 podarray<eT> work( static_cast<uword>(lwork) ); 02203 02204 // let gesvd_() calculate the optimum size of the workspace 02205 blas_int lwork_tmp = -1; 02206 02207 lapack::gesvd<eT> 02208 ( 02209 &jobu, &jobvt, 02210 &m, &n, 02211 A.memptr(), &lda, 02212 S.memptr(), 02213 U.memptr(), &ldu, 02214 V.memptr(), &ldvt, 02215 work.memptr(), &lwork_tmp, 02216 &info 02217 ); 02218 02219 if(info == 0) 02220 { 02221 blas_int proposed_lwork = static_cast<blas_int>(work[0]); 02222 if(proposed_lwork > lwork) 02223 { 02224 lwork = proposed_lwork; 02225 work.set_size( static_cast<uword>(lwork) ); 02226 } 02227 02228 lapack::gesvd<eT> 02229 ( 02230 &jobu, &jobvt, 02231 &m, &n, 02232 A.memptr(), &lda, 02233 S.memptr(), 02234 U.memptr(), &ldu, 02235 V.memptr(), &ldvt, 02236 work.memptr(), &lwork, 02237 &info 02238 ); 02239 02240 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done 02241 } 02242 02243 return (info == 0); 02244 } 02245 #else 02246 { 02247 arma_ignore(U); 02248 arma_ignore(S); 02249 arma_ignore(V); 02250 arma_ignore(X); 02251 arma_ignore(mode); 02252 arma_stop("svd(): use of LAPACK needs to be enabled"); 02253 return false; 02254 } 02255 #endif 02256 } 02257 02258 02259



//! Economical singular value decomposition of a complex matrix (via LAPACK's complex gesvd).
//!
//! mode selects which singular vectors are computed:
//!   'l' : left singular vectors only  (U is filled, V is reset)
//!   'r' : right singular vectors only (V is filled, U is reset)
//!   'b' : both left and right singular vectors
//! Any other value of mode resets U, S and V and returns false.
//!
//! For an m x n input, S is sized to min(m,n).
//! Returns true if X is empty or if LAPACK reports info == 0; returns false otherwise.
//! When LAPACK support is not enabled, arma_stop() is called.
template<typename T, typename T1>
inline
bool
auxlib::svd_econ(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X, const char mode)
  {
  arma_extra_debug_sigprint();
  
  typedef std::complex<T> eT;
  
  #if defined(ARMA_USE_LAPACK)
    {
    // work on a copy of X; its memory is handed to gesvd as the input/scratch matrix
    Mat<eT> A(X.get_ref());
    
    blas_int m   = A.n_rows;
    blas_int n   = A.n_cols;
    blas_int lda = A.n_rows;
    
    S.set_size( static_cast<uword>((std::min)(m,n)) );
    
    blas_int ldu  = 0;
    blas_int ldvt = 0;
    
    // initialised only to keep compilers quiet; every valid mode overwrites both
    char jobu  = 'N';
    char jobvt = 'N';
    
    switch(mode)
      {
      case 'l':
        jobu  = 'S';
        jobvt = 'N';
        
        ldu  = m;
        ldvt = 1;
        
        U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
        V.reset();
        
        break;
      
      
      case 'r':
        jobu  = 'N';
        jobvt = 'S';
        
        ldu  = 1;
        ldvt = (std::min)(m,n);
        
        U.reset();
        V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
        
        break;
      
      
      case 'b':
        jobu  = 'S';
        jobvt = 'S';
        
        ldu  = m;
        ldvt = (std::min)(m,n);
        
        U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
        V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
        
        break;
      
      
      default:
        U.reset();
        S.reset();
        V.reset();
        return false;
      }
    
    
    if(A.is_empty())
      {
      U.eye();
      S.reset();
      V.eye();
      return true;
      }
    
    
    // initial workspace size; enlarged below if gesvd asks for more
    blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
    blas_int info  = 0;
    
    
    podarray<eT>  work( static_cast<uword>(lwork) );
    podarray<T>  rwork( static_cast<uword>(5*(std::min)(m,n)) );
    
    // let gesvd_() calculate the optimum size of the workspace
    blas_int lwork_tmp = -1;
    
    lapack::cx_gesvd<T>
      (
      &jobu, &jobvt,
      &m, &n,
      A.memptr(), &lda,
      S.memptr(),
      U.memptr(), &ldu,
      V.memptr(), &ldvt,
      work.memptr(), &lwork_tmp,
      rwork.memptr(),
      &info
      );
    
    if(info == 0)
      {
      // the workspace query stores the proposed size in the real part of work[0]
      blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
      if(proposed_lwork > lwork)
        {
        lwork = proposed_lwork;
        work.set_size( static_cast<uword>(lwork) );
        }
      
      lapack::cx_gesvd<T>
        (
        &jobu, &jobvt,
        &m, &n,
        A.memptr(), &lda,
        S.memptr(),
        U.memptr(), &ldu,
        V.memptr(), &ldvt,
        work.memptr(), &lwork,
        rwork.memptr(),
        &info
        );
      
      op_htrans::apply(V,V);  // op_htrans will work out that an in-place transpose can be done
      }
    
    return (info == 0);
    }
  #else
    {
    arma_ignore(U);
    arma_ignore(S);
    arma_ignore(V);
    arma_ignore(X);
    arma_ignore(mode);
    arma_stop("svd(): use of LAPACK needs to be enabled");
    return false;
    }
  #endif
  }



//! Solve the system A*X = B for X and store the result in out.
//!
//! If A or B is empty, out is set to an A.n_cols x B.n_cols matrix of zeros and true is returned.
//! When A has at most 4 rows and slow == false, the closed-form tiny-matrix inverse is tried first;
//! if that fails (or is skipped), the system is handed to ATLAS clapack_gesv or LAPACK gesv.
//! NOTE: A is passed by non-const reference and its memory is given directly to gesv,
//! which overwrites it; callers must not rely on the contents of A afterwards.
//! Returns true on success, false if the solver reports a non-zero info.
template<typename eT>
inline
bool
auxlib::solve(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B, const bool slow)
  {
  arma_extra_debug_sigprint();
  
  if(A.is_empty() || B.is_empty())
    {
    out.zeros(A.n_cols, B.n_cols);
    return true;
    }
  
  const uword A_n_rows = A.n_rows;
  
  if( (A_n_rows <= 4) && (slow == false) )
    {
    Mat<eT> A_inv;
    
    const bool tiny_ok = auxlib::inv_noalias_tinymat(A_inv, A, A_n_rows);
    
    if(tiny_ok == true)
      {
      out.set_size(A_n_rows, B.n_cols);
      
      gemm_emul<false,false,false,false>::apply(out, A_inv, B);
      
      return true;
      }
    }
  
  // reaching this point means that either the matrix is larger than 4x4, the caller
  // asked for the slow route, or the tiny-matrix inverse failed
  
  #if defined(ARMA_USE_ATLAS)
    {
    podarray<int> ipiv(A_n_rows);
    
    out = B;
    
    int info = atlas::clapack_gesv<eT>(atlas::CblasColMajor, A_n_rows, B.n_cols, A.memptr(), A_n_rows, ipiv.memptr(), out.memptr(), A_n_rows);
    
    return (info == 0);
    }
  #elif defined(ARMA_USE_LAPACK)
    {
    blas_int n    = A_n_rows;
    blas_int lda  = A_n_rows;
    blas_int ldb  = A_n_rows;
    blas_int nrhs = B.n_cols;
    blas_int info = 0;
    
    podarray<blas_int> ipiv(A_n_rows);
    
    out = B;
    
    lapack::gesv<eT>(&n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info);
    
    return (info == 0);
    }
  #else
    {
    arma_stop("solve(): use of ATLAS or LAPACK needs to be enabled");
    return false;
    }
  #endif
  }



//! Solve A*X = B in the least-squares sense with LAPACK gels.
//! NOTE(review): the name and the workspace formula suggest this variant is meant for
//! over-determined systems (A.n_rows >= A.n_cols) with B.n_rows == A.n_rows -- confirm against the caller.
//!
//! If A or B is empty, out is set to an A.n_cols x B.n_cols matrix of zeros and true is returned.
//! A is handed to gels directly and is overwritten by it.
//! out always receives the first A.n_cols rows of the gels result, even when info != 0.
//! Returns true if gels reports info == 0.
template<typename eT>
inline
bool
auxlib::solve_od(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B)
  {
  arma_extra_debug_sigprint();
  
  #if defined(ARMA_USE_LAPACK)
    {
    if(A.is_empty() || B.is_empty())
      {
      out.zeros(A.n_cols, B.n_cols);
      return true;
      }
    
    char trans = 'N';
    
    blas_int m     = A.n_rows;
    blas_int n     = A.n_cols;
    blas_int lda   = A.n_rows;
    blas_int ldb   = A.n_rows;
    blas_int nrhs  = B.n_cols;
    blas_int lwork = n + (std::max)(n, nrhs);
    blas_int info  = 0;
    
    // gels overwrites the right-hand side with the solution, so work on a copy of B
    Mat<eT> tmp = B;
    
    podarray<eT> work( static_cast<uword>(lwork) );
    
    arma_extra_debug_print("lapack::gels()");
    
    // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
    
    lapack::gels<eT>
      (
      &trans, &m, &n, &nrhs,
      A.memptr(), &lda,
      tmp.memptr(), &ldb,
      work.memptr(), &lwork,
      &info
      );
    
    arma_extra_debug_print("lapack::gels() -- finished");
    
    out.set_size(A.n_cols, B.n_cols);
    
    // only the leading A.n_cols rows of each column are copied into out
    for(uword col=0; col<B.n_cols; ++col)
      {
      arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols );
      }
    
    return (info == 0);
    }
  #else
    {
    arma_ignore(out);
    arma_ignore(A);
    arma_ignore(B);
    arma_stop("solve(): use of LAPACK needs to be enabled");
    return false;
    }
  #endif
  }



//! Solve A*X = B with LAPACK gels, using a right-hand side padded to A.n_cols rows.
//! NOTE(review): the name and the workspace formula suggest this variant is meant for
//! under-determined systems (A.n_rows < A.n_cols) with B.n_rows == A.n_rows -- confirm against the caller.
//!
//! If A or B is empty, out is set to an A.n_cols x B.n_cols matrix of zeros and true is returned.
//! A is handed to gels directly and is overwritten by it.
//! out always receives the gels result (A.n_cols x B.n_cols), even when info != 0.
//! Returns true if gels reports info == 0.
template<typename eT>
inline
bool
auxlib::solve_ud(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B)
  {
  arma_extra_debug_sigprint();
  
  #if defined(ARMA_USE_LAPACK)
    {
    if(A.is_empty() || B.is_empty())
      {
      out.zeros(A.n_cols, B.n_cols);
      return true;
      }
    
    char trans = 'N';
    
    blas_int m     = A.n_rows;
    blas_int n     = A.n_cols;
    blas_int lda   = A.n_rows;
    blas_int ldb   = A.n_cols;
    blas_int nrhs  = B.n_cols;
    blas_int lwork = m + (std::max)(m,nrhs);
    blas_int info  = 0;
    
    
    // tmp holds B in its leading rows; zeros() has already cleared the padding rows
    // (B.n_rows .. A.n_cols-1), so no further zero-filling is required
    Mat<eT> tmp;
    tmp.zeros(A.n_cols, B.n_cols);
    
    for(uword col=0; col<B.n_cols; ++col)
      {
      arrayops::copy( tmp.colptr(col), B.colptr(col), B.n_rows );
      }
    
    podarray<eT> work( static_cast<uword>(lwork) );
    
    arma_extra_debug_print("lapack::gels()");
    
    // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
    
    lapack::gels<eT>
      (
      &trans, &m, &n, &nrhs,
      A.memptr(), &lda,
      tmp.memptr(), &ldb,
      work.memptr(), &lwork,
      &info
      );
    
    arma_extra_debug_print("lapack::gels() -- finished");
    
    out.set_size(A.n_cols, B.n_cols);
    
    for(uword col=0; col<B.n_cols; ++col)
      {
      arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols );
      }
    
    return (info == 0);
    }
  #else
    {
    arma_ignore(out);
    arma_ignore(A);
    arma_ignore(B);
    arma_stop("solve(): use of LAPACK needs to be enabled");
    return false;
    }
  #endif
  }



//
// solve_tr

//! Solve A*X = B where A is treated as triangular, using LAPACK trtrs.
//! layout == 0 tells trtrs to use the upper triangle of A; any other value selects the lower triangle.
//! The diagonal is taken from A as stored (diag = 'N') and A is not transposed (trans = 'N').
//!
//! If A or B is empty, out is set to an A.n_cols x B.n_cols matrix of zeros and true is returned.
//! Returns true if trtrs reports info == 0.
template<typename eT>
inline
bool
auxlib::solve_tr(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const uword layout)
  {
  arma_extra_debug_sigprint();
  
  #if defined(ARMA_USE_LAPACK)
    {
    if(A.is_empty() || B.is_empty())
      {
      out.zeros(A.n_cols, B.n_cols);
      return true;
      }
    
    // trtrs overwrites the right-hand side with the solution
    out = B;
    
    char     uplo  = (layout == 0) ? 'U' : 'L';
    char     trans = 'N';
    char     diag  = 'N';
    blas_int n     = blas_int(A.n_rows);
    blas_int nrhs  = blas_int(B.n_cols);
    blas_int info  = 0;
    
    // n doubles as the leading dimension of both A and out
    lapack::trtrs<eT>(&uplo, &trans, &diag, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info);
    
    return (info == 0);
    }
  #else
    {
    arma_ignore(out);
    arma_ignore(A);
    arma_ignore(B);
    arma_ignore(layout);
    arma_stop("solve(): use of LAPACK needs to be enabled");
    return false;
    }
  #endif
  }



//
// Schur decomposition

//! Schur decomposition of a square matrix A, using LAPACK gees.
//! On success T holds the Schur form and Z the Schur vectors; eigenvalues are not sorted.
//! If A is empty, Z and T are reset and true is returned.
//! Returns true if gees reports info == 0.
template<typename eT>
inline
bool
auxlib::schur_dec(Mat<eT>& Z, Mat<eT>& T, const Mat<eT>& A)
  {
  arma_extra_debug_sigprint();
  
  #if defined(ARMA_USE_LAPACK)
    {
    arma_debug_check( (A.is_square() == false), "schur_dec(): matrix A is not square" );
    
    if(A.is_empty())
      {
      Z.reset();
      T.reset();
      return true;
      }
    
    const uword A_n_rows = A.n_rows;
    
    char      jobvs  = 'V';                // get Schur vectors (Z)
    char      sort   = 'N';                // do not sort eigenvalues/vectors
    blas_int* select = 0;                  // pointer to sorting function
    blas_int  n      = blas_int(A_n_rows);
    blas_int  sdim   = 0;                  // output for sorting
    
    blas_int lwork = 3 * n;                // workspace must be at least 3 * n (if set to -1, optimal size is output in work(0) and nothing else is done)
    
    podarray<eT>       work( static_cast<uword>(lwork) );
    podarray<blas_int> bwork(A_n_rows);
    
    blas_int info = 0;
    
    Z.set_size(A_n_rows, A_n_rows);
    T = A;                                 // gees works in-place on T
    
    podarray<eT> wr(A_n_rows);             // output for eigenvalues
    podarray<eT> wi(A_n_rows);             // output for eigenvalues
    
    lapack::gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, wr.memptr(), wi.memptr(), Z.memptr(), &n, work.memptr(), &lwork, bwork.memptr(), &info);
    
    return (info == 0);
    }
  #else
    {
    arma_ignore(Z);
    arma_ignore(T);
    arma_ignore(A);
    arma_stop("schur_dec(): use of LAPACK needs to be enabled");
    return false;
    }
  #endif
  }



//! Schur decomposition of a square complex matrix A, using LAPACK's complex gees.
//! On success T holds the Schur form and Z the Schur vectors; eigenvalues are not sorted.
//! If A is empty, Z and T are reset and true is returned.
//! Returns true if gees reports info == 0.
template<typename cT>
inline
bool
auxlib::schur_dec(Mat<std::complex<cT> >& Z, Mat<std::complex<cT> >& T, const Mat<std::complex<cT> >& A)
  {
  arma_extra_debug_sigprint();
  
  #if defined(ARMA_USE_LAPACK)
    {
    arma_debug_check( (A.is_square() == false), "schur_dec(): matrix A is not square" );
    
    if(A.is_empty())
      {
      Z.reset();
      T.reset();
      return true;
      }
    
    typedef std::complex<cT> eT;
    
    const uword A_n_rows = A.n_rows;
    
    char      jobvs  = 'V';                // get Schur vectors (Z)
    char      sort   = 'N';                // do not sort eigenvalues/vectors
    blas_int* select = 0;                  // pointer to sorting function
    blas_int  n      = blas_int(A_n_rows);
    blas_int  sdim   = 0;                  // output for sorting
    
    blas_int lwork = 3 * n;                // workspace must be at least 3 * n (if set to -1, optimal size is output in work(0) and nothing else is done)
    
    podarray<eT>       work( static_cast<uword>(lwork) );
    podarray<blas_int> bwork(A_n_rows);
    
    blas_int info = 0;
    
    Z.set_size(A_n_rows, A_n_rows);
    T = A;                                 // gees works in-place on T
    
    podarray<eT> w(A_n_rows);              // output for eigenvalues
    podarray<cT> rwork(A_n_rows);
    
    lapack::cx_gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, w.memptr(), Z.memptr(), &n, work.memptr(), &lwork, rwork.memptr(), bwork.memptr(), &info);
    
    return (info == 0);
    }
  #else
    {
    arma_ignore(Z);
    arma_ignore(T);
    arma_ignore(A);
    arma_stop("schur_dec(): use of LAPACK needs to be enabled");
    return false;
    }
  #endif
  }



//
// syl (solution of the Sylvester equation AX + XB + C = 0)

//! Solve the Sylvester equation A*X + X*B + C = 0 for X.
//! A and B must be square, and C must have A.n_rows rows and B.n_cols columns.
//! If any of A, B, C is empty, X is reset and true is returned.
//!
//! Method: A and B are reduced to Schur form (A -> Z1,T1 and B -> Z2,T2), the transformed
//! equation is solved with LAPACK trsyl, and the result is transformed back.
//! Returns false if either Schur decomposition fails or if trsyl reports a non-zero info.
template<typename eT>
inline
bool
auxlib::syl(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C)
  {
  arma_extra_debug_sigprint();
  
  arma_debug_check( (A.is_square() == false), "syl(): matrix A is not square" );
  arma_debug_check( (B.is_square() == false), "syl(): matrix B is not square" );
  
  arma_debug_check( (C.n_rows != A.n_rows) || (C.n_cols != B.n_cols), "syl(): matrices are not conformant" );
  
  if(A.is_empty() || B.is_empty() || C.is_empty())
    {
    X.reset();
    return true;
    }
  
  bool status = false;
  
  #if defined(ARMA_USE_LAPACK)
    {
    Mat<eT> Z1, Z2, T1, T2;
    
    status = auxlib::schur_dec(Z1, T1, A);
    if(status == false)
      {
      return false;
      }
    
    status = auxlib::schur_dec(Z2, T2, B);
    if(status == false)
      {
      return false;
      }
    
    char     trana = 'N';
    char     tranb = 'N';
    blas_int isgn  = +1;
    blas_int m     = blas_int(T1.n_rows);
    blas_int n     = blas_int(T2.n_cols);
    
    eT       scale = eT(0);
    blas_int info  = 0;
    
    // right-hand side expressed in the Schur bases of A and B
    Mat<eT> Y = trans(Z1) * C * Z2;
    
    lapack::trsyl<eT>(&trana, &tranb, &isgn, &m, &n, T1.memptr(), &m, T2.memptr(), &n, Y.memptr(), &m, &scale, &info);
    
    // with isgn = +1, trsyl solves  T1*Y + Y*T2 = scale * (transformed C);
    // dividing by -scale (rather than +scale) moves C to the left-hand side,
    // so that the final X satisfies  A*X + X*B + C = 0
    Y /= (-scale);
    
    X = Z1 * Y * trans(Z2);
    
    status = (info == 0);
    }
  #else
    {
    arma_stop("syl(): use of LAPACK needs to be enabled");
    return false;
    }
  #endif
  
  
  return status;
  }



//
// lyap (solution of the continuous Lyapunov equation AX + XA^H + Q = 0)

//! Solve the continuous Lyapunov equation A*X + X*A^H + Q = 0 for X.
//! A and Q must be square and have the same number of rows.
//! The problem is delegated to syl() with B = A^H and C = Q.
//! Returns the status reported by syl().
template<typename eT>
inline
bool
auxlib::lyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
  {
  arma_extra_debug_sigprint();
  
  arma_debug_check( (A.is_square() == false), "lyap(): matrix A is not square");
  arma_debug_check( (Q.is_square() == false), "lyap(): matrix Q is not square");
  arma_debug_check( (A.n_rows != Q.n_rows),   "lyap(): matrices A and Q have different dimensions");
  
  Mat<eT> htransA;
  op_htrans::apply_noalias(htransA, A);
  
  // syl() solves A*X + X*B + C = 0 (it divides the trsyl result by -scale), which already
  // has the form of the Lyapunov equation; Q must therefore be passed through unchanged.
  // Passing -Q, as was done previously, produced the solution of A*X + X*A^H - Q = 0.
  return auxlib::syl(X, A, htransA, Q);
  }



//
// dlyap (solution of the discrete Lyapunov equation AXA^H - X + Q = 0)

//! Solve the discrete Lyapunov equation A*X*A^H - X + Q = 0 for X.
//! A and Q must be square and have the same number of rows.
//!
//! Method: the equation is vectorised into the linear system
//!   (I - kron(conj(A), A)) * vec(X) = vec(Q)
//! which is passed to solve(). The system has Q.n_elem unknowns, so the
//! coefficient matrix is Q.n_elem x Q.n_elem.
//! On failure X is reset and false is returned.
template<typename eT>
inline
bool
auxlib::dlyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
  {
  arma_extra_debug_sigprint();
  
  arma_debug_check( (A.is_square() == false), "dlyap(): matrix A is not square");
  arma_debug_check( (Q.is_square() == false), "dlyap(): matrix Q is not square");
  arma_debug_check( (A.n_rows != Q.n_rows),   "dlyap(): matrices A and Q have different dimensions");
  
  // vec(Q): the columns of Q stacked into a single column
  const Col<eT> vecQ = reshape(Q, Q.n_elem, 1);
  
  const Mat<eT> M = eye< Mat<eT> >(Q.n_elem, Q.n_elem) - kron(conj(A), A);
  
  Col<eT> vecX;
  
  const bool status = solve(vecX, M, vecQ);
  
  if(status == true)
    {
    X = reshape(vecX, Q.n_rows, Q.n_cols);
    return true;
    }
  else
    {
    X.reset();
    return false;
    }
  }


