00001 #include "PrioritizedSweeping.hh"
00002 #include <algorithm>
00003
00004
00005 #include <sys/time.h>
00006
00007
00008 PrioritizedSweeping::PrioritizedSweeping(int numactions, float gamma,
00009 float MAX_TIME, bool onlyAddLastSA, int modelType,
00010 const std::vector<float> &fmax,
00011 const std::vector<float> &fmin,
00012 Random r):
00013 numactions(numactions), gamma(gamma), MAX_TIME(MAX_TIME),
00014 onlyAddLastSA(onlyAddLastSA), modelType(modelType)
00015 {
00016 rng = r;
00017 nstates = 0;
00018 nactions = 0;
00019
00020 timingType = false;
00021
00022 model = NULL;
00023 planTime = getSeconds();
00024
00025
00026 MAX_STEPS = 10;
00027
00028 lastModelUpdate = -1;
00029
00030 PLANNERDEBUG = false;
00031 POLICYDEBUG = false;
00032 ACTDEBUG = false;
00033 MODELDEBUG = false;
00034 LISTDEBUG = false;
00035
00036 featmax = fmax;
00037 featmin = fmin;
00038
00039 }
00040
00041 PrioritizedSweeping::~PrioritizedSweeping() {}
00042
00043 void PrioritizedSweeping::setModel(MDPModel* m){
00044
00045 model = m;
00046
00047 }
00048
00049
00051
00053
00054
00055 void PrioritizedSweeping::initNewState(state_t s){
00056 if (PLANNERDEBUG) cout << "initNewState(s = " << s
00057 << ") size = " << s->size() << endl;
00058
00059 if (MODELDEBUG) cout << "New State: " << endl;
00060
00061
00062 state_info* info = &(statedata[s]);
00063 initStateInfo(info);
00064
00065
00066 for (int i = 0; i < numactions; i++){
00067 model->getStateActionInfo(*s, i, &(info->modelInfo[i]));
00068 }
00069
00070
00071
00072 for (int j = 0; j < numactions; j++){
00073
00074 updateQValues(*s, j);
00075 info->Q[j] += rng.uniform(0,0.01);
00076 }
00077
00078 if (PLANNERDEBUG) cout << "done with initNewState()" << endl;
00079
00080 }
00081
00083 bool PrioritizedSweeping::updateModelWithExperience(const std::vector<float> &laststate,
00084 int lastact,
00085 const std::vector<float> &currstate,
00086 float reward, bool term){
00087 if (PLANNERDEBUG) cout << "updateModelWithExperience(last = " << &laststate
00088 << ", curr = " << &currstate
00089 << ", lastact = " << lastact
00090 << ", r = " << reward
00091 << ")" << endl;
00092
00093 if (!timingType)
00094 planTime = getSeconds();
00095
00096
00097 state_t last = canonicalize(laststate);
00098 state_t curr = canonicalize(currstate);
00099
00100 prevstate = laststate;
00101 prevact = lastact;
00102
00103
00104 if (curr == NULL)
00105 return false;
00106
00107
00108 state_info* info = &(statedata[last]);
00109
00110
00111 info->visits[lastact]++;
00112
00113
00114 if (model == NULL){
00115 cout << "ERROR IN MODEL OR MODEL SIZE" << endl;
00116 exit(-1);
00117 }
00118
00119 experience e;
00120 e.s = *last;
00121 e.next = *curr;
00122 e.act = lastact;
00123 e.reward = reward;
00124 e.terminal = term;
00125 bool modelChanged = model->updateWithExperience(e);
00126
00127 if (PLANNERDEBUG) cout << "Added exp: " << modelChanged << endl;
00128 if (timingType)
00129 planTime = getSeconds();
00130
00131 return modelChanged;
00132
00133 }
00134
00135
00136
00138 void PrioritizedSweeping::updateStatesFromModel(){
00139 if (PLANNERDEBUG || LISTDEBUG) cout << "updateStatesFromModel()" << endl;
00140
00141
00142 for (std::set<std::vector<float> >::iterator i = statespace.begin();
00143 i != statespace.end(); i++){
00144
00145 for (int j = 0; j < numactions; j++){
00146 updateStateActionFromModel(*i, j);
00147 }
00148
00149 }
00150
00151 }
00152
00153
00154
00155
00157 int PrioritizedSweeping::getBestAction(const std::vector<float> &state){
00158 if (PLANNERDEBUG) cout << "getBestAction(s = " << &state
00159 << ")" << endl;
00160
00161 state_t s = canonicalize(state);
00162
00163
00164 state_info* info = &(statedata[s]);
00165
00166
00167 std::vector<float> &Q = info->Q;
00168
00169
00170 const std::vector<float>::iterator a =
00171 random_max_element(Q.begin(), Q.end());
00172
00173 int act = a - Q.begin();
00174 float val = *a;
00175
00176 if (ACTDEBUG){
00177 cout << endl << "chooseAction State " << (*s)[0] << "," << (*s)[1]
00178 << " act: " << act << " val: " << val << endl;
00179 for (int iAct = 0; iAct < numactions; iAct++){
00180 cout << " Action: " << iAct
00181 << " val: " << Q[iAct]
00182 << " visits: " << info->visits[iAct]
00183 << " modelsAgree: " << info->modelInfo[iAct].known << endl;
00184 }
00185 }
00186
00187 nactions++;
00188
00189
00190 return act;
00191 }
00192
00193
00194
00195 void PrioritizedSweeping::planOnNewModel(){
00196
00197
00198
00199
00200 if (PLANNERDEBUG){
00201 cout << endl << endl << "Before update" << endl << endl;
00202 printStates();
00203 }
00204
00205
00206 if (false && modelType == RMAX){
00207 updateStateActionFromModel(prevstate, prevact);
00208 }
00209 else {
00210 updateStatesFromModel();
00211 }
00212
00213
00214
00215 if (onlyAddLastSA || modelType == RMAX){
00216 float diff = updateQValues(prevstate, prevact);
00217 addSAToList(prevstate, prevact, diff);
00218 }
00219
00220 if (PLANNERDEBUG){
00221 cout << endl << endl << "After update" << endl << endl;
00222 printStates();
00223 }
00224
00225
00226 createPolicy();
00227
00228 }
00229
00230
00232
00234
00235 PrioritizedSweeping::state_t PrioritizedSweeping::canonicalize(const std::vector<float> &s) {
00236 if (PLANNERDEBUG) cout << "canonicalize(s = " << s[0] << ", "
00237 << s[1] << ")" << endl;
00238
00239
00240 const std::pair<std::set<std::vector<float> >::iterator, bool> result =
00241 statespace.insert(s);
00242 state_t retval = &*result.first;
00243
00244 if (PLANNERDEBUG) cout << " returns " << retval
00245 << " New: " << result.second << endl;
00246
00247
00248 if (result.second) {
00249 initNewState(retval);
00250 if (PLANNERDEBUG) cout << " New state initialized" << endl;
00251 }
00252
00253
00254 return retval;
00255 }
00256
00257
00258 void PrioritizedSweeping::initStateInfo(state_info* info){
00259 if (PLANNERDEBUG) cout << "initStateInfo()";
00260
00261 info->id = nstates++;
00262 if (PLANNERDEBUG) cout << " id = " << info->id << endl;
00263
00264 info->fresh = true;
00265
00266
00267 info->modelInfo = new StateActionInfo[numactions];
00268
00269
00270 info->visits.resize(numactions, 0);
00271 info->Q.resize(numactions, 0);
00272 info->lastUpdate.resize(numactions, nactions);
00273
00274 for (int i = 0; i < numactions; i++){
00275 info->Q[i] = rng.uniform(0,1);
00276 }
00277
00278 if (PLANNERDEBUG) cout << "done with initStateInfo()" << endl;
00279
00280 }
00281
00282
00284 void PrioritizedSweeping::printStates(){
00285
00286 for (std::set< std::vector<float> >::iterator i = statespace.begin();
00287 i != statespace.end(); i++){
00288
00289 state_t s = canonicalize(*i);
00290
00291 state_info* info = &(statedata[s]);
00292
00293 cout << endl << "State " << info->id << ": ";
00294 for (unsigned j = 0; j < s->size(); j++){
00295 cout << (*s)[j] << ", ";
00296 }
00297 cout << endl;
00298
00299 for (int act = 0; act < numactions; act++){
00300 cout << " visits[" << act << "] = " << info->visits[act]
00301 << " Q: " << info->Q[act]
00302 << " R: " << info->modelInfo[act].reward << endl;
00303
00304 cout << " Next states: " << endl;
00305 for (std::map<std::vector<float>, float>::iterator outIt
00306 = info->modelInfo[act].transitionProbs.begin();
00307 outIt != info->modelInfo[act].transitionProbs.end(); outIt++){
00308
00309 std::vector<float> nextstate = (*outIt).first;
00310 float prob = (*outIt).second;
00311
00312 cout << " State ";
00313 for (unsigned k = 0; k < nextstate.size(); k++){
00314 cout << nextstate[k] << ", ";
00315 }
00316 cout << " prob: " << prob << endl;
00317
00318 }
00319
00320 }
00321
00322
00323 for (std::list<saqPair>::iterator x = info->pred.begin();
00324 x != info->pred.end(); x++){
00325
00326 std::vector<float> s = (*x).s;
00327 int a = (*x).a;
00328
00329 cout << "Has predecessor state: ";
00330 for (unsigned k = 0; k < s.size(); k++){
00331 cout << s[k] << ", ";
00332 }
00333 cout << " action: " << a << endl;
00334 }
00335
00336 }
00337 }
00338
00339
00340
00341
00342
00343 void PrioritizedSweeping::deleteInfo(state_info* info){
00344
00345 delete [] info->modelInfo;
00346
00347 }
00348
00349
00350 double PrioritizedSweeping::getSeconds(){
00351 struct timezone tz;
00352 timeval timeT;
00353 gettimeofday(&timeT, &tz);
00354 return timeT.tv_sec + (timeT.tv_usec / 1000000.0);
00355 }
00356
00357
00359 void PrioritizedSweeping::createPolicy(){
00360 if (POLICYDEBUG) cout << endl << "createPolicy()" << endl;
00361
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371
00372
00373
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388
00389
00390 int updates = 0;
00391
00392
00393 while (!priorityList.empty()){
00394
00395 if ((getSeconds() - planTime) > MAX_TIME)
00396 break;
00397
00398
00399 if (LISTDEBUG){
00400 cout << endl << "Current List (" << updates << "):" << endl;
00401 for (std::list<saqPair>::iterator k = priorityList.begin(); k != priorityList.end(); k++){
00402 cout << "State: ";
00403 for (unsigned l = 0; l < (*k).s.size(); l++){
00404 cout << (*k).s[l] << ", ";
00405 }
00406 cout << " act: " << (*k).a << " Q: " << (*k).q << endl;
00407 }
00408 }
00409
00410 updates++;
00411
00412
00413 saqPair currUpdate = priorityList.front();
00414 priorityList.pop_front();
00415
00416 state_t s = canonicalize(currUpdate.s);
00417
00418
00419 state_info* info = &(statedata[s]);
00420
00421 updatePriorityList(info, *s);
00422
00423 }
00424
00425
00426 priorityList.clear();
00427
00428 if (LISTDEBUG)
00429 cout << "priority list complete after updates to "
00430 << updates << " states." <<endl;
00431
00432 }
00433
00434
00435 void PrioritizedSweeping::updatePriorityList(state_info* info,
00436 const std::vector<float> &next){
00437 if (LISTDEBUG) cout << "update priority list" << endl;
00438
00439 float MIN_ERROR = 0.01;
00440
00441
00442 std::vector<float>::iterator maxAct =
00443 std::max_element(info->Q.begin(),
00444 info->Q.end());
00445 float maxval = *maxAct;
00446
00447 if (LISTDEBUG) cout << " maxQ at this state: " << maxval << endl;
00448
00449
00450 for (std::list<saqPair>::iterator i = info->pred.begin();
00451 i != info->pred.end(); i++){
00452
00453 if ((getSeconds() - planTime) > MAX_TIME)
00454 break;
00455
00456 std::vector<float> s = (*i).s;
00457 int a = (*i).a;
00458
00459 if (LISTDEBUG) {
00460 cout << endl << " For predecessor state: ";
00461 for (unsigned j = 0; j < s.size(); j++){
00462 cout << s[j] << ", ";
00463 }
00464 cout << " action: " << a << endl;
00465 }
00466
00467
00468 float diff = updateQValues(s, a);
00469
00470 if (LISTDEBUG) {
00471 cout << " diff: " << diff << endl;
00472 }
00473
00474 if (diff > MIN_ERROR){
00475 saqPair saq;
00476 saq.s = s;
00477 saq.a = a;
00478 saq.q = diff;
00479
00480
00481 if (priorityList.empty()){
00482 if (LISTDEBUG) cout << " empty list" << endl;
00483 priorityList.push_front(saq);
00484 }
00485 else {
00486
00487
00488 for (std::list<saqPair>::iterator k = priorityList.begin(); k != priorityList.end(); k++){
00489
00490 if (saqPairMatch(saq, *k)){
00491 if (LISTDEBUG)
00492 cout << " found matching element already in list" << endl;
00493
00494 priorityList.erase(k);
00495 break;
00496 }
00497
00498 }
00499
00500 int l = 0;
00501 std::list<saqPair>::iterator k;
00502 for (k = priorityList.begin(); k != priorityList.end(); k++){
00503 if (LISTDEBUG)
00504 cout << " Element " << l << " has q value " << (*k).q << endl;
00505 if (diff > (*k).q){
00506 if (LISTDEBUG)
00507 cout << " insert at " << l << endl;
00508 priorityList.insert(k, saq);
00509 break;
00510 }
00511 l++;
00512 }
00513
00514 if (k == priorityList.end()){
00515 if (LISTDEBUG)
00516 cout << " insert at end" << endl;
00517 priorityList.push_back(saq);
00518 }
00519
00520 }
00521
00522
00523 } else {
00524 if (LISTDEBUG){
00525 cout << " Error " << diff << " not big enough to put on list." << endl;
00526 }
00527 }
00528 }
00529 }
00530
00531
00532 bool PrioritizedSweeping::saqPairMatch(saqPair a, saqPair b){
00533 if (a.a != b.a)
00534 return false;
00535
00536 for (unsigned i = 0; i < a.s.size(); i++){
00537 if (a.s[i] != b.s[i])
00538 return false;
00539 }
00540
00541 return true;
00542 }
00543
00544
00545
00546 float PrioritizedSweeping::updateQValues(const std::vector<float> &state, int act){
00547
00548 state_t s = canonicalize(state);
00549
00550
00551 state_info* info = &(statedata[s]);
00552
00553
00554
00555
00556
00557
00558
00559
00560
00561
00562
00563
00564
00565 if (LISTDEBUG || POLICYDEBUG){
00566 cout << endl << " State: id: " << info->id << ": " ;
00567 for (unsigned si = 0; si < s->size(); si++){
00568 cout << (*s)[si] << ",";
00569 }
00570 }
00571
00572
00573 StateActionInfo *modelInfo = &(info->modelInfo[act]);
00574
00575 if (LISTDEBUG || POLICYDEBUG)
00576 cout << " Action: " << act
00577 << " State visits: " << info->visits[act] << endl;
00578
00579
00580
00581 float newQ = modelInfo->reward;
00582
00583 float probSum = modelInfo->termProb;
00584
00585
00586
00587 for (std::map<std::vector<float>, float>::iterator outIt
00588 = modelInfo->transitionProbs.begin();
00589 outIt != modelInfo->transitionProbs.end(); outIt++){
00590
00591 std::vector<float> nextstate = (*outIt).first;
00592
00593 if (POLICYDEBUG){
00594 cout << " Next state was: ";
00595 for (unsigned oi = 0; oi < nextstate.size(); oi++){
00596 cout << nextstate[oi] << ",";
00597 }
00598 cout << endl;
00599 }
00600
00601
00602 float transitionProb = (1.0-modelInfo->termProb) *
00603 modelInfo->transitionProbs[nextstate];
00604
00605 probSum += transitionProb;
00606
00607 if (POLICYDEBUG)
00608 cout << " prob: " << transitionProb << endl;
00609
00610 if (transitionProb < 0 || transitionProb > 1.0001){
00611 cout << "Error with transitionProb: " << transitionProb << endl;
00612 exit(-1);
00613 }
00614
00615
00616 if (transitionProb > 0.0){
00617
00618
00619 float maxval = 0.0;
00620
00621
00622 bool realState = true;
00623
00624
00625 for (unsigned b = 0; b < nextstate.size(); b++){
00626 if (nextstate[b] < (featmin[b]-EPSILON)
00627 || nextstate[b] > (featmax[b]+EPSILON)){
00628 realState = false;
00629 if (POLICYDEBUG)
00630 cout << " Next state is not valid (feature "
00631 << b << " out of range)" << endl;
00632 break;
00633 }
00634 }
00635
00636
00637
00638
00639 if (realState){
00640
00641 state_t next = canonicalize(nextstate);
00642
00643 state_info* nextinfo = &(statedata[next]);
00644
00645
00646
00647 std::vector<float>::iterator maxAct =
00648 std::max_element(nextinfo->Q.begin(),
00649 nextinfo->Q.end());
00650 maxval = *maxAct;
00651
00652 }
00653 else {
00654 maxval = 0.0;
00655 if (POLICYDEBUG){
00656 cout << "This state is too far away, state: ";
00657 for (unsigned si = 0; si < s->size(); si++){
00658 cout << (*s)[si] << ",";
00659 }
00660 cout << " Action: " << act << endl;
00661 }
00662 }
00663
00664 nextstate.clear();
00665
00666 if (POLICYDEBUG) cout << " Max value: " << maxval << endl;
00667
00668
00669 newQ += (gamma * transitionProb * maxval);
00670
00671 }
00672
00673 }
00674
00675
00676 if (probSum < 0.9999 || probSum > 1.0001){
00677 cout << "Error: transition probabilities do not add to 1: Sum: "
00678 << probSum << endl;
00679 exit(-1);
00680 }
00681
00682
00683
00684 float tdError = fabs(info->Q[act] - newQ);
00685 if (LISTDEBUG || POLICYDEBUG) cout << " NewQ: " << newQ
00686 << " OldQ: " << info->Q[act] << endl;
00687 info->Q[act] = newQ;
00688
00689 return tdError;
00690 }
00691
00692 void PrioritizedSweeping::addSAToList(const std::vector<float> &s, int act, float q){
00693
00694 saqPair saq;
00695 saq.s = s;
00696 saq.a = act;
00697 saq.q = q;
00698
00699 if (LISTDEBUG){
00700 cout << "Added state ";
00701 for (unsigned k = 0; k < saq.s.size(); k++){
00702 cout << saq.s[k] << ", ";
00703 }
00704 cout << " action: " << saq.a
00705 << " value: " << saq.q << endl;
00706 }
00707
00708 priorityList.push_front(saq);
00709
00710 }
00711
00712
00713
00715 void PrioritizedSweeping::updateStateActionFromModel(const std::vector<float> &state, int a){
00716
00717 if ((getSeconds() - planTime) > MAX_TIME)
00718 return;
00719
00720 state_t s = canonicalize(state);
00721
00722
00723 state_info* info = &(statedata[s]);
00724
00725 int j = a;
00726
00727
00728 model->getStateActionInfo(*s, j, &(info->modelInfo[j]));
00729 info->lastUpdate[j] = nactions;
00730
00731 if (info->modelInfo[j].termProb >= 1.0)
00732 return;
00733
00734
00735 for (std::map<std::vector<float>, float>::iterator outIt
00736 = info->modelInfo[j].transitionProbs.begin();
00737 outIt != info->modelInfo[j].transitionProbs.end(); outIt++){
00738
00739 std::vector<float> nextstate = (*outIt).first;
00740 state_t next = canonicalize(nextstate);
00741 state_info* nextinfo = &(statedata[next]);
00742
00743
00744 if (LISTDEBUG){
00745 cout << "State ";
00746 for (unsigned k = 0; k < nextstate.size(); k++){
00747 cout << nextstate[k] << ", ";
00748 }
00749 cout << " has predecessor: ";
00750 for (unsigned k = 0; k < nextstate.size(); k++){
00751 cout << (*s)[k] << ", ";
00752 }
00753 cout << " action: " << j << endl;
00754 }
00755
00756 saqPair saq;
00757 saq.s = *s;
00758 saq.a = j;
00759 saq.q = 0.0;
00760
00761
00762
00763 bool nothere = true;
00764 for (std::list<saqPair>::iterator k = nextinfo->pred.begin();
00765 k != nextinfo->pred.end(); k++){
00766 if (saqPairMatch(saq, *k)){
00767 nothere = false;
00768 break;
00769 }
00770 }
00771 if (nothere)
00772 nextinfo->pred.push_front(saq);
00773
00774 }
00775
00776 info->fresh = false;
00777
00778
00779
00780 }
00781
00782