Go to the documentation of this file.00001 #include "ObservationProbabilitiesByFunction.h"
00002 #include "MOMDP.h"
00003 #include <sstream>
00004
00005 ObservationProbabilitiesByFunction::ObservationProbabilitiesByFunction(void)
00006 {
00007 }
00008
00009 ObservationProbabilitiesByFunction::~ObservationProbabilitiesByFunction(void)
00010 {
00011 }
00012
00013
00014 SharedPointer<SparseMatrix> ObservationProbabilitiesByFunction::getMatrix(int a, int x)
00015 {
00016 stringstream ss;
00017 ss << "obsProb a " << a << " x " << x;
00018 string key = ss.str();
00019 if(problem->cache.hasKey(key))
00020 {
00021
00022 }
00023 else
00024 {
00025 problem->cache.put(key, getMatrixInner(a,x,false));
00026 }
00027 return problem->cache.get(key);
00028
00029 }
00030 SharedPointer<SparseMatrix> ObservationProbabilitiesByFunction::getMatrixTr(int a, int x)
00031 {
00032 stringstream ss;
00033 ss << "obsProbTr a " << a << " x " << x;
00034 string key = ss.str();
00035 if(problem->cache.hasKey(key))
00036 {
00037
00038 }
00039 else
00040 {
00041 problem->cache.put(key, getMatrixInner(a,x,true));
00042 }
00043 return problem->cache.get(key);
00044 }
00045
00046
00047 SharedPointer<SparseMatrix> ObservationProbabilitiesByFunction::getMatrixInner(int a, int x, bool transpose)
00048 {
00049 kmatrix tempMatrix;
00050
00051 int numObs = problem->observations->size();
00052 int numUnobsState = problem->YStates->size();
00053 if(!transpose)
00054 {
00055 tempMatrix.resize(numUnobsState, numObs);
00056 }
00057 else
00058 {
00059 tempMatrix.resize(numObs, numUnobsState);
00060 }
00061
00062 SharedPointer<SparseMatrix> result (new SparseMatrix());
00063 ValueSet aVals = problem->actions->get(a);
00064 ValueSet xVals = problem->XStates->get(x);
00065
00066 FOR(s, problem->YStates->size())
00067 {
00068 ValueSet yVals = problem->YStates->get(s);
00069 map<string, SharedPointer<IVariableValue> > sourceVals;
00070 sourceVals.insert(aVals.vals.begin(), aVals.vals.end());
00071 sourceVals.insert(xVals.vals.begin(), xVals.vals.end());
00072 sourceVals.insert(yVals.vals.begin(), yVals.vals.end());
00073
00074 vector<vector<SharedPointer<RelEntry> > > RelEntries;
00075 FOREACH(SharedPointer<VariableRelation> , curRel, relations)
00076 {
00077 RelEntries.push_back((*curRel)->getProb(sourceVals));
00078 }
00079
00080 vector<int> curProgress;
00081 FOR(index, RelEntries.size())
00082 {
00083 curProgress.push_back(0);
00084 }
00085 while(true)
00086 {
00087 map<string, SharedPointer<IVariableValue> > combinedDestProb;
00088 double combinedProb = 1.0;
00089 FOR(index, RelEntries.size())
00090 {
00091 vector<SharedPointer<RelEntry> > destProbs = RelEntries[index];
00092 int progress = curProgress[index];
00093 SharedPointer<RelEntry> curRel = destProbs[progress];
00094
00095 combinedDestProb.insert(curRel->destValues.begin(), curRel->destValues.end());
00096 combinedProb *= curRel->prob;
00097
00098 }
00099 int destO = problem->observations->indexOf(combinedDestProb);
00100 if(!transpose)
00101 {
00102 tempMatrix.push_back(s, destO, combinedProb);
00103 }
00104 else
00105 {
00106 tempMatrix.push_back(destO, s, combinedProb);
00107 }
00108
00109
00110
00111 curProgress[0] ++;
00112 bool done = false;
00113 FOR(index, RelEntries.size())
00114 {
00115 if(curProgress[index] >= RelEntries[index].size())
00116 {
00117 curProgress[index] = 0;
00118 if(index + 1 >= curProgress.size())
00119 {
00120
00121 done = true;
00122 }
00123 else
00124 {
00125 curProgress[index + 1] ++;
00126 }
00127 }
00128 }
00129
00130 if(done)
00131 {
00132 break;
00133 }
00134
00135 }
00136 }
00137
00138 copy(*result, tempMatrix);
00139 return result;
00140 }
00141