00001 #include "MOMDP.h"
00002 #include "AlphaVectorPolicy.h"
00003 #include "BeliefForest.h"
00004 #include "Sample.h"
00005 #include "BeliefCache.h"
00006 #include "EvaluatorSampleEngine.h"
00007 #include "EvaluatorBeliefTreeNodeTuple.h"
00008
00009 #include "PolicyGraphGenerator.h"
00010
00011 #include <string>
00012 #include <stdlib.h>
00013 #include <sstream>
00014 #include <fstream>
00015
00016 using namespace std;
00017 using namespace momdp;
00018
00019
00020
00021
00022 bool compareNodeRelationsProb(NodeRelation a, NodeRelation b)
00023 {
00024 return a.xProb*a.oProb > b.xProb* b.oProb;
00025 }
00026
00027 bool nodeRelationsOrder(NodeRelation a, NodeRelation b)
00028 {
00029 if(a.destNode > b.destNode)
00030 return true;
00031 else
00032 if(a.srcNode > b.srcNode)
00033 return true;
00034 else
00035 return false;
00036 }
00037
00038 bool nodeRelationsEqual(NodeRelation a, NodeRelation b)
00039 {
00040 return a.destNode == b.destNode && a.srcNode == b.srcNode;
00041 }
00042
00043
00044
00045
00046 PolicyGraphGenerator::PolicyGraphGenerator(SharedPointer<MOMDP> _problem, SharedPointer<AlphaVectorPolicy> _policy, PolicyGraphParam _graphParam)
00047 : problem(_problem), policy(_policy), graphParam(_graphParam)
00048 {
00049 int xStateNum = problem->XStates->size();
00050
00051 beliefCacheSet.resize(xStateNum);
00052
00053 for(States::iterator iter = problem->XStates->begin(); iter != problem->XStates->end(); iter ++ )
00054 {
00055 beliefCacheSet[iter.index()] = new BeliefCache();
00056 }
00057
00058 beliefForest = new BeliefForest();
00059 sampleEngine = new EvaluatorSampleEngine();
00060
00061 sampleEngine->setup(NULL, problem, &beliefCacheSet, beliefForest);
00062 beliefForest->setup(problem, sampleEngine, &beliefCacheSet);
00063 beliefForest->globalRootPrepare();
00064 }
00065
00066 PolicyGraphGenerator::~PolicyGraphGenerator()
00067 {
00068 }
00069
00070
00071
00072
00073
00074 void PolicyGraphGenerator::dfsBeliefForest(BeliefTreeNode* curNode, vector<BeliefTreeNode*>& path, int level)
00075 {
00076 nodesList.insert(curNode);
00077 if(level == 0){
00078 return;
00079 }
00080
00081
00082 for(vector<BeliefTreeNode*>::iterator it=path.begin();it!=path.end();it++){
00083 if(*it==curNode){
00084 return;
00085 }
00086 }
00087 path.push_back(curNode);
00088
00089 int maxEdge = graphParam.maxEdge;
00090 double probThreshold = graphParam.probThreshold;
00091 bool lookahead = graphParam.useLookahead;
00092
00093 sampleEngine->samplePrepare(curNode);
00094
00095 EvaluatorBeliefTreeNodeTuple *curNodeExtraData = (EvaluatorBeliefTreeNodeTuple *)curNode->extraData;
00096
00097 SharedPointer<BeliefWithState> currBelSt = curNode->s;
00098 int bestAction = curNodeExtraData->selectedAction;
00099 if(bestAction == -1){
00100 if(graphParam.useLookahead){
00101 bestAction = policy->getBestActionLookAhead(*currBelSt);
00102 }else{
00103 bestAction = policy->getBestAction(*currBelSt);
00104 }
00105 curNodeExtraData->selectedAction = bestAction;
00106 }
00107
00108 vector<NodeRelation> curExpansion;
00109 expandNode(curNode, bestAction, curExpansion);
00110
00111
00112 if(probThreshold > 0){
00113 for(vector<NodeRelation>::iterator it=curExpansion.begin();it!=curExpansion.end();){
00114 if((it->xProb * it->oProb) < probThreshold){
00115 it=curExpansion.erase(it);
00116 }else{
00117 it++;
00118 }
00119 }
00120 }
00121
00122
00123 if(maxEdge > 0 && curExpansion.size() > maxEdge){
00124 sort(curExpansion.begin(), curExpansion.end(), compareNodeRelationsProb);
00125 vector<NodeRelation>::iterator it=curExpansion.begin();
00126
00127 for(int i=0;i<maxEdge;i++){
00128 it++;
00129 }
00130 curExpansion.erase(it, curExpansion.end());
00131 }
00132
00133 for(vector<NodeRelation>::iterator it=curExpansion.begin();it!=curExpansion.end();it++){
00134
00135 bool found = false;
00136 for(vector<NodeRelation>::iterator nit=nodeRelationsList.begin();nit!=nodeRelationsList.end();nit++){
00137 if(it->destNode == nit->destNode && nit->srcNode == it->srcNode)
00138 found = true;
00139 }
00140 if(!found)
00141 nodeRelationsList.push_back(*it);
00142 }
00143
00144
00145 for(vector<NodeRelation>::iterator it=curExpansion.begin();it!=curExpansion.end();it++){
00146 if(level < 0){
00147 level=0;
00148 }
00149
00150
00151 if(nodesList.find(it->destNode)==nodesList.end())
00152 dfsBeliefForest(it->destNode, path,level-1);
00153 }
00154 path.pop_back();
00155 }
00156
00157
00158 void PolicyGraphGenerator::expandNode(BeliefTreeNode* curNode, int bestAction, vector<NodeRelation>& expansion)
00159 {
00160
00161
00162
00163 BeliefTreeQEntry& Qa = curNode->Q[bestAction];
00164 int numXstate = Qa.stateOutcomes.size();
00165
00166 EvaluatorAfterActionDataTuple *afterActionDataTuple = (EvaluatorAfterActionDataTuple *)Qa.extraData;
00167
00168 for(int X = 0 ; X < numXstate ; X++)
00169 {
00170 if(Qa.stateOutcomes[X] == NULL)
00171 {
00172 continue;
00173 }
00174 REAL_VALUE xProb = afterActionDataTuple->spv->operator ()(X);
00175 BeliefTreeObsState* obsX = Qa.stateOutcomes[X];
00176
00177 EvaluatorAfterObsDataTuple* afterObsDataTupel = (EvaluatorAfterObsDataTuple*)obsX->extraData;
00178
00179 int numObs = obsX->outcomes.size();
00180
00181 for(int o = 0 ; o < numObs ; o++)
00182 {
00183 if(obsX->outcomes[o] == NULL)
00184 {
00185 continue;
00186 }
00187 REAL_VALUE oProb = afterObsDataTupel->opv->operator ()(o);
00188
00189 BeliefTreeNode* nextNode = obsX->outcomes[o]->nextState;
00190
00191 NodeRelation newRelation;
00192 newRelation.srcNode = curNode;
00193 newRelation.destNode = nextNode;
00194 newRelation.X = X;
00195 newRelation.xProb = xProb;
00196 newRelation.o = o;
00197 newRelation.oProb = oProb;
00198
00199 expansion.push_back(newRelation);
00200 }
00201 }
00202 }
00203
00204
00205
00206
00207
00208
00209 string PolicyGraphGenerator::formatTuple(string name, int index, map<string, string> dataMap)
00210 {
00211 stringstream sstream;
00212 sstream << name << " ";
00213 if(dataMap.empty())
00214 {
00215 sstream << index;
00216 }
00217 else
00218 {
00219
00220 int lineLength=0;
00221 sstream << "(";
00222 for(map<string, string>::iterator iter = dataMap.begin() ; iter != dataMap.end() ; iter ++)
00223 {
00224 lineLength += iter->second.length();
00225 if(lineLength > 15)
00226 {
00227 sstream << "\\n";
00228 lineLength = 0;
00229 }
00230 if(iter != dataMap.begin())
00231 sstream << ",";
00232
00233 sstream << iter->second;
00234 }
00235 sstream << ")";
00236 }
00237 return sstream.str();
00238 }
00239
00240 void PolicyGraphGenerator::convertStCacheIndex(ostream& output, cacherow_stval& stRowIndex)
00241 {
00242 output << "x"<< stRowIndex.sval << "row" << stRowIndex.row;
00243 }
00244
00245
00246
00247
00248 void PolicyGraphGenerator::generateNodesDot(ostream& output, BeliefTreeNode* node)
00249 {
00250 convertStCacheIndex(output, node->cacheIndex);
00251 output << " [label=\"" ;
00252
00253 if(problem->XStates->size()>1){
00254 int Xindex = node->cacheIndex.sval;
00255 map<string, string> Xstate = problem->getFactoredObservedStatesSymbols(node->cacheIndex.sval);
00256 output << formatTuple("X", Xindex, Xstate) << "\\l";
00257 }
00258
00259 BeliefCache* cache = beliefCacheSet[node->cacheIndex.sval];
00260 BeliefCacheRow* cacheRow = cache->getRow(node->cacheIndex.row);
00261 SharedPointer<belief_vector> curBelief = cacheRow->BELIEF;
00262
00263 int mostProbY = curBelief->argmax();
00264 double prob = curBelief->operator()(mostProbY);
00265
00266 map<string, string> mostProbYstate = problem->getFactoredUnobservedStatesSymbols(mostProbY);
00267 output << formatTuple("Y", mostProbY, mostProbYstate) << " " << prob << "\\l";
00268
00269 EvaluatorBeliefTreeNodeTuple *curNodeExtraData = (EvaluatorBeliefTreeNodeTuple *)node->extraData;
00270 int bestAction = curNodeExtraData->selectedAction;
00271
00272 map<string, string> actName = problem->getActionsSymbols(bestAction);
00273 output << formatTuple("A", bestAction, actName) << "\\l";
00274
00275 output << "\"";
00276
00277
00278
00279
00280
00281
00282 output << "];" << endl;
00283
00284 }
00285
00286
00287
00288 void PolicyGraphGenerator::generateNodesRelation(ostream& output, vector<NodeRelation>& nodeRelations, set<BeliefTreeNode*> firstLevelNodes)
00289 {
00290 for(vector<NodeRelation>::iterator iter = nodeRelations.begin(); iter != nodeRelations.end() ; iter ++)
00291 {
00292
00293
00294 if(firstLevelNodes.find(iter->srcNode) == firstLevelNodes.end() && iter->srcNode != NULL){
00295 convertStCacheIndex(output, iter->srcNode->cacheIndex);
00296 }
00297 else{
00298 output << "root";
00299 }
00300 output << " -> " ;
00301 if(firstLevelNodes.find(iter->destNode) == firstLevelNodes.end()){
00302 convertStCacheIndex(output, iter->destNode->cacheIndex);
00303 }
00304 else{
00305 output << "root";
00306 }
00307 output << " [label=\"";
00308
00309 if(problem->XStates->size()>1){
00310 map<string, string> Xstate = problem->getFactoredObservedStatesSymbols(iter->X);
00311 output << formatTuple("X", iter->X, Xstate) << " " << iter->xProb << "\\l";
00312 }
00313 map<string, string> obsName = problem->getObservationsSymbols(iter->o);
00314 output << formatTuple("o", iter->o, obsName) <<" " << iter->oProb <<"\\l";
00315
00316 output <<"\"]";
00317 output << ";" ;
00318 output << endl;
00319
00320 }
00321 }
00322
00323 void PolicyGraphGenerator::drawRootNodeDot(ostream& output, SharedPointer<SparseVector> initialBeliefY, SharedPointer<DenseVector> initialBeliefX, int bestAction)
00324 {
00325 output << "root";
00326 output << " [label=\"" ;
00327
00328
00329 int mostProbX;
00330 int numXstate = initialBeliefX->size();
00331 double xProb = 0.0;
00332 for(int i=0;i<numXstate;i++){
00333 if(initialBeliefX->operator()(i) > xProb){
00334 xProb = initialBeliefX->operator()(i);
00335 mostProbX = i;
00336 }
00337 }
00338
00339 if(numXstate > 1){
00340 map<string, string> Xstate = problem->getFactoredObservedStatesSymbols(mostProbX);
00341 output << formatTuple("X", mostProbX, Xstate) << " " << xProb << "\\l";
00342 }
00343
00344 int mostProbY = initialBeliefY->argmax();
00345 double prob = initialBeliefY->operator()(mostProbY);
00346
00347 map<string, string> mostProbYstate = problem->getFactoredUnobservedStatesSymbols(mostProbY);
00348 output << formatTuple("Y", mostProbY, mostProbYstate) << " " << prob << "\\l";
00349
00350 map<string, string> actName = problem->getActionsSymbols(bestAction);
00351 output << formatTuple("A", bestAction, actName) << "\\l";
00352
00353 output << "\"";
00354 output << " shape=doublecircle";
00355 output << " labeljust=\"l\"";
00356 output << "];" << endl;
00357
00358 }
00359
00360
00361
00362
00363
00364 void PolicyGraphGenerator::generateGraph(ostream& output)
00365 {
00366 output << "digraph G" << endl;
00367 output << "{" << endl;
00368
00369
00370 int depth = graphParam.depth;
00371 vector<BeliefTreeNode*> path;
00372
00373
00374
00375
00376
00377
00378 SharedPointer<SparseVector> initialBeliefY = problem->getInitialBeliefY(0);
00379
00380 if(problem->hasIntraslice) {
00381 cerr << "Policy graph generating with intraslice conditioning is not supported yet." << endl;
00382 exit(-1);
00383 }
00384
00385 SharedPointer<DenseVector> initialBeliefX = problem->initialBeliefX;
00386 int initialBestAction;
00387 if(graphParam.useLookahead){
00388 initialBestAction= policy->getBestActionLookAhead(initialBeliefY, *initialBeliefX);
00389 }
00390 else{
00391 initialBestAction= policy->getBestAction(initialBeliefY, *initialBeliefX);
00392 }
00393
00394
00395 vector<NodeRelation> firstLevel;
00396 set<BeliefTreeNode*> firstLevelNodes;
00397
00398 if(problem->initialBeliefStval->sval == -1)
00399 {
00400
00401 obsState_prob_vector spv;
00402 problem->getObsStateProbVector(spv, initialBeliefY, *initialBeliefX, initialBestAction);
00403
00404
00405 beliefForest->sampleRootEdges.resize(problem->XStates->size() * problem->observations->size());
00406
00407 FOR(Xn, problem->XStates->size()){
00408
00409 obs_prob_vector opv;
00410 problem->getObsProbVector(opv, initialBeliefY, spv, initialBestAction, Xn);
00411
00412 FOR(O, problem->observations->size()){
00413 double rprob = opv(O) * spv(Xn);
00414 unsigned int r = Xn * problem->observations->size();
00415
00416 if (rprob > OBS_IS_ZERO_EPS)
00417 {
00418 SharedPointer<BeliefWithState> thisRootb_s = problem->beliefTransition->nextBelief(initialBeliefY, *initialBeliefX, initialBestAction, O, Xn);
00419
00420 SampleRootEdge* rE = new SampleRootEdge();
00421 beliefForest->sampleRootEdges[r] = rE;
00422 rE->sampleRootProb = rprob;
00423 rE->sampleRoot = sampleEngine->getNode(thisRootb_s);
00424
00425 rE->sampleRoot->count = 1;
00426
00427
00428 NodeRelation newRelation;
00429 newRelation.srcNode = NULL;
00430 newRelation.destNode = rE->sampleRoot;
00431 newRelation.X = Xn;
00432 newRelation.xProb = spv(Xn);
00433 newRelation.o = O;
00434 newRelation.oProb = opv(O);
00435
00436 firstLevel.push_back(newRelation);
00437 }
00438 else
00439 {
00440 beliefForest->sampleRootEdges[r] = NULL;
00441 }
00442
00443 }
00444 }
00445 }
00446 else
00447 {
00448
00449 for(vector<SampleRootEdge*>::iterator it=beliefForest->sampleRootEdges.begin();it!=beliefForest->sampleRootEdges.end();it++){
00450 if((*it)!=NULL){
00451 sampleEngine->samplePrepare((*it)->sampleRoot);
00452 expandNode((*it)->sampleRoot, initialBestAction, firstLevel);
00453 firstLevelNodes.insert((*it)->sampleRoot);
00454 path.push_back((*it)->sampleRoot);
00455 }
00456 }
00457 }
00458
00459 drawRootNodeDot(output, initialBeliefY, initialBeliefX, initialBestAction);
00460 nodeRelationsList.insert(nodeRelationsList.end(),firstLevel.begin(),firstLevel.end());
00461
00462
00463 for(vector<NodeRelation>::iterator it=firstLevel.begin();it!=firstLevel.end();it++){
00464 dfsBeliefForest(it->destNode, path, depth-1);
00465 }
00466
00467
00468 for( set<BeliefTreeNode *>::iterator iter = nodesList.begin() ; iter != nodesList.end(); iter++)
00469 {
00470 BeliefTreeNode* curNode = *iter;
00471 if(firstLevelNodes.find(*iter) == firstLevelNodes.end())
00472 generateNodesDot(output, curNode);
00473 }
00474
00475
00476 generateNodesRelation(output, nodeRelationsList, firstLevelNodes);
00477
00478 output << "}" << endl;
00479 output.flush();
00480 }