StateTransitionXYByFunction.cpp
Go to the documentation of this file.
00001 #include "StateTransitionXYByFunction.h"
00002 #include "MOMDP.h"
00003 
00004 StateTransitionXYByFunction::StateTransitionXYByFunction(void)
00005 {
00006 }
00007 
00008 StateTransitionXYByFunction::~StateTransitionXYByFunction(void)
00009 {
00010 }
00011 // (unobserved states, unobserved states)
00012 SharedPointer<SparseMatrix> StateTransitionXYByFunction::getMatrix(int a, int x)
00013 {
00014         stringstream ss;
00015         ss << "StateTransitionXY a " << a << " x " << x;
00016         string key = ss.str();
00017         if(problem->cache.hasKey(key))
00018         {
00019                 
00020         }
00021         else
00022         {
00023                 problem->cache.put(key, getMatrixInner(a,x,false));
00024         }
00025         return problem->cache.get(key);
00026 }
00027 SharedPointer<SparseMatrix> StateTransitionXYByFunction::getMatrixTr(int a, int x)
00028 {
00029         stringstream ss;
00030         ss << "StateTransitionXYTr a " << a << " x " << x;
00031         string key = ss.str();
00032         if(problem->cache.hasKey(key))
00033         {
00034                 
00035         }
00036         else
00037         {
00038                 problem->cache.put(key, getMatrixInner(a,x,true));
00039         }
00040         return problem->cache.get(key);
00041 }
00042 SharedPointer<SparseMatrix> StateTransitionXYByFunction::getMatrixInner(int a, int x, bool transpose)
00043 {
00044         kmatrix tempMatrix;
00045         int numUnobsState = problem->YStates->size();
00046         tempMatrix.resize(numUnobsState, numUnobsState);
00047 
00048         SharedPointer<SparseMatrix> result (new SparseMatrix());
00049         ValueSet aVals = problem->actions->get(a);
00050         ValueSet xVals = problem->XStates->get(x);
00051 
00052         FOR(s, problem->YStates->size())
00053         {
00054                 ValueSet yVals = problem->YStates->get(s);
00055                 map<string, SharedPointer<IVariableValue> > sourceVals;
00056                 sourceVals.insert(aVals.vals.begin(), aVals.vals.end());
00057                 sourceVals.insert(xVals.vals.begin(), xVals.vals.end());
00058                 sourceVals.insert(yVals.vals.begin(), yVals.vals.end());
00059                 
00060                 vector<vector<SharedPointer<RelEntry> > > RelEntries;
00061                 FOREACH(SharedPointer<VariableRelation> , curRel, relations)
00062                 {
00063                         RelEntries.push_back((*curRel)->getProb(sourceVals));
00064                 }
00065 
00066                 vector<int> curProgress;
00067                 FOR(index, RelEntries.size())
00068                 {
00069                         curProgress.push_back(0);
00070                 }
00071                 while(true)
00072                 {
00073                         map<string, SharedPointer<IVariableValue> > combinedDestProb;
00074                         double combinedProb = 1.0;
00075                         FOR(index, RelEntries.size())
00076                         {
00077                                 vector<SharedPointer<RelEntry> > destProbs =  RelEntries[index];
00078                                 int progress = curProgress[index];
00079                                 SharedPointer<RelEntry> curRel = destProbs[progress];
00080 
00081                                 combinedDestProb.insert(curRel->destValues.begin(), curRel->destValues.end());
00082                                 combinedProb *= curRel->prob;
00083 
00084                         }
00085                         int destY = problem->YStates->indexOf(combinedDestProb);
00086                         if(!transpose)
00087                         {
00088                                 tempMatrix.push_back(s, destY, combinedProb);
00089                         }
00090                         else
00091                         {
00092                                 tempMatrix.push_back(destY, s, combinedProb);
00093                         }
00094                         // set (s, destX) = combinedProb;
00095 
00096                         // move to next combination
00097                         curProgress[0] ++;
00098                         bool done = false;
00099                         FOR(index, RelEntries.size())
00100                         {
00101                                 if(curProgress[index] >= RelEntries[index].size())
00102                                 {
00103                                         curProgress[index] = 0;
00104                                         if(index + 1 >= curProgress.size())
00105                                         {
00106                                                 // carry in at most significant pos, should stop
00107                                                 done = true;
00108                                         }
00109                                         else
00110                                         {
00111                                                 curProgress[index + 1] ++;
00112                                         }
00113                                 }
00114                         }
00115 
00116                         if(done)
00117                         {
00118                                 break;
00119                         }
00120 
00121                 }
00122         }
00123 
00124         copy(*result, tempMatrix);
00125         return result;
00126 }


appl
Author(s): petercai
autogenerated on Tue Jan 7 2014 11:02:29