StateTransitionXByFunction.cpp
Go to the documentation of this file.
00001 #include "StateTransitionXByFunction.h"
00002 #include "MOMDP.h"
00003 
00004 StateTransitionXByFunction::StateTransitionXByFunction(void)
00005 {
00006 }
00007 
00008 StateTransitionXByFunction::~StateTransitionXByFunction(void)
00009 {
00010 }
00011 
00012 // (unobserved states, observed states)
00013 SharedPointer<SparseMatrix> StateTransitionXByFunction::getMatrix(int a, int x)
00014 {
00015         stringstream ss;
00016         ss << "StateTransX a " << a << " x " << x;
00017         string key = ss.str();
00018         if(problem->cache.hasKey(key))
00019         {
00020                 
00021         }
00022         else
00023         {
00024                 problem->cache.put(key, getMatrixInner(a,x,false));
00025         }
00026         return problem->cache.get(key);
00027 }
00028 SharedPointer<SparseMatrix> StateTransitionXByFunction::getMatrixTr(int a, int x)
00029 {
00030         stringstream ss;
00031         ss << "StateTransXTr a " << a << " x " << x;
00032         string key = ss.str();
00033         if(problem->cache.hasKey(key))
00034         {
00035                 
00036         }
00037         else
00038         {
00039                 problem->cache.put(key, getMatrixInner(a,x,true));
00040         }
00041         return problem->cache.get(key);
00042 }
00043 SharedPointer<SparseMatrix> StateTransitionXByFunction::getMatrixInner(int a, int x, bool transpose)
00044 {
00045         kmatrix tempMatrix;
00046 
00047         int numObsState = problem->XStates->size();
00048         int numUnobsState = problem->YStates->size();
00049         if(!transpose)
00050         {
00051                 tempMatrix.resize(numUnobsState, numObsState);
00052         }
00053         else
00054         {
00055                 tempMatrix.resize(numObsState, numUnobsState);
00056         }
00057 
00058         SharedPointer<SparseMatrix> result (new SparseMatrix());
00059         ValueSet aVals = problem->actions->get(a);
00060         ValueSet xVals = problem->XStates->get(x);
00061 
00062         FOR(s, problem->YStates->size())
00063         {
00064                 ValueSet yVals = problem->YStates->get(s);
00065                 map<string, SharedPointer<IVariableValue> > sourceVals;
00066                 sourceVals.insert(aVals.vals.begin(), aVals.vals.end());
00067                 sourceVals.insert(xVals.vals.begin(), xVals.vals.end());
00068                 sourceVals.insert(yVals.vals.begin(), yVals.vals.end());
00069                 
00070                 vector<vector<SharedPointer<RelEntry> > > RelEntries;
00071                 FOREACH(SharedPointer<VariableRelation> , curRel, relations)
00072                 {
00073                         RelEntries.push_back((*curRel)->getProb(sourceVals));
00074                 }
00075 
00076                 vector<int> curProgress;
00077                 FOR(index, RelEntries.size())
00078                 {
00079                         curProgress.push_back(0);
00080                 }
00081                 while(true)
00082                 {
00083                         map<string, SharedPointer<IVariableValue> > combinedDestProb;
00084                         double combinedProb = 1.0;
00085                         FOR(index, RelEntries.size())
00086                         {
00087                                 vector<SharedPointer<RelEntry> > destProbs =  RelEntries[index];
00088                                 int progress = curProgress[index];
00089                                 SharedPointer<RelEntry> curRel = destProbs[progress];
00090 
00091                                 combinedDestProb.insert(curRel->destValues.begin(), curRel->destValues.end());
00092                                 combinedProb *= curRel->prob;
00093 
00094                         }
00095                         int destX = problem->XStates->indexOf(combinedDestProb);
00096                         if(!transpose)
00097                         {
00098                                 if(destX >= numObsState)
00099                                 {
00100                                         throw runtime_error("index exceeded");
00101                                 }
00102                                 tempMatrix.push_back(s, destX, combinedProb);
00103                         }
00104                         else
00105                         {
00106                                 tempMatrix.push_back(destX, s, combinedProb);
00107                         }
00108                         // set (s, destX) = combinedProb;
00109 
00110                         // move to next combination
00111                         curProgress[0] ++;
00112                         bool done = false;
00113                         FOR(index, RelEntries.size())
00114                         {
00115                                 if(curProgress[index] >= RelEntries[index].size())
00116                                 {
00117                                         curProgress[index] = 0;
00118                                         if(index + 1 >= curProgress.size())
00119                                         {
00120                                                 // carry in at most significant pos, should stop
00121                                                 done = true;
00122                                         }
00123                                         else
00124                                         {
00125                                                 curProgress[index + 1] ++;
00126                                         }
00127                                 }
00128                         }
00129 
00130                         if(done)
00131                         {
00132                                 break;
00133                         }
00134 
00135                 }
00136         }
00137 
00138         copy(*result, tempMatrix);
00139         return result;
00140 }
00141 


appl
Author(s): petercai
autogenerated on Tue Jan 7 2014 11:02:29