00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00019
00020
00021
00023 template<typename eT, typename T1>
00024 inline
00025 bool
00026 auxlib::inv(Mat<eT>& out, const Base<eT,T1>& X, const bool slow)
00027 {
00028 arma_extra_debug_sigprint();
00029
00030 out = X.get_ref();
00031
00032 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
00033
00034 bool status = false;
00035
00036 const uword N = out.n_rows;
00037
00038 if( (N <= 4) && (slow == false) )
00039 {
00040 status = auxlib::inv_inplace_tinymat(out, N);
00041 }
00042
00043 if( (N > 4) || (status == false) )
00044 {
00045 status = auxlib::inv_inplace_lapack(out);
00046 }
00047
00048 return status;
00049 }
00050
00051
00052
00053 template<typename eT>
00054 inline
00055 bool
00056 auxlib::inv(Mat<eT>& out, const Mat<eT>& X, const bool slow)
00057 {
00058 arma_extra_debug_sigprint();
00059
00060 arma_debug_check( (X.is_square() == false), "inv(): given matrix is not square" );
00061
00062 bool status = false;
00063
00064 const uword N = X.n_rows;
00065
00066 if( (N <= 4) && (slow == false) )
00067 {
00068 status = (&out != &X) ? auxlib::inv_noalias_tinymat(out, X, N) : auxlib::inv_inplace_tinymat(out, N);
00069 }
00070
00071 if( (N > 4) || (status == false) )
00072 {
00073 out = X;
00074 status = auxlib::inv_inplace_lapack(out);
00075 }
00076
00077 return status;
00078 }
00079
00080
00081
00082 template<typename eT>
00083 inline
00084 bool
00085 auxlib::inv_noalias_tinymat(Mat<eT>& out, const Mat<eT>& X, const uword N)
00086 {
00087 arma_extra_debug_sigprint();
00088
00089 bool det_ok = true;
00090
00091 out.set_size(N,N);
00092
00093 switch(N)
00094 {
00095 case 1:
00096 {
00097 out[0] = eT(1) / X[0];
00098 };
00099 break;
00100
00101 case 2:
00102 {
00103 const eT* Xm = X.memptr();
00104
00105 const eT a = Xm[pos<0,0>::n2];
00106 const eT b = Xm[pos<0,1>::n2];
00107 const eT c = Xm[pos<1,0>::n2];
00108 const eT d = Xm[pos<1,1>::n2];
00109
00110 const eT tmp_det = (a*d - b*c);
00111
00112 if(tmp_det != eT(0))
00113 {
00114 eT* outm = out.memptr();
00115
00116 outm[pos<0,0>::n2] = d / tmp_det;
00117 outm[pos<0,1>::n2] = -b / tmp_det;
00118 outm[pos<1,0>::n2] = -c / tmp_det;
00119 outm[pos<1,1>::n2] = a / tmp_det;
00120 }
00121 else
00122 {
00123 det_ok = false;
00124 }
00125 };
00126 break;
00127
00128 case 3:
00129 {
00130 const eT* X_col0 = X.colptr(0);
00131 const eT a11 = X_col0[0];
00132 const eT a21 = X_col0[1];
00133 const eT a31 = X_col0[2];
00134
00135 const eT* X_col1 = X.colptr(1);
00136 const eT a12 = X_col1[0];
00137 const eT a22 = X_col1[1];
00138 const eT a32 = X_col1[2];
00139
00140 const eT* X_col2 = X.colptr(2);
00141 const eT a13 = X_col2[0];
00142 const eT a23 = X_col2[1];
00143 const eT a33 = X_col2[2];
00144
00145 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
00146
00147 if(tmp_det != eT(0))
00148 {
00149 eT* out_col0 = out.colptr(0);
00150 out_col0[0] = (a33*a22 - a32*a23) / tmp_det;
00151 out_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
00152 out_col0[2] = (a32*a21 - a31*a22) / tmp_det;
00153
00154 eT* out_col1 = out.colptr(1);
00155 out_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
00156 out_col1[1] = (a33*a11 - a31*a13) / tmp_det;
00157 out_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
00158
00159 eT* out_col2 = out.colptr(2);
00160 out_col2[0] = (a23*a12 - a22*a13) / tmp_det;
00161 out_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
00162 out_col2[2] = (a22*a11 - a21*a12) / tmp_det;
00163 }
00164 else
00165 {
00166 det_ok = false;
00167 }
00168 };
00169 break;
00170
00171 case 4:
00172 {
00173 const eT tmp_det = det(X);
00174
00175 if(tmp_det != eT(0))
00176 {
00177 const eT* Xm = X.memptr();
00178 eT* outm = out.memptr();
00179
00180 outm[pos<0,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
00181 outm[pos<1,0>::n4] = ( Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
00182 outm[pos<2,0>::n4] = ( Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
00183 outm[pos<3,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
00184
00185 outm[pos<0,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
00186 outm[pos<1,1>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
00187 outm[pos<2,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
00188 outm[pos<3,1>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
00189
00190 outm[pos<0,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
00191 outm[pos<1,2>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
00192 outm[pos<2,2>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
00193 outm[pos<3,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
00194
00195 outm[pos<0,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
00196 outm[pos<1,3>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
00197 outm[pos<2,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
00198 outm[pos<3,3>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] ) / tmp_det;
00199 }
00200 else
00201 {
00202 det_ok = false;
00203 }
00204 };
00205 break;
00206
00207 default:
00208 ;
00209 }
00210
00211 return det_ok;
00212 }
00213
00214
00215
00216 template<typename eT>
00217 inline
00218 bool
00219 auxlib::inv_inplace_tinymat(Mat<eT>& X, const uword N)
00220 {
00221 arma_extra_debug_sigprint();
00222
00223 bool det_ok = true;
00224
00225
00226
00227
00228
00229
00230
00231 switch(N)
00232 {
00233 case 1:
00234 {
00235 X[0] = eT(1) / X[0];
00236 };
00237 break;
00238
00239 case 2:
00240 {
00241 const eT a = X[pos<0,0>::n2];
00242 const eT b = X[pos<0,1>::n2];
00243 const eT c = X[pos<1,0>::n2];
00244 const eT d = X[pos<1,1>::n2];
00245
00246 const eT tmp_det = (a*d - b*c);
00247
00248 if(tmp_det != eT(0))
00249 {
00250 X[pos<0,0>::n2] = d / tmp_det;
00251 X[pos<0,1>::n2] = -b / tmp_det;
00252 X[pos<1,0>::n2] = -c / tmp_det;
00253 X[pos<1,1>::n2] = a / tmp_det;
00254 }
00255 else
00256 {
00257 det_ok = false;
00258 }
00259 };
00260 break;
00261
00262 case 3:
00263 {
00264 eT* X_col0 = X.colptr(0);
00265 eT* X_col1 = X.colptr(1);
00266 eT* X_col2 = X.colptr(2);
00267
00268 const eT a11 = X_col0[0];
00269 const eT a21 = X_col0[1];
00270 const eT a31 = X_col0[2];
00271
00272 const eT a12 = X_col1[0];
00273 const eT a22 = X_col1[1];
00274 const eT a32 = X_col1[2];
00275
00276 const eT a13 = X_col2[0];
00277 const eT a23 = X_col2[1];
00278 const eT a33 = X_col2[2];
00279
00280 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
00281
00282 if(tmp_det != eT(0))
00283 {
00284 X_col0[0] = (a33*a22 - a32*a23) / tmp_det;
00285 X_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
00286 X_col0[2] = (a32*a21 - a31*a22) / tmp_det;
00287
00288 X_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
00289 X_col1[1] = (a33*a11 - a31*a13) / tmp_det;
00290 X_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
00291
00292 X_col2[0] = (a23*a12 - a22*a13) / tmp_det;
00293 X_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
00294 X_col2[2] = (a22*a11 - a21*a12) / tmp_det;
00295 }
00296 else
00297 {
00298 det_ok = false;
00299 }
00300 };
00301 break;
00302
00303 case 4:
00304 {
00305 const eT tmp_det = det(X);
00306
00307 if(tmp_det != eT(0))
00308 {
00309 const Mat<eT> A(X);
00310
00311 const eT* Am = A.memptr();
00312 eT* Xm = X.memptr();
00313
00314 Xm[pos<0,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] - Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] + Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
00315 Xm[pos<1,0>::n4] = ( Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
00316 Xm[pos<2,0>::n4] = ( Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
00317 Xm[pos<3,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
00318
00319 Xm[pos<0,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] + Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] - Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
00320 Xm[pos<1,1>::n4] = ( Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
00321 Xm[pos<2,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
00322 Xm[pos<3,1>::n4] = ( Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
00323
00324 Xm[pos<0,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] + Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
00325 Xm[pos<1,2>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
00326 Xm[pos<2,2>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
00327 Xm[pos<3,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
00328
00329 Xm[pos<0,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] + Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
00330 Xm[pos<1,3>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
00331 Xm[pos<2,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
00332 Xm[pos<3,3>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] ) / tmp_det;
00333 }
00334 else
00335 {
00336 det_ok = false;
00337 }
00338 };
00339 break;
00340
00341 default:
00342 ;
00343 }
00344
00345 return det_ok;
00346 }
00347
00348
00349
00350 template<typename eT>
00351 inline
00352 bool
00353 auxlib::inv_inplace_lapack(Mat<eT>& out)
00354 {
00355 arma_extra_debug_sigprint();
00356
00357 if(out.is_empty())
00358 {
00359 return true;
00360 }
00361
00362 #if defined(ARMA_USE_ATLAS)
00363 {
00364 podarray<int> ipiv(out.n_rows);
00365
00366 int info = atlas::clapack_getrf(atlas::CblasColMajor, out.n_rows, out.n_cols, out.memptr(), out.n_rows, ipiv.memptr());
00367
00368 if(info == 0)
00369 {
00370 info = atlas::clapack_getri(atlas::CblasColMajor, out.n_rows, out.memptr(), out.n_rows, ipiv.memptr());
00371 }
00372
00373 return (info == 0);
00374 }
00375 #elif defined(ARMA_USE_LAPACK)
00376 {
00377 blas_int n_rows = out.n_rows;
00378 blas_int n_cols = out.n_cols;
00379 blas_int info = 0;
00380
00381 podarray<blas_int> ipiv(out.n_rows);
00382
00383
00384
00385
00386
00387
00388
00389 blas_int work_len = (std::max)(blas_int(1), n_rows*84);
00390 podarray<eT> work( static_cast<uword>(work_len) );
00391
00392 lapack::getrf(&n_rows, &n_cols, out.memptr(), &n_rows, ipiv.memptr(), &info);
00393
00394 if(info == 0)
00395 {
00396
00397
00398 blas_int work_len_tmp = -1;
00399 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len_tmp, &info);
00400
00401 if(info == 0)
00402 {
00403 blas_int proposed_work_len = static_cast<blas_int>(access::tmp_real(work[0]));
00404
00405
00406 if(work_len < proposed_work_len)
00407 {
00408 work_len = proposed_work_len;
00409 work.set_size( static_cast<uword>(work_len) );
00410 }
00411 }
00412
00413 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len, &info);
00414 }
00415
00416 return (info == 0);
00417 }
00418 #else
00419 {
00420 arma_ignore(out);
00421 arma_stop("inv(): use of ATLAS or LAPACK needs to be enabled");
00422 return false;
00423 }
00424 #endif
00425 }
00426
00427
00428
00429 template<typename eT, typename T1>
00430 inline
00431 bool
00432 auxlib::inv_tr(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
00433 {
00434 arma_extra_debug_sigprint();
00435
00436 out = X.get_ref();
00437
00438 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
00439
00440 if(out.is_empty())
00441 {
00442 return true;
00443 }
00444
00445 bool status;
00446
00447 #if defined(ARMA_USE_LAPACK)
00448 {
00449 char uplo = (layout == 0) ? 'U' : 'L';
00450 char diag = 'N';
00451 blas_int n = blas_int(out.n_rows);
00452 blas_int info = 0;
00453
00454 lapack::trtri(&uplo, &diag, &n, out.memptr(), &n, &info);
00455
00456 status = (info == 0);
00457 }
00458 #else
00459 {
00460 arma_ignore(layout);
00461 arma_stop("inv(): use of LAPACK needs to be enabled");
00462 status = false;
00463 }
00464 #endif
00465
00466
00467 if(status == true)
00468 {
00469 if(layout == 0)
00470 {
00471
00472 out = trimatu(out);
00473 }
00474 else
00475 {
00476
00477 out = trimatl(out);
00478 }
00479 }
00480
00481 return status;
00482 }
00483
00484
00485
00486 template<typename eT, typename T1>
00487 inline
00488 bool
00489 auxlib::inv_sym(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
00490 {
00491 arma_extra_debug_sigprint();
00492
00493 out = X.get_ref();
00494
00495 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
00496
00497 if(out.is_empty())
00498 {
00499 return true;
00500 }
00501
00502 bool status;
00503
00504 #if defined(ARMA_USE_LAPACK)
00505 {
00506 char uplo = (layout == 0) ? 'U' : 'L';
00507 blas_int n = blas_int(out.n_rows);
00508 blas_int lwork = n*n;
00509 blas_int info = 0;
00510
00511 podarray<blas_int> ipiv;
00512 ipiv.set_size(out.n_rows);
00513
00514 podarray<eT> work;
00515 work.set_size( uword(lwork) );
00516
00517 lapack::sytrf(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &lwork, &info);
00518
00519 status = (info == 0);
00520
00521 if(status == true)
00522 {
00523 lapack::sytri(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &info);
00524
00525 out = (layout == 0) ? symmatu(out) : symmatl(out);
00526
00527 status = (info == 0);
00528 }
00529 }
00530 #else
00531 {
00532 arma_ignore(layout);
00533 arma_stop("inv(): use of LAPACK needs to be enabled");
00534 status = false;
00535 }
00536 #endif
00537
00538 return status;
00539 }
00540
00541
00542
00543 template<typename eT, typename T1>
00544 inline
00545 bool
00546 auxlib::inv_sympd(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
00547 {
00548 arma_extra_debug_sigprint();
00549
00550 out = X.get_ref();
00551
00552 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
00553
00554 if(out.is_empty())
00555 {
00556 return true;
00557 }
00558
00559 bool status;
00560
00561 #if defined(ARMA_USE_LAPACK)
00562 {
00563 char uplo = (layout == 0) ? 'U' : 'L';
00564 blas_int n = blas_int(out.n_rows);
00565 blas_int info = 0;
00566
00567 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
00568
00569 status = (info == 0);
00570
00571 if(status == true)
00572 {
00573 lapack::potri(&uplo, &n, out.memptr(), &n, &info);
00574
00575 out = (layout == 0) ? symmatu(out) : symmatl(out);
00576
00577 status = (info == 0);
00578 }
00579 }
00580 #else
00581 {
00582 arma_ignore(layout);
00583 arma_stop("inv(): use of LAPACK needs to be enabled");
00584 status = false;
00585 }
00586 #endif
00587
00588 return status;
00589 }
00590
00591
00592
00593 template<typename eT, typename T1>
00594 inline
00595 eT
00596 auxlib::det(const Base<eT,T1>& X, const bool slow)
00597 {
00598 const unwrap<T1> tmp(X.get_ref());
00599 const Mat<eT>& A = tmp.M;
00600
00601 arma_debug_check( (A.is_square() == false), "det(): matrix is not square" );
00602
00603 const bool make_copy = (is_Mat<T1>::value == true) ? true : false;
00604
00605 if(slow == false)
00606 {
00607 const uword N = A.n_rows;
00608
00609 switch(N)
00610 {
00611 case 0:
00612 case 1:
00613 case 2:
00614 return auxlib::det_tinymat(A, N);
00615 break;
00616
00617 case 3:
00618 case 4:
00619 {
00620 const eT tmp_det = auxlib::det_tinymat(A, N);
00621 return (tmp_det != eT(0)) ? tmp_det : auxlib::det_lapack(A, make_copy);
00622 }
00623 break;
00624
00625 default:
00626 return auxlib::det_lapack(A, make_copy);
00627 }
00628 }
00629 else
00630 {
00631 return auxlib::det_lapack(A, make_copy);
00632 }
00633 }
00634
00635
00636
00637 template<typename eT>
00638 inline
00639 eT
00640 auxlib::det_tinymat(const Mat<eT>& X, const uword N)
00641 {
00642 arma_extra_debug_sigprint();
00643
00644 switch(N)
00645 {
00646 case 0:
00647 return eT(1);
00648 break;
00649
00650 case 1:
00651 return X[0];
00652 break;
00653
00654 case 2:
00655 {
00656 const eT* Xm = X.memptr();
00657
00658 return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] );
00659 }
00660 break;
00661
00662 case 3:
00663 {
00664
00665
00666
00667
00668
00669
00670
00671
00672 const eT* a_col0 = X.colptr(0);
00673 const eT a11 = a_col0[0];
00674 const eT a21 = a_col0[1];
00675 const eT a31 = a_col0[2];
00676
00677 const eT* a_col1 = X.colptr(1);
00678 const eT a12 = a_col1[0];
00679 const eT a22 = a_col1[1];
00680 const eT a32 = a_col1[2];
00681
00682 const eT* a_col2 = X.colptr(2);
00683 const eT a13 = a_col2[0];
00684 const eT a23 = a_col2[1];
00685 const eT a33 = a_col2[2];
00686
00687 return ( a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13) );
00688 }
00689 break;
00690
00691 case 4:
00692 {
00693 const eT* Xm = X.memptr();
00694
00695 const eT val = \
00696 Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
00697 - Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
00698 - Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
00699 + Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
00700 + Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
00701 - Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
00702 - Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
00703 + Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
00704 + Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
00705 - Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
00706 - Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
00707 + Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
00708 + Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
00709 - Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
00710 - Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
00711 + Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
00712 + Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
00713 - Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
00714 - Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
00715 + Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
00716 + Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
00717 - Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
00718 - Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
00719 + Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
00720 ;
00721
00722 return val;
00723 }
00724 break;
00725
00726 default:
00727 return eT(0);
00728 ;
00729 }
00730 }
00731
00732
00733
00735 template<typename eT>
00736 inline
00737 eT
00738 auxlib::det_lapack(const Mat<eT>& X, const bool make_copy)
00739 {
00740 arma_extra_debug_sigprint();
00741
00742 Mat<eT> X_copy;
00743
00744 if(make_copy == true)
00745 {
00746 X_copy = X;
00747 }
00748
00749 Mat<eT>& tmp = (make_copy == true) ? X_copy : const_cast< Mat<eT>& >(X);
00750
00751 if(tmp.is_empty())
00752 {
00753 return eT(1);
00754 }
00755
00756
00757 #if defined(ARMA_USE_ATLAS)
00758 {
00759 podarray<int> ipiv(tmp.n_rows);
00760
00761
00762 atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
00763
00764
00765 eT val = tmp.at(0,0);
00766 for(uword i=1; i < tmp.n_rows; ++i)
00767 {
00768 val *= tmp.at(i,i);
00769 }
00770
00771 int sign = +1;
00772 for(uword i=0; i < tmp.n_rows; ++i)
00773 {
00774 if( int(i) != ipiv.mem[i] )
00775 {
00776 sign *= -1;
00777 }
00778 }
00779
00780 return ( (sign < 0) ? -val : val );
00781 }
00782 #elif defined(ARMA_USE_LAPACK)
00783 {
00784 podarray<blas_int> ipiv(tmp.n_rows);
00785
00786 blas_int info = 0;
00787 blas_int n_rows = blas_int(tmp.n_rows);
00788 blas_int n_cols = blas_int(tmp.n_cols);
00789
00790 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
00791
00792
00793 eT val = tmp.at(0,0);
00794 for(uword i=1; i < tmp.n_rows; ++i)
00795 {
00796 val *= tmp.at(i,i);
00797 }
00798
00799 blas_int sign = +1;
00800 for(uword i=0; i < tmp.n_rows; ++i)
00801 {
00802 if( blas_int(i) != (ipiv.mem[i] - 1) )
00803 {
00804 sign *= -1;
00805 }
00806 }
00807
00808 return ( (sign < 0) ? -val : val );
00809 }
00810 #else
00811 {
00812 arma_ignore(X);
00813 arma_ignore(make_copy);
00814 arma_ignore(tmp);
00815 arma_stop("det(): use of ATLAS or LAPACK needs to be enabled");
00816 return eT(0);
00817 }
00818 #endif
00819 }
00820
00821
00822
00824 template<typename eT, typename T1>
00825 inline
00826 bool
00827 auxlib::log_det(eT& out_val, typename get_pod_type<eT>::result& out_sign, const Base<eT,T1>& X)
00828 {
00829 arma_extra_debug_sigprint();
00830
00831 typedef typename get_pod_type<eT>::result T;
00832
00833 #if defined(ARMA_USE_ATLAS)
00834 {
00835 Mat<eT> tmp(X.get_ref());
00836 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
00837
00838 if(tmp.is_empty())
00839 {
00840 out_val = eT(0);
00841 out_sign = T(1);
00842 return true;
00843 }
00844
00845 podarray<int> ipiv(tmp.n_rows);
00846
00847 const int info = atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
00848
00849
00850
00851 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
00852 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
00853
00854 for(uword i=1; i < tmp.n_rows; ++i)
00855 {
00856 const eT x = tmp.at(i,i);
00857
00858 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
00859 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
00860 }
00861
00862 for(uword i=0; i < tmp.n_rows; ++i)
00863 {
00864 if( int(i) != ipiv.mem[i] )
00865 {
00866 sign *= -1;
00867 }
00868 }
00869
00870 out_val = val;
00871 out_sign = T(sign);
00872
00873 return (info == 0);
00874 }
00875 #elif defined(ARMA_USE_LAPACK)
00876 {
00877 Mat<eT> tmp(X.get_ref());
00878 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
00879
00880 if(tmp.is_empty())
00881 {
00882 out_val = eT(0);
00883 out_sign = T(1);
00884 return true;
00885 }
00886
00887 podarray<blas_int> ipiv(tmp.n_rows);
00888
00889 blas_int info = 0;
00890 blas_int n_rows = blas_int(tmp.n_rows);
00891 blas_int n_cols = blas_int(tmp.n_cols);
00892
00893 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
00894
00895
00896
00897 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
00898 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
00899
00900 for(uword i=1; i < tmp.n_rows; ++i)
00901 {
00902 const eT x = tmp.at(i,i);
00903
00904 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
00905 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
00906 }
00907
00908 for(uword i=0; i < tmp.n_rows; ++i)
00909 {
00910 if( blas_int(i) != (ipiv.mem[i] - 1) )
00911 {
00912 sign *= -1;
00913 }
00914 }
00915
00916 out_val = val;
00917 out_sign = T(sign);
00918
00919 return (info == 0);
00920 }
00921 #else
00922 {
00923 out_val = eT(0);
00924 out_sign = T(0);
00925
00926 arma_stop("log_det(): use of ATLAS or LAPACK needs to be enabled");
00927
00928 return false;
00929 }
00930 #endif
00931 }
00932
00933
00934
00936 template<typename eT, typename T1>
00937 inline
00938 bool
00939 auxlib::lu(Mat<eT>& L, Mat<eT>& U, podarray<blas_int>& ipiv, const Base<eT,T1>& X)
00940 {
00941 arma_extra_debug_sigprint();
00942
00943 U = X.get_ref();
00944
00945 const uword U_n_rows = U.n_rows;
00946 const uword U_n_cols = U.n_cols;
00947
00948 if(U.is_empty())
00949 {
00950 L.set_size(U_n_rows, 0);
00951 U.set_size(0, U_n_cols);
00952 ipiv.reset();
00953 return true;
00954 }
00955
00956 #if defined(ARMA_USE_ATLAS) || defined(ARMA_USE_LAPACK)
00957 {
00958 bool status;
00959
00960 #if defined(ARMA_USE_ATLAS)
00961 {
00962 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
00963
00964 int info = atlas::clapack_getrf(atlas::CblasColMajor, U_n_rows, U_n_cols, U.memptr(), U_n_rows, ipiv.memptr());
00965
00966 status = (info == 0);
00967 }
00968 #elif defined(ARMA_USE_LAPACK)
00969 {
00970 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
00971
00972 blas_int info = 0;
00973
00974 blas_int n_rows = U_n_rows;
00975 blas_int n_cols = U_n_cols;
00976
00977
00978 lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info);
00979
00980
00981 arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem);
00982
00983 status = (info == 0);
00984 }
00985 #endif
00986
00987 L.copy_size(U);
00988
00989 for(uword col=0; col < U_n_cols; ++col)
00990 {
00991 for(uword row=0; (row < col) && (row < U_n_rows); ++row)
00992 {
00993 L.at(row,col) = eT(0);
00994 }
00995
00996 if( L.in_range(col,col) == true )
00997 {
00998 L.at(col,col) = eT(1);
00999 }
01000
01001 for(uword row = (col+1); row < U_n_rows; ++row)
01002 {
01003 L.at(row,col) = U.at(row,col);
01004 U.at(row,col) = eT(0);
01005 }
01006 }
01007
01008 return status;
01009 }
01010 #else
01011 {
01012 arma_stop("lu(): use of ATLAS or LAPACK needs to be enabled");
01013
01014 return false;
01015 }
01016 #endif
01017 }
01018
01019
01020
01021 template<typename eT, typename T1>
01022 inline
01023 bool
01024 auxlib::lu(Mat<eT>& L, Mat<eT>& U, Mat<eT>& P, const Base<eT,T1>& X)
01025 {
01026 arma_extra_debug_sigprint();
01027
01028 podarray<blas_int> ipiv1;
01029 const bool status = auxlib::lu(L, U, ipiv1, X);
01030
01031 if(status == true)
01032 {
01033 if(U.is_empty())
01034 {
01035
01036 P.eye(L.n_rows, L.n_rows);
01037 return true;
01038 }
01039
01040 const uword n = ipiv1.n_elem;
01041 const uword P_rows = U.n_rows;
01042
01043 podarray<blas_int> ipiv2(P_rows);
01044
01045 const blas_int* ipiv1_mem = ipiv1.memptr();
01046 blas_int* ipiv2_mem = ipiv2.memptr();
01047
01048 for(uword i=0; i<P_rows; ++i)
01049 {
01050 ipiv2_mem[i] = blas_int(i);
01051 }
01052
01053 for(uword i=0; i<n; ++i)
01054 {
01055 const uword k = static_cast<uword>(ipiv1_mem[i]);
01056
01057 if( ipiv2_mem[i] != ipiv2_mem[k] )
01058 {
01059 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
01060 }
01061 }
01062
01063 P.zeros(P_rows, P_rows);
01064
01065 for(uword row=0; row<P_rows; ++row)
01066 {
01067 P.at(row, static_cast<uword>(ipiv2_mem[row])) = eT(1);
01068 }
01069
01070 if(L.n_cols > U.n_rows)
01071 {
01072 L.shed_cols(U.n_rows, L.n_cols-1);
01073 }
01074
01075 if(U.n_rows > L.n_cols)
01076 {
01077 U.shed_rows(L.n_cols, U.n_rows-1);
01078 }
01079 }
01080
01081 return status;
01082 }
01083
01084
01085
01086 template<typename eT, typename T1>
01087 inline
01088 bool
01089 auxlib::lu(Mat<eT>& L, Mat<eT>& U, const Base<eT,T1>& X)
01090 {
01091 arma_extra_debug_sigprint();
01092
01093 podarray<blas_int> ipiv1;
01094 const bool status = auxlib::lu(L, U, ipiv1, X);
01095
01096 if(status == true)
01097 {
01098 if(U.is_empty())
01099 {
01100
01101 return true;
01102 }
01103
01104 const uword n = ipiv1.n_elem;
01105 const uword P_rows = U.n_rows;
01106
01107 podarray<blas_int> ipiv2(P_rows);
01108
01109 const blas_int* ipiv1_mem = ipiv1.memptr();
01110 blas_int* ipiv2_mem = ipiv2.memptr();
01111
01112 for(uword i=0; i<P_rows; ++i)
01113 {
01114 ipiv2_mem[i] = blas_int(i);
01115 }
01116
01117 for(uword i=0; i<n; ++i)
01118 {
01119 const uword k = static_cast<uword>(ipiv1_mem[i]);
01120
01121 if( ipiv2_mem[i] != ipiv2_mem[k] )
01122 {
01123 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
01124 L.swap_rows( static_cast<uword>(ipiv2_mem[i]), static_cast<uword>(ipiv2_mem[k]) );
01125 }
01126 }
01127
01128 if(L.n_cols > U.n_rows)
01129 {
01130 L.shed_cols(U.n_rows, L.n_cols-1);
01131 }
01132
01133 if(U.n_rows > L.n_cols)
01134 {
01135 U.shed_rows(L.n_cols, U.n_rows-1);
01136 }
01137 }
01138
01139 return status;
01140 }
01141
01142
01143
01145 template<typename eT, typename T1>
01146 inline
01147 bool
01148 auxlib::eig_sym(Col<eT>& eigval, const Base<eT,T1>& X)
01149 {
01150 arma_extra_debug_sigprint();
01151
01152 #if defined(ARMA_USE_LAPACK)
01153 {
01154 Mat<eT> A(X.get_ref());
01155
01156 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square");
01157
01158 if(A.is_empty())
01159 {
01160 eigval.reset();
01161 return true;
01162 }
01163
01164
01165
01166
01167 char jobz = 'N';
01168 char uplo = 'U';
01169
01170 blas_int n_rows = A.n_rows;
01171 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1);
01172
01173 eigval.set_size( static_cast<uword>(n_rows) );
01174 podarray<eT> work( static_cast<uword>(lwork) );
01175
01176 blas_int info;
01177
01178 arma_extra_debug_print("lapack::syev()");
01179 lapack::syev(&jobz, &uplo, &n_rows, A.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info);
01180
01181 return (info == 0);
01182 }
01183 #else
01184 {
01185 arma_ignore(eigval);
01186 arma_ignore(X);
01187 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
01188 return false;
01189 }
01190 #endif
01191 }
01192
01193
01194
01196 template<typename T, typename T1>
01197 inline
01198 bool
01199 auxlib::eig_sym(Col<T>& eigval, const Base<std::complex<T>,T1>& X)
01200 {
01201 arma_extra_debug_sigprint();
01202
01203 typedef typename std::complex<T> eT;
01204
01205 #if defined(ARMA_USE_LAPACK)
01206 {
01207 Mat<eT> A(X.get_ref());
01208 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not hermitian");
01209
01210 if(A.is_empty())
01211 {
01212 eigval.reset();
01213 return true;
01214 }
01215
01216 char jobz = 'N';
01217 char uplo = 'U';
01218
01219 blas_int n_rows = A.n_rows;
01220 blas_int lda = A.n_rows;
01221 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1);
01222
01223 eigval.set_size( static_cast<uword>(n_rows) );
01224
01225 podarray<eT> work( static_cast<uword>(lwork) );
01226 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) );
01227
01228 blas_int info;
01229
01230 arma_extra_debug_print("lapack::heev()");
01231 lapack::heev(&jobz, &uplo, &n_rows, A.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
01232
01233 return (info == 0);
01234 }
01235 #else
01236 {
01237 arma_ignore(eigval);
01238 arma_ignore(X);
01239 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
01240 return false;
01241 }
01242 #endif
01243 }
01244
01245
01246
01248 template<typename eT, typename T1>
01249 inline
01250 bool
01251 auxlib::eig_sym(Col<eT>& eigval, Mat<eT>& eigvec, const Base<eT,T1>& X)
01252 {
01253 arma_extra_debug_sigprint();
01254
01255 #if defined(ARMA_USE_LAPACK)
01256 {
01257 eigvec = X.get_ref();
01258
01259 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
01260
01261 if(eigvec.is_empty())
01262 {
01263 eigval.reset();
01264 eigvec.reset();
01265 return true;
01266 }
01267
01268
01269
01270
01271 char jobz = 'V';
01272 char uplo = 'U';
01273
01274 blas_int n_rows = eigvec.n_rows;
01275 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1);
01276
01277 eigval.set_size( static_cast<uword>(n_rows) );
01278 podarray<eT> work( static_cast<uword>(lwork) );
01279
01280 blas_int info;
01281
01282 arma_extra_debug_print("lapack::syev()");
01283 lapack::syev(&jobz, &uplo, &n_rows, eigvec.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info);
01284
01285 return (info == 0);
01286 }
01287 #else
01288 {
01289 arma_ignore(eigval);
01290 arma_ignore(eigvec);
01291 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
01292
01293 return false;
01294 }
01295 #endif
01296 }
01297
01298
01299
01301 template<typename T, typename T1>
01302 inline
01303 bool
01304 auxlib::eig_sym(Col<T>& eigval, Mat< std::complex<T> >& eigvec, const Base<std::complex<T>,T1>& X)
01305 {
01306 arma_extra_debug_sigprint();
01307
01308 typedef typename std::complex<T> eT;
01309
01310 #if defined(ARMA_USE_LAPACK)
01311 {
01312 eigvec = X.get_ref();
01313
01314 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not hermitian" );
01315
01316 if(eigvec.is_empty())
01317 {
01318 eigval.reset();
01319 eigvec.reset();
01320 return true;
01321 }
01322
01323 char jobz = 'V';
01324 char uplo = 'U';
01325
01326 blas_int n_rows = eigvec.n_rows;
01327 blas_int lda = eigvec.n_rows;
01328 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1);
01329
01330 eigval.set_size( static_cast<uword>(n_rows) );
01331
01332 podarray<eT> work( static_cast<uword>(lwork) );
01333 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) );
01334
01335 blas_int info;
01336
01337 arma_extra_debug_print("lapack::heev()");
01338 lapack::heev(&jobz, &uplo, &n_rows, eigvec.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
01339
01340 return (info == 0);
01341 }
01342 #else
01343 {
01344 arma_ignore(eigval);
01345 arma_ignore(eigvec);
01346 arma_ignore(X);
01347 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
01348 return false;
01349 }
01350 #endif
01351 }
01352
01353
01354
01358 template<typename T, typename T1>
01359 inline
01360 bool
01361 auxlib::eig_gen
01362 (
01363 Col< std::complex<T> >& eigval,
01364 Mat<T>& l_eigvec,
01365 Mat<T>& r_eigvec,
01366 const Base<T,T1>& X,
01367 const char side
01368 )
01369 {
01370 arma_extra_debug_sigprint();
01371
01372 #if defined(ARMA_USE_LAPACK)
01373 {
01374 char jobvl;
01375 char jobvr;
01376
01377 switch(side)
01378 {
01379 case 'l':
01380 jobvl = 'V';
01381 jobvr = 'N';
01382 break;
01383
01384 case 'r':
01385 jobvl = 'N';
01386 jobvr = 'V';
01387 break;
01388
01389 case 'b':
01390 jobvl = 'V';
01391 jobvr = 'V';
01392 break;
01393
01394 case 'n':
01395 jobvl = 'N';
01396 jobvr = 'N';
01397 break;
01398
01399 default:
01400 arma_stop("eig_gen(): parameter 'side' is invalid");
01401 return false;
01402 }
01403
01404 Mat<T> A(X.get_ref());
01405 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
01406
01407 if(A.is_empty())
01408 {
01409 eigval.reset();
01410 l_eigvec.reset();
01411 r_eigvec.reset();
01412 return true;
01413 }
01414
01415 uword A_n_rows = A.n_rows;
01416
01417 blas_int n_rows = A_n_rows;
01418 blas_int lda = A_n_rows;
01419 blas_int lwork = (std::max)(blas_int(1), 4*n_rows);
01420
01421 eigval.set_size(A_n_rows);
01422 l_eigvec.set_size(A_n_rows, A_n_rows);
01423 r_eigvec.set_size(A_n_rows, A_n_rows);
01424
01425 podarray<T> work( static_cast<uword>(lwork) );
01426 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) );
01427
01428 podarray<T> wr(A_n_rows);
01429 podarray<T> wi(A_n_rows);
01430
01431 Mat<T> A_copy = A;
01432 blas_int info;
01433
01434 arma_extra_debug_print("lapack::geev()");
01435 lapack::geev(&jobvl, &jobvr, &n_rows, A_copy.memptr(), &lda, wr.memptr(), wi.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, &info);
01436
01437
01438 eigval.set_size(A_n_rows);
01439 for(uword i=0; i<A_n_rows; ++i)
01440 {
01441 eigval[i] = std::complex<T>(wr[i], wi[i]);
01442 }
01443
01444 return (info == 0);
01445 }
01446 #else
01447 {
01448 arma_ignore(eigval);
01449 arma_ignore(l_eigvec);
01450 arma_ignore(r_eigvec);
01451 arma_ignore(X);
01452 arma_ignore(side);
01453 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
01454 return false;
01455 }
01456 #endif
01457 }
01458
01459
01460
01461
01462
01466 template<typename T, typename T1>
01467 inline
01468 bool
01469 auxlib::eig_gen
01470 (
01471 Col< std::complex<T> >& eigval,
01472 Mat< std::complex<T> >& l_eigvec,
01473 Mat< std::complex<T> >& r_eigvec,
01474 const Base< std::complex<T>, T1 >& X,
01475 const char side
01476 )
01477 {
01478 arma_extra_debug_sigprint();
01479
01480 typedef typename std::complex<T> eT;
01481
01482 #if defined(ARMA_USE_LAPACK)
01483 {
01484 char jobvl;
01485 char jobvr;
01486
01487 switch(side)
01488 {
01489 case 'l':
01490 jobvl = 'V';
01491 jobvr = 'N';
01492 break;
01493
01494 case 'r':
01495 jobvl = 'N';
01496 jobvr = 'V';
01497 break;
01498
01499 case 'b':
01500 jobvl = 'V';
01501 jobvr = 'V';
01502 break;
01503
01504 case 'n':
01505 jobvl = 'N';
01506 jobvr = 'N';
01507 break;
01508
01509 default:
01510 arma_stop("eig_gen(): parameter 'side' is invalid");
01511 return false;
01512 }
01513
01514 Mat<eT> A(X.get_ref());
01515 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
01516
01517 if(A.is_empty())
01518 {
01519 eigval.reset();
01520 l_eigvec.reset();
01521 r_eigvec.reset();
01522 return true;
01523 }
01524
01525 uword A_n_rows = A.n_rows;
01526
01527 blas_int n_rows = A_n_rows;
01528 blas_int lda = A_n_rows;
01529 blas_int lwork = (std::max)(blas_int(1), 4*n_rows);
01530
01531 eigval.set_size(A_n_rows);
01532 l_eigvec.set_size(A_n_rows, A_n_rows);
01533 r_eigvec.set_size(A_n_rows, A_n_rows);
01534
01535 podarray<eT> work( static_cast<uword>(lwork) );
01536 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) );
01537
01538 blas_int info;
01539
01540 arma_extra_debug_print("lapack::cx_geev()");
01541 lapack::cx_geev(&jobvl, &jobvr, &n_rows, A.memptr(), &lda, eigval.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, rwork.memptr(), &info);
01542
01543 return (info == 0);
01544 }
01545 #else
01546 {
01547 arma_ignore(eigval);
01548 arma_ignore(l_eigvec);
01549 arma_ignore(r_eigvec);
01550 arma_ignore(X);
01551 arma_ignore(side);
01552 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
01553 return false;
01554 }
01555 #endif
01556 }
01557
01558
01559
01560 template<typename eT, typename T1>
01561 inline
01562 bool
01563 auxlib::chol(Mat<eT>& out, const Base<eT,T1>& X)
01564 {
01565 arma_extra_debug_sigprint();
01566
01567 #if defined(ARMA_USE_LAPACK)
01568 {
01569 out = X.get_ref();
01570
01571 arma_debug_check( (out.is_square() == false), "chol(): given matrix is not square" );
01572
01573 if(out.is_empty())
01574 {
01575 return true;
01576 }
01577
01578 const uword out_n_rows = out.n_rows;
01579
01580 char uplo = 'U';
01581 blas_int n = out_n_rows;
01582 blas_int info;
01583
01584 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
01585
01586 for(uword col=0; col<out_n_rows; ++col)
01587 {
01588 eT* colptr = out.colptr(col);
01589
01590 for(uword row=(col+1); row < out_n_rows; ++row)
01591 {
01592 colptr[row] = eT(0);
01593 }
01594 }
01595
01596 return (info == 0);
01597 }
01598 #else
01599 {
01600 arma_ignore(out);
01601 arma_stop("chol(): use of LAPACK needs to be enabled");
01602 return false;
01603 }
01604 #endif
01605 }
01606
01607
01608
01609 template<typename eT, typename T1>
01610 inline
01611 bool
01612 auxlib::qr(Mat<eT>& Q, Mat<eT>& R, const Base<eT,T1>& X)
01613 {
01614 arma_extra_debug_sigprint();
01615
01616 #if defined(ARMA_USE_LAPACK)
01617 {
01618 R = X.get_ref();
01619
01620 const uword R_n_rows = R.n_rows;
01621 const uword R_n_cols = R.n_cols;
01622
01623 if(R.is_empty())
01624 {
01625 Q.eye(R_n_rows, R_n_rows);
01626 return true;
01627 }
01628
01629 blas_int m = static_cast<blas_int>(R_n_rows);
01630 blas_int n = static_cast<blas_int>(R_n_cols);
01631 blas_int work_len = (std::max)(blas_int(1),n);
01632 blas_int work_len_tmp;
01633 blas_int k = (std::min)(m,n);
01634 blas_int info;
01635
01636 podarray<eT> tau( static_cast<uword>(k) );
01637 podarray<eT> work( static_cast<uword>(work_len) );
01638
01639
01640 work_len_tmp = -1;
01641 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
01642
01643 if(info == 0)
01644 {
01645 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
01646 work.set_size( static_cast<uword>(work_len) );
01647 }
01648
01649 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
01650
01651 Q.set_size(R_n_rows, R_n_rows);
01652
01653 arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) );
01654
01655
01656
01657
01658 for(uword col=0; col < R_n_cols; ++col)
01659 {
01660 for(uword row=(col+1); row < R_n_rows; ++row)
01661 {
01662 R.at(row,col) = eT(0);
01663 }
01664 }
01665
01666
01667 if( (is_float<eT>::value == true) || (is_double<eT>::value == true) )
01668 {
01669
01670 work_len_tmp = -1;
01671 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
01672
01673 if(info == 0)
01674 {
01675 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
01676 work.set_size( static_cast<uword>(work_len) );
01677 }
01678
01679 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
01680 }
01681 else
01682 if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
01683 {
01684
01685 work_len_tmp = -1;
01686 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
01687
01688 if(info == 0)
01689 {
01690 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
01691 work.set_size( static_cast<uword>(work_len) );
01692 }
01693
01694 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
01695 }
01696
01697 return (info == 0);
01698 }
01699 #else
01700 {
01701 arma_ignore(Q);
01702 arma_ignore(R);
01703 arma_ignore(X);
01704 arma_stop("qr(): use of LAPACK needs to be enabled");
01705 return false;
01706 }
01707 #endif
01708 }
01709
01710
01711
01712 template<typename eT, typename T1>
01713 inline
01714 bool
01715 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X, uword& X_n_rows, uword& X_n_cols)
01716 {
01717 arma_extra_debug_sigprint();
01718
01719 #if defined(ARMA_USE_LAPACK)
01720 {
01721 Mat<eT> A(X.get_ref());
01722
01723 X_n_rows = A.n_rows;
01724 X_n_cols = A.n_cols;
01725
01726 if(A.is_empty())
01727 {
01728 S.reset();
01729 return true;
01730 }
01731
01732 Mat<eT> U(1, 1);
01733 Mat<eT> V(1, A.n_cols);
01734
01735 char jobu = 'N';
01736 char jobvt = 'N';
01737
01738 blas_int m = A.n_rows;
01739 blas_int n = A.n_cols;
01740 blas_int lda = A.n_rows;
01741 blas_int ldu = U.n_rows;
01742 blas_int ldvt = V.n_rows;
01743 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
01744 blas_int info;
01745
01746 S.set_size( static_cast<uword>((std::min)(m, n)) );
01747
01748 podarray<eT> work( static_cast<uword>(lwork) );
01749
01750
01751
01752 blas_int lwork_tmp = -1;
01753
01754 lapack::gesvd<eT>
01755 (
01756 &jobu, &jobvt,
01757 &m,&n,
01758 A.memptr(), &lda,
01759 S.memptr(),
01760 U.memptr(), &ldu,
01761 V.memptr(), &ldvt,
01762 work.memptr(), &lwork_tmp,
01763 &info
01764 );
01765
01766 if(info == 0)
01767 {
01768 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
01769
01770 if(proposed_lwork > lwork)
01771 {
01772 lwork = proposed_lwork;
01773 work.set_size( static_cast<uword>(lwork) );
01774 }
01775
01776 lapack::gesvd<eT>
01777 (
01778 &jobu, &jobvt,
01779 &m, &n,
01780 A.memptr(), &lda,
01781 S.memptr(),
01782 U.memptr(), &ldu,
01783 V.memptr(), &ldvt,
01784 work.memptr(), &lwork,
01785 &info
01786 );
01787 }
01788
01789 return (info == 0);
01790 }
01791 #else
01792 {
01793 arma_ignore(S);
01794 arma_ignore(X);
01795 arma_ignore(X_n_rows);
01796 arma_ignore(X_n_cols);
01797 arma_stop("svd(): use of LAPACK needs to be enabled");
01798 return false;
01799 }
01800 #endif
01801 }
01802
01803
01804
01805 template<typename T, typename T1>
01806 inline
01807 bool
01808 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X, uword& X_n_rows, uword& X_n_cols)
01809 {
01810 arma_extra_debug_sigprint();
01811
01812 typedef std::complex<T> eT;
01813
01814 #if defined(ARMA_USE_LAPACK)
01815 {
01816 Mat<eT> A(X.get_ref());
01817
01818 X_n_rows = A.n_rows;
01819 X_n_cols = A.n_cols;
01820
01821 if(A.is_empty())
01822 {
01823 S.reset();
01824 return true;
01825 }
01826
01827 Mat<eT> U(1, 1);
01828 Mat<eT> V(1, A.n_cols);
01829
01830 char jobu = 'N';
01831 char jobvt = 'N';
01832
01833 blas_int m = A.n_rows;
01834 blas_int n = A.n_cols;
01835 blas_int lda = A.n_rows;
01836 blas_int ldu = U.n_rows;
01837 blas_int ldvt = V.n_rows;
01838 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) );
01839 blas_int info;
01840
01841 S.set_size( static_cast<uword>((std::min)(m,n)) );
01842
01843 podarray<eT> work( static_cast<uword>(lwork) );
01844 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
01845
01846
01847 blas_int lwork_tmp = -1;
01848
01849 lapack::cx_gesvd<T>
01850 (
01851 &jobu, &jobvt,
01852 &m, &n,
01853 A.memptr(), &lda,
01854 S.memptr(),
01855 U.memptr(), &ldu,
01856 V.memptr(), &ldvt,
01857 work.memptr(), &lwork_tmp,
01858 rwork.memptr(),
01859 &info
01860 );
01861
01862 if(info == 0)
01863 {
01864 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
01865 if(proposed_lwork > lwork)
01866 {
01867 lwork = proposed_lwork;
01868 work.set_size( static_cast<uword>(lwork) );
01869 }
01870
01871 lapack::cx_gesvd<T>
01872 (
01873 &jobu, &jobvt,
01874 &m, &n,
01875 A.memptr(), &lda,
01876 S.memptr(),
01877 U.memptr(), &ldu,
01878 V.memptr(), &ldvt,
01879 work.memptr(), &lwork,
01880 rwork.memptr(),
01881 &info
01882 );
01883 }
01884
01885 return (info == 0);
01886 }
01887 #else
01888 {
01889 arma_ignore(S);
01890 arma_ignore(X);
01891 arma_ignore(X_n_rows);
01892 arma_ignore(X_n_cols);
01893
01894 arma_stop("svd(): use of LAPACK needs to be enabled");
01895 return false;
01896 }
01897 #endif
01898 }
01899
01900
01901
01902 template<typename eT, typename T1>
01903 inline
01904 bool
01905 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X)
01906 {
01907 arma_extra_debug_sigprint();
01908
01909 uword junk;
01910 return auxlib::svd(S, X, junk, junk);
01911 }
01912
01913
01914
01915 template<typename T, typename T1>
01916 inline
01917 bool
01918 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X)
01919 {
01920 arma_extra_debug_sigprint();
01921
01922 uword junk;
01923 return auxlib::svd(S, X, junk, junk);
01924 }
01925
01926
01927
01928 template<typename eT, typename T1>
01929 inline
01930 bool
01931 auxlib::svd(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X)
01932 {
01933 arma_extra_debug_sigprint();
01934
01935 #if defined(ARMA_USE_LAPACK)
01936 {
01937 Mat<eT> A(X.get_ref());
01938
01939 if(A.is_empty())
01940 {
01941 U.eye(A.n_rows, A.n_rows);
01942 S.reset();
01943 V.eye(A.n_cols, A.n_cols);
01944 return true;
01945 }
01946
01947 U.set_size(A.n_rows, A.n_rows);
01948 V.set_size(A.n_cols, A.n_cols);
01949
01950 char jobu = 'A';
01951 char jobvt = 'A';
01952
01953 blas_int m = A.n_rows;
01954 blas_int n = A.n_cols;
01955 blas_int lda = A.n_rows;
01956 blas_int ldu = U.n_rows;
01957 blas_int ldvt = V.n_rows;
01958 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
01959 blas_int info;
01960
01961
01962 S.set_size( static_cast<uword>((std::min)(m,n)) );
01963 podarray<eT> work( static_cast<uword>(lwork) );
01964
01965
01966 blas_int lwork_tmp = -1;
01967
01968 lapack::gesvd<eT>
01969 (
01970 &jobu, &jobvt,
01971 &m, &n,
01972 A.memptr(), &lda,
01973 S.memptr(),
01974 U.memptr(), &ldu,
01975 V.memptr(), &ldvt,
01976 work.memptr(), &lwork_tmp,
01977 &info
01978 );
01979
01980 if(info == 0)
01981 {
01982 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
01983 if(proposed_lwork > lwork)
01984 {
01985 lwork = proposed_lwork;
01986 work.set_size( static_cast<uword>(lwork) );
01987 }
01988
01989 lapack::gesvd<eT>
01990 (
01991 &jobu, &jobvt,
01992 &m, &n,
01993 A.memptr(), &lda,
01994 S.memptr(),
01995 U.memptr(), &ldu,
01996 V.memptr(), &ldvt,
01997 work.memptr(), &lwork,
01998 &info
01999 );
02000
02001 op_strans::apply(V,V);
02002 }
02003
02004 return (info == 0);
02005 }
02006 #else
02007 {
02008 arma_ignore(U);
02009 arma_ignore(S);
02010 arma_ignore(V);
02011 arma_ignore(X);
02012 arma_stop("svd(): use of LAPACK needs to be enabled");
02013 return false;
02014 }
02015 #endif
02016 }
02017
02018
02019
02020 template<typename T, typename T1>
02021 inline
02022 bool
02023 auxlib::svd(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X)
02024 {
02025 arma_extra_debug_sigprint();
02026
02027 typedef std::complex<T> eT;
02028
02029 #if defined(ARMA_USE_LAPACK)
02030 {
02031 Mat<eT> A(X.get_ref());
02032
02033 if(A.is_empty())
02034 {
02035 U.eye(A.n_rows, A.n_rows);
02036 S.reset();
02037 V.eye(A.n_cols, A.n_cols);
02038 return true;
02039 }
02040
02041 U.set_size(A.n_rows, A.n_rows);
02042 V.set_size(A.n_cols, A.n_cols);
02043
02044 char jobu = 'A';
02045 char jobvt = 'A';
02046
02047 blas_int m = A.n_rows;
02048 blas_int n = A.n_cols;
02049 blas_int lda = A.n_rows;
02050 blas_int ldu = U.n_rows;
02051 blas_int ldvt = V.n_rows;
02052 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) );
02053 blas_int info;
02054
02055 S.set_size( static_cast<uword>((std::min)(m,n)) );
02056
02057 podarray<eT> work( static_cast<uword>(lwork) );
02058 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
02059
02060
02061 blas_int lwork_tmp = -1;
02062 lapack::cx_gesvd<T>
02063 (
02064 &jobu, &jobvt,
02065 &m, &n,
02066 A.memptr(), &lda,
02067 S.memptr(),
02068 U.memptr(), &ldu,
02069 V.memptr(), &ldvt,
02070 work.memptr(), &lwork_tmp,
02071 rwork.memptr(),
02072 &info
02073 );
02074
02075 if(info == 0)
02076 {
02077 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
02078 if(proposed_lwork > lwork)
02079 {
02080 lwork = proposed_lwork;
02081 work.set_size( static_cast<uword>(lwork) );
02082 }
02083
02084 lapack::cx_gesvd<T>
02085 (
02086 &jobu, &jobvt,
02087 &m, &n,
02088 A.memptr(), &lda,
02089 S.memptr(),
02090 U.memptr(), &ldu,
02091 V.memptr(), &ldvt,
02092 work.memptr(), &lwork,
02093 rwork.memptr(),
02094 &info
02095 );
02096
02097 op_htrans::apply(V,V);
02098 }
02099
02100 return (info == 0);
02101 }
02102 #else
02103 {
02104 arma_ignore(U);
02105 arma_ignore(S);
02106 arma_ignore(V);
02107 arma_ignore(X);
02108 arma_stop("svd(): use of LAPACK needs to be enabled");
02109 return false;
02110 }
02111 #endif
02112
02113 }
02114
02115
02116
02117 template<typename eT, typename T1>
02118 inline
02119 bool
02120 auxlib::svd_econ(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X, const char mode)
02121 {
02122 arma_extra_debug_sigprint();
02123
02124 #if defined(ARMA_USE_LAPACK)
02125 {
02126 Mat<eT> A(X.get_ref());
02127
02128 blas_int m = A.n_rows;
02129 blas_int n = A.n_cols;
02130 blas_int lda = A.n_rows;
02131
02132 S.set_size( static_cast<uword>((std::min)(m,n)) );
02133
02134 blas_int ldu = 0;
02135 blas_int ldvt = 0;
02136
02137 char jobu;
02138 char jobvt;
02139
02140 switch(mode)
02141 {
02142 case 'l':
02143 jobu = 'S';
02144 jobvt = 'N';
02145
02146 ldu = m;
02147 ldvt = 1;
02148
02149 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
02150 V.reset();
02151
02152 break;
02153
02154
02155 case 'r':
02156 jobu = 'N';
02157 jobvt = 'S';
02158
02159 ldu = 1;
02160 ldvt = (std::min)(m,n);
02161
02162 U.reset();
02163 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
02164
02165 break;
02166
02167
02168 case 'b':
02169 jobu = 'S';
02170 jobvt = 'S';
02171
02172 ldu = m;
02173 ldvt = (std::min)(m,n);
02174
02175 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
02176 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
02177
02178 break;
02179
02180
02181 default:
02182 U.reset();
02183 S.reset();
02184 V.reset();
02185 return false;
02186 }
02187
02188
02189 if(A.is_empty())
02190 {
02191 U.eye();
02192 S.reset();
02193 V.eye();
02194 return true;
02195 }
02196
02197
02198 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
02199 blas_int info = 0;
02200
02201
02202 podarray<eT> work( static_cast<uword>(lwork) );
02203
02204
02205 blas_int lwork_tmp = -1;
02206
02207 lapack::gesvd<eT>
02208 (
02209 &jobu, &jobvt,
02210 &m, &n,
02211 A.memptr(), &lda,
02212 S.memptr(),
02213 U.memptr(), &ldu,
02214 V.memptr(), &ldvt,
02215 work.memptr(), &lwork_tmp,
02216 &info
02217 );
02218
02219 if(info == 0)
02220 {
02221 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
02222 if(proposed_lwork > lwork)
02223 {
02224 lwork = proposed_lwork;
02225 work.set_size( static_cast<uword>(lwork) );
02226 }
02227
02228 lapack::gesvd<eT>
02229 (
02230 &jobu, &jobvt,
02231 &m, &n,
02232 A.memptr(), &lda,
02233 S.memptr(),
02234 U.memptr(), &ldu,
02235 V.memptr(), &ldvt,
02236 work.memptr(), &lwork,
02237 &info
02238 );
02239
02240 op_strans::apply(V,V);
02241 }
02242
02243 return (info == 0);
02244 }
02245 #else
02246 {
02247 arma_ignore(U);
02248 arma_ignore(S);
02249 arma_ignore(V);
02250 arma_ignore(X);
02251 arma_ignore(mode);
02252 arma_stop("svd(): use of LAPACK needs to be enabled");
02253 return false;
02254 }
02255 #endif
02256 }
02257
02258
02259
02260 template<typename T, typename T1>
02261 inline
02262 bool
02263 auxlib::svd_econ(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X, const char mode)
02264 {
02265 arma_extra_debug_sigprint();
02266
02267 typedef std::complex<T> eT;
02268
02269 #if defined(ARMA_USE_LAPACK)
02270 {
02271 Mat<eT> A(X.get_ref());
02272
02273 blas_int m = A.n_rows;
02274 blas_int n = A.n_cols;
02275 blas_int lda = A.n_rows;
02276
02277 S.set_size( static_cast<uword>((std::min)(m,n)) );
02278
02279 blas_int ldu = 0;
02280 blas_int ldvt = 0;
02281
02282 char jobu;
02283 char jobvt;
02284
02285 switch(mode)
02286 {
02287 case 'l':
02288 jobu = 'S';
02289 jobvt = 'N';
02290
02291 ldu = m;
02292 ldvt = 1;
02293
02294 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
02295 V.reset();
02296
02297 break;
02298
02299
02300 case 'r':
02301 jobu = 'N';
02302 jobvt = 'S';
02303
02304 ldu = 1;
02305 ldvt = (std::min)(m,n);
02306
02307 U.reset();
02308 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
02309
02310 break;
02311
02312
02313 case 'b':
02314 jobu = 'S';
02315 jobvt = 'S';
02316
02317 ldu = m;
02318 ldvt = (std::min)(m,n);
02319
02320 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
02321 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
02322
02323 break;
02324
02325
02326 default:
02327 U.reset();
02328 S.reset();
02329 V.reset();
02330 return false;
02331 }
02332
02333
02334 if(A.is_empty())
02335 {
02336 U.eye();
02337 S.reset();
02338 V.eye();
02339 return true;
02340 }
02341
02342
02343 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
02344 blas_int info = 0;
02345
02346
02347 podarray<eT> work( static_cast<uword>(lwork) );
02348 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
02349
02350
02351 blas_int lwork_tmp = -1;
02352
02353 lapack::cx_gesvd<T>
02354 (
02355 &jobu, &jobvt,
02356 &m, &n,
02357 A.memptr(), &lda,
02358 S.memptr(),
02359 U.memptr(), &ldu,
02360 V.memptr(), &ldvt,
02361 work.memptr(), &lwork_tmp,
02362 rwork.memptr(),
02363 &info
02364 );
02365
02366 if(info == 0)
02367 {
02368 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
02369 if(proposed_lwork > lwork)
02370 {
02371 lwork = proposed_lwork;
02372 work.set_size( static_cast<uword>(lwork) );
02373 }
02374
02375 lapack::cx_gesvd<T>
02376 (
02377 &jobu, &jobvt,
02378 &m, &n,
02379 A.memptr(), &lda,
02380 S.memptr(),
02381 U.memptr(), &ldu,
02382 V.memptr(), &ldvt,
02383 work.memptr(), &lwork,
02384 rwork.memptr(),
02385 &info
02386 );
02387
02388 op_htrans::apply(V,V);
02389 }
02390
02391 return (info == 0);
02392 }
02393 #else
02394 {
02395 arma_ignore(U);
02396 arma_ignore(S);
02397 arma_ignore(V);
02398 arma_ignore(X);
02399 arma_ignore(mode);
02400 arma_stop("svd(): use of LAPACK needs to be enabled");
02401 return false;
02402 }
02403 #endif
02404 }
02405
02406
02407
02410 template<typename eT>
02411 inline
02412 bool
02413 auxlib::solve(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B, const bool slow)
02414 {
02415 arma_extra_debug_sigprint();
02416
02417 if(A.is_empty() || B.is_empty())
02418 {
02419 out.zeros(A.n_cols, B.n_cols);
02420 return true;
02421 }
02422 else
02423 {
02424 const uword A_n_rows = A.n_rows;
02425
02426 bool status = false;
02427
02428 if( (A_n_rows <= 4) && (slow == false) )
02429 {
02430 Mat<eT> A_inv;
02431
02432 status = auxlib::inv_noalias_tinymat(A_inv, A, A_n_rows);
02433
02434 if(status == true)
02435 {
02436 out.set_size(A_n_rows, B.n_cols);
02437
02438 gemm_emul<false,false,false,false>::apply(out, A_inv, B);
02439
02440 return true;
02441 }
02442 }
02443
02444 if( (A_n_rows > 4) || (status == false) )
02445 {
02446 #if defined(ARMA_USE_ATLAS)
02447 {
02448 podarray<int> ipiv(A_n_rows);
02449
02450 out = B;
02451
02452 int info = atlas::clapack_gesv<eT>(atlas::CblasColMajor, A_n_rows, B.n_cols, A.memptr(), A_n_rows, ipiv.memptr(), out.memptr(), A_n_rows);
02453
02454 return (info == 0);
02455 }
02456 #elif defined(ARMA_USE_LAPACK)
02457 {
02458 blas_int n = A_n_rows;
02459 blas_int lda = A_n_rows;
02460 blas_int ldb = A_n_rows;
02461 blas_int nrhs = B.n_cols;
02462 blas_int info;
02463
02464 podarray<blas_int> ipiv(A_n_rows);
02465
02466 out = B;
02467
02468 lapack::gesv<eT>(&n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info);
02469
02470 return (info == 0);
02471 }
02472 #else
02473 {
02474 arma_stop("solve(): use of ATLAS or LAPACK needs to be enabled");
02475 return false;
02476 }
02477 #endif
02478 }
02479 }
02480
02481 return true;
02482 }
02483
02484
02485
02488 template<typename eT>
02489 inline
02490 bool
02491 auxlib::solve_od(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B)
02492 {
02493 arma_extra_debug_sigprint();
02494
02495 #if defined(ARMA_USE_LAPACK)
02496 {
02497 if(A.is_empty() || B.is_empty())
02498 {
02499 out.zeros(A.n_cols, B.n_cols);
02500 return true;
02501 }
02502
02503 char trans = 'N';
02504
02505 blas_int m = A.n_rows;
02506 blas_int n = A.n_cols;
02507 blas_int lda = A.n_rows;
02508 blas_int ldb = A.n_rows;
02509 blas_int nrhs = B.n_cols;
02510 blas_int lwork = n + (std::max)(n, nrhs);
02511 blas_int info;
02512
02513 Mat<eT> tmp = B;
02514
02515 podarray<eT> work( static_cast<uword>(lwork) );
02516
02517 arma_extra_debug_print("lapack::gels()");
02518
02519
02520
02521 lapack::gels<eT>
02522 (
02523 &trans, &m, &n, &nrhs,
02524 A.memptr(), &lda,
02525 tmp.memptr(), &ldb,
02526 work.memptr(), &lwork,
02527 &info
02528 );
02529
02530 arma_extra_debug_print("lapack::gels() -- finished");
02531
02532 out.set_size(A.n_cols, B.n_cols);
02533
02534 for(uword col=0; col<B.n_cols; ++col)
02535 {
02536 arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols );
02537 }
02538
02539 return (info == 0);
02540 }
02541 #else
02542 {
02543 arma_ignore(out);
02544 arma_ignore(A);
02545 arma_ignore(B);
02546 arma_stop("solve(): use of LAPACK needs to be enabled");
02547 return false;
02548 }
02549 #endif
02550 }
02551
02552
02553
02556 template<typename eT>
02557 inline
02558 bool
02559 auxlib::solve_ud(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B)
02560 {
02561 arma_extra_debug_sigprint();
02562
02563 #if defined(ARMA_USE_LAPACK)
02564 {
02565 if(A.is_empty() || B.is_empty())
02566 {
02567 out.zeros(A.n_cols, B.n_cols);
02568 return true;
02569 }
02570
02571 char trans = 'N';
02572
02573 blas_int m = A.n_rows;
02574 blas_int n = A.n_cols;
02575 blas_int lda = A.n_rows;
02576 blas_int ldb = A.n_cols;
02577 blas_int nrhs = B.n_cols;
02578 blas_int lwork = m + (std::max)(m,nrhs);
02579 blas_int info;
02580
02581
02582 Mat<eT> tmp;
02583 tmp.zeros(A.n_cols, B.n_cols);
02584
02585 for(uword col=0; col<B.n_cols; ++col)
02586 {
02587 eT* tmp_colmem = tmp.colptr(col);
02588
02589 arrayops::copy( tmp_colmem, B.colptr(col), B.n_rows );
02590
02591 for(uword row=B.n_rows; row<A.n_cols; ++row)
02592 {
02593 tmp_colmem[row] = eT(0);
02594 }
02595 }
02596
02597 podarray<eT> work( static_cast<uword>(lwork) );
02598
02599 arma_extra_debug_print("lapack::gels()");
02600
02601
02602
02603 lapack::gels<eT>
02604 (
02605 &trans, &m, &n, &nrhs,
02606 A.memptr(), &lda,
02607 tmp.memptr(), &ldb,
02608 work.memptr(), &lwork,
02609 &info
02610 );
02611
02612 arma_extra_debug_print("lapack::gels() -- finished");
02613
02614 out.set_size(A.n_cols, B.n_cols);
02615
02616 for(uword col=0; col<B.n_cols; ++col)
02617 {
02618 arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols );
02619 }
02620
02621 return (info == 0);
02622 }
02623 #else
02624 {
02625 arma_ignore(out);
02626 arma_ignore(A);
02627 arma_ignore(B);
02628 arma_stop("solve(): use of LAPACK needs to be enabled");
02629 return false;
02630 }
02631 #endif
02632 }
02633
02634
02635
02636
02637
02638
02639 template<typename eT>
02640 inline
02641 bool
02642 auxlib::solve_tr(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const uword layout)
02643 {
02644 arma_extra_debug_sigprint();
02645
02646 #if defined(ARMA_USE_LAPACK)
02647 {
02648 if(A.is_empty() || B.is_empty())
02649 {
02650 out.zeros(A.n_cols, B.n_cols);
02651 return true;
02652 }
02653
02654 out = B;
02655
02656 char uplo = (layout == 0) ? 'U' : 'L';
02657 char trans = 'N';
02658 char diag = 'N';
02659 blas_int n = blas_int(A.n_rows);
02660 blas_int nrhs = blas_int(B.n_cols);
02661 blas_int info = 0;
02662
02663 lapack::trtrs<eT>(&uplo, &trans, &diag, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info);
02664
02665 return (info == 0);
02666 }
02667 #else
02668 {
02669 arma_ignore(out);
02670 arma_ignore(A);
02671 arma_ignore(B);
02672 arma_ignore(layout);
02673 arma_stop("solve(): use of LAPACK needs to be enabled");
02674 return false;
02675 }
02676 #endif
02677 }
02678
02679
02680
02681
02682
02683
02684 template<typename eT>
02685 inline
02686 bool
02687 auxlib::schur_dec(Mat<eT>& Z, Mat<eT>& T, const Mat<eT>& A)
02688 {
02689 arma_extra_debug_sigprint();
02690
02691 #if defined(ARMA_USE_LAPACK)
02692 {
02693 arma_debug_check( (A.is_square() == false), "schur_dec(): matrix A is not square" );
02694
02695 if(A.is_empty())
02696 {
02697 Z.reset();
02698 T.reset();
02699 return true;
02700 }
02701
02702 const uword A_n_rows = A.n_rows;
02703
02704 char jobvs = 'V';
02705 char sort = 'N';
02706 blas_int* select = 0;
02707 blas_int n = blas_int(A_n_rows);
02708 blas_int sdim = 0;
02709
02710 blas_int lwork = 3 * n;
02711
02712 podarray<eT> work( static_cast<uword>(lwork) );
02713 podarray<blas_int> bwork(A_n_rows);
02714
02715 blas_int info = 0;
02716
02717 Z.set_size(A_n_rows, A_n_rows);
02718 T = A;
02719
02720 podarray<eT> wr(A_n_rows);
02721 podarray<eT> wi(A_n_rows);
02722
02723 lapack::gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, wr.memptr(), wi.memptr(), Z.memptr(), &n, work.memptr(), &lwork, bwork.memptr(), &info);
02724
02725 return (info == 0);
02726 }
02727 #else
02728 {
02729 arma_ignore(Z);
02730 arma_ignore(T);
02731 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
02732 return false;
02733 }
02734 #endif
02735 }
02736
02737
02738
02739 template<typename cT>
02740 inline
02741 bool
02742 auxlib::schur_dec(Mat<std::complex<cT> >& Z, Mat<std::complex<cT> >& T, const Mat<std::complex<cT> >& A)
02743 {
02744 arma_extra_debug_sigprint();
02745
02746 #if defined(ARMA_USE_LAPACK)
02747 {
02748 arma_debug_check( (A.is_square() == false), "schur_dec(): matrix A is not square" );
02749
02750 if(A.is_empty())
02751 {
02752 Z.reset();
02753 T.reset();
02754 return true;
02755 }
02756
02757 typedef std::complex<cT> eT;
02758
02759 const uword A_n_rows = A.n_rows;
02760
02761 char jobvs = 'V';
02762 char sort = 'N';
02763 blas_int* select = 0;
02764 blas_int n = blas_int(A_n_rows);
02765 blas_int sdim = 0;
02766
02767 blas_int lwork = 3 * n;
02768
02769 podarray<eT> work( static_cast<uword>(lwork) );
02770 podarray<blas_int> bwork(A_n_rows);
02771
02772 blas_int info = 0;
02773
02774 Z.set_size(A_n_rows, A_n_rows);
02775 T = A;
02776
02777 podarray<eT> w(A_n_rows);
02778 podarray<cT> rwork(A_n_rows);
02779
02780 lapack::cx_gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, w.memptr(), Z.memptr(), &n, work.memptr(), &lwork, rwork.memptr(), bwork.memptr(), &info);
02781
02782 return (info == 0);
02783 }
02784 #else
02785 {
02786 arma_ignore(Z);
02787 arma_ignore(T);
02788 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
02789 return false;
02790 }
02791 #endif
02792 }
02793
02794
02795
02796
02797
02798
02799 template<typename eT>
02800 inline
02801 bool
02802 auxlib::syl(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C)
02803 {
02804 arma_extra_debug_sigprint();
02805
02806 arma_debug_check( (A.is_square() == false), "syl(): matrix A is not square" );
02807 arma_debug_check( (B.is_square() == false), "syl(): matrix B is not square" );
02808
02809 arma_debug_check( (C.n_rows != A.n_rows) || (C.n_cols != B.n_cols), "syl(): matrices are not conformant" );
02810
02811 if(A.is_empty() || B.is_empty() || C.is_empty())
02812 {
02813 X.reset();
02814 return true;
02815 }
02816
02817 bool status;
02818
02819 #if defined(ARMA_USE_LAPACK)
02820 {
02821 Mat<eT> Z1, Z2, T1, T2;
02822
02823 status = auxlib::schur_dec(Z1, T1, A);
02824 if(status == false)
02825 {
02826 return false;
02827 }
02828
02829 status = auxlib::schur_dec(Z2, T2, B);
02830 if(status == false)
02831 {
02832 return false;
02833 }
02834
02835 char trana = 'N';
02836 char tranb = 'N';
02837 blas_int isgn = +1;
02838 blas_int m = blas_int(T1.n_rows);
02839 blas_int n = blas_int(T2.n_cols);
02840
02841 eT scale = eT(0);
02842 blas_int info = 0;
02843
02844 Mat<eT> Y = trans(Z1) * C * Z2;
02845
02846 lapack::trsyl<eT>(&trana, &tranb, &isgn, &m, &n, T1.memptr(), &m, T2.memptr(), &n, Y.memptr(), &m, &scale, &info);
02847
02848
02849 Y /= (-scale);
02850
02851 X = Z1 * Y * trans(Z2);
02852
02853 status = (info == 0);
02854 }
02855 #else
02856 {
02857 arma_stop("syl(): use of LAPACK needs to be enabled");
02858 return false;
02859 }
02860 #endif
02861
02862
02863 return status;
02864 }
02865
02866
02867
02868
02869
02870
02871 template<typename eT>
02872 inline
02873 bool
02874 auxlib::lyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
02875 {
02876 arma_extra_debug_sigprint();
02877
02878 arma_debug_check( (A.is_square() == false), "lyap(): matrix A is not square");
02879 arma_debug_check( (Q.is_square() == false), "lyap(): matrix Q is not square");
02880 arma_debug_check( (A.n_rows != Q.n_rows), "lyap(): matrices A and Q have different dimensions");
02881
02882 Mat<eT> htransA;
02883 op_htrans::apply_noalias(htransA, A);
02884
02885 const Mat<eT> mQ = -Q;
02886
02887 return auxlib::syl(X, A, htransA, mQ);
02888 }
02889
02890
02891
02892
02893
02894
02895 template<typename eT>
02896 inline
02897 bool
02898 auxlib::dlyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
02899 {
02900 arma_extra_debug_sigprint();
02901
02902 arma_debug_check( (A.is_square() == false), "dlyap(): matrix A is not square");
02903 arma_debug_check( (Q.is_square() == false), "dlyap(): matrix Q is not square");
02904 arma_debug_check( (A.n_rows != Q.n_rows), "dlyap(): matrices A and Q have different dimensions");
02905
02906 const Col<eT> vecQ = reshape(Q, Q.n_elem, 1);
02907
02908 const Mat<eT> M = eye< Mat<eT> >(Q.n_elem, Q.n_elem) - kron(conj(A), A);
02909
02910 Col<eT> vecX;
02911
02912 const bool status = solve(vecX, M, vecQ);
02913
02914 if(status == true)
02915 {
02916 X = reshape(vecX, Q.n_rows, Q.n_cols);
02917 return true;
02918 }
02919 else
02920 {
02921 X.reset();
02922 return false;
02923 }
02924 }
02925
02926
02927