Go to the documentation of this file.00001 #include "StateTransitionXByFunction.h"
00002 #include "MOMDP.h"
00003
00004 StateTransitionXByFunction::StateTransitionXByFunction(void)
00005 {
00006 }
00007
00008 StateTransitionXByFunction::~StateTransitionXByFunction(void)
00009 {
00010 }
00011
00012
00013 SharedPointer<SparseMatrix> StateTransitionXByFunction::getMatrix(int a, int x)
00014 {
00015 stringstream ss;
00016 ss << "StateTransX a " << a << " x " << x;
00017 string key = ss.str();
00018 if(problem->cache.hasKey(key))
00019 {
00020
00021 }
00022 else
00023 {
00024 problem->cache.put(key, getMatrixInner(a,x,false));
00025 }
00026 return problem->cache.get(key);
00027 }
00028 SharedPointer<SparseMatrix> StateTransitionXByFunction::getMatrixTr(int a, int x)
00029 {
00030 stringstream ss;
00031 ss << "StateTransXTr a " << a << " x " << x;
00032 string key = ss.str();
00033 if(problem->cache.hasKey(key))
00034 {
00035
00036 }
00037 else
00038 {
00039 problem->cache.put(key, getMatrixInner(a,x,true));
00040 }
00041 return problem->cache.get(key);
00042 }
00043 SharedPointer<SparseMatrix> StateTransitionXByFunction::getMatrixInner(int a, int x, bool transpose)
00044 {
00045 kmatrix tempMatrix;
00046
00047 int numObsState = problem->XStates->size();
00048 int numUnobsState = problem->YStates->size();
00049 if(!transpose)
00050 {
00051 tempMatrix.resize(numUnobsState, numObsState);
00052 }
00053 else
00054 {
00055 tempMatrix.resize(numObsState, numUnobsState);
00056 }
00057
00058 SharedPointer<SparseMatrix> result (new SparseMatrix());
00059 ValueSet aVals = problem->actions->get(a);
00060 ValueSet xVals = problem->XStates->get(x);
00061
00062 FOR(s, problem->YStates->size())
00063 {
00064 ValueSet yVals = problem->YStates->get(s);
00065 map<string, SharedPointer<IVariableValue> > sourceVals;
00066 sourceVals.insert(aVals.vals.begin(), aVals.vals.end());
00067 sourceVals.insert(xVals.vals.begin(), xVals.vals.end());
00068 sourceVals.insert(yVals.vals.begin(), yVals.vals.end());
00069
00070 vector<vector<SharedPointer<RelEntry> > > RelEntries;
00071 FOREACH(SharedPointer<VariableRelation> , curRel, relations)
00072 {
00073 RelEntries.push_back((*curRel)->getProb(sourceVals));
00074 }
00075
00076 vector<int> curProgress;
00077 FOR(index, RelEntries.size())
00078 {
00079 curProgress.push_back(0);
00080 }
00081 while(true)
00082 {
00083 map<string, SharedPointer<IVariableValue> > combinedDestProb;
00084 double combinedProb = 1.0;
00085 FOR(index, RelEntries.size())
00086 {
00087 vector<SharedPointer<RelEntry> > destProbs = RelEntries[index];
00088 int progress = curProgress[index];
00089 SharedPointer<RelEntry> curRel = destProbs[progress];
00090
00091 combinedDestProb.insert(curRel->destValues.begin(), curRel->destValues.end());
00092 combinedProb *= curRel->prob;
00093
00094 }
00095 int destX = problem->XStates->indexOf(combinedDestProb);
00096 if(!transpose)
00097 {
00098 if(destX >= numObsState)
00099 {
00100 throw runtime_error("index exceeded");
00101 }
00102 tempMatrix.push_back(s, destX, combinedProb);
00103 }
00104 else
00105 {
00106 tempMatrix.push_back(destX, s, combinedProb);
00107 }
00108
00109
00110
00111 curProgress[0] ++;
00112 bool done = false;
00113 FOR(index, RelEntries.size())
00114 {
00115 if(curProgress[index] >= RelEntries[index].size())
00116 {
00117 curProgress[index] = 0;
00118 if(index + 1 >= curProgress.size())
00119 {
00120
00121 done = true;
00122 }
00123 else
00124 {
00125 curProgress[index + 1] ++;
00126 }
00127 }
00128 }
00129
00130 if(done)
00131 {
00132 break;
00133 }
00134
00135 }
00136 }
00137
00138 copy(*result, tempMatrix);
00139 return result;
00140 }
00141