SARSOPPrune.cpp
Go to the documentation of this file.
00001 #include "SARSOPPrune.h"
00002 
00003 namespace momdp 
00004 {
00005 
00006         void SARSOPPrune::prune(void)
00007         {
00008                 DEBUG_TRACE(cout << "SARSOPPrune" << endl;);
00009 
00010                 pruneLowerBound(); // bounds->pruneAlpha->prune();
00011                 pruneUpperBound();  // bounds->pruneBVpair->prune();
00012 
00013                 if(solver->numBackups/pruneInterval>= currentRound)
00014                 {
00015                         DEBUG_TRACE(cout << "currentRound " << currentRound << endl;);
00016                         currentRound++;
00017                         //tranverse the entire tree to nullify sub-optimal branches 
00018                         //      in the tree
00019                         //SYL_NO_NULLIFY  280409 added if statement
00020                         if (problem->XStates->size() == 1)
00021                         {
00022                                 nullifySubOptimalBranches();
00023                         }
00024                         //check for delta-dominance, and prune the alpha planes
00025                         //      which do not hold any certificate
00026 
00027                         //bounds->pruneAlpha->updateCertsAndUses(bounds->numBackups);
00028                         //bounds->pruneAlpha->pruneNotCertedAndNotUsed();
00029 
00030                         //SYL_NO_PRUNE  
00031                         if (problem->XStates->size() == 1)
00032                         {
00033                                 pruneDynamicDeltaVersion();  // calls wrapper in SARSOPPrune class
00034                                 // bounds->pruneAlpha->pruneDynamicDeltaVersion(bounds->numBackups);
00035                         }
00036 
00037 
00038                 } 
00039         }
00040 
00041         void SARSOPPrune::pruneLowerBound()
00042         {
00043                 FOR (stateidx, sarsopSolver->lowerBoundSet->set.size()) 
00044                 {
00045                         sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->prune();
00046                 }
00047         }
00048         void SARSOPPrune::pruneUpperBound()
00049         {
00050                 FOR (stateidx, sarsopSolver->upperBoundSet->set.size()) 
00051                 {
00052                         sarsopSolver->upperBoundSet->set[stateidx]->pruneEngine->prune();
00053                 }
00054         }
00055 
00056         void SARSOPPrune::pruneDynamicDeltaVersion()
00057         {
00058                 DEBUG_TRACE( cout << "SARSOPPrune::pruneDynamicDeltaVersion" << endl;);
00059 
00060                 double curTime = GlobalResource::getInstance()->getInstance()->getRunTime();    
00061                 double sumPruneTime = 0;
00062                 int overPruneLocal = 0, underPruneLocal = 0; 
00063                 int overPruneSum = 0;
00064                 int underPruneSum = 0; 
00065 
00066                 int sumNumPrune=0;
00067 
00068                 FOR (stateidx, sarsopSolver->lowerBoundSet->set.size()) 
00069                 {
00070                         sumPruneTime  += sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->pruneTime;
00071                         sumNumPrune  += sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->numPrune;
00072                 }
00073 
00074                 if (firstPass) 
00075                 {
00076                         // first time in this function - record that as the elapsed time
00077                         DEBUG_TRACE( cout << "firstPass" << endl;);
00078                         elapsed = curTime;
00079                         firstPass = false;
00080                 }
00081 
00082         
00083                 // below follows the settings in appl-0.3w
00084                 //              if((curTime-elapsed) > 10 && sumPruneTime < (curTime-elapsed-10) * 0.05)  //curTime-elapsed-10
00085                 //              if(curTime > 5 && sumPruneTime < curTime * 0.1)
00086                 
00087                 if((curTime-elapsed) > 5 && sumPruneTime < (curTime-elapsed-5) * 0.1)  
00088                 {
00089                         //cout << "elapsed : " << elapsed << " curTime : " << curTime << " (curTime-elapsed) : " << (curTime-elapsed) << " (curTime-elapsed-5) : " << (curTime-elapsed-5) << " sumPruneTime : " << sumPruneTime << " sumNumPrune : " << sumNumPrune << endl;
00090                         FOR (stateidx, sarsopSolver->lowerBoundSet->set.size()) 
00091                         {
00092                                 DEBUG_TRACE( cout << "stateidx " << stateidx << endl;);
00093                                 sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->pruneDynamicDeltaVersion(solver->numBackups, overPruneLocal, underPruneLocal);
00094                                 overPruneSum += overPruneLocal;
00095                                 underPruneSum += underPruneLocal;
00096 
00097                         }
00098 
00099 
00100                         //      cout << "overPruneSum : " << overPruneSum << endl;
00101                         //      cout << "underPruneSum : " << underPruneSum << endl;
00102 
00103                         // adjust bglobal_delta accordingly
00104                         updateDeltaVersion2(overPruneSum, underPruneSum);  
00105 
00106                         //      cout << "bglobal_delta : " << bglobal_delta << endl;
00107 
00108                         // set global_delta for each pruneAlpha accordingly
00109                         FOR (stateidx, sarsopSolver->lowerBoundSet->set.size()) 
00110                         {
00111                                 sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->setDelta(bglobal_delta);
00112                         }
00113                 }
00114         }
00115 
00116 
00117         void SARSOPPrune::setDelta(double newDelta)
00118         {
00119                 DEBUG_TRACE(cout << "SARSOPPrune::setDelta newDelta " << newDelta << endl;);
00120                 bglobal_delta = newDelta;
00121                 FOR (stateidx, sarsopSolver->lowerBoundSet->set.size()) 
00122                 {
00123                         sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->setDelta(newDelta);
00124                 }
00125         }
00126 
00127         void SARSOPPrune::updateDeltaVersion2(int overPrune, int underPrune)
00128         {
00129                 double overPruneThreshold = solver->solverParams-> overPruneThreshold;
00130                 double lowerPruneThreshold = solver->solverParams-> lowerPruneThreshold;
00131                 
00132                 double fOverPrune = overPrune;
00133                 
00134                 DEBUG_TRACE(cout << "SARSOPPrune::updateDeltaVersion2" << endl;);
00135                 DEBUG_TRACE(cout << "overPruneThreshold" << overPruneThreshold << endl;);
00136                 DEBUG_TRACE(cout << "lowerPruneThreshold" << lowerPruneThreshold << endl;);
00137                 DEBUG_TRACE(cout << "fOverPrune" << fOverPrune << endl;);
00138                 DEBUG_TRACE(cout << "underPrune" << underPrune << endl;);
00139                 
00140 
00141 
00142                 if(solver->solverParams->dynamicDeltaPercentageMode)
00143                 {
00144                         DEBUG_TRACE(cout << "dynamicDeltaPercentageMode" << endl;);
00145                         unsigned int total_planes = 0;
00146 
00147                         FOR (stateidx, sarsopSolver->lowerBoundSet->set.size()) 
00148                         {
00149                                 total_planes += sarsopSolver->lowerBoundSet->set[stateidx]->planes.size();
00150                         }
00151                         fOverPrune = ((double)overPrune)/total_planes;
00152                         DEBUG_TRACE(cout << "fOverPrune" << fOverPrune << endl;);
00153                 }
00154 
00155                 
00156                 
00157                 //cout  << "overPruneThreshold : "<< overPruneThreshold<< endl; 
00158                 //cout  << "lowerPruneThreshold : "<< lowerPruneThreshold<< endl;
00159                 //cout  << "overPrune : "<< fOverPrune << endl;                                 
00160         
00161                 // 0 is increase delta
00162                 // 1 is stay
00163                 // 2 is decrease delta
00164                 
00165                 //int state = 2;
00166                 switch(state)
00167                 {
00168                         case 0:
00169                                 DEBUG_TRACE(cout << "state 0" << endl;);
00170                                 if(fOverPrune < overPruneThreshold)
00171                                 {
00172                                         state = 1;
00173                                 }
00174                                 else
00175                                 {
00176                                         increaseDelta();
00177                                         
00178                                 }
00179                                 break;
00180                         case 1:
00181                                 DEBUG_TRACE(cout << "state 1" << endl;);
00182                                 if(fOverPrune < lowerPruneThreshold)
00183                                 {
00184                                         state = 2;
00185                                         decreaseDelta();
00186                                         
00187                                 }
00188                                 else if(fOverPrune > overPruneThreshold)
00189                                 {
00190                                         state = 0;
00191                                         increaseDelta();
00192                                         
00193                                 }
00194                                 break;
00195                         case 2:
00196                                 DEBUG_TRACE(cout << "state 2" << endl;);
00197                                 if(fOverPrune > overPruneThreshold)
00198                                 {
00199                                         state = 0;
00200                                         increaseDelta();
00201 
00202                                 }
00203                                 else
00204                                 {
00205                                         decreaseDelta();
00206                                 }
00207                                 break;
00208                 }
00209                 state = 1;
00210 
00211         }
00212         void SARSOPPrune::increaseDelta()
00213         {
00214                 DEBUG_TRACE(cout << "increaseDelta" << endl;);
00215                 if (bglobal_delta < 2.0+1e-7) 
00216                 {       // 2 is the maximum distance between 2 belief points.
00217                         bglobal_delta *= 2;
00218                 }
00219                 //cout << "Increase Delta to : " << global_delta << endl;
00220         }
00221 
00222         void SARSOPPrune::decreaseDelta()
00223         {
00224                 DEBUG_TRACE(cout << "decreaseDelta" << endl;);
00225                 bglobal_delta /= 2; // on the left follows the settings in appl-0.3w *= 0.75;
00226                 //bglobal_delta /= 2;
00227                 //cout << "Decrease Delta to : " << global_delta << endl;
00228         }
00229 
00230         void SARSOPPrune::nullifySubOptimalBranches()
00231         {
00232                 DEBUG_TRACE( cout << "SARSOPPrune::nullifySubOptimalBranches" << endl;);
00233 
00234                 BeliefForest* globalRoot = sarsopSolver->sampleEngine->getGlobalNode();
00235 
00236                 BeliefTreeNode* currRoot;
00237                 //unsigned int numSampleRoots = globalRoot->getGlobalRootNumSampleroots();
00238                 unsigned int numSampleRoots = globalRoot->sampleRootEdges.size();
00239 
00240                 FOR (r, numSampleRoots) 
00241                 {
00242                         //FOR (r, globalRoot->getGlobalRootNumSampleroots()) {
00243                         SampleRootEdge* eR = globalRoot->sampleRootEdges[r];
00244                         if (NULL != eR) 
00245                         {
00246                                 currRoot = eR->sampleRoot;
00247                                 uncheckAllSubNodes(currRoot);
00248                                 nullifySubOptimalCerts(currRoot);
00249                         }
00250                 }
00251         }
00252 
00253 
00254         //Function: nullifySubOptimalCerts
00255         //Functionality:
00256         //      tranverse the reachability-tree with root 'cn', and update
00257         //      the 'count' of non-suboptimal paths which leads to the node
00258         //      for each node in the tree. (i.e. a node with reachableCount==0
00259         //      is a suboptimal path
00260         //Note:
00261         //      note that a node which becomes suboptimal may become a valid
00262         //      node again later, as it maybe reached from other paths as
00263         //      the tree grows
00264         //Parameters:
00265         //      cn: the root for the reachability-tree that we are going to
00266         //              explore
00267         //Returns:
00268         //      NA
00269         void SARSOPPrune::nullifySubOptimalCerts(BeliefTreeNode* cn)
00270         {
00271                 double ubVal, lbVal;
00272                 DEBUG_TRACE ( cout << "SARSOPPrune::nullifySubOptimalCerts" << endl; );
00273                 if(cn->checked==false)
00274                 {
00275                         DEBUG_TRACE ( cout << "cn->checked==false" << endl; );
00276                         //check the checked
00277                         cn->checked = true;
00278 
00279                         //get ubVal and lbVal
00280                         
00281                         ubVal=sarsopSolver->beliefCacheSet[cn->cacheIndex.sval]->getRow(cn->cacheIndex.row)->UB;
00282                         lbVal=sarsopSolver->beliefCacheSet[cn->cacheIndex.sval]->getRow(cn->cacheIndex.row)->LB;
00283                         //ubVal=beliefCache->getRow(cn->cacheIndex)->UB;
00284                         //lbVal=beliefCache->getRow(cn->cacheIndex)->LB;
00285 
00286                         DEBUG_TRACE ( cout << "ubVal " << ubVal << endl; );
00287                         DEBUG_TRACE ( cout << "lbVal " << lbVal << endl; );
00288 
00289                         //check for each sub entry E of cn
00290                         FOR(a, cn->getNodeNumActions()) 
00291                         {
00292                                 if(cn->Q[a].ubVal <lbVal - 0.0001)
00293                                 {
00294                                         DEBUG_TRACE ( cout << "cn->Q[a].ubVal " << cn->Q[a].ubVal << endl; );
00295                                         //update the entry cn.Q[a] as uninitialized
00296                                         nullifyEntry(&cn->Q[a]);
00297                                 }
00298                                 nullifySubOptimalCerts(&cn->Q[a]);
00299                         }
00300                 }
00301         }
00302 
00303 
00304         //Function: nullifySubOptimalCerts
00305         //Functionality:
00306         //      for the BeliefTreeQEntry (Action node of pomdp problem) 'e' given,
00307         //      check and nullify all its children nodes
00308         //      Basically, the method pass down the checking procedure
00309         //Parameters:
00310         //      e: the action node whose children nodes are to be checked
00311         //Returns:
00312         //      NA
00313         void SARSOPPrune::nullifySubOptimalCerts(BeliefTreeQEntry* e)
00314         {
00315                 FOR(x, e->getNumStateOutcomes()) {
00316                         BeliefTreeObsState* xpt = e->stateOutcomes[x];
00317                         if (xpt!=NULL) {
00318                                 FOR(o, xpt->getNumOutcomes()){//for all observations, nullify all subsequent BeliefTreeNodes
00319                                         if(xpt->outcomes[o]!=NULL){
00320                                                 BeliefTreeNode* cn_p = xpt->outcomes[o]->nextState;//@
00321                                                 if(cn_p != NULL){
00322                                                         if((*cn_p).isFringe()==false){
00323                                                                 nullifySubOptimalCerts(cn_p);
00324                                                         }
00325                                                 }
00326                                         }
00327                                 }
00328                         }
00329                 } 
00330         }
00331 
00332 
00333         //Function: nullifyEntry
00334         //Functionality:
00335         //      To set the entire subtree of 'e' to be invalid, i.e. 'valid' attribute of
00336         //      'e' and its descendent-BeliefTreeQEntry-s are set to 'false', and all its
00337         //      descendent-BeliefTreeNode-s will have their validCount--
00338         //Parameters:
00339         //      e: the root entry which is suboptimal
00340         //Returns:
00341         //      NA
00342         void SARSOPPrune::nullifyEntry(BeliefTreeQEntry* e)
00343         {
00344                 //only do nullification if it hasn't been nullified before
00345                 if(e->valid==true){
00346                         e->valid = false;//nullify validness of entry
00347                         FOR(x, e->getNumStateOutcomes()) 
00348                         {
00349                                 BeliefTreeObsState* xpt = e->stateOutcomes[x];
00350                                 if (xpt!=NULL) 
00351                                 {
00352                                         FOR(o, xpt->getNumOutcomes())
00353                                         {
00354                                                 //for all observations, nullify all subsequent BeliefTreeNodes
00355                                                 if(xpt->outcomes[o]!=NULL)
00356                                                 {
00357                                                         BeliefTreeNode* cn_p = xpt->outcomes[o]->nextState;//@
00358                                                         if(cn_p != NULL)
00359                                                         {
00360                                                                 cn_p->count--;//decrease counter of cn
00361 
00362                                                                 DEBUG_TRACE ( cout << "nullifyEntry" << endl; );
00363                                                                 DEBUG_TRACE ( cout << "Node " << cn_p->cacheIndex.row << " count " << cn_p->count << endl; );
00364 
00365                                                                 if(cn_p->count==0)
00366                                                                 {
00367                                                                         FOR(a, cn_p->getNodeNumActions())
00368                                                                         {
00369                                                                                 nullifyEntry(&cn_p->Q[a]);
00370                                                                         }
00371                                                                 }
00372                                                         }
00373                                                 }
00374                                         }
00375                                 } 
00376                         }       
00377                 }
00378         }
00379 
00380 
00381 
00382         //Function: uncheckAllSubNodes
00383         //Functionality:
00384         //      refresh the 'checked' record of every node in the subtree
00385         //      to be false. (so that later can be used for tranversing the tree)
00386         //Parameter:
00387         //      cn: the root node of the tree to be unchecked
00388         //Returns:
00389         //      NA
00390         void SARSOPPrune::uncheckAllSubNodes(BeliefTreeNode* cn)
00391         {
00392                 if (cn->checked==true) 
00393                 {
00394                         cn->checked = false;
00395                         //uncheck all sub entries
00396                         FOR(a, cn->getNodeNumActions()) 
00397                         {
00398                                 uncheckEntry(&cn->Q[a]);
00399                         }
00400                 }
00401         }
00402 
00403         //Function: uncheckEntry
00404         //Functionality:
00405         //      refresh the 'checked' record of every node in the subtree
00406         //      of 'e' to be false. (for later's tranversing purpose)
00407         //Parameter:
00408         //      e: the action node whose subtree is to be unchecked
00409         //Returns:
00410         //      NA
00411         void SARSOPPrune::uncheckEntry(BeliefTreeQEntry* e)
00412         {
00413                 FOR(x, e->getNumStateOutcomes()) 
00414                 {
00415                         BeliefTreeObsState* xpt = e->stateOutcomes[x];
00416                         if (xpt!=NULL) 
00417                         {
00418                                 FOR(o, xpt->getNumOutcomes())
00419                                 {
00420                                         //for all observations, nullify all subsequent BeliefTreeNodes
00421                                         if(xpt->outcomes[o]!=NULL)
00422                                         {
00423                                                 BeliefTreeNode* cn_p = xpt->outcomes[o]->nextState;//@
00424                                                 if(cn_p != NULL)
00425                                                 {
00426                                                         uncheckAllSubNodes(cn_p);
00427                                                 }
00428                                         }
00429                                 }
00430                         }
00431                 } 
00432 
00433         }
00434 }
00435 


appl
Author(s): petercai
autogenerated on Tue Jan 7 2014 11:02:29