PolicyGraphGenerator.cpp
Go to the documentation of this file.
00001 #include "MOMDP.h"
00002 #include "AlphaVectorPolicy.h"
00003 #include "BeliefForest.h"
00004 #include "Sample.h"
00005 #include "BeliefCache.h"
00006 #include "EvaluatorSampleEngine.h"
00007 #include "EvaluatorBeliefTreeNodeTuple.h"
00008 
00009 #include "PolicyGraphGenerator.h"
00010 
00011 #include <string>
00012 #include <stdlib.h>
00013 #include <sstream>
00014 #include <fstream>
00015 
00016 using namespace std;
00017 using namespace momdp;
00018 
00019 /*
00020  *      Comparator for NodeRelation by joint probability of o, X
00021  */
00022 bool compareNodeRelationsProb(NodeRelation a, NodeRelation b)
00023 {
00024     return a.xProb*a.oProb > b.xProb* b.oProb;
00025 }
00026 //comparator for removing duplicates
00027 bool nodeRelationsOrder(NodeRelation a, NodeRelation b)
00028 {
00029     if(a.destNode > b.destNode)
00030         return true;
00031     else
00032         if(a.srcNode > b.srcNode)
00033             return true;
00034         else
00035             return false;
00036 }
00037 
00038 bool nodeRelationsEqual(NodeRelation a, NodeRelation b)
00039 {
00040     return a.destNode == b.destNode && a.srcNode == b.srcNode;
00041 }
00042 /*
00043  * Initialise and setup data member 
00044  */
00045 
00046     PolicyGraphGenerator::PolicyGraphGenerator(SharedPointer<MOMDP> _problem, SharedPointer<AlphaVectorPolicy> _policy, PolicyGraphParam _graphParam)
00047 : problem(_problem), policy(_policy), graphParam(_graphParam)
00048 {
00049     int xStateNum = problem->XStates->size();
00050 
00051     beliefCacheSet.resize(xStateNum);
00052 
00053     for(States::iterator iter = problem->XStates->begin(); iter != problem->XStates->end(); iter ++ )
00054     {
00055         beliefCacheSet[iter.index()] = new BeliefCache();
00056     }
00057 
00058     beliefForest = new BeliefForest();
00059     sampleEngine = new EvaluatorSampleEngine();
00060 
00061     sampleEngine->setup(NULL, problem, &beliefCacheSet, beliefForest);
00062     beliefForest->setup(problem, sampleEngine, &beliefCacheSet);
00063     beliefForest->globalRootPrepare();
00064 }
00065 
00066 PolicyGraphGenerator::~PolicyGraphGenerator()
00067 {
00068 }
00069 
00070 
00071 /*
00072  * Recursive dfs from a BeliefTreeNode
00073  */
00074 void PolicyGraphGenerator::dfsBeliefForest(BeliefTreeNode* curNode, vector<BeliefTreeNode*>& path, int level)
00075 {
00076     nodesList.insert(curNode);
00077     if(level == 0){
00078         return; 
00079     }
00080 
00081     //detect cycle 
00082     for(vector<BeliefTreeNode*>::iterator it=path.begin();it!=path.end();it++){
00083         if(*it==curNode){
00084             return;     //do not expand if it is a cycle
00085         }
00086     }
00087     path.push_back(curNode);    
00088 
00089     int maxEdge = graphParam.maxEdge;
00090     double probThreshold = graphParam.probThreshold;
00091     bool lookahead = graphParam.useLookahead;
00092 
00093     sampleEngine->samplePrepare(curNode); 
00094 
00095     EvaluatorBeliefTreeNodeTuple *curNodeExtraData = (EvaluatorBeliefTreeNodeTuple *)curNode->extraData;
00096 
00097     SharedPointer<BeliefWithState> currBelSt = curNode->s;
00098     int bestAction = curNodeExtraData->selectedAction;
00099     if(bestAction == -1){
00100         if(graphParam.useLookahead){
00101             bestAction = policy->getBestActionLookAhead(*currBelSt);
00102         }else{
00103             bestAction = policy->getBestAction(*currBelSt);
00104         }
00105         curNodeExtraData->selectedAction = bestAction;
00106     }
00107 
00108     vector<NodeRelation> curExpansion;
00109     expandNode(curNode, bestAction, curExpansion);
00110 
00111     //prune edges with joint probability of X and observation less than threshold
00112     if(probThreshold > 0){
00113         for(vector<NodeRelation>::iterator it=curExpansion.begin();it!=curExpansion.end();){
00114             if((it->xProb * it->oProb) < probThreshold){
00115                 it=curExpansion.erase(it);
00116             }else{
00117                 it++;
00118             }
00119         }
00120     }
00121 
00122     //limit the number of edges 
00123     if(maxEdge > 0 && curExpansion.size() > maxEdge){
00124         sort(curExpansion.begin(), curExpansion.end(), compareNodeRelationsProb); 
00125         vector<NodeRelation>::iterator it=curExpansion.begin();
00126         //only keep top maxEdge number of edges
00127         for(int i=0;i<maxEdge;i++){
00128             it++;
00129         }
00130         curExpansion.erase(it, curExpansion.end());
00131     }
00132     //primitive method of search for duplicate edges
00133     for(vector<NodeRelation>::iterator it=curExpansion.begin();it!=curExpansion.end();it++){
00134 
00135         bool found = false;     
00136         for(vector<NodeRelation>::iterator nit=nodeRelationsList.begin();nit!=nodeRelationsList.end();nit++){
00137             if(it->destNode == nit->destNode && nit->srcNode == it->srcNode)
00138                 found = true;
00139         }
00140         if(!found)
00141             nodeRelationsList.push_back(*it);
00142     }
00143 
00144     //dfs remaining child node
00145     for(vector<NodeRelation>::iterator it=curExpansion.begin();it!=curExpansion.end();it++){
00146         if(level < 0){
00147             level=0;            //no limit
00148         }
00149 
00150         //do not expand nodes that are already been expanded
00151         if(nodesList.find(it->destNode)==nodesList.end())
00152             dfsBeliefForest(it->destNode, path,level-1);
00153     }
00154     path.pop_back();
00155 }
00156 
00157 
00158 void PolicyGraphGenerator::expandNode(BeliefTreeNode* curNode, int bestAction, vector<NodeRelation>& expansion)
00159 {
00160     /*
00161      *  DFS nodes of best action only
00162      */
00163     BeliefTreeQEntry& Qa = curNode->Q[bestAction];
00164     int numXstate = Qa.stateOutcomes.size();
00165 
00166     EvaluatorAfterActionDataTuple *afterActionDataTuple = (EvaluatorAfterActionDataTuple *)Qa.extraData;
00167 
00168     for(int X = 0 ; X < numXstate ; X++)
00169     {
00170         if(Qa.stateOutcomes[X] == NULL)
00171         {
00172             continue;
00173         }
00174         REAL_VALUE xProb = afterActionDataTuple->spv->operator ()(X);
00175         BeliefTreeObsState* obsX = Qa.stateOutcomes[X];
00176 
00177         EvaluatorAfterObsDataTuple* afterObsDataTupel = (EvaluatorAfterObsDataTuple*)obsX->extraData;
00178 
00179         int numObs = obsX->outcomes.size();
00180 
00181         for(int o = 0 ; o < numObs ; o++)
00182         {
00183             if(obsX->outcomes[o] == NULL)
00184             {
00185                 continue;
00186             }
00187             REAL_VALUE oProb = afterObsDataTupel->opv->operator ()(o);
00188 
00189             BeliefTreeNode* nextNode = obsX->outcomes[o]->nextState;
00190 
00191             NodeRelation newRelation;
00192             newRelation.srcNode = curNode;
00193             newRelation.destNode = nextNode;
00194             newRelation.X = X;
00195             newRelation.xProb = xProb;
00196             newRelation.o = o;
00197             newRelation.oProb = oProb;
00198 
00199             expansion.push_back(newRelation);
00200         }
00201     }
00202 }
00203 
00204 /***
00205  *      get a string representation of X [or action or Y] in the form of:
00206  *      For eg: X: (state1, state2, ..) or: X:index if dataMap is empty
00207  *      map<string, string> dataMap stores the variable name and value mapping
00208  */
00209 string PolicyGraphGenerator::formatTuple(string name, int index, map<string, string> dataMap)
00210 {
00211     stringstream sstream;
00212     sstream << name << " ";
00213     if(dataMap.empty())
00214     {
00215         sstream << index;
00216     }
00217     else
00218     {
00219         //line wrapping for smaller circles
00220         int lineLength=0;
00221         sstream << "(";
00222         for(map<string, string>::iterator iter = dataMap.begin() ; iter != dataMap.end() ; iter ++)
00223         {
00224             lineLength += iter->second.length(); 
00225             if(lineLength > 15)
00226             {
00227                 sstream << "\\n";
00228                 lineLength = 0;
00229             }
00230             if(iter != dataMap.begin()) 
00231                 sstream <<  ",";
00232 
00233             sstream <<  iter->second;
00234         }
00235         sstream << ")";
00236     }
00237     return sstream.str();
00238 }
00239 
00240 void PolicyGraphGenerator::convertStCacheIndex(ostream& output, cacherow_stval& stRowIndex)
00241 {
00242     output << "x"<< stRowIndex.sval << "row" << stRowIndex.row;
00243 }
00244 
00245 /*
00246  * Generate a Dot node from a BeliefTreeNode
00247  */
00248 void PolicyGraphGenerator::generateNodesDot(ostream& output, BeliefTreeNode* node)
00249 {
00250     convertStCacheIndex(output, node->cacheIndex); //to check purpose
00251     output << " [label=\"" ;
00252 
00253     if(problem->XStates->size()>1){
00254         int Xindex = node->cacheIndex.sval;
00255         map<string, string> Xstate = problem->getFactoredObservedStatesSymbols(node->cacheIndex.sval);
00256         output << formatTuple("X", Xindex, Xstate) << "\\l";
00257     }
00258     //get most probable Y state
00259     BeliefCache* cache  = beliefCacheSet[node->cacheIndex.sval];
00260     BeliefCacheRow* cacheRow = cache->getRow(node->cacheIndex.row);
00261     SharedPointer<belief_vector> curBelief = cacheRow->BELIEF;
00262 
00263     int mostProbY  = curBelief->argmax();       //get the most probable Y state
00264     double prob = curBelief->operator()(mostProbY);     //get its probability
00265 
00266     map<string, string> mostProbYstate = problem->getFactoredUnobservedStatesSymbols(mostProbY);
00267     output << formatTuple("Y", mostProbY, mostProbYstate) << " " << prob << "\\l";
00268 
00269     EvaluatorBeliefTreeNodeTuple *curNodeExtraData = (EvaluatorBeliefTreeNodeTuple *)node->extraData;
00270     int bestAction =  curNodeExtraData->selectedAction;
00271 
00272     map<string, string> actName = problem->getActionsSymbols(bestAction);
00273     output << formatTuple("A", bestAction, actName) << "\\l";
00274 
00275     output << "\"";
00276 
00277     /*if(rootNodesList.find(node) != rootNodesList.end())
00278       {
00279     // is starting node
00280     output << " shape=doublecircle";
00281     }*/
00282     output << "];" << endl;
00283 
00284 }
00285 /*
00286  * Generate a DOT edge from NodeRelation
00287  */
00288 void PolicyGraphGenerator::generateNodesRelation(ostream& output, vector<NodeRelation>& nodeRelations, set<BeliefTreeNode*> firstLevelNodes)
00289 {
00290     for(vector<NodeRelation>::iterator iter = nodeRelations.begin(); iter != nodeRelations.end() ; iter ++)
00291     {
00292         //special case for nodes link back to first level nodes
00293         //link back to a single root node instead
00294         if(firstLevelNodes.find(iter->srcNode) == firstLevelNodes.end() && iter->srcNode != NULL){
00295             convertStCacheIndex(output, iter->srcNode->cacheIndex);
00296         }
00297         else{
00298             output << "root";
00299         }
00300         output << " -> " ;
00301         if(firstLevelNodes.find(iter->destNode) == firstLevelNodes.end()){
00302             convertStCacheIndex(output, iter->destNode->cacheIndex);
00303         }
00304         else{
00305             output << "root";
00306         }
00307         output << " [label=\""; //X: " << iter->X << " (" << iter->xProb << ") o: " << iter->o << " (" << iter->oProb << ")";
00308 
00309         if(problem->XStates->size()>1){
00310             map<string, string> Xstate = problem->getFactoredObservedStatesSymbols(iter->X);
00311             output << formatTuple("X", iter->X, Xstate) << " " << iter->xProb  << "\\l";
00312         }
00313         map<string, string> obsName = problem->getObservationsSymbols(iter->o);
00314         output << formatTuple("o", iter->o, obsName) <<" " << iter->oProb  <<"\\l";
00315 
00316         output <<"\"]";
00317         output << ";" ;
00318         output << endl;
00319 
00320     }
00321 }
00322 
00323 void PolicyGraphGenerator::drawRootNodeDot(ostream& output, SharedPointer<SparseVector> initialBeliefY, SharedPointer<DenseVector> initialBeliefX, int bestAction)
00324 {
00325     output << "root"; 
00326     output << " [label=\"" ;
00327 
00328     //most probable X state from initial X belief
00329     int mostProbX;
00330     int numXstate = initialBeliefX->size();
00331     double xProb = 0.0;
00332     for(int i=0;i<numXstate;i++){
00333         if(initialBeliefX->operator()(i) > xProb){
00334             xProb = initialBeliefX->operator()(i);
00335             mostProbX = i;
00336         }
00337     }
00338 
00339     if(numXstate > 1){          //don't output X if there is only one value, i.e dummy X for pure POMDP problem
00340         map<string, string> Xstate = problem->getFactoredObservedStatesSymbols(mostProbX);
00341         output << formatTuple("X", mostProbX, Xstate) << " " << xProb << "\\l";
00342     }
00343     //get most probable Y state
00344     int mostProbY  = initialBeliefY->argmax();  //get the most probable Y state
00345     double prob = initialBeliefY->operator()(mostProbY);        //get its probability
00346 
00347     map<string, string> mostProbYstate = problem->getFactoredUnobservedStatesSymbols(mostProbY);
00348     output << formatTuple("Y", mostProbY, mostProbYstate) << " " << prob << "\\l";
00349 
00350     map<string, string> actName = problem->getActionsSymbols(bestAction);
00351     output << formatTuple("A", bestAction, actName) << "\\l";
00352 
00353     output << "\"";
00354     output << " shape=doublecircle";
00355     output << " labeljust=\"l\""; 
00356     output << "];" << endl;
00357 
00358 }
00359 
00360 /*
00361  *  Generate policy graph in DOT format
00362  *  Precondition: BeliefForest must be initialised (via setup function) and globalPrepareRootNode must be called
00363  */
00364 void PolicyGraphGenerator::generateGraph(ostream& output)
00365 {
00366     output << "digraph G" << endl;
00367     output << "{" << endl;
00368 
00369     //draw root node.  X might be a distribution at first step
00370     int depth = graphParam.depth;
00371     vector<BeliefTreeNode*> path;       //store nodes in current search path to detect cycle and halt
00372 
00373     /*
00374      * special case for first level when X might be a distribution
00375      */
00376 
00377     //get best action for initial X belief
00378     SharedPointer<SparseVector> initialBeliefY = problem->getInitialBeliefY(0);
00379     // TODO(haoyu) figure out how to generate policy graph for intraslice
00380     if(problem->hasIntraslice) {
00381       cerr << "Policy graph generating with intraslice conditioning is not supported yet." << endl;
00382       exit(-1);
00383     }
00384     
00385     SharedPointer<DenseVector> initialBeliefX = problem->initialBeliefX;
00386     int initialBestAction;
00387     if(graphParam.useLookahead){
00388         initialBestAction= policy->getBestActionLookAhead(initialBeliefY, *initialBeliefX);
00389     }
00390     else{
00391         initialBestAction= policy->getBestAction(initialBeliefY, *initialBeliefX);
00392     }
00393 
00394 
00395     vector<NodeRelation> firstLevel;
00396     set<BeliefTreeNode*> firstLevelNodes;
00397     //special case for initial X is a distribution
00398     if(problem->initialBeliefStval->sval == -1)
00399     {
00400         //get the next X distribution after initial action
00401         obsState_prob_vector spv;
00402         problem->getObsStateProbVector(spv, initialBeliefY, *initialBeliefX, initialBestAction);
00403         
00404         //manipulate the belief forest sample root edges such that there is one sample root for each X' and O
00405         beliefForest->sampleRootEdges.resize(problem->XStates->size() * problem->observations->size());
00406         
00407         FOR(Xn, problem->XStates->size()){
00408 
00409             obs_prob_vector opv;
00410             problem->getObsProbVector(opv, initialBeliefY, spv, initialBestAction, Xn);
00411 
00412             FOR(O, problem->observations->size()){
00413                 double rprob = opv(O) * spv(Xn);
00414                 unsigned int r = Xn * problem->observations->size();
00415 
00416                 if (rprob > OBS_IS_ZERO_EPS) 
00417                 {
00418                         SharedPointer<BeliefWithState>  thisRootb_s = problem->beliefTransition->nextBelief(initialBeliefY, *initialBeliefX, initialBestAction, O, Xn);  
00419 
00420                         SampleRootEdge* rE = new SampleRootEdge();
00421                         beliefForest->sampleRootEdges[r] = rE;
00422                         rE->sampleRootProb = rprob;
00423                         rE->sampleRoot = sampleEngine->getNode(thisRootb_s);
00424 
00425                         rE->sampleRoot->count = 1;//for counting valid path
00426 
00427                         //draw the edges from root node to first level nodes
00428                         NodeRelation newRelation;
00429                         newRelation.srcNode = NULL;
00430                         newRelation.destNode = rE->sampleRoot;
00431                         newRelation.X = Xn;
00432                         newRelation.xProb = spv(Xn);
00433                         newRelation.o = O;
00434                         newRelation.oProb = opv(O);
00435 
00436                         firstLevel.push_back(newRelation);
00437                 } 
00438                 else 
00439                 {
00440                         beliefForest->sampleRootEdges[r] = NULL;
00441                 }
00442                         
00443             }
00444         }
00445     }
00446     else
00447     {
00448         //dfs for each X state with non-zero initial probability
00449         for(vector<SampleRootEdge*>::iterator it=beliefForest->sampleRootEdges.begin();it!=beliefForest->sampleRootEdges.end();it++){
00450             if((*it)!=NULL){
00451                 sampleEngine->samplePrepare((*it)->sampleRoot);
00452                 expandNode((*it)->sampleRoot, initialBestAction, firstLevel); 
00453                 firstLevelNodes.insert((*it)->sampleRoot);      //for checking purpose when drawing edges that links back to root node
00454                 path.push_back((*it)->sampleRoot);
00455             }
00456         }
00457     }
00458     
00459     drawRootNodeDot(output, initialBeliefY, initialBeliefX, initialBestAction);
00460     nodeRelationsList.insert(nodeRelationsList.end(),firstLevel.begin(),firstLevel.end()); 
00461 
00462     //DFS all children of first level nodes 
00463     for(vector<NodeRelation>::iterator it=firstLevel.begin();it!=firstLevel.end();it++){
00464         dfsBeliefForest(it->destNode, path, depth-1);
00465     }
00466 
00467     //list all nodes in graph in DOT
00468     for( set<BeliefTreeNode *>::iterator iter = nodesList.begin() ; iter != nodesList.end(); iter++)
00469     {
00470         BeliefTreeNode* curNode = *iter;
00471         if(firstLevelNodes.find(*iter) == firstLevelNodes.end())        //don't out put first level nodes
00472             generateNodesDot(output, curNode);
00473     }
00474 
00475     //list all edges in DOT
00476     generateNodesRelation(output, nodeRelationsList, firstLevelNodes);     
00477 
00478     output << "}" << endl;
00479     output.flush();
00480 }


appl
Author(s): petercai
autogenerated on Tue Jan 7 2014 11:02:29