SARSOP.cpp
Go to the documentation of this file.
00001 #include <cerrno>
00002 #include <cstring>
00003 #include <iomanip>
00004 
00005 #include "SARSOP.h"
00006 #include "SARSOPPrune.h"
00007 #include "MOMDP.h"
00008 #include "BeliefValuePairPoolSet.h"
00009 #include "AlphaPlanePoolSet.h"
00010 #include "BeliefTreeNode.h"
00011 #include "CPTimer.h"
00012 #include "BlindLBInitializer.h"
00013 #include "FastInfUBInitializer.h"
00014 #include "BackupBeliefValuePairMOMDP.h"
00015 #include "BackupAlphaPlaneMOMDP.h"
00016 
00017 
00018 void printSampleBelief(list<cacherow_stval>& beliefNStates)
00019 {
00020         cout << "SampledBelief" <<  endl;
00021         for(list<cacherow_stval>::iterator iter =beliefNStates.begin(); iter != beliefNStates.end() ; iter ++)
00022         {
00023                 cout <<  "[ " <<(*iter).row << " : " << (*iter).sval << " ] ";
00024         }
00025         cout <<  endl;
00026 }
00027 
00028 
00029 void SARSOP::progressiveIncreasePolicyInteval(int& numPolicies)
00030 {
00031         if (numPolicies == 0) 
00032         {
00033                 this->solverParams->interval *= 10; 
00034                 numPolicies++;
00035 
00036         } 
00037         else 
00038         {
00039                 if (numPolicies == 5)
00040                 {
00041                         this->solverParams->interval *= 5;
00042                 }
00043                 else if (numPolicies == 10)
00044                 {
00045                         this->solverParams->interval *= 2;
00046                 }
00047                 else if (numPolicies == 15)
00048                 {
00049                         this->solverParams->interval *= 4;
00050                 }
00051 
00052                 numPolicies++;
00053         }
00054 }
00055 
00056 SARSOP::SARSOP(SharedPointer<MOMDP> problem, SolverParams * solverParams)
00057 {
00058         this->problem = problem;
00059         this->solverParams = solverParams;
00060         beliefForest = new BeliefForest(); 
00061         sampleEngine = new SampleBP();
00062         ((SampleBP*)sampleEngine)->setup(problem, this);
00063         beliefForest->setup(problem, this->sampleEngine, &this->beliefCacheSet);
00064         numBackups = 0;
00065 }
00066 
00067 SARSOP::~SARSOP(void)
00068 {
00069 }
00070 
00071 
00072 void SARSOP::solve(SharedPointer<MOMDP> problem)
00073 {
00074         try
00075         {
00076 
00077                 bool skipSample = false;        // ADDED_24042009 flag for when all roots have ended their last trial and precision gap  <= 0
00078 
00079                 //struct tms now;
00080                 int policyIndex, checkIndex;//index for policy output file, and index for checking whether to output policy file
00081                 bool stop;
00082                 std::vector<cacherow_stval> currentBeliefIndexArr; //int currentBeliefIndex;
00083                 // modified for parallel trials
00084                 //cacherow_stval currentBeliefIndex; //int currentBeliefIndex;
00085                 currentBeliefIndexArr.resize(problem->initialBeliefX->size());
00086 
00087                 FOR(r,currentBeliefIndexArr.size())
00088                 {
00089                         currentBeliefIndexArr[r].sval = -1;
00090                         currentBeliefIndexArr[r].row = -1;
00091                 }
00092 
00093                 list<cacherow_stval> sampledBeliefs;  //modified for factored, prevly: list<int> sampledBeliefs;
00094                 cacherow_stval lastRootBeliefIndex; //24092008 added to keep track of root chosen for each new trial
00095                 lastRootBeliefIndex.row = -1;
00096                 lastRootBeliefIndex.sval = -1;
00097 
00098                 int numPolicies = 0; //SYLADDED 07082008 temporary
00099 
00100                 //start timing
00101                 //times(&start);
00102                 runtimeTimer.start();
00103                 cout << "\nSARSOP initializing ..." << endl;
00104 
00105                 initialize(problem);
00106 
00107                 if(problem->XStates->size() != 1 && problem->hasPOMDPMatrices())
00108                 {
00109                         // only POMDPX parser can generates 2 sets of matrices, therefore, only release the second set if it is using POMDPX parser and the second set is generated
00110                         problem->deletePOMDPMatrices();
00111                 }
00112                 if(problem->XStates->size() != 1 && problem->hasPOMDPMatrices())
00113                 {
00114                         // only POMDPX parser can generates 2 sets of matrices, therefore, only release the second set if it is using POMDPX parser and the second set is generated
00115                         problem->deletePOMDPMatrices();
00116                 }
00117                 if(problem->XStates->size() != 1 && problem->hasPOMDPMatrices())
00118                 {
00119                         // only POMDPX parser can generates 2 sets of matrices, therefore, only release the second set if it is using POMDPX parser and the second set is generated
00120                         problem->deletePOMDPMatrices();
00121                 }
00122                 GlobalResource::getInstance()->getInstance()->solving = true;
00123 
00124                 //cout << "finished calling initialize() in SARSOP::solve()" << endl;
00125 
00126                 //initialize parameters 
00127                 stop = false;
00128 
00129                 //ADD SYLTAG - need to expand global root and all the roots for sampling
00130                 BeliefForest& globalroot = *(sampleEngine->getGlobalNode());
00131                 beliefForest->globalRootPrepare();//do preparation work for global root
00132 
00133                 // cycle through all the roots and do preparation work
00134                 FOR(r, globalroot.sampleRootEdges.size()) {
00135                         //              FOR(r, sampleEngine->globalRoot->sampleRootEdges.size()) {
00136                         if (NULL != globalroot.sampleRootEdges[r]) 
00137                         {
00138                                 BeliefTreeNode& thisRoot = *(globalroot.sampleRootEdges[r]->sampleRoot);
00139                                 sampleEngine->samplePrepare(thisRoot.cacheIndex);//do preparation work for this root
00140                         }
00141                 }
00142 
00143                 // TODO:: sampleEngine->dumpData = dumpData;//dump data
00144                 // TODO:: sampleEngine->dumpPolicyTrace = dumpPolicyTrace;//dump datadone
00145                 policyIndex = 0;
00146                 checkIndex = 0;
00147 
00148                 lapTimer.start();
00149                 elapsed = runtimeTimer.elapsed();
00150                 printf("  initialization time : %.2fs\n", elapsed);
00151 
00152 
00153                 DEBUG_LOG(logFilePrint(policyIndex-1););
00154 
00155                 // paused timer for writing policy
00156                 double currentElapsed = lapTimer.elapsed();
00157                 lapTimer.pause();
00158                 runtimeTimer.pause();
00159 
00160                 //write out INITIAL policy
00161 
00162                 DEBUG_LOG(writeIntermediatePolicyTraceToFile(0, 0.0, this->solverParams->outPolicyFileName, this->solverParams->problemName ); );
00163                 DEBUG_LOG(cout << "Initial policy written" << endl;);
00164                 
00165                 printHeader();  
00166 
00167                 lapTimer.resume();
00168                 runtimeTimer.resume();
00169 
00170                 policyIndex++;
00171                 elapsed += currentElapsed;
00172 
00173                 //times(&last);//renew 'last' time flag
00174                 lapTimer.restart();
00175 
00176                 alwaysPrint();
00177 
00178                 DEBUG_LOG( logFilePrint(policyIndex-1); );
00179 
00180 
00181                 //create while loop where:
00182                 int lastTrial = ((SampleBP*)sampleEngine)->numTrials;
00183 
00184                 // no root assigned as active at the beginning
00185                 int activeRoot = -1;
00186 
00187                 while(!stop)
00188                 {
00189                         int numTrials = ((SampleBP*)sampleEngine)->numTrials;
00190                         if( this->solverParams->targetTrials > 0 && numTrials >  this->solverParams->targetTrials )
00191                         {
00192                                 //    target number of trials reached
00193                                 break;
00194                         }
00195 
00196                         //0. IF this is the start of a new trial, 
00197                         // backup the list of nodes in sampledBeliefs, then 
00198                         // decide on which root to sample from
00199                         //      (choose the root which has the largest weighted excess uncertainty)
00200                         //   ELSE, do a regular backup of just one node
00201 
00202                         if (activeRoot == -1) 
00203                         {
00204                                 FOR (r, globalroot.getGlobalRootNumSampleroots()) 
00205                                 {
00206                                         SampleRootEdge* eR = globalroot.sampleRootEdges[r];
00207 
00208                                         if (NULL != eR) 
00209                                         {
00210                                                 BeliefTreeNode & sn = *eR->sampleRoot;
00211                                                 sampledBeliefs.clear();
00212                                                 sampledBeliefs.push_back(sn.cacheIndex);
00213 
00214                                                 DEBUG_TRACE( printSampleBelief(sampledBeliefs); );
00215 
00216                                                 currentBeliefIndexArr[r] =  backup(sampledBeliefs);
00217                                         }
00218                                 }
00219                                 sampledBeliefs.clear();
00220 
00221                         } 
00222                         else  
00223                         {
00224 
00225                                 if (((SampleBP *)sampleEngine)->newTrialFlagArr[activeRoot] == 1) 
00226                                 {
00227                                         // backup the list of nodes in sampledBeliefs
00228                                         DEBUG_TRACE( printSampleBelief(sampledBeliefs); );
00229                                         currentBeliefIndexArr[activeRoot] = backup(sampledBeliefs);
00230                                         lastRootBeliefIndex = currentBeliefIndexArr[activeRoot];
00231                                         sampledBeliefs.clear();
00232                                         // backup at all root nodes except for the root node that we had just backedup
00233                                         FOR (r, globalroot.getGlobalRootNumSampleroots()) 
00234                                         {
00235                                                 SampleRootEdge* eR = globalroot.sampleRootEdges[r];
00236 
00237                                                 if (NULL != eR) {
00238                                                         BeliefTreeNode & sn = *eR->sampleRoot;
00239                                                         // check if we had just done backup at this root,
00240                                                         if( !((sn.cacheIndex.row == lastRootBeliefIndex.row)&&(sn.cacheIndex.sval == lastRootBeliefIndex.sval)) ) {
00241 
00242                                                                 // ADDED_24042009 - dont do LB backup if precision gap <= 0
00243                                                                 // check if the precision gap for this root is already zero
00244                                                                 double lbVal = beliefCacheSet[sn.cacheIndex.sval]->getRow(sn.cacheIndex.row)->LB;
00245                                                                 double ubVal = beliefCacheSet[sn.cacheIndex.sval]->getRow(sn.cacheIndex.row)->UB;
00246 
00247                                                                 if (!((ubVal - lbVal) <= 0)) 
00248                                                                 {
00249                                                                         // else, do backup at this root
00250                                                                         sampledBeliefs.clear();
00251                                                                         sampledBeliefs.push_back(sn.cacheIndex);
00252 
00253                                                                         DEBUG_TRACE( cout << "LB backup only " << endl; );
00254                                                                         DEBUG_TRACE( printSampleBelief(sampledBeliefs); );
00255 
00256                                                                         backupLBonly(sampledBeliefs);
00257                                                                         //ofsol1710d: backup(sampledBeliefs);
00258                                                                 } 
00259                                                         }
00260                                                 }
00261                                         }
00262                                         sampledBeliefs.clear();
00263 
00264                                         ((SampleBP *)sampleEngine)->newTrialFlagArr[activeRoot]  = 0;
00265 
00266                                 } 
00267                                 else 
00268                                 {
00269                                         DEBUG_TRACE( printSampleBelief(sampledBeliefs); );
00270                                         currentBeliefIndexArr[activeRoot] = backup(sampledBeliefs);
00271 
00272                                 }
00273 
00274                         }       
00275 
00276                         // go to next valid activeRoot here
00277                         if (activeRoot == -1) // set to the first valid root
00278                         {
00279                                 // cycle through all roots till we find a valid one
00280                                 FOR (r, globalroot.getGlobalRootNumSampleroots()) 
00281                                 {
00282                                         SampleRootEdge* eR = globalroot.sampleRootEdges[r];
00283                                         if (NULL != eR) 
00284                                         {
00285                                                 activeRoot = r;
00286                                                 break;
00287                                         }
00288                                 } 
00289                         } 
00290                         else                    // set to the next valid root
00291                         {
00292                                 int currActiveRoot = activeRoot;        // ADDED_24042009  
00293                                 bool passedcurrActiveRoot = false;      // ADDED_24042009 
00294                                 while(true){
00295 
00296                                         // ADDED_24042009
00297                                         if ((activeRoot == currActiveRoot) && passedcurrActiveRoot) // i.e. this is the second time that activeRoot == currActiveRoot, the while loop has cycled through all roots and not found one that passes the tests below
00298                                         {       skipSample = true;              // flag to indicate dont call sample()
00299                                         break;
00300                                         } 
00301 
00302                                         if (activeRoot == currActiveRoot) passedcurrActiveRoot = true;
00303 
00304                                         if (activeRoot == (globalroot.getGlobalRootNumSampleroots()-1)) 
00305                                                 activeRoot = 0;
00306                                         else activeRoot++;
00307 
00308                                         if (globalroot.sampleRootEdges[activeRoot] != NULL){
00309 
00310                                                 // ADDED_24042009 - dont go to this root if this root is about to start a new trial
00311                                                 // and the precision gap for the root is already zero
00312                                                 cacherow_stval currCacheIndex = globalroot.sampleRootEdges[activeRoot]->sampleRoot->cacheIndex;
00313                                                 double lbVal = beliefCacheSet[currCacheIndex.sval]->getRow(currCacheIndex.row)->LB;
00314                                                 double ubVal = beliefCacheSet[currCacheIndex.sval]->getRow(currCacheIndex.row)->UB;
00315                                                 if (!((((SampleBP *)sampleEngine)->trialTargetPrecisionArr[activeRoot] == -1)&&((ubVal - lbVal) <= 0) ))
00316                                                 {
00317                                                         break;
00318                                                 }
00319                                         }
00320                                 }
00321                         }
00322 
00323                         //2. sample
00324                         //  samples the next belief to do backup
00325                         //  a. if haven't reached target depth, search further
00326                         //  b. if target depth has been reached, go back to root
00327                         if (!skipSample) 
00328                         {               
00329                                 // ADDED_24042009
00330                                 sampledBeliefs = sampleEngine->sample(currentBeliefIndexArr[activeRoot], activeRoot);
00331                         }
00332                         //3. prune
00333                         //      decide whether needs pruning at this moment, if so,
00334                         //  prune off the unnecessary nodes
00335 
00336                         //DEBUG_TRACE (beliefForest->print(););
00337 
00338                         pruneEngine->prune(); 
00339 
00340                         //4. write out policy file if interval time reached
00341                         // check time every CHECK_INTERVAL backups
00342                         if(this->solverParams->interval > 0 || this->solverParams->timeoutSeconds > 0)
00343                         {
00344                                 //only do this if required
00345                                 if((numBackups/CHECK_INTERVAL) >= checkIndex)
00346                                 {//do check every CHECK_INTERVAL(50) backups
00347                                         //times(&now);
00348                                         checkIndex++;//for next check
00349 
00350                                         //check and write out policy file periodically
00351                                         if (this->solverParams->interval > 0)
00352                                         {
00353                                                 double currentElapsed = lapTimer.elapsed();
00354                                                 if(currentElapsed > this->solverParams->interval)
00355                                                 {
00356                                                         //write out policy and reset parameters
00357 
00358                                                         // paused timer for writing policy
00359                                                         lapTimer.pause();
00360                                                         runtimeTimer.pause();
00361 
00362                                                         writeIntermediatePolicyTraceToFile(numTrials, runtimeTimer.elapsed(), this->solverParams->outPolicyFileName, this->solverParams->problemName );
00363 
00364                                                         lapTimer.resume();
00365                                                         runtimeTimer.resume();
00366 
00367                                                         policyIndex++;
00368                                                         elapsed += currentElapsed;
00369                                                         cout << "Intermediate policy written(interval: "<< this->solverParams->interval <<")" << endl;
00370 
00371                                                         // reset laptime so that next interval can start
00372                                                         lapTimer.restart();
00373 
00374 
00375                                                         DEBUG_LOG(logFilePrint(policyIndex-1););
00376                                                         DEBUG_LOG( progressiveIncreasePolicyInteval(numPolicies); );
00377 
00378 
00379                                                 }
00380                                         }//end write out policy periodically
00381 
00382                                         else if(this->solverParams->timeoutSeconds >0)
00383                                         {
00384                                                 double currentElapsed = runtimeTimer.elapsed();
00385                                                 elapsed = currentElapsed;
00386                                         }
00387                                 }//end check periodically for policy write out and elapsed time update 
00388                         }
00389 
00390                         //5. do printing for current precision
00391                         print();
00392 
00393                         //6. decide whether to stop here
00394                         stop = stopNow();
00395 
00396 
00397                 }
00398 
00399         }
00400 
00401         catch(bad_alloc &e)
00402         {
00403                 // likely bad_alloc exception
00404                 // should we remove the last alpha vector?
00405                 cout << "Memory limit reached, trying to write out policy" << endl;
00406 
00407         }
00408 
00409         //prune for the last time
00410         FOR (stateidx, lowerBoundSet->set.size())
00411         {
00412                 lowerBoundSet->set[stateidx]->pruneEngine->prunePlanes();
00413         }
00414 
00415         printHeader();
00416         alwaysPrint();
00417         printDivider();
00418         DEBUG_LOG(logFilePrint(-1););
00419 
00420         //now output policy to the outfile
00421         cout << endl << "Writing out policy ..." << endl;
00422         cout << "  output file : " << this->solverParams->outPolicyFileName << endl;
00423         writePolicy(this->solverParams->outPolicyFileName, this->solverParams->problemName);
00424 }
00425 
00426 //Function: print
00427 //Functionality:
00428 //      print the necessary info for help  understanding current situation inside
00429 //      solver
00430 
00431 void SARSOP::print()
00432 {
00433         if(numBackups/CHECK_INTERVAL>printIndex)
00434         {
00435                 printIndex++;
00436                 //print time now
00437                 alwaysPrint();
00438         }
00439 }
00440 
00441 //Function: print
00442 //Functionality:
00443 //    print the necessary info for help  understanding current situation inside
00444 //    solver
00445 void SARSOP::alwaysPrint()
00446 {
00447         //struct tms now;
00448         //float utime, stime;
00449         //long int clk_tck = sysconf(_SC_CLK_TCK);
00450 
00451         //print time now
00452         //times(&now);
00453         double currentTime =0;
00454         if(this->solverParams->interval >0)
00455         {
00456                 currentTime = elapsed + lapTimer.elapsed();
00457         }
00458         else
00459         {
00460                 currentTime = runtimeTimer.elapsed();
00461         }
00462         //printf("%.2fs ", currentTime);
00463         cout.precision(6);
00464         cout <<" ";cout.width(8);cout << left << currentTime;
00465 
00466         //print current trial number, num of backups
00467         int numTrials = ((SampleBP*)sampleEngine)->numTrials;
00468         //printf("#Trial %d ",numTrials);
00469         cout.width(7);cout << left  <<numTrials << " "; 
00470         //printf("#Backup %d ", numBackups); 
00471         cout.width(8);cout << left << numBackups << " ";
00472         //print #alpha vectors
00473         //print precision
00474 
00475         //ADD SYLTAG
00476         //assume we can estimate lb and ub at the global root
00477         //by cycling through all the roots to find their bounds
00478         double lb = 0, ub = 0, width = 0;
00479 
00480         BeliefForest& globalRoot  = *(sampleEngine->getGlobalNode());
00481         FOR (r, globalRoot.getGlobalRootNumSampleroots()) 
00482         {
00483                 SampleRootEdge* eR = globalRoot.sampleRootEdges[r];
00484                 if (NULL != eR) 
00485                 {
00486                         BeliefTreeNode & sn = *eR->sampleRoot;
00487                         double lbVal =  beliefCacheSet[sn.cacheIndex.sval]->getRow(sn.cacheIndex.row)->LB;
00488                         double ubVal =  beliefCacheSet[sn.cacheIndex.sval]->getRow(sn.cacheIndex.row)->UB;
00489                         lb += eR->sampleRootProb * lbVal;
00490                         ub += eR->sampleRootProb * ubVal;
00491                         width += eR->sampleRootProb * (ubVal - lbVal);
00492                 }
00493         }
00494 
00495         //REMOVE SYLTAG
00496         //cacherow_stval rootIndex = sampleEngine->getRootNode()->cacheIndex;
00497         //double lb = bounds->boundsSet[rootIndex.sval]->beliefCache->getRow(rootIndex.row)->LB;
00498         //double ub = bounds->boundsSet[rootIndex.sval]->beliefCache->getRow(rootIndex.row)->UB;
00499 
00500         //printf("[%f,%f],", lb, ub);
00501         cout.width(10); cout << left << lb<< " ";
00502         cout.width(10); cout << left << ub<< " ";
00503         
00504         //print precision
00505         double precision = width; // ub - lb;   //MOD SYLTAG
00506         //printf("%f, ", precision);
00507         cout.width(11); 
00508         cout << left << precision << " ";
00509         int numAlphas = 0;
00510         FOR (setIdx, beliefCacheSet.size()) 
00511         {
00512                 numAlphas += (int)lowerBoundSet->set[setIdx]->planes.size();
00513         }
00514 
00515         //printf("#Alphas %d ", numAlphas);                     //SYLTEMP FOR EXPTS
00516         cout.width(9);cout << left << numAlphas;
00517 
00518         //print #belief nodes
00519         //printf("#Beliefs %d", sampleEngine->numStatesExpanded);
00520         cout.width(9);cout << left << sampleEngine->numStatesExpanded;
00521 
00522         //printf("#alphas %d", (int)bounds->alphaPlanePool->planes.size());
00523         printf("\n"); 
00524 
00525 }
00526 
00527 //SYL ADDED FOR EXPTS
00528 //Function: print
00529 //Functionality:
00530 //    print the necessary info for help  understanding current situation inside
00531 //    solver
00532 
00533 void SARSOP::logFilePrint(int index)
00534 {
00535         //struct tms now;
00536         //float utime, stime;
00537         //long int clk_tck = sysconf(_SC_CLK_TCK);
00538 
00539         //print time now
00540         //times(&now);
00541 
00542         FILE *fp = fopen("solve.log", "a");
00543         if(fp==NULL){
00544                 cerr << "can't open logfile\n";
00545                 exit(1);
00546         }
00547 
00548 
00549         fprintf(fp,"%d ",index);
00550 
00551         //print current trial number, num of backups
00552         int numTrials = ((SampleBP*)sampleEngine)->numTrials;
00553         //int numBackups = numBackups;
00554         fprintf(fp,"%d ",numTrials);                    //SYLTEMP FOR EXPTS
00555         //printf("#Trial %d, #Backup %d ",numTrials, numBackups); 
00556 
00557         //print #alpha vectors
00558         int numAlphas = 0;
00559         FOR (setIdx, beliefCacheSet.size()) 
00560         {
00561                 //cout << " p : " << setIdx << " : " <<          (int)bounds->boundsSet[setIdx]->alphaPlanePool->planes.size();
00562                 numAlphas += (int)lowerBoundSet->set[setIdx]->planes.size();
00563         }
00564 
00565         fprintf(fp, "%d ", numAlphas);                  //SYLTEMP FOR EXPTS
00566 
00567         double currentTime =0;
00568         if(this->solverParams->interval >0)
00569         {
00570                 //utime = ((float)(now.tms_utime-last.tms_utime))/clk_tck;
00571                 //stime = ((float)(now.tms_stime-last.tms_stime))/clk_tck;
00572                 //currentTime = elapsed+utime+stime;
00573                 currentTime = elapsed + lapTimer.elapsed();
00574                 fprintf(fp, "%.2f ", currentTime);                      //SYLTEMP FOR EXPTS
00575                 //printf("<%.2fs> ", currentTime);
00576         }
00577         else{
00578                 //utime = ((float)(now.tms_utime-start.tms_utime))/clk_tck;
00579                 //stime = ((float)(now.tms_stime-start.tms_stime))/clk_tck;
00580                 //currentTime = utime+stime;
00581                 currentTime = runtimeTimer.elapsed();
00582                 fprintf(fp, "%.2f ", currentTime);              //SYLTEMP FOR EXPTS
00583                 //printf("<%.2fs> ", currentTime);      
00584         }
00585 
00586         fprintf(fp,"\n"); 
00587 
00588         fclose(fp);
00589 }
00590 
00591 //ADD SYLTAG
00592 bool SARSOP::stopNow(){
00593         bool stop = false;
00594 
00595         double width = 0;
00596         BeliefForest& globalRoot  = *(sampleEngine->getGlobalNode());
00597 
00598         //find the weighted excess uncertainty at the global root
00599         //cycle through all the roots to find their bounds
00600         FOR (r, globalRoot.getGlobalRootNumSampleroots()) 
00601         {
00602                 SampleRootEdge* eR = globalRoot.sampleRootEdges[r];
00603                 if (NULL != eR) {
00604                         BeliefTreeNode & sn = *eR->sampleRoot;
00605                         double lbVal =  beliefCacheSet[sn.cacheIndex.sval]->getRow(sn.cacheIndex.row)->LB;
00606                         double ubVal =  beliefCacheSet[sn.cacheIndex.sval]->getRow(sn.cacheIndex.row)->UB;
00607                         width += eR->sampleRootProb * (ubVal - lbVal);
00608                 }
00609         }
00610 
00611         if(GlobalResource::getInstance()->userTerminatedG)
00612         {
00613                 stop = true;
00614         }
00615 
00616         if ((width) < this->solverParams->targetPrecision)
00617         {      
00618                 alwaysPrint();
00619                 printDivider();
00620                 printf("\nSARSOP finishing ...\n");
00621                 printf("  target precision reached\n");
00622                 printf("  target precision  : %f\n", this->solverParams->targetPrecision);
00623                 printf("  precision reached : %f \n", width);
00624 
00625                 stop = true;
00626         }
00627         if (this->solverParams->timeoutSeconds > 0)
00628         {
00629                 if (elapsed > this->solverParams->timeoutSeconds )
00630                 {
00631                         printDivider();
00632                         printf("\nSARSOP finishing ...\n");
00633                         printf("  Preset timeout reached\n");
00634                         printf("  Timeout     : %fs\n",  this->solverParams->timeoutSeconds );
00635                         printf("  Actual Time : %fs\n", elapsed);
00636                         stop = true;
00637                 }
00638         }
00639         return stop;
00640 }
00641 
00642 
00643 //REMOVE SYLTAG
00644 /*      bool SARSOP::stopNow(){
00645 bool stop = false;
00646 cacherow_stval rootIndex = sampleEngine->getRootNode()->cacheIndex;
00647 //int rootIndex = sampleEngine->getRootNode()->cacheIndex;
00648 
00649 //decide whether to stop or not depend on current root precision
00650 double lb = bounds->boundsSet[rootIndex.sval]->beliefCache->getRow(rootIndex.row)->LB;
00651 double ub = bounds->boundsSet[rootIndex.sval]->beliefCache->getRow(rootIndex.row)->UB;
00652 #if USE_DEBUG_PRINT
00653 printf("targetPrecision is %f, precision is %f\n", targetPrecision, ub-lb);
00654 #endif
00655 if(GlobalResource::getInstance()->userTerminatedG)
00656 {
00657 stop = true;
00658 }
00659 if ((ub-lb)<targetPrecision){      
00660 alwaysPrint();
00661 printf("Target precision reached: %f (%f)\n\n", ub-lb, targetPrecision);
00662 stop = true;
00663 }
00664 if (timeout > 0){
00665 if (elapsed > timeout){
00666 printf("Preset timeout reached %f (%fs)\n\n", elapsed, timeout);
00667 stop = true;
00668 }
00669 }
00670 return stop;
00671 }
00672 */
00673 
00674 void SARSOP::writeIntermediatePolicyTraceToFile(int trial, double time, const string& outFileName, string problemName)
00675 {
00676         stringstream newFileNameStream;
00677         string outputBasename = GlobalResource::parseBaseNameWithPath(outFileName);
00678         newFileNameStream << outputBasename << "_" << trial << "_" << time << ".policy";
00679         string newFileName = newFileNameStream.str();
00680         cout << "Writing policy file: " << newFileName << endl;
00681         writePolicy(newFileName, problemName);
00682 }
00683 
00684 
00685 BeliefTreeNode& SARSOP::getMaxExcessUncRoot(BeliefForest& globalroot) 
00686 {
00687 
00688         double maxExcessUnc = -99e+20;
00689         int maxExcessUncRoot = -1;
00690         double width;
00691         double lbVal, ubVal;
00692 
00693         FOR (r, globalroot.getGlobalRootNumSampleroots()) {
00694                 SampleRootEdge* eR = globalroot.sampleRootEdges[r];
00695                 if (NULL != eR) {
00696                         BeliefTreeNode & sn = *eR->sampleRoot;
00697                         lbVal = beliefCacheSet[sn.cacheIndex.sval]->getRow(sn.cacheIndex.row)->LB;
00698                         ubVal = beliefCacheSet[sn.cacheIndex.sval]->getRow(sn.cacheIndex.row)->UB;
00699                         width = eR->sampleRootProb * (ubVal - lbVal);
00700 
00701                         if (width > maxExcessUnc)
00702                         {
00703                                 maxExcessUnc = width;
00704                                 maxExcessUncRoot = r;
00705                         }
00706                 }
00707 
00708         }
00709 
00710         return *(globalroot.sampleRootEdges[maxExcessUncRoot]->sampleRoot);
00711 
00712 }
00713 
00714 void SARSOP::backup(BeliefTreeNode* node)
00715 {
00716         upperBoundSet->backup(node);
00717         lowerBoundSet->backup(node);
00718 }
00719 
00720 void SARSOP::initialize(SharedPointer<MOMDP> problem)
00721 {
00722         printIndex = 0; // reset printing counter
00723 
00724         int xStateNum = problem->XStates->size();
00725         beliefCacheSet.resize(xStateNum);
00726         lbDataTableSet.resize(xStateNum);
00727         ubDataTableSet.resize(xStateNum);
00728 
00729         for(States::iterator iter = problem->XStates->begin(); iter != problem->XStates->end(); iter ++ )
00730         {
00731                 beliefCacheSet[iter.index()] = new BeliefCache();
00732                 lbDataTableSet[iter.index()] = new IndexedTuple<AlphaPlanePoolDataTuple>();
00733                 ubDataTableSet[iter.index()] = new IndexedTuple<BeliefValuePairPoolDataTuple>();
00734         }
00735 
00736         initializeUpperBound(problem);
00737 
00738         upperBoundSet->setBeliefCache(beliefCacheSet);
00739         upperBoundSet->setDataTable(ubDataTableSet);
00740 
00741         initializeLowerBound(problem);
00742         lowerBoundSet->setBeliefCache(beliefCacheSet);
00743         lowerBoundSet->setDataTable(lbDataTableSet);
00744 
00745         initializeBounds(this->solverParams->targetPrecision);
00746         initSampleEngine(problem);
00747 
00748         pruneEngine = new SARSOPPrune(this);
00749 
00750 }
00751 void SARSOP::initSampleEngine(SharedPointer<MOMDP> problem)
00752 {
00753         sampleEngine->appendOnGetNodeHandler(&SARSOP::onGetNode);
00754         binManagerSet = new BinManagerSet(upperBoundSet);
00755         ((SampleBP*)sampleEngine)->setBinManager(binManagerSet);
00756         ((SampleBP*)sampleEngine)->setRandomization(solverParams->randomizationBP);
00757 
00758 }
00759 void SARSOP::initializeUpperBound(SharedPointer<MOMDP> problem)
00760 {
00761         upperBoundSet = new BeliefValuePairPoolSet(upperBoundBackup);
00762         upperBoundSet->setProblem(problem);
00763         upperBoundSet->setSolver(this);
00764         upperBoundSet->initialize();
00765         upperBoundSet->appendOnBackupHandler(&SARSOP::onUpperBoundBackup);
00766         ((BackupBeliefValuePairMOMDP*)upperBoundBackup)->boundSet = upperBoundSet;
00767 }
00768 void SARSOP::initializeLowerBound(SharedPointer<MOMDP> problem)
00769 {
00770         lowerBoundSet = new AlphaPlanePoolSet(lowerBoundBackup);
00771         lowerBoundSet->setProblem(problem);
00772         lowerBoundSet->setSolver(this);
00773         lowerBoundSet->initialize();
00774         lowerBoundSet->appendOnBackupHandler(&SARSOP::onLowerBoundBackup);
00775         lowerBoundSet->appendOnBackupHandler(&SARSOPPrune::onLowerBoundBackup);
00776         ((BackupAlphaPlaneMOMDP* )lowerBoundBackup)->boundSet = lowerBoundSet;
00777 }
00778 
00779 void SARSOP::initializeBounds(double _targetPrecision)
00780 {
00781         double targetPrecision = _targetPrecision * CB_INITIALIZATION_PRECISION_FACTOR;
00782 
00783         CPTimer heurTimer;
00784         heurTimer.start();      // for timing heuristics
00785         BlindLBInitializer blb(problem, lowerBoundSet);
00786         blb.initialize(targetPrecision);
00787         elapsed = heurTimer.elapsed();
00788 
00789         DEBUG_LOG( cout << fixed << setprecision(2) << elapsed << "s blb.initialize(targetPrecision) done" << endl; );
00790 
00791         heurTimer.restart();
00792         FastInfUBInitializer fib(problem, upperBoundSet); 
00793         fib.initialize(targetPrecision);
00794         elapsed = heurTimer.elapsed();
00795         DEBUG_LOG(cout << fixed << setprecision(2) << elapsed << "s fib.initialize(targetPrecision) done" << endl;);
00796 
00797         FOR (state_idx, problem->XStates->size()) 
00798         {
00799                 upperBoundSet->set[state_idx]->cornerPointsVersion++;   // advance the version by one so that next time get value will calculate rather than skip
00800         }
00801 
00802         //cout << "finished setting cornerPointsVersion" << endl;
00803 
00804         numBackups = 0;
00805 }//end method initialize
00806 
00807 cacherow_stval SARSOP::backup(list<cacherow_stval> beliefNStates)
00808 {
00809         //decide the order of backups 
00810         cacherow_stval rowNState, nextRowNState;
00811         nextRowNState.row = -1;
00812 
00813         //for each belief given, we perform backup for it
00814         LISTFOREACH(cacherow_stval, beliefNState,  beliefNStates) 
00815         {
00816                 //get belief
00817                 rowNState = *beliefNState;
00818                 nextRowNState = backup(rowNState);
00819         }//end FOR_EACH
00820 
00821         // prevly:
00822         //for each belief given, we perform backup for it
00823         /* LISTFOREACH(int, belief,  beliefs) {
00824         //get belief
00825         row = *belief;
00826         nextRow = backup(row);
00827         }//end FOR_EACH */
00828 
00829 
00830 
00831         if(nextRowNState.row== -1)
00832         {
00833                 printf("Error: backup list empty\n");
00834                 cout << "In SARSOP::backup( )" << endl;
00835         }
00836         return nextRowNState;
00837 }//end method: backup
00838 
00839 
00840 cacherow_stval SARSOP::backupLBonly(list<cacherow_stval> beliefNStates){
00841         //decide the order of backups 
00842         cacherow_stval rowNState, nextRowNState;
00843         nextRowNState.row = -1;
00844 
00845         //for each belief given, we perform backup for it
00846         LISTFOREACH(cacherow_stval, beliefNState,  beliefNStates) {
00847                 //get belief
00848                 rowNState = *beliefNState;
00849                 nextRowNState = backupLBonly(rowNState);
00850         }//end FOR_EACH
00851 
00852 
00853         if(nextRowNState.row== -1){
00854                 printf("Error: backup list empty\n");
00855                 cout << "In SARSOP::backupLBonly( )" << endl;
00856         }
00857         return nextRowNState;
00858 }//end method: backup
00859 
00860 //Function: backup
00861 //Functionality:
00862 //      do backup at a single belief
00863 //Parameters:
00864 //      row: the row index of the to-be-backuped belief in BeliefCache
00865 //Returns:
00866 //      row index of the belief as the starting point for sampling
00867 cacherow_stval SARSOP::backup(cacherow_stval beliefNState)
00868 {
00869 
00870         unsigned int stateidx = beliefNState.sval;
00871         int row = beliefNState.row;
00872 
00873         //cout << "in SARSOP::backup(), beliefNState.sval : " << beliefNState.sval << " beliefNState.row : " << beliefNState.row << endl;
00874 
00875         //do belief propogation if the belief is a fringe node in tree
00876         BeliefTreeNode* cn = beliefCacheSet[stateidx]->getRow(row)->REACHABLE;
00877         //should we use BeliefTreeNode or should we use row index?
00878         // TEMP, should move all the time stamp code to Global Resource
00879 
00880         //SYL220210 the very first backup should have timestamp of 1, so we increment the timestamp first 
00881         numBackups++;
00882         
00883         GlobalResource::getInstance()->setTimeStamp(numBackups);
00884         lowerBoundSet->backup(cn);
00885         upperBoundSet->backup(cn);
00886         //numBackups++;
00887         GlobalResource::getInstance()->setTimeStamp(numBackups);
00888         return beliefNState;
00889 }
00890 
00891 cacherow_stval SARSOP::backupLBonly(cacherow_stval beliefNState){
00892 
00893         unsigned int stateidx = beliefNState.sval;
00894         int row = beliefNState.row;
00895 
00896         //cout << "in SARSOP::backup(), beliefNState.sval : " << beliefNState.sval << " beliefNState.row : " << beliefNState.row << endl;
00897 
00898         //do belief propogation if the belief is a fringe node in tree
00899         BeliefTreeNode* cn = beliefCacheSet[stateidx]->getRow(row)->REACHABLE;
00900         //should we use BeliefTreeNode or should we use row index?
00901         // TEMP, should move all the time stamp code to Global Resource
00902         GlobalResource::getInstance()->setTimeStamp(numBackups);
00903         lowerBoundSet->backup(cn);
00904         //bounds->backupUpperBoundBVpair->backup(*cn);
00905         numBackups++;
00906         GlobalResource::getInstance()->setTimeStamp(numBackups);
00907         return beliefNState;
00908 }
00909 
00910 void SARSOP::writePolicy(string fileName, string problemName)
00911 {
00912         writeToFile(fileName, problemName);
00913 }
00914 
00915 void SARSOP::writeToFile(const std::string& outFileName, string problemName) 
00916 {
00917         lowerBoundSet->writeToFile(outFileName, problemName);
00918 
00919 }
00920 
00921 void SARSOP::printHeader(){
00922     cout << endl;
00923     printDivider();
00924     cout << " Time   |#Trial |#Backup |LBound    |UBound    |Precision  |#Alphas |#Beliefs  " << endl;
00925     printDivider();
00926 }
00927 
00928 void SARSOP::printDivider(){
00929     cout << "-------------------------------------------------------------------------------" << endl;
00930 }


appl
Author(s): petercai
autogenerated on Tue Jan 7 2014 11:02:29