00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 #ifndef __SVD_H
00031 #define __SVD_H
00032
#include <algorithm>
#include <vector>

#include <TooN/TooN.h>
#include <TooN/lapack.h>
00035
00036 namespace TooN {
00037
00038
/// Default value of the `condition` argument used by SVD::backsub(),
/// SVD::get_pinv() and SVD::rank(): a singular value s is treated as zero
/// when s * condition is not greater than the first singular value.
static const double condition_no = 1e9;
00040
00041
00042
00043
00044
00045
00046
00087 template<int Rows=Dynamic, int Cols=Rows, typename Precision=DefaultPrecision>
00088 class SVD {
00089
00090
00091 static const int Min_Dim = Rows<Cols?Rows:Cols;
00092
00093 public:
00094
00096 SVD() {}
00097
00099 SVD(int rows, int cols)
00100 : my_copy(rows,cols),
00101 my_diagonal(std::min(rows,cols)),
00102 my_square(std::min(rows,cols), std::min(rows,cols))
00103 {}
00104
00107 template <int R2, int C2, typename P2, typename B2>
00108 SVD(const Matrix<R2,C2,P2,B2>& m)
00109 : my_copy(m),
00110 my_diagonal(std::min(m.num_rows(),m.num_cols())),
00111 my_square(std::min(m.num_rows(),m.num_cols()),std::min(m.num_rows(),m.num_cols()))
00112 {
00113 do_compute();
00114 }
00115
00117 template <int R2, int C2, typename P2, typename B2>
00118 void compute(const Matrix<R2,C2,P2,B2>& m){
00119 my_copy=m;
00120 do_compute();
00121 }
00122
00123 private:
00124 void do_compute(){
00125 Precision* const a = my_copy.my_data;
00126 int lda = my_copy.num_cols();
00127 int m = my_copy.num_cols();
00128 int n = my_copy.num_rows();
00129 Precision* const uorvt = my_square.my_data;
00130 Precision* const s = my_diagonal.my_data;
00131 int ldu;
00132 int ldvt = lda;
00133 int LWORK;
00134 int INFO;
00135 char JOBU;
00136 char JOBVT;
00137
00138 if(is_vertical()){
00139 JOBU='O';
00140 JOBVT='S';
00141 ldu = lda;
00142 } else {
00143 JOBU='S';
00144 JOBVT='O';
00145 ldu = my_square.num_cols();
00146 }
00147
00148 Precision* wk;
00149
00150 Precision size;
00151 LWORK = -1;
00152
00153
00154
00155 dgesvd_( &JOBVT, &JOBU, &m, &n, a, &lda, s, uorvt,
00156 &ldvt, uorvt, &ldu, &size, &LWORK, &INFO);
00157
00158 LWORK = (long int)(size);
00159 wk = new Precision[LWORK];
00160
00161 dgesvd_( &JOBVT, &JOBU, &m, &n, a, &lda, s, uorvt,
00162 &ldvt, uorvt, &ldu, wk, &LWORK, &INFO);
00163
00164 delete[] wk;
00165 }
00166
00167 bool is_vertical(){
00168 return (my_copy.num_rows() >= my_copy.num_cols());
00169 }
00170
00171 int min_dim(){ return std::min(my_copy.num_rows(), my_copy.num_cols()); }
00172
00173 public:
00174
00179 template <int Rows2, int Cols2, typename P2, typename B2>
00180 Matrix<Cols,Cols2, typename Internal::MultiplyType<Precision,P2>::type >
00181 backsub(const Matrix<Rows2,Cols2,P2,B2>& rhs, const Precision condition=condition_no)
00182 {
00183 Vector<Min_Dim> inv_diag(min_dim());
00184 get_inv_diag(inv_diag,condition);
00185 return (get_VT().T() * diagmult(inv_diag, (get_U().T() * rhs)));
00186 }
00187
00192 template <int Size, typename P2, typename B2>
00193 Vector<Cols, typename Internal::MultiplyType<Precision,P2>::type >
00194 backsub(const Vector<Size,P2,B2>& rhs, const Precision condition=condition_no)
00195 {
00196 Vector<Min_Dim> inv_diag(min_dim());
00197 get_inv_diag(inv_diag,condition);
00198 return (get_VT().T() * diagmult(inv_diag, (get_U().T() * rhs)));
00199 }
00200
00205 Matrix<Cols,Rows> get_pinv(const Precision condition = condition_no){
00206 Vector<Min_Dim> inv_diag(min_dim());
00207 get_inv_diag(inv_diag,condition);
00208 return diagmult(get_VT().T(),inv_diag) * get_U().T();
00209 }
00210
00213 Precision determinant() {
00214 Precision result = my_diagonal[0];
00215 for(int i=1; i<my_diagonal.size(); i++){
00216 result *= my_diagonal[i];
00217 }
00218 return result;
00219 }
00220
00223 int rank(const Precision condition = condition_no) {
00224 if (my_diagonal[0] == 0) return 0;
00225 int result=1;
00226 for(int i=0; i<min_dim(); i++){
00227 if(my_diagonal[i] * condition <= my_diagonal[0]){
00228 result++;
00229 }
00230 }
00231 return result;
00232 }
00233
00237 Matrix<Rows,Min_Dim,Precision,Reference::RowMajor> get_U(){
00238 if(is_vertical()){
00239 return Matrix<Rows,Min_Dim,Precision,Reference::RowMajor>
00240 (my_copy.my_data,my_copy.num_rows(),my_copy.num_cols());
00241 } else {
00242 return Matrix<Rows,Min_Dim,Precision,Reference::RowMajor>
00243 (my_square.my_data, my_square.num_rows(), my_square.num_cols());
00244 }
00245 }
00246
00248 Vector<Min_Dim,Precision>& get_diagonal(){ return my_diagonal; }
00249
00253 Matrix<Min_Dim,Cols,Precision,Reference::RowMajor> get_VT(){
00254 if(is_vertical()){
00255 return Matrix<Min_Dim,Cols,Precision,Reference::RowMajor>
00256 (my_square.my_data, my_square.num_rows(), my_square.num_cols());
00257 } else {
00258 return Matrix<Min_Dim,Cols,Precision,Reference::RowMajor>
00259 (my_copy.my_data,my_copy.num_rows(),my_copy.num_cols());
00260 }
00261 }
00262
00268 void get_inv_diag(Vector<Min_Dim>& inv_diag, const Precision condition){
00269 for(int i=0; i<min_dim(); i++){
00270 if(my_diagonal[i] * condition <= my_diagonal[0]){
00271 inv_diag[i]=0;
00272 } else {
00273 inv_diag[i]=static_cast<Precision>(1)/my_diagonal[i];
00274 }
00275 }
00276 }
00277
00278 private:
00279 Matrix<Rows,Cols,Precision,RowMajor> my_copy;
00280 Vector<Min_Dim,Precision> my_diagonal;
00281 Matrix<Min_Dim,Min_Dim,Precision,RowMajor> my_square;
00282 };
00283
00284
00285
00286
00287
00288
00292 template<int Size, typename Precision>
00293 struct SQSVD : public SVD<Size, Size, Precision> {
00297 SQSVD() {}
00298 SQSVD(int size) : SVD<Size,Size,Precision>(size, size) {}
00299
00300 template <int R2, int C2, typename P2, typename B2>
00301 SQSVD(const Matrix<R2,C2,P2,B2>& m) : SVD<Size,Size,Precision>(m) {}
00303 };
00304
00305
00306 }
00307
00308
00309 #endif