Go to the documentation of this file.00001 #include <algorithm>
00002 #include "SparseMatrix.h"
00003
00004 #include "DenseVector.h"
00005 using namespace momdp;
00006 namespace momdp
00007 {
00008 REAL_VALUE SparseMatrix::operator()(int r, int c) const
00009 {
00010 vector<SparseVector_Entry>::const_iterator di;
00011
00012 vector<int>::const_iterator col = lower_bound(cols.begin(), cols.end(), c);
00013 if (col == cols.end() or *col != c)
00014 return 0.0;
00015 int ci = col - cols.begin();
00016
00017 vector<SparseVector_Entry>::const_iterator col_end = data.begin() + colEnd(ci);
00018
00019 for (di = data.begin() + cols_start[ci]; di != col_end; di++) {
00020 if (di->index >= r) {
00021 if (di->index == r) {
00022 return di->value;
00023 } else {
00024 return 0.0;
00025 }
00026 }
00027 }
00028 return 0.0;
00029 }
00030
00031 SparseCol SparseMatrix::col(int c) const
00032 {
00033 vector<int>::const_iterator col = lower_bound(cols.begin(), cols.end(), c);
00034 if (col == cols.end() or *col != c)
00035 return SparseCol();
00036
00037 int ci = col - cols.begin();
00038 vector<SparseVector_Entry>::const_iterator col_start = data.begin() + cols_start[ci];
00039 vector<SparseVector_Entry>::const_iterator col_end = data.begin() + colEnd(ci);
00040 return SparseCol(col_start, col_end);
00041 }
00042
00043 DenseVector* SparseMatrix::mult(const DenseVector& x) const {
00044 DenseVector *result = new DenseVector( x.size());
00045 this->mult(x, *result);
00046 return result;
00047 }
00048
00049 void SparseMatrix::mult(const DenseVector& x, DenseVector& result) const
00050 {
00051
00052
00053
00054 vector<SparseVector_Entry>::const_iterator Ai, col_end;
00055
00056 result.resize( x.size() );
00057
00058 double xval;
00059
00060 int mycolIndex = 0;
00061
00062 FOR(xind, x.size())
00063 {
00064 xval = x.data[xind];
00065 while (mycolIndex < cols.size() && cols[mycolIndex] < xind) mycolIndex++;
00066 if (mycolIndex == cols.size()) break;
00067 if (cols[mycolIndex] == xind) {
00068 col_end = data.begin() + colEnd(mycolIndex);
00069 for (Ai = data.begin() + cols_start[mycolIndex]; Ai != col_end; Ai++)
00070 {
00071 result(Ai->index) += xval * Ai->value;
00072 }
00073 }
00074 }
00075
00076 return;
00077 }
00078
00079 DenseVector* SparseMatrix::mult(const SparseVector& x) const {
00080 DenseVector *result = new DenseVector( x.size());
00081 this->mult(x, *result);
00082 return result;
00083 }
00084
00085 void SparseMatrix::mult(const SparseVector& x, DenseVector& result) const
00086 {
00087 vector<SparseVector_Entry>::const_iterator Ai, col_end;
00088
00089 result.resize( x.size() );
00090
00091 int xind;
00092 double xval;
00093
00094 int mycolIndex = 0;
00095
00096 FOREACH(SparseVector_Entry, xi, x.data)
00097 {
00098 xind = xi->index;
00099 xval = xi->value;
00100 while (mycolIndex < cols.size() && cols[mycolIndex] < xind) mycolIndex++;
00101 if (mycolIndex == cols.size()) break;
00102 if (cols[mycolIndex] == xind) {
00103 col_end = data.begin() + colEnd(mycolIndex);
00104 for (Ai = data.begin() + cols_start[mycolIndex];
00105 Ai != col_end;
00106 Ai++)
00107 {
00108 result(Ai->index) += xval * Ai->value;
00109 }
00110 }
00111 }
00112
00113 return;
00114 }
00115
00116 void SparseMatrix::leftMult(const DenseVector& x, DenseVector& result) const
00117 {
00118 vector<SparseVector_Entry>::const_iterator Ai, col_end;
00119
00120 assert( x.size() == size1() );
00121 result.resize( size2() );
00122
00123 FOR (ci, cols.size()) {
00124 int c = cols[ci];
00125 col_end = data.begin() + colEnd(ci);
00126 for (Ai = data.begin() + cols_start[ci]; Ai != col_end; Ai++) {
00127 result(c) += x(Ai->index) * Ai->value;
00128 }
00129 }
00130 }
00131
00132 void SparseMatrix::leftMult(const SparseVector& x, DenseVector& result) const
00133 {
00134 assert( x.size() == size1() );
00135 result.resize( size2() );
00136
00137 FOR (ci, cols.size()) {
00138 int c = cols[ci];
00139 result(c) = inner_prod_SparseVector_internal(data.begin() + cols_start[ci],
00140 data.begin() + colEnd(ci), x.data.begin(), x.data.end() );
00141 }
00142 }
00143
00144 void SparseMatrix::resize(int _size1, int _size2)
00145 {
00146 size1_ = _size1;
00147 size2_ = _size2;
00148 cols.clear();
00149 cols_start.clear();
00150 data.clear();
00151 }
00152
00153 void SparseMatrix::push_back(int r, int c, REAL_VALUE value)
00154 {
00155
00156
00157 data.push_back( SparseVector_Entry( r, value ) );
00158
00159 if (cols.empty() or cols.back() < c) {
00160
00161 cols.push_back(c);
00162 cols_start.push_back(data.size()-1);
00163 } else {
00164 assert(cols.back() == c);
00165 }
00166 }
00167
00168 void SparseMatrix::canonicalize(void)
00169 {
00170
00171
00172
00173
00174
00175
00176
00177 }
00178
00179 void SparseMatrix::read(std::istream& in)
00180 {
00181
00182 int rows, cols;
00183 int num_entries;
00184 int r, c;
00185 REAL_VALUE value;
00186
00187 in >> rows >> cols;
00188 resize( rows, cols );
00189
00190 in >> num_entries;
00191 FOR (i, num_entries)
00192 {
00193 in >> r >> c >> value;
00194 push_back( r, c, value );
00195 }
00196 }
00197
00198 std::ostream& SparseMatrix::write(std::ostream& out) const
00199 {
00200 out << size1_ << " " << size2_ << std::endl;
00201 out << data.size() << std::endl;
00202 FOR (ci, cols.size()) {
00203 int c = cols[ci];
00204 int col_start = cols_start[ci];
00205 int col_end = colEnd(ci);
00206 for (int di = col_start; di < col_end; di++) {
00207 out << data[di].index << " " << c << " " << data[di].value << std::endl;
00208 }
00209 }
00210 return out;
00211 }
00212
00213 REAL_VALUE SparseMatrix::getMaxValue()
00214 {
00215 REAL_VALUE maxVal = data.begin()->value;
00216 REAL_VALUE val;
00217 FOREACH(SparseVector_Entry, entry, data)
00218 {
00219 val = entry->value;
00220 if(val>maxVal){
00221 maxVal = val;
00222 }
00223 }
00224 return maxVal;
00225 }
00226
00227 const vector<int>& SparseMatrix::nonEmptyColumns() const {
00228 return cols;
00229 }
00230
00231 bool SparseMatrix::isColumnEmpty(int c) const
00232 {
00233 return !binary_search(cols.begin(), cols.end(), c);
00234 }
00235
00236 int SparseMatrix::colEnd(int index) const
00237 {
00238 int col_end = index+1 < cols.size() ? cols_start[index+1] : data.size();
00239 return col_end;
00240 }
00241 }
00242