00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef TOON_INCLUDE_LU_H
00032 #define TOON_INCLUDE_LU_H
00033
#include <iostream>
#include <vector>

#include <TooN/lapack.h>

#include <TooN/TooN.h>
00039
00040 namespace TooN {
00068 template <int Size=-1, class Precision=double>
00069 class LU {
00070 public:
00071
00074 template<int S1, int S2, class Base>
00075 LU(const Matrix<S1,S2,Precision, Base>& m)
00076 :my_lu(m.num_rows(),m.num_cols()),my_IPIV(m.num_rows()){
00077 compute(m);
00078 }
00079
00081 template<int S1, int S2, class Base>
00082 void compute(const Matrix<S1,S2,Precision,Base>& m){
00083
00084 SizeMismatch<Size, S1>::test(my_lu.num_rows(),m.num_rows());
00085 SizeMismatch<Size, S2>::test(my_lu.num_rows(),m.num_cols());
00086
00087
00088 my_lu=m;
00089 int lda = m.num_rows();
00090 int M = m.num_rows();
00091 int N = m.num_rows();
00092
00093 getrf_(&M,&N,&my_lu[0][0],&lda,&my_IPIV[0],&my_info);
00094
00095 if(my_info < 0){
00096 std::cerr << "error in LU, INFO was " << my_info << std::endl;
00097 }
00098 }
00099
00102 template <int Rows, int NRHS, class Base>
00103 Matrix<Size,NRHS,Precision> backsub(const Matrix<Rows,NRHS,Precision,Base>& rhs){
00104
00105 SizeMismatch<Size, Rows>::test(my_lu.num_rows(), rhs.num_rows());
00106
00107 Matrix<Size, NRHS, Precision> result(rhs);
00108
00109 int M=rhs.num_cols();
00110 int N=my_lu.num_rows();
00111 double alpha=1;
00112 int lda=my_lu.num_rows();
00113 int ldb=rhs.num_cols();
00114 trsm_("R","U","N","N",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0][0],&ldb);
00115 trsm_("R","L","N","U",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0][0],&ldb);
00116
00117
00118 for(int i=N-1; i>=0; i--){
00119 const int swaprow = my_IPIV[i]-1;
00120 for(int j=0; j<NRHS; j++){
00121 Precision temp = result[i][j];
00122 result[i][j] = result[swaprow][j];
00123 result[swaprow][j] = temp;
00124 }
00125 }
00126 return result;
00127 }
00128
00131 template <int Rows, class Base>
00132 Vector<Size,Precision> backsub(const Vector<Rows,Precision,Base>& rhs){
00133
00134 SizeMismatch<Size, Rows>::test(my_lu.num_rows(), rhs.size());
00135
00136 Vector<Size, Precision> result(rhs);
00137
00138 int M=1;
00139 int N=my_lu.num_rows();
00140 double alpha=1;
00141 int lda=my_lu.num_rows();
00142 int ldb=1;
00143 trsm_("R","U","N","N",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0],&ldb);
00144 trsm_("R","L","N","U",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0],&ldb);
00145
00146
00147 for(int i=N-1; i>=0; i--){
00148 const int swaprow = my_IPIV[i]-1;
00149 Precision temp = result[i];
00150 result[i] = result[swaprow];
00151 result[swaprow] = temp;
00152 }
00153 return result;
00154 }
00155
00158 Matrix<Size,Size,Precision> get_inverse(){
00159 Matrix<Size,Size,Precision> Inverse(my_lu);
00160 int N = my_lu.num_rows();
00161 int lda=my_lu.num_rows();
00162 int lwork=-1;
00163 Precision size;
00164 getri_(&N, &Inverse[0][0], &lda, &my_IPIV[0], &size, &lwork, &my_info);
00165 lwork=int(size);
00166 Precision* WORK = new Precision[lwork];
00167 getri_(&N, &Inverse[0][0], &lda, &my_IPIV[0], WORK, &lwork, &my_info);
00168 delete [] WORK;
00169 return Inverse;
00170 }
00171
00177 const Matrix<Size,Size,Precision>& get_lu()const {return my_lu;}
00178
00179 private:
00180 inline int get_sign() const {
00181 int result=1;
00182 for(int i=0; i<my_lu.num_rows()-1; i++){
00183 if(my_IPIV[i] > i+1){
00184 result=-result;
00185 }
00186 }
00187 return result;
00188 }
00189 public:
00190
00192 inline Precision determinant() const {
00193 Precision result = get_sign();
00194 for (int i=0; i<my_lu.num_rows(); i++){
00195 result*=my_lu(i,i);
00196 }
00197 return result;
00198 }
00199
00201 int get_info() const { return my_info; }
00202
00203 private:
00204
00205 Matrix<Size,Size,Precision> my_lu;
00206 int my_info;
00207 Vector<Size, int> my_IPIV;
00208
00209 };
00210 }
00211
00212
00213 #endif