Go to the documentation of this file.00001 #include "StateTransitionXYByFunction.h"
00002 #include "MOMDP.h"
00003
00004 StateTransitionXYByFunction::StateTransitionXYByFunction(void)
00005 {
00006 }
00007
00008 StateTransitionXYByFunction::~StateTransitionXYByFunction(void)
00009 {
00010 }
00011
00012 SharedPointer<SparseMatrix> StateTransitionXYByFunction::getMatrix(int a, int x)
00013 {
00014 stringstream ss;
00015 ss << "StateTransitionXY a " << a << " x " << x;
00016 string key = ss.str();
00017 if(problem->cache.hasKey(key))
00018 {
00019
00020 }
00021 else
00022 {
00023 problem->cache.put(key, getMatrixInner(a,x,false));
00024 }
00025 return problem->cache.get(key);
00026 }
00027 SharedPointer<SparseMatrix> StateTransitionXYByFunction::getMatrixTr(int a, int x)
00028 {
00029 stringstream ss;
00030 ss << "StateTransitionXYTr a " << a << " x " << x;
00031 string key = ss.str();
00032 if(problem->cache.hasKey(key))
00033 {
00034
00035 }
00036 else
00037 {
00038 problem->cache.put(key, getMatrixInner(a,x,true));
00039 }
00040 return problem->cache.get(key);
00041 }
00042 SharedPointer<SparseMatrix> StateTransitionXYByFunction::getMatrixInner(int a, int x, bool transpose)
00043 {
00044 kmatrix tempMatrix;
00045 int numUnobsState = problem->YStates->size();
00046 tempMatrix.resize(numUnobsState, numUnobsState);
00047
00048 SharedPointer<SparseMatrix> result (new SparseMatrix());
00049 ValueSet aVals = problem->actions->get(a);
00050 ValueSet xVals = problem->XStates->get(x);
00051
00052 FOR(s, problem->YStates->size())
00053 {
00054 ValueSet yVals = problem->YStates->get(s);
00055 map<string, SharedPointer<IVariableValue> > sourceVals;
00056 sourceVals.insert(aVals.vals.begin(), aVals.vals.end());
00057 sourceVals.insert(xVals.vals.begin(), xVals.vals.end());
00058 sourceVals.insert(yVals.vals.begin(), yVals.vals.end());
00059
00060 vector<vector<SharedPointer<RelEntry> > > RelEntries;
00061 FOREACH(SharedPointer<VariableRelation> , curRel, relations)
00062 {
00063 RelEntries.push_back((*curRel)->getProb(sourceVals));
00064 }
00065
00066 vector<int> curProgress;
00067 FOR(index, RelEntries.size())
00068 {
00069 curProgress.push_back(0);
00070 }
00071 while(true)
00072 {
00073 map<string, SharedPointer<IVariableValue> > combinedDestProb;
00074 double combinedProb = 1.0;
00075 FOR(index, RelEntries.size())
00076 {
00077 vector<SharedPointer<RelEntry> > destProbs = RelEntries[index];
00078 int progress = curProgress[index];
00079 SharedPointer<RelEntry> curRel = destProbs[progress];
00080
00081 combinedDestProb.insert(curRel->destValues.begin(), curRel->destValues.end());
00082 combinedProb *= curRel->prob;
00083
00084 }
00085 int destY = problem->YStates->indexOf(combinedDestProb);
00086 if(!transpose)
00087 {
00088 tempMatrix.push_back(s, destY, combinedProb);
00089 }
00090 else
00091 {
00092 tempMatrix.push_back(destY, s, combinedProb);
00093 }
00094
00095
00096
00097 curProgress[0] ++;
00098 bool done = false;
00099 FOR(index, RelEntries.size())
00100 {
00101 if(curProgress[index] >= RelEntries[index].size())
00102 {
00103 curProgress[index] = 0;
00104 if(index + 1 >= curProgress.size())
00105 {
00106
00107 done = true;
00108 }
00109 else
00110 {
00111 curProgress[index + 1] ++;
00112 }
00113 }
00114 }
00115
00116 if(done)
00117 {
00118 break;
00119 }
00120
00121 }
00122 }
00123
00124 copy(*result, tempMatrix);
00125 return result;
00126 }