ObservationProbabilitiesByFunction.cpp
Go to the documentation of this file.
00001 #include "ObservationProbabilitiesByFunction.h"
00002 #include "MOMDP.h"
00003 #include <sstream>
00004 
00005 ObservationProbabilitiesByFunction::ObservationProbabilitiesByFunction(void)
00006 {
00007 }
00008 
00009 ObservationProbabilitiesByFunction::~ObservationProbabilitiesByFunction(void)
00010 {
00011 }
00012 
00013 // (unobserved states, obs)
00014 SharedPointer<SparseMatrix> ObservationProbabilitiesByFunction::getMatrix(int a, int x)
00015 {
00016         stringstream ss;
00017         ss << "obsProb a " << a << " x " << x;
00018         string key = ss.str();
00019         if(problem->cache.hasKey(key))
00020         {
00021                 
00022         }
00023         else
00024         {
00025                 problem->cache.put(key, getMatrixInner(a,x,false));
00026         }
00027         return problem->cache.get(key);
00028         
00029 }
00030 SharedPointer<SparseMatrix> ObservationProbabilitiesByFunction::getMatrixTr(int a, int x)
00031 {
00032         stringstream ss;
00033         ss << "obsProbTr a " << a << " x " << x;
00034         string key = ss.str();
00035         if(problem->cache.hasKey(key))
00036         {
00037                 
00038         }
00039         else
00040         {
00041                 problem->cache.put(key, getMatrixInner(a,x,true));
00042         }
00043         return problem->cache.get(key);
00044 }
00045 
00046 
00047 SharedPointer<SparseMatrix> ObservationProbabilitiesByFunction::getMatrixInner(int a, int x, bool transpose)
00048 {
00049         kmatrix tempMatrix;
00050         
00051         int numObs = problem->observations->size();
00052         int numUnobsState = problem->YStates->size();
00053         if(!transpose)
00054         {
00055                 tempMatrix.resize(numUnobsState, numObs);
00056         }
00057         else
00058         {
00059                 tempMatrix.resize(numObs, numUnobsState);
00060         }
00061 
00062         SharedPointer<SparseMatrix> result (new SparseMatrix());
00063         ValueSet aVals = problem->actions->get(a);
00064         ValueSet xVals = problem->XStates->get(x);
00065 
00066         FOR(s, problem->YStates->size())
00067         {
00068                 ValueSet yVals = problem->YStates->get(s);
00069                 map<string, SharedPointer<IVariableValue> > sourceVals;
00070                 sourceVals.insert(aVals.vals.begin(), aVals.vals.end());
00071                 sourceVals.insert(xVals.vals.begin(), xVals.vals.end());
00072                 sourceVals.insert(yVals.vals.begin(), yVals.vals.end());
00073                 
00074                 vector<vector<SharedPointer<RelEntry> > > RelEntries;
00075                 FOREACH(SharedPointer<VariableRelation> , curRel, relations)
00076                 {
00077                         RelEntries.push_back((*curRel)->getProb(sourceVals));
00078                 }
00079 
00080                 vector<int> curProgress;
00081                 FOR(index, RelEntries.size())
00082                 {
00083                         curProgress.push_back(0);
00084                 }
00085                 while(true)
00086                 {
00087                         map<string, SharedPointer<IVariableValue> > combinedDestProb;
00088                         double combinedProb = 1.0;
00089                         FOR(index, RelEntries.size())
00090                         {
00091                                 vector<SharedPointer<RelEntry> > destProbs =  RelEntries[index];
00092                                 int progress = curProgress[index];
00093                                 SharedPointer<RelEntry> curRel = destProbs[progress];
00094 
00095                                 combinedDestProb.insert(curRel->destValues.begin(), curRel->destValues.end());
00096                                 combinedProb *= curRel->prob;
00097 
00098                         }
00099                         int destO = problem->observations->indexOf(combinedDestProb);
00100                         if(!transpose)
00101                         {
00102                                 tempMatrix.push_back(s, destO, combinedProb);
00103                         }
00104                         else
00105                         {
00106                                 tempMatrix.push_back(destO, s, combinedProb);
00107                         }
00108                         // set (s, destX) = combinedProb;
00109 
00110                         // move to next combination
00111                         curProgress[0] ++;
00112                         bool done = false;
00113                         FOR(index, RelEntries.size())
00114                         {
00115                                 if(curProgress[index] >= RelEntries[index].size())
00116                                 {
00117                                         curProgress[index] = 0;
00118                                         if(index + 1 >= curProgress.size())
00119                                         {
00120                                                 // carry in at most significant pos, should stop
00121                                                 done = true;
00122                                         }
00123                                         else
00124                                         {
00125                                                 curProgress[index + 1] ++;
00126                                         }
00127                                 }
00128                         }
00129 
00130                         if(done)
00131                         {
00132                                 break;
00133                         }
00134 
00135                 }
00136         }
00137 
00138         copy(*result, tempMatrix);
00139         return result;
00140 }
00141 


appl
Author(s): petercai
autogenerated on Tue Jan 7 2014 11:02:29