00001 #ifndef TOON_INCLUDE_QR_LAPACK_H
00002 #define TOON_INCLUDE_QR_LAPACK_H
00003
00004
#include <TooN/TooN.h>
#include <TooN/lapack.h>

#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
00008
00009 namespace TooN{
00010
00030 template<int Rows=Dynamic, int Cols=Rows, class Precision=double>
00031 class QR_Lapack{
00032
00033 private:
00034 static const int square_Size = (Rows>=0 && Cols>=0)?(Rows<Cols?Rows:Cols):Dynamic;
00035
00036 public:
00041 template<int R, int C, class P, class B>
00042 QR_Lapack(const Matrix<R,C,P,B>& m, bool p=0)
00043 :copy(m),tau(square_size()), Q(square_size(), square_size()), do_pivoting(p), pivot(Zeros(square_size()))
00044 {
00045
00046
00047
00048 compute();
00049 }
00050
00052 const Matrix<Rows, Cols, Precision, ColMajor>& get_R()
00053 {
00054 return copy;
00055 }
00056
00058 const Matrix<square_Size, square_Size, Precision, ColMajor>& get_Q()
00059 {
00060 return Q;
00061 }
00062
00065 const Vector<Cols, int>& get_P()
00066 {
00067 return pivot;
00068 }
00069
00070 private:
00071
00072 void compute()
00073 {
00074 int M = copy.num_rows();
00075 int N = copy.num_cols();
00076
00077 int LWORK=-1;
00078 int INFO;
00079 int lda = M;
00080
00081 Precision size;
00082
00083
00084 if(do_pivoting)
00085 pivot = Zeros;
00086 else
00087 for(int i=0; i < pivot.size(); i++)
00088 pivot[i] = i+1;
00089
00090
00091
00092 geqp3_(&M, &N, copy.get_data_ptr(), &lda, pivot.get_data_ptr(), tau.get_data_ptr(), &size, &LWORK, &INFO);
00093
00094 LWORK = (int) size;
00095
00096 Precision* work = new Precision[LWORK];
00097
00098 geqp3_(&M, &N, copy.get_data_ptr(), &lda, pivot.get_data_ptr(), tau.get_data_ptr(), work, &LWORK, &INFO);
00099
00100
00101 if(INFO < 0)
00102 std::cerr << "error in QR, INFO was " << INFO << std::endl;
00103
00104
00105
00106
00107
00108 Q = copy.template slice<0,0,square_Size, square_Size>(0,0,square_size(), square_size());
00109
00110 int K = square_size();
00111 M=K;
00112 N=K;
00113 lda = K;
00114 orgqr_(&M, &N, &K, Q.get_data_ptr(), &lda, tau.get_data_ptr(), work, &LWORK, &INFO);
00115
00116 if(INFO < 0)
00117 std::cerr << "error in QR, INFO was " << INFO << std::endl;
00118
00119 delete [] work;
00120
00121
00122 for(int r=1; r < square_size(); r++)
00123 for(int c=0; c<r; c++)
00124 copy[r][c] = 0;
00125
00126
00127
00128 for(int i=0; i < pivot.size(); i++)
00129 pivot[i]--;
00130 }
00131
00132 Matrix<Rows, Cols, Precision, ColMajor> copy;
00133 Matrix<square_Size, square_Size, Precision, ColMajor> Q;
00134 Vector<square_Size, Precision> tau;
00135 Vector<Cols, int> pivot;
00136
00137 bool do_pivoting;
00138
00139 int square_size()
00140 {
00141 return std::min(copy.num_rows(), copy.num_cols());
00142 }
00143 };
00144
00145 }
00146
00147
00148 #endif