PrioritizedSweeping.cc
Go to the documentation of this file.
00001 #include "PrioritizedSweeping.hh"
00002 #include <algorithm>
00003 
00004 //#include <time.h>
00005 #include <sys/time.h>
00006 
00007 
00008 PrioritizedSweeping::PrioritizedSweeping(int numactions, float gamma,
00009                                          float MAX_TIME, bool onlyAddLastSA,  int modelType,
00010                                          const std::vector<float> &fmax, 
00011                                          const std::vector<float> &fmin, 
00012                                          Random r):
00013   numactions(numactions), gamma(gamma), MAX_TIME(MAX_TIME),
00014   onlyAddLastSA(onlyAddLastSA),  modelType(modelType)
00015 {
00016   rng = r;
00017   nstates = 0;
00018   nactions = 0;
00019 
00020   timingType = false; //true;
00021 
00022   model = NULL;
00023   planTime = getSeconds();
00024 
00025   // algorithm options
00026   MAX_STEPS = 10; //50; //60; //80; //0; //5; //10;
00027 
00028   lastModelUpdate = -1;
00029 
00030   PLANNERDEBUG = false;
00031   POLICYDEBUG = false; //true; //false;
00032   ACTDEBUG = false; //true;
00033   MODELDEBUG = false; //true;
00034   LISTDEBUG = false; // true; //false;
00035 
00036   featmax = fmax;
00037   featmin = fmin;
00038 
00039 }
00040 
00041 PrioritizedSweeping::~PrioritizedSweeping() {}
00042 
00043 void PrioritizedSweeping::setModel(MDPModel* m){
00044 
00045   model = m;
00046 
00047 }
00048 
00049 
00051 // Functional functions :) //
00053 
00054 
00055 void PrioritizedSweeping::initNewState(state_t s){
00056   if (PLANNERDEBUG) cout << "initNewState(s = " << s
00057                          << ") size = " << s->size() << endl;
00058 
00059   if (MODELDEBUG) cout << "New State: " << endl;
00060 
00061   // create state info and add to hash map
00062   state_info* info = &(statedata[s]);
00063   initStateInfo(info);
00064 
00065   // init these from model
00066   for (int i = 0; i < numactions; i++){
00067     model->getStateActionInfo(*s, i, &(info->modelInfo[i]));
00068   }
00069 
00070   // we have to make sure q-values are initialized properly
00071   // or we'll get bizarre results (if these aren't swept over)
00072   for (int j = 0; j < numactions; j++){
00073     // update q values
00074     updateQValues(*s, j);
00075     info->Q[j] += rng.uniform(0,0.01);
00076   }
00077 
00078   if (PLANNERDEBUG) cout << "done with initNewState()" << endl;
00079 
00080 }
00081 
00083 bool PrioritizedSweeping::updateModelWithExperience(const std::vector<float> &laststate,
00084                                                     int lastact,
00085                                                     const std::vector<float> &currstate,
00086                                                     float reward, bool term){
00087   if (PLANNERDEBUG) cout << "updateModelWithExperience(last = " << &laststate
00088                          << ", curr = " << &currstate
00089                          << ", lastact = " << lastact
00090                          << ", r = " << reward
00091                          << ")" << endl;
00092 
00093   if (!timingType)
00094     planTime = getSeconds();
00095 
00096   // canonicalize these things
00097   state_t last = canonicalize(laststate);
00098   state_t curr = canonicalize(currstate);
00099 
00100   prevstate = laststate;
00101   prevact = lastact;
00102 
00103   // if not transition to terminal
00104   if (curr == NULL)
00105     return false;
00106 
00107   // get state info
00108   state_info* info = &(statedata[last]);
00109 
00110   // update the state visit count
00111   info->visits[lastact]++;
00112 
00113   // init model?
00114   if (model == NULL){
00115     cout << "ERROR IN MODEL OR MODEL SIZE" << endl;
00116     exit(-1);
00117   }
00118 
00119   experience e;
00120   e.s = *last;
00121   e.next = *curr;
00122   e.act = lastact;
00123   e.reward = reward;
00124   e.terminal = term;
00125   bool modelChanged = model->updateWithExperience(e);
00126 
00127   if (PLANNERDEBUG) cout << "Added exp: " << modelChanged << endl;
00128   if (timingType)
00129     planTime = getSeconds();
00130 
00131   return modelChanged;
00132 
00133 }
00134 
00135 
00136 
00138 void PrioritizedSweeping::updateStatesFromModel(){
00139   if (PLANNERDEBUG || LISTDEBUG) cout << "updateStatesFromModel()" << endl;
00140 
00141   // for each state
00142   for (std::set<std::vector<float> >::iterator i = statespace.begin();
00143        i != statespace.end(); i++){
00144 
00145     for (int j = 0; j < numactions; j++){
00146       updateStateActionFromModel(*i, j);
00147     }
00148 
00149   }
00150 
00151 }
00152 
00153 
00154 
00155 
00157 int PrioritizedSweeping::getBestAction(const std::vector<float> &state){
00158   if (PLANNERDEBUG) cout << "getBestAction(s = " << &state
00159                          << ")" << endl;
00160 
00161   state_t s = canonicalize(state);
00162 
00163   // get state info
00164   state_info* info = &(statedata[s]);
00165 
00166   // Get Q values
00167   std::vector<float> &Q = info->Q;
00168 
00169   // Choose an action
00170   const std::vector<float>::iterator a =
00171     random_max_element(Q.begin(), Q.end()); // Choose maximum
00172 
00173   int act = a - Q.begin();
00174   float val = *a;
00175 
00176   if (ACTDEBUG){
00177     cout << endl << "chooseAction State " << (*s)[0] << "," << (*s)[1]
00178          << " act: " << act << " val: " << val << endl;
00179     for (int iAct = 0; iAct < numactions; iAct++){
00180       cout << " Action: " << iAct
00181            << " val: " << Q[iAct]
00182            << " visits: " << info->visits[iAct]
00183            << " modelsAgree: " << info->modelInfo[iAct].known << endl;
00184     }
00185   }
00186 
00187   nactions++;
00188 
00189   // return index of action
00190   return act;
00191 }
00192 
00193 
00194 
00195 void PrioritizedSweeping::planOnNewModel(){
00196 
00197   // update model info
00198 
00199   // print state
00200   if (PLANNERDEBUG){
00201     cout << endl << endl << "Before update" << endl << endl;
00202     printStates();
00203   }
00204 
00205   // tabular - can just update last state-action from model.
00206   if (false && modelType == RMAX){
00207     updateStateActionFromModel(prevstate, prevact);
00208   }
00209   else {
00210     updateStatesFromModel();
00211   }
00212 
00213   // just add last state action (this is normal prioritized sweeping).
00214   // if bool was false, will have checked for differences and added them in update above
00215   if (onlyAddLastSA || modelType == RMAX){
00216     float diff = updateQValues(prevstate, prevact);
00217     addSAToList(prevstate, prevact, diff);
00218   }
00219 
00220   if (PLANNERDEBUG){
00221     cout << endl << endl << "After update" << endl << endl;
00222     printStates();
00223   }
00224 
00225   // run value iteration
00226   createPolicy();
00227 
00228 }
00229 
00230 
00232 // Helper Functions       //
00234 
00235 PrioritizedSweeping::state_t PrioritizedSweeping::canonicalize(const std::vector<float> &s) {
00236   if (PLANNERDEBUG) cout << "canonicalize(s = " << s[0] << ", "
00237                          << s[1] << ")" << endl;
00238 
00239   // get state_t for pointer if its in statespace
00240   const std::pair<std::set<std::vector<float> >::iterator, bool> result =
00241     statespace.insert(s);
00242   state_t retval = &*result.first; // Dereference iterator then get pointer
00243 
00244   if (PLANNERDEBUG) cout << " returns " << retval
00245                          << " New: " << result.second << endl;
00246 
00247   // if not, init this new state
00248   if (result.second) { // s is new, so initialize Q(s,a) for all a
00249     initNewState(retval);
00250     if (PLANNERDEBUG) cout << " New state initialized" << endl;
00251   }
00252 
00253 
00254   return retval;
00255 }
00256 
00257 // init state info
00258 void PrioritizedSweeping::initStateInfo(state_info* info){
00259   if (PLANNERDEBUG) cout << "initStateInfo()";
00260 
00261   info->id = nstates++;
00262   if (PLANNERDEBUG) cout << " id = " << info->id << endl;
00263 
00264   info->fresh = true;
00265 
00266   // model data (transition, reward, known)
00267   info->modelInfo = new StateActionInfo[numactions];
00268 
00269   // model q values, visit counts
00270   info->visits.resize(numactions, 0);
00271   info->Q.resize(numactions, 0);
00272   info->lastUpdate.resize(numactions, nactions);
00273 
00274   for (int i = 0; i < numactions; i++){
00275     info->Q[i] = rng.uniform(0,1);
00276   }
00277 
00278   if (PLANNERDEBUG) cout << "done with initStateInfo()" << endl;
00279 
00280 }
00281 
00282 
00284 void PrioritizedSweeping::printStates(){
00285 
00286   for (std::set< std::vector<float> >::iterator i = statespace.begin();
00287        i != statespace.end(); i++){
00288 
00289     state_t s = canonicalize(*i);
00290 
00291     state_info* info = &(statedata[s]);
00292 
00293     cout << endl << "State " << info->id << ": ";
00294     for (unsigned j = 0; j < s->size(); j++){
00295       cout << (*s)[j] << ", ";
00296     }
00297     cout << endl;
00298 
00299     for (int act = 0; act < numactions; act++){
00300       cout << " visits[" << act << "] = " << info->visits[act]
00301            << " Q: " << info->Q[act]
00302            << " R: " << info->modelInfo[act].reward << endl;
00303 
00304       cout << "  Next states: " << endl;
00305       for (std::map<std::vector<float>, float>::iterator outIt
00306              = info->modelInfo[act].transitionProbs.begin();
00307            outIt != info->modelInfo[act].transitionProbs.end(); outIt++){
00308 
00309         std::vector<float> nextstate = (*outIt).first;
00310         float prob = (*outIt).second;
00311 
00312         cout << "   State ";
00313         for (unsigned k = 0; k < nextstate.size(); k++){
00314           cout << nextstate[k] << ", ";
00315         }
00316         cout << " prob: " << prob << endl;
00317 
00318       } // end of next states
00319 
00320     } // end of actions
00321 
00322     // print predecessors
00323     for (std::list<saqPair>::iterator x = info->pred.begin();
00324          x != info->pred.end(); x++){
00325 
00326       std::vector<float> s = (*x).s;
00327       int a = (*x).a;
00328 
00329       cout << "Has predecessor state: ";
00330       for (unsigned k = 0; k < s.size(); k++){
00331         cout << s[k] << ", ";
00332       }
00333       cout << " action: " << a << endl;
00334     }
00335 
00336   }
00337 }
00338 
00339 
00340 
00341 
00342 
00343 void PrioritizedSweeping::deleteInfo(state_info* info){
00344 
00345   delete [] info->modelInfo;
00346 
00347 }
00348 
00349 
00350 double PrioritizedSweeping::getSeconds(){
00351   struct timezone tz;
00352   timeval timeT;
00353   gettimeofday(&timeT, &tz);
00354   return  timeT.tv_sec + (timeT.tv_usec / 1000000.0);
00355 }
00356 
00357 
00359 void PrioritizedSweeping::createPolicy(){
00360   if (POLICYDEBUG) cout << endl << "createPolicy()" << endl;
00361 
00362   /*
00363   // loop through all states, add them all to queue with some high value.
00364   for (std::set<std::vector<float> >::iterator i = statespace.begin();
00365   i != statespace.end(); i++){
00366 
00367   saqPair saq;
00368   saq.s = *i;
00369   saq.q = 100.0;
00370 
00371   for (int j = 0; j < numactions; j++){
00372   saq.a = j;
00373 
00374   if (LISTDEBUG){
00375   cout << "Added state ";
00376   for (unsigned k = 0; k < saq.s.size(); k++){
00377   cout << saq.s[k] << ", ";
00378   }
00379   cout << " action: " << saq.a << endl;
00380   }
00381 
00382   priorityList.push_front(saq);
00383   }
00384   }
00385   */
00386 
00387   // add last state-action to priority list
00388   //addSAToList(prevstate, prevact, 100.0);
00389 
00390   int updates = 0;
00391 
00392   // go through queue, doing prioritized sweeping. until nothing left on queue.
00393   while (!priorityList.empty()){
00394 
00395     if ((getSeconds() - planTime) > MAX_TIME)
00396       break;
00397 
00398     // print list!
00399     if (LISTDEBUG){
00400       cout << endl << "Current List (" << updates << "):" << endl;
00401       for (std::list<saqPair>::iterator k = priorityList.begin(); k != priorityList.end(); k++){
00402         cout << "State: ";
00403         for (unsigned l = 0; l < (*k).s.size(); l++){
00404           cout << (*k).s[l] << ", ";
00405         }
00406         cout << " act: " << (*k).a << " Q: " << (*k).q << endl;
00407       }
00408     }
00409 
00410     updates++;
00411 
00412     // pull off first item
00413     saqPair currUpdate = priorityList.front();
00414     priorityList.pop_front();
00415 
00416     state_t s = canonicalize(currUpdate.s);
00417 
00418     // get state's info
00419     state_info* info = &(statedata[s]);
00420 
00421     updatePriorityList(info, *s);
00422 
00423   } // is list empty loop
00424 
00425 
00426   priorityList.clear();
00427 
00428   if (LISTDEBUG)
00429     cout << "priority list complete after updates to "
00430          << updates << " states." <<endl;
00431 
00432 }
00433 
00434 
/// Propagate a state's value change to the state-actions that lead into it.
///
/// For each (s, a) in info->pred, re-runs updateQValues(s, a). If the
/// resulting |delta Q| exceeds MIN_ERROR, (s, a) is (re)queued in
/// priorityList, which is kept in descending order of q. Stops early once
/// the MAX_TIME budget (measured from planTime) is spent.
///
/// @param info state whose value just changed (its pred list is walked)
/// @param next that state's features; NOTE(review): not used in this body
void PrioritizedSweeping::updatePriorityList(state_info* info,
                                             const std::vector<float> &next){
  if (LISTDEBUG) cout << "update priority list" << endl;

  // Predecessor changes at or below this threshold are not queued.
  float MIN_ERROR = 0.01;

  // find maxq at this state (only used for the debug print below)
  std::vector<float>::iterator maxAct =
    std::max_element(info->Q.begin(),
                     info->Q.end());
  float maxval = *maxAct;

  if (LISTDEBUG) cout << " maxQ at this state: " << maxval << endl;

  // loop through all s,a predicted to lead to this state
  for (std::list<saqPair>::iterator i = info->pred.begin();
       i != info->pred.end(); i++){

    // respect the planning time budget
    if ((getSeconds() - planTime) > MAX_TIME)
      break;

    std::vector<float> s = (*i).s;
    int a = (*i).a;

    if (LISTDEBUG) {
      cout << endl << "  For predecessor state: ";
      for (unsigned j = 0; j < s.size(); j++){
        cout << s[j] << ", ";
      }
      cout << " action: " << a << endl;
    }

    // figure out amount of update: back up Q(s,a); returns |old - new|
    float diff = updateQValues(s, a);

    if (LISTDEBUG) {
      cout << " diff: " << diff << endl;
    }
    // possibly add to queue
    if (diff > MIN_ERROR){
      saqPair saq;
      saq.s = s;
      saq.a = a;
      saq.q = diff;

      // find spot for it in queue
      if (priorityList.empty()){
        if (LISTDEBUG) cout << "  empty list" << endl;
        priorityList.push_front(saq);
      }
      else {

        // check that its not already in queue; an existing entry is removed
        // and re-inserted with priority `diff`, even if its old q was larger
        for (std::list<saqPair>::iterator k = priorityList.begin(); k != priorityList.end(); k++){
          // matched
          if (saqPairMatch(saq, *k)){
            if (LISTDEBUG)
              cout << "   found matching element already in list" << endl;

            priorityList.erase(k);
            break;
          }

        }

        // sorted insert: before the first entry with a strictly smaller q,
        // so ties go after existing entries
        int l = 0;
        std::list<saqPair>::iterator k;
        for (k = priorityList.begin(); k != priorityList.end(); k++){
          if (LISTDEBUG)
            cout << "    Element " << l << " has q value " << (*k).q << endl;
          if (diff > (*k).q){
            if (LISTDEBUG)
              cout << "   insert at " << l << endl;
            priorityList.insert(k, saq);
            break;
          }
          l++;
        }
        // put this at the end (k only reaches end() if nothing was inserted;
        // list::insert leaves k valid and pointing past the new element)
        if (k == priorityList.end()){
          if (LISTDEBUG)
            cout << "   insert at end" << endl;
          priorityList.push_back(saq);
        }

      } // not empty


    } else {
      if (LISTDEBUG){
        cout << " Error " << diff << " not big enough to put on list." << endl;
      }
    }
  }
}
00530 
00531 
00532 bool PrioritizedSweeping::saqPairMatch(saqPair a, saqPair b){
00533   if (a.a != b.a)
00534     return false;
00535 
00536   for (unsigned i = 0; i < a.s.size(); i++){
00537     if (a.s[i] != b.s[i])
00538       return false;
00539   }
00540 
00541   return true;
00542 }
00543 
00544 
00545 
00546 float PrioritizedSweeping::updateQValues(const std::vector<float> &state, int act){
00547 
00548   state_t s = canonicalize(state);
00549 
00550   // get state's info
00551   state_info* info = &(statedata[s]);
00552 
00553   // see if we should update mode for this state,action
00554   /*
00555     if (info->lastUpdate[act] < lastModelUpdate){
00556     if (LISTDEBUG) {
00557     cout << "Updating this state action. Last updated at "
00558     << info->lastUpdate[act]
00559     << " last model update: " << lastModelUpdate << endl;
00560     }
00561     updateStateActionFromModel(state, act);
00562     }
00563   */
00564 
00565   if (LISTDEBUG || POLICYDEBUG){
00566     cout << endl << " State: id: " << info->id << ": " ;
00567     for (unsigned si = 0; si < s->size(); si++){
00568       cout << (*s)[si] << ",";
00569     }
00570   }
00571 
00572   // get state action info for this action
00573   StateActionInfo *modelInfo = &(info->modelInfo[act]);
00574 
00575   if (LISTDEBUG || POLICYDEBUG)
00576     cout << "  Action: " << act
00577          << " State visits: " << info->visits[act] << endl;
00578 
00579   // Q = R + discounted val of next state
00580   // this is the R part :)
00581   float newQ = modelInfo->reward;
00582 
00583   float probSum = modelInfo->termProb;
00584 
00585   // for all next states, add discounted value appropriately
00586   // loop through next state's that are in this state-actions list
00587   for (std::map<std::vector<float>, float>::iterator outIt
00588          = modelInfo->transitionProbs.begin();
00589        outIt != modelInfo->transitionProbs.end(); outIt++){
00590 
00591     std::vector<float> nextstate = (*outIt).first;
00592 
00593     if (POLICYDEBUG){
00594       cout << "  Next state was: ";
00595       for (unsigned oi = 0; oi < nextstate.size(); oi++){
00596         cout << nextstate[oi] << ",";
00597       }
00598       cout << endl;
00599     }
00600 
00601     // get transition probability
00602     float transitionProb =  (1.0-modelInfo->termProb) *
00603       modelInfo->transitionProbs[nextstate];
00604 
00605     probSum += transitionProb;
00606 
00607     if (POLICYDEBUG)
00608       cout << "   prob: " << transitionProb << endl;
00609 
00610     if (transitionProb < 0 || transitionProb > 1.0001){
00611       cout << "Error with transitionProb: " << transitionProb << endl;
00612       exit(-1);
00613     }
00614 
00615     // if there is some probability of this transition
00616     if (transitionProb > 0.0){
00617 
00618       // assume maxval of qmax if we don't know the state
00619       float maxval = 0.0;
00620 
00621       // make sure its a real state
00622       bool realState = true;
00623 
00624 
00625       for (unsigned b = 0; b < nextstate.size(); b++){
00626         if (nextstate[b] < (featmin[b]-EPSILON)
00627             || nextstate[b] > (featmax[b]+EPSILON)){
00628           realState = false;
00629           if (POLICYDEBUG)
00630             cout << "    Next state is not valid (feature "
00631                  << b << " out of range)" << endl;
00632           break;
00633         }
00634       }
00635 
00636 
00637 
00638       // update q values for any states within MAX_STEPS of visited states
00639       if (realState){
00640 
00641         state_t next = canonicalize(nextstate);
00642 
00643         state_info* nextinfo = &(statedata[next]);
00644         //nextinfo->fresh = false;
00645 
00646         // find the max value of this next state
00647         std::vector<float>::iterator maxAct =
00648           std::max_element(nextinfo->Q.begin(),
00649                            nextinfo->Q.end());
00650         maxval = *maxAct;
00651 
00652       } // within max steps
00653       else {
00654         maxval = 0.0;
00655         if (POLICYDEBUG){
00656           cout << "This state is too far away, state: ";
00657           for (unsigned si = 0; si < s->size(); si++){
00658             cout << (*s)[si] << ",";
00659           }
00660           cout << " Action: " << act << endl;
00661         }
00662       }
00663 
00664       nextstate.clear();
00665 
00666       if (POLICYDEBUG) cout << "    Max value: " << maxval << endl;
00667 
00668       // update q value with this value
00669       newQ += (gamma * transitionProb * maxval);
00670 
00671     } // transition probability > 0
00672 
00673   } // outcome loop
00674 
00675 
00676   if (probSum < 0.9999 || probSum > 1.0001){
00677     cout << "Error: transition probabilities do not add to 1: Sum: "
00678          << probSum << endl;
00679     exit(-1);
00680   }
00681 
00682 
00683   // set q value
00684   float tdError = fabs(info->Q[act] - newQ);
00685   if (LISTDEBUG || POLICYDEBUG) cout << "  NewQ: " << newQ
00686                                      << " OldQ: " << info->Q[act] << endl;
00687   info->Q[act] = newQ;
00688 
00689   return tdError;
00690 }
00691 
00692 void PrioritizedSweeping::addSAToList(const std::vector<float> &s, int act, float q){
00693 
00694   saqPair saq;
00695   saq.s = s;
00696   saq.a = act;
00697   saq.q = q;
00698 
00699   if (LISTDEBUG){
00700     cout << "Added state ";
00701     for (unsigned k = 0; k < saq.s.size(); k++){
00702       cout << saq.s[k] << ", ";
00703     }
00704     cout << " action: " << saq.a
00705          << " value: " << saq.q << endl;
00706   }
00707 
00708   priorityList.push_front(saq);
00709 
00710 }
00711 
00712 
00713 
00715 void PrioritizedSweeping::updateStateActionFromModel(const std::vector<float> &state, int a){
00716 
00717   if ((getSeconds() - planTime) > MAX_TIME)
00718     return;
00719 
00720   state_t s = canonicalize(state);
00721 
00722   // get state's info
00723   state_info* info = &(statedata[s]);
00724 
00725   int j = a;
00726 
00727   // get updated model
00728   model->getStateActionInfo(*s, j, &(info->modelInfo[j]));
00729   info->lastUpdate[j] = nactions;
00730 
00731   if (info->modelInfo[j].termProb >= 1.0)
00732     return;
00733 
00734   // go through next states, for each one, add self to predecessor list
00735   for (std::map<std::vector<float>, float>::iterator outIt
00736          = info->modelInfo[j].transitionProbs.begin();
00737        outIt != info->modelInfo[j].transitionProbs.end(); outIt++){
00738 
00739     std::vector<float> nextstate = (*outIt).first;
00740     state_t next = canonicalize(nextstate);
00741     state_info* nextinfo = &(statedata[next]);
00742     //float prob = (*outIt).second;
00743 
00744     if (LISTDEBUG){
00745       cout << "State ";
00746       for (unsigned k = 0; k < nextstate.size(); k++){
00747         cout << nextstate[k] << ", ";
00748       }
00749       cout << " has predecessor: ";
00750       for (unsigned k = 0; k < nextstate.size(); k++){
00751         cout << (*s)[k] << ", ";
00752       }
00753       cout << " action: " << j << endl;
00754     }
00755 
00756     saqPair saq;
00757     saq.s = *s;
00758     saq.a = j;
00759     saq.q = 0.0;
00760 
00761     // add to list
00762     // check that its not already here
00763     bool nothere = true;
00764     for (std::list<saqPair>::iterator k = nextinfo->pred.begin();
00765          k != nextinfo->pred.end(); k++){
00766       if (saqPairMatch(saq, *k)){
00767         nothere = false;
00768         break;
00769       }
00770     }
00771     if (nothere)
00772       nextinfo->pred.push_front(saq);
00773 
00774   }
00775 
00776   info->fresh = false;
00777 
00778   //if (PLANNERDEBUG || LISTDEBUG) cout << " updateStatesFromModel i = " << &i << " complete" << endl;
00779 
00780 }
00781 
00782 


rl_agent
Author(s): Todd Hester
autogenerated on Thu Jun 6 2019 22:00:13