00001 #ifndef DMATRIX_HXX
00002 #define DMATRIX_HXX
00003
00004 #include <iostream>
00005 #include <exception>
namespace GMapping {

/// Thrown by DMatrix::inv() when the matrix is not square or is singular
/// (no non-zero pivot could be found).
class DNotInvertibleMatrixException: public std::exception {};
/// Thrown by DMatrix binary operations whose operand dimensions do not match.
class DIncompatibleMatrixException: public std::exception {};
/// Thrown by DMatrix::det() when the matrix is not square.
class DNotSquareMatrixException: public std::exception {};

/**
 * @brief Dense row-major matrix of X with reference-counted, copy-on-write storage.
 *
 * Copy construction and assignment are shallow: the element buffer and its
 * row-pointer table are shared, and the number of sharers is kept in the
 * counter @c shares. The non-const operator[] calls detach() first when the
 * storage is shared, so a write through it never affects other copies.
 *
 * Caveats:
 *  - On a non-const, shared matrix even a read through operator[] makes a copy.
 *    Use a const reference to avoid that.
 *  - A row pointer obtained from the non-const operator[] is not tracked. If the
 *    matrix is copied while the pointer is held, writes through it are seen by
 *    both copies.
 *  - The share counter is a plain, non-atomic int.
 *
 * Element requirements used by the member functions:
 *  - default construction;
 *  - X(0) and X(1);
 *  - copy/assignment;
 *  - ==, binary + - * /, and unary -.
 */
template <class X> class DMatrix {
public:
    /// Creates an n x m matrix filled with X(0). Sizes below 1 are clamped to 1.
    /// NOTE(review): not explicit, so an int converts implicitly to an int x 1 matrix.
    DMatrix(int n=0,int m=0);
    /// Drops this reference and frees the storage when it was the last one.
    ~DMatrix();

    /// Shallow copy: shares the other matrix's storage.
    DMatrix(const DMatrix&);
    /// Releases the current storage and shares the other matrix's storage (self-assignment safe).
    DMatrix& operator=(const DMatrix&);

    /// Mutable access to row i (no bounds check). Detaches shared storage first.
    X * operator[](int i) {
        if ((*shares)>1) detach();
        return mrows[i];
    }

    /// Read-only access to row i (no bounds check). Never copies.
    const X * operator[](int i) const { return mrows[i]; }

    /// Determinant. Throws DNotSquareMatrixException if rows() != columns().
    const X det() const;
    /// Inverse. Throws DNotInvertibleMatrixException if not square or singular.
    DMatrix inv() const;
    /// Returns the columns() x rows() transpose.
    DMatrix transpose() const;
    /// Matrix product. Throws DIncompatibleMatrixException if columns() != m.rows().
    DMatrix operator*(const DMatrix&) const;
    /// Element-wise sum. Throws DIncompatibleMatrixException on a size mismatch.
    DMatrix operator+(const DMatrix&) const;
    /// Element-wise difference. Throws DIncompatibleMatrixException on a size mismatch.
    DMatrix operator-(const DMatrix&) const;
    /// Multiplies every element by a scalar.
    DMatrix operator*(const X&) const;

    int rows() const { return nrows; }
    int columns() const { return ncols; }

    /// Makes sure this matrix owns an unshared copy of its storage.
    void detach();

    /// Returns the identity matrix of size max(n,1).
    static DMatrix I(int);

protected:
    int nrows,ncols;   // dimensions; always >= 1 (the constructor clamps them)
    X * elems;         // nrows*ncols elements, row-major
    X ** mrows;        // mrows[i] == elems + i*ncols
    int * shares;      // number of DMatrix objects sharing elems/mrows
};

template <class X> DMatrix<X>::DMatrix(int n,int m) {
    if (n<1) n=1;
    if (m<1) m=1;
    nrows=n;
    ncols=m;
    elems=0;
    mrows=0;
    shares=0;
    try {
        elems=new X[nrows*ncols];
        mrows=new X* [nrows];
        shares=new int(1);
        for (int i=0;i<nrows;i++) mrows[i]=elems+ncols*i;
        for (int i=0;i<nrows*ncols;i++) elems[i]=X(0);
    } catch (...) {
        // Nothing is shared yet: undo whatever was allocated so a failing
        // allocation (or a throwing X assignment) does not leak the others.
        delete shares;
        delete [] mrows;
        delete [] elems;
        throw;
    }
}

template <class X> DMatrix<X>::~DMatrix() {
    if (--(*shares)) return;
    delete [] elems;
    delete [] mrows;
    delete shares;
}

template <class X> DMatrix<X>::DMatrix(const DMatrix& m)
    : nrows(m.nrows), ncols(m.ncols), elems(m.elems), mrows(m.mrows), shares(m.shares) {
    (*shares)++;
}

template <class X> DMatrix<X>& DMatrix<X>::operator=(const DMatrix& m) {
    // Acquire the new reference before releasing the old one. If m shares our
    // storage (including m being *this with a count of 1), releasing first
    // would free the buffer we are about to adopt.
    ++(*m.shares);
    if (!--(*shares)) {
        delete [] elems;
        delete [] mrows;
        delete shares;
    }
    shares=m.shares;
    elems=m.elems;
    nrows=m.nrows;
    ncols=m.ncols;
    mrows=m.mrows;
    return *this;
}

template <class X> DMatrix<X> DMatrix<X>::inv() const {
    // A non-square matrix is reported as not invertible.
    if (nrows!=ncols) throw DNotInvertibleMatrixException();
    // Gauss-Jordan elimination: reduce a private copy of *this to the identity
    // while applying the same row operations to an identity matrix, which
    // thereby becomes the inverse. Pivots are tested against X(0) exactly
    // (no partial pivoting), so rounding error is not controlled.
    DMatrix<X> aux1(*this),aux2(I(nrows));
    aux1.detach();
    for (int i=0;i<nrows;i++) {
        // Pivot: the first row at or below i with a non-zero entry in column i.
        int k=i;
        for (;k<nrows&&aux1.mrows[k][i]==X(0);k++);
        if (k>=nrows) throw DNotInvertibleMatrixException();
        // Scale the pivot row so the pivot becomes 1.
        X val=aux1.mrows[k][i];
        for (int j=0;j<nrows;j++) {
            aux1.mrows[k][j]=aux1.mrows[k][j]/val;
            aux2.mrows[k][j]=aux2.mrows[k][j]/val;
        }
        // Move the pivot row into position i.
        if (k!=i) {
            for (int j=0;j<nrows;j++) {
                X tmp=aux1.mrows[k][j];
                aux1.mrows[k][j]=aux1.mrows[i][j];
                aux1.mrows[i][j]=tmp;
                tmp=aux2.mrows[k][j];
                aux2.mrows[k][j]=aux2.mrows[i][j];
                aux2.mrows[i][j]=tmp;
            }
        }
        // Clear column i in every other row.
        for (int j=0;j<nrows;j++) {
            if (j==i) continue;
            X f=aux1.mrows[j][i];
            for (int l=0;l<nrows;l++) {
                aux1.mrows[j][l]=aux1.mrows[j][l]-f*aux1.mrows[i][l];
                aux2.mrows[j][l]=aux2.mrows[j][l]-f*aux2.mrows[i][l];
            }
        }
    }
    return aux2;
}

template <class X> const X DMatrix<X>::det() const {
    if (nrows!=ncols) throw DNotSquareMatrixException();
    // Gaussian elimination on a private copy. Each row scaling by 1/val is
    // compensated by d*=val, each row swap by negating d. The reduced matrix is
    // unit upper-triangular (determinant 1), so d ends up equal to det(*this).
    DMatrix<X> aux(*this);
    aux.detach();
    X d=X(1);
    for (int i=0;i<nrows;i++) {
        int k=i;
        for (;k<nrows&&aux.mrows[k][i]==X(0);k++);
        // No non-zero pivot in column i: the matrix is singular.
        if (k>=nrows) return X(0);
        X val=aux.mrows[k][i];
        // Columns left of i are never read again, so only columns i.. are updated.
        for (int j=i;j<nrows;j++) aux.mrows[k][j]=aux.mrows[k][j]/val;
        d=d*val;
        if (k!=i) {
            for (int j=i;j<nrows;j++) {
                X tmp=aux.mrows[k][j];
                aux.mrows[k][j]=aux.mrows[i][j];
                aux.mrows[i][j]=tmp;
            }
            d=-d;
        }
        for (int j=i+1;j<nrows;j++) {
            X f=aux.mrows[j][i];
            if (f==X(0)) continue;
            for (int l=i;l<nrows;l++)
                aux.mrows[j][l]=aux.mrows[j][l]-f*aux.mrows[i][l];
        }
    }
    return d;
}

template <class X> DMatrix<X> DMatrix<X>::transpose() const {
    DMatrix<X> aux(ncols, nrows);
    // aux is freshly built and unshared: write through mrows directly rather
    // than paying the copy-on-write check of operator[] on every element.
    for (int i=0; i<nrows; i++)
        for (int j=0; j<ncols; j++)
            aux.mrows[j][i]=mrows[i][j];
    return aux;
}

template <class X> DMatrix<X> DMatrix<X>::operator*(const DMatrix<X>& m) const {
    if (ncols!=m.nrows) throw DIncompatibleMatrixException();
    DMatrix<X> aux(nrows,m.ncols);
    for (int i=0;i<nrows;i++)
        for (int j=0;j<m.ncols;j++){
            // X(0), not copy-initialisation from the literal 0, so X needs no
            // implicit conversion from int (matches the rest of the class).
            X a=X(0);
            for (int k=0;k<ncols;k++)
                a=a+mrows[i][k]*m.mrows[k][j];
            aux.mrows[i][j]=a;
        }
    return aux;
}

template <class X> DMatrix<X> DMatrix<X>::operator+(const DMatrix<X>& m) const {
    if (ncols!=m.ncols||nrows!=m.nrows) throw DIncompatibleMatrixException();
    DMatrix<X> aux(nrows,ncols);
    for (int i=0;i<nrows*ncols;i++) aux.elems[i]=elems[i]+m.elems[i];
    return aux;
}

template <class X> DMatrix<X> DMatrix<X>::operator-(const DMatrix<X>& m) const {
    if (ncols!=m.ncols||nrows!=m.nrows) throw DIncompatibleMatrixException();
    DMatrix<X> aux(nrows,ncols);
    for (int i=0;i<nrows*ncols;i++) aux.elems[i]=elems[i]-m.elems[i];
    return aux;
}

template <class X> DMatrix<X> DMatrix<X>::operator*(const X& e) const {
    DMatrix<X> aux(nrows,ncols);
    for (int i=0;i<nrows*ncols;i++) aux.elems[i]=elems[i]*e;
    return aux;
}

template <class X> void DMatrix<X>::detach() {
    // Already the sole owner: the storage is private and row pointers handed
    // out by operator[] stay valid.
    if ((*shares)==1) return;
    DMatrix<X> aux(nrows,ncols);
    for (int i=0;i<nrows*ncols;i++) aux.elems[i]=elems[i];
    operator=(aux);
}

template <class X> DMatrix<X> DMatrix<X>::I(int n) {
    DMatrix<X> aux(n,n);
    // Use the constructed size: the constructor clamps n<1 up to 1, and the
    // result must still be an identity matrix in that case.
    for (int i=0;i<aux.nrows;i++) aux.mrows[i][i]=X(1);
    return aux;
}

/// Prints m as nested braces, e.g. {{1,0},{0,1}}.
template <class X> std::ostream& operator<<(std::ostream& os, const DMatrix<X> &m) {
    os << "{";
    for (int i=0;i<m.rows();i++) {
        if (i>0) os << ",";
        os << "{";
        for (int j=0;j<m.columns();j++) {
            if (j>0) os << ",";
            os << m[i][j];
        }
        os << "}";
    }
    return os << "}";
}

} // namespace GMapping
00232 #endif