Go to the documentation of this file.00001 #include "SARSOPPrune.h"
00002
00003 namespace momdp
00004 {
00005
00006 void SARSOPPrune::prune(void)
00007 {
00008 DEBUG_TRACE(cout << "SARSOPPrune" << endl;);
00009
00010 pruneLowerBound();
00011 pruneUpperBound();
00012
00013 if(solver->numBackups/pruneInterval>= currentRound)
00014 {
00015 DEBUG_TRACE(cout << "currentRound " << currentRound << endl;);
00016 currentRound++;
00017
00018
00019
00020 if (problem->XStates->size() == 1)
00021 {
00022 nullifySubOptimalBranches();
00023 }
00024
00025
00026
00027
00028
00029
00030
00031 if (problem->XStates->size() == 1)
00032 {
00033 pruneDynamicDeltaVersion();
00034
00035 }
00036
00037
00038 }
00039 }
00040
00041 void SARSOPPrune::pruneLowerBound()
00042 {
00043 FOR (stateidx, sarsopSolver->lowerBoundSet->set.size())
00044 {
00045 sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->prune();
00046 }
00047 }
00048 void SARSOPPrune::pruneUpperBound()
00049 {
00050 FOR (stateidx, sarsopSolver->upperBoundSet->set.size())
00051 {
00052 sarsopSolver->upperBoundSet->set[stateidx]->pruneEngine->prune();
00053 }
00054 }
00055
00056 void SARSOPPrune::pruneDynamicDeltaVersion()
00057 {
00058 DEBUG_TRACE( cout << "SARSOPPrune::pruneDynamicDeltaVersion" << endl;);
00059
00060 double curTime = GlobalResource::getInstance()->getInstance()->getRunTime();
00061 double sumPruneTime = 0;
00062 int overPruneLocal = 0, underPruneLocal = 0;
00063 int overPruneSum = 0;
00064 int underPruneSum = 0;
00065
00066 int sumNumPrune=0;
00067
00068 FOR (stateidx, sarsopSolver->lowerBoundSet->set.size())
00069 {
00070 sumPruneTime += sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->pruneTime;
00071 sumNumPrune += sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->numPrune;
00072 }
00073
00074 if (firstPass)
00075 {
00076
00077 DEBUG_TRACE( cout << "firstPass" << endl;);
00078 elapsed = curTime;
00079 firstPass = false;
00080 }
00081
00082
00083
00084
00085
00086
00087 if((curTime-elapsed) > 5 && sumPruneTime < (curTime-elapsed-5) * 0.1)
00088 {
00089
00090 FOR (stateidx, sarsopSolver->lowerBoundSet->set.size())
00091 {
00092 DEBUG_TRACE( cout << "stateidx " << stateidx << endl;);
00093 sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->pruneDynamicDeltaVersion(solver->numBackups, overPruneLocal, underPruneLocal);
00094 overPruneSum += overPruneLocal;
00095 underPruneSum += underPruneLocal;
00096
00097 }
00098
00099
00100
00101
00102
00103
00104 updateDeltaVersion2(overPruneSum, underPruneSum);
00105
00106
00107
00108
00109 FOR (stateidx, sarsopSolver->lowerBoundSet->set.size())
00110 {
00111 sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->setDelta(bglobal_delta);
00112 }
00113 }
00114 }
00115
00116
00117 void SARSOPPrune::setDelta(double newDelta)
00118 {
00119 DEBUG_TRACE(cout << "SARSOPPrune::setDelta newDelta " << newDelta << endl;);
00120 bglobal_delta = newDelta;
00121 FOR (stateidx, sarsopSolver->lowerBoundSet->set.size())
00122 {
00123 sarsopSolver->lowerBoundSet->set[stateidx]->pruneEngine->setDelta(newDelta);
00124 }
00125 }
00126
00127 void SARSOPPrune::updateDeltaVersion2(int overPrune, int underPrune)
00128 {
00129 double overPruneThreshold = solver->solverParams-> overPruneThreshold;
00130 double lowerPruneThreshold = solver->solverParams-> lowerPruneThreshold;
00131
00132 double fOverPrune = overPrune;
00133
00134 DEBUG_TRACE(cout << "SARSOPPrune::updateDeltaVersion2" << endl;);
00135 DEBUG_TRACE(cout << "overPruneThreshold" << overPruneThreshold << endl;);
00136 DEBUG_TRACE(cout << "lowerPruneThreshold" << lowerPruneThreshold << endl;);
00137 DEBUG_TRACE(cout << "fOverPrune" << fOverPrune << endl;);
00138 DEBUG_TRACE(cout << "underPrune" << underPrune << endl;);
00139
00140
00141
00142 if(solver->solverParams->dynamicDeltaPercentageMode)
00143 {
00144 DEBUG_TRACE(cout << "dynamicDeltaPercentageMode" << endl;);
00145 unsigned int total_planes = 0;
00146
00147 FOR (stateidx, sarsopSolver->lowerBoundSet->set.size())
00148 {
00149 total_planes += sarsopSolver->lowerBoundSet->set[stateidx]->planes.size();
00150 }
00151 fOverPrune = ((double)overPrune)/total_planes;
00152 DEBUG_TRACE(cout << "fOverPrune" << fOverPrune << endl;);
00153 }
00154
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166 switch(state)
00167 {
00168 case 0:
00169 DEBUG_TRACE(cout << "state 0" << endl;);
00170 if(fOverPrune < overPruneThreshold)
00171 {
00172 state = 1;
00173 }
00174 else
00175 {
00176 increaseDelta();
00177
00178 }
00179 break;
00180 case 1:
00181 DEBUG_TRACE(cout << "state 1" << endl;);
00182 if(fOverPrune < lowerPruneThreshold)
00183 {
00184 state = 2;
00185 decreaseDelta();
00186
00187 }
00188 else if(fOverPrune > overPruneThreshold)
00189 {
00190 state = 0;
00191 increaseDelta();
00192
00193 }
00194 break;
00195 case 2:
00196 DEBUG_TRACE(cout << "state 2" << endl;);
00197 if(fOverPrune > overPruneThreshold)
00198 {
00199 state = 0;
00200 increaseDelta();
00201
00202 }
00203 else
00204 {
00205 decreaseDelta();
00206 }
00207 break;
00208 }
00209 state = 1;
00210
00211 }
00212 void SARSOPPrune::increaseDelta()
00213 {
00214 DEBUG_TRACE(cout << "increaseDelta" << endl;);
00215 if (bglobal_delta < 2.0+1e-7)
00216 {
00217 bglobal_delta *= 2;
00218 }
00219
00220 }
00221
00222 void SARSOPPrune::decreaseDelta()
00223 {
00224 DEBUG_TRACE(cout << "decreaseDelta" << endl;);
00225 bglobal_delta /= 2;
00226
00227
00228 }
00229
00230 void SARSOPPrune::nullifySubOptimalBranches()
00231 {
00232 DEBUG_TRACE( cout << "SARSOPPrune::nullifySubOptimalBranches" << endl;);
00233
00234 BeliefForest* globalRoot = sarsopSolver->sampleEngine->getGlobalNode();
00235
00236 BeliefTreeNode* currRoot;
00237
00238 unsigned int numSampleRoots = globalRoot->sampleRootEdges.size();
00239
00240 FOR (r, numSampleRoots)
00241 {
00242
00243 SampleRootEdge* eR = globalRoot->sampleRootEdges[r];
00244 if (NULL != eR)
00245 {
00246 currRoot = eR->sampleRoot;
00247 uncheckAllSubNodes(currRoot);
00248 nullifySubOptimalCerts(currRoot);
00249 }
00250 }
00251 }
00252
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269 void SARSOPPrune::nullifySubOptimalCerts(BeliefTreeNode* cn)
00270 {
00271 double ubVal, lbVal;
00272 DEBUG_TRACE ( cout << "SARSOPPrune::nullifySubOptimalCerts" << endl; );
00273 if(cn->checked==false)
00274 {
00275 DEBUG_TRACE ( cout << "cn->checked==false" << endl; );
00276
00277 cn->checked = true;
00278
00279
00280
00281 ubVal=sarsopSolver->beliefCacheSet[cn->cacheIndex.sval]->getRow(cn->cacheIndex.row)->UB;
00282 lbVal=sarsopSolver->beliefCacheSet[cn->cacheIndex.sval]->getRow(cn->cacheIndex.row)->LB;
00283
00284
00285
00286 DEBUG_TRACE ( cout << "ubVal " << ubVal << endl; );
00287 DEBUG_TRACE ( cout << "lbVal " << lbVal << endl; );
00288
00289
00290 FOR(a, cn->getNodeNumActions())
00291 {
00292 if(cn->Q[a].ubVal <lbVal - 0.0001)
00293 {
00294 DEBUG_TRACE ( cout << "cn->Q[a].ubVal " << cn->Q[a].ubVal << endl; );
00295
00296 nullifyEntry(&cn->Q[a]);
00297 }
00298 nullifySubOptimalCerts(&cn->Q[a]);
00299 }
00300 }
00301 }
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313 void SARSOPPrune::nullifySubOptimalCerts(BeliefTreeQEntry* e)
00314 {
00315 FOR(x, e->getNumStateOutcomes()) {
00316 BeliefTreeObsState* xpt = e->stateOutcomes[x];
00317 if (xpt!=NULL) {
00318 FOR(o, xpt->getNumOutcomes()){
00319 if(xpt->outcomes[o]!=NULL){
00320 BeliefTreeNode* cn_p = xpt->outcomes[o]->nextState;
00321 if(cn_p != NULL){
00322 if((*cn_p).isFringe()==false){
00323 nullifySubOptimalCerts(cn_p);
00324 }
00325 }
00326 }
00327 }
00328 }
00329 }
00330 }
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342 void SARSOPPrune::nullifyEntry(BeliefTreeQEntry* e)
00343 {
00344
00345 if(e->valid==true){
00346 e->valid = false;
00347 FOR(x, e->getNumStateOutcomes())
00348 {
00349 BeliefTreeObsState* xpt = e->stateOutcomes[x];
00350 if (xpt!=NULL)
00351 {
00352 FOR(o, xpt->getNumOutcomes())
00353 {
00354
00355 if(xpt->outcomes[o]!=NULL)
00356 {
00357 BeliefTreeNode* cn_p = xpt->outcomes[o]->nextState;
00358 if(cn_p != NULL)
00359 {
00360 cn_p->count--;
00361
00362 DEBUG_TRACE ( cout << "nullifyEntry" << endl; );
00363 DEBUG_TRACE ( cout << "Node " << cn_p->cacheIndex.row << " count " << cn_p->count << endl; );
00364
00365 if(cn_p->count==0)
00366 {
00367 FOR(a, cn_p->getNodeNumActions())
00368 {
00369 nullifyEntry(&cn_p->Q[a]);
00370 }
00371 }
00372 }
00373 }
00374 }
00375 }
00376 }
00377 }
00378 }
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388
00389
00390 void SARSOPPrune::uncheckAllSubNodes(BeliefTreeNode* cn)
00391 {
00392 if (cn->checked==true)
00393 {
00394 cn->checked = false;
00395
00396 FOR(a, cn->getNodeNumActions())
00397 {
00398 uncheckEntry(&cn->Q[a]);
00399 }
00400 }
00401 }
00402
00403
00404
00405
00406
00407
00408
00409
00410
00411 void SARSOPPrune::uncheckEntry(BeliefTreeQEntry* e)
00412 {
00413 FOR(x, e->getNumStateOutcomes())
00414 {
00415 BeliefTreeObsState* xpt = e->stateOutcomes[x];
00416 if (xpt!=NULL)
00417 {
00418 FOR(o, xpt->getNumOutcomes())
00419 {
00420
00421 if(xpt->outcomes[o]!=NULL)
00422 {
00423 BeliefTreeNode* cn_p = xpt->outcomes[o]->nextState;
00424 if(cn_p != NULL)
00425 {
00426 uncheckAllSubNodes(cn_p);
00427 }
00428 }
00429 }
00430 }
00431 }
00432
00433 }
00434 }
00435