SarsaActionSelector.cpp
Go to the documentation of this file.
00001 
00002 
00003 /* This algorithm is an implementation of
00004  *
00005  * Seijen, Harm V., and Rich Sutton. "True Online TD (lambda)."
00006  * Proceedings of the 31st International Conference on Machine Learning (ICML-14). 2014.
00007  *
00008  */
00009 
00010 
00011 #include "SarsaActionSelector.h"
00012 #include "DefaultActionValue.h"
00013 #include "RewardFunction.h"
00014 
00015 #include <actasp/AspKR.h>
00016 #include <actasp/AspFluent.h>
00017 
00018 #define ROSOUTPUT 0
00019 #if ROSOUTPUT
00020 #include <ros/console.h>
00021 #else
00022 #include <iostream>
00023 // #define ROS_DEBUG( X ) std::cout << "* "<< X << std::endl;
00024 // #define ROS_DEBUG_STREAM( X ) std::cout << "* " << X << std::endl;
00025 // #define ROS_INFO_STREAM( X ) std::cout << "** " << X<< std::endl ;
00026 
00027 #define ROS_DEBUG( X )
00028 #define ROS_DEBUG_STREAM( X )
00029 #define ROS_INFO_STREAM( X )
00030 #endif
00031 
00032 
00033 #include <algorithm>
00034 #include <iterator>
00035 #include <sstream>
00036 #include <fstream>
00037 #include <iostream>
00038 #include <ctime>
00039 
00040 #define FILTER true
00041 
00042 
00043 using namespace std;
00044 using namespace actasp;
00045 
00046 namespace bwi_krexec {
00047 
00048 SarsaActionSelector::SarsaActionSelector(actasp::FilteringKR* reasoner, DefaultActionValue *defval,
00049     RewardFunction<State>*reward, const SarsaParams& p) :
00050   reasoner(reasoner), defval(defval), p(p), reward(reward),value(),e(),
00051   initial(), final(), previousAction("nopreviousaction(0)"), v_s(0), policy(NULL) {}
00052 
00053 struct CompareValues {
00054 
00055   CompareValues(SarsaActionSelector::ActionValueMap& value) : value(value) {}
00056 
00057   bool operator()(const AspFluent& first, const AspFluent& second) {
00058     //if (value[first] == value[second]) {
00059     //choose whether true or false randomly to remove alphabetical bias .. this also good in grid but not robots
00060     //  return (rand() < 0.5*RAND_MAX);
00061     //}
00062     //else {
00063     return value[first] < value[second];
00064     //}
00065   }
00066 
00067   SarsaActionSelector::ActionValueMap& value;
00068 };
00069 
00070 actasp::ActionSet::const_iterator SarsaActionSelector::choose(const actasp::ActionSet& options) throw() {
00071 
00072   stringstream ss;
00073   ss << "Evaluating options: ";
00074   copy(options.begin(), options.end(), ostream_iterator<string>(ss, " "));
00075   ss << endl;
00076 
00077   AnswerSet currentState = reasoner->currentStateQuery(vector<AspRule>());
00078 
00079   set<AspFluent> currentSet(currentState.getFluents().begin(), currentState.getFluents().end()); //not filtered
00080   set<AspFluent> stateFluents; //filtered
00081 
00082 
00083   if (FILTER && policy != NULL) {
00084 
00085     // check if State "currentSet" is in the notFilteredToFiltered map ..
00086     // map is cleared when goal changes, so must be valid if there is.
00087     set<AspFluent> &filtered = notFilteredToFiltered[currentSet];
00088     if (!filtered.empty()) { //there is already the filtered state in the map
00089       stateFluents = filtered;
00090     } else { //not on the map, need to compute filtered state
00091       vector<AnswerSet> plansFromHere = policy->plansFrom(currentSet);
00092       AnswerSet filteredCurrentState = reasoner->filterState(plansFromHere, goalRules);
00093       set<AspFluent> temp(filteredCurrentState.getFluents().begin(), filteredCurrentState.getFluents().end());
00094       stateFluents = temp;
00095       filtered = stateFluents;
00096     }
00097 
00098   } //end of if filter
00099 
00100   else { // no filter
00101     stateFluents = currentSet;
00102   }
00103 
00104   // cout << "state: ";
00105   // set<AspFluent>::iterator printing = stateFluents.begin();
00106   // for (; printing != stateFluents.end(); ++printing)
00107   //   cout << printing->toString() << " ";
00108   // cout << "  ";
00109 
00110   ActionSet::const_iterator optIt = options.begin();
00111   for (; optIt != options.end(); ++optIt) {
00112 
00113     ActionValueMap &thisState = value[stateFluents];
00114 
00115     if (thisState.find(*optIt) == thisState.end()) {
00116       //initialize to default
00117       thisState[*optIt] = defval->value(*optIt);
00118     }
00119     ss << value[stateFluents][*optIt] << " ";
00120   }
00121 
00122   ROS_DEBUG(ss.str());
00123 
00124   ROS_INFO_STREAM(ss.str());
00125 
00126   double prob = p.epsilon;
00127 
00128   ActionSet::const_iterator chosen;
00129 
00130   if (rand() <= prob * RAND_MAX) { //random
00131     // cout << " r ";
00132     chosen =  options.begin();
00133     advance(chosen, rand() % options.size());
00134   } else {
00135     // cout << " c ";
00136     chosen = max_element(options.begin(), options.end(),CompareValues(value[stateFluents]));
00137   }
00138 
00139 
00140   if (!(initial.empty() || final.empty())) {
00141 
00142     //we have a full state, action ,reward, state, action sequence!
00143 
00144     double v_s_prime = value[final][*chosen];
00145     updateValue(v_s_prime);
00146 
00147     initial.clear();
00148     final.clear();
00149 
00150   }
00151 
00152   // cout << "action: " << chosen->toString() << endl;
00153 
00154   return chosen;
00155 
00156 }
00157 
00158 void printE(SarsaActionSelector::StateActionMap &e) {
00159   cout << "--- E table ---" << endl;
00160   SarsaActionSelector::StateActionMap::iterator state = e.begin();
00161   for (; state != e.end(); ++state) {
00162     SarsaActionSelector::ActionValueMap::iterator action = state->second.begin();
00163     for (; action != state->second.end(); ++action) {
00164       copy(state->first.begin(), state->first.end(), ostream_iterator<string>(cout , " "));
00165       cout << endl << action->second << " " << action->first.toString() << endl;
00166     }
00167   }
00168   cout << "---" << endl;
00169 }
00170 
00171 void SarsaActionSelector::updateValue(double v_s_prime) {
00172 
00173 
00174 //   printE(e);
00175 
00176   double rew = reward->r(initial,previousAction,final);
00177 
00178   double delta = rew + p.gamma * v_s_prime - v_s;
00179 
00180   //set the elegibility trace of the current state-action pair
00181   double &e_current = e[initial][previousAction];
00182 //   cout << "E " << e_current << " ";
00183   e_current =  p.alpha + (1 - p.alpha) *(p.gamma * p.lambda * e_current);
00184 //   cout << e_current <<endl;
00185 
00186   //set the value function for the current state-action pair
00187   double &v_current = value[initial][previousAction];
00188   v_current += delta * e_current + p.alpha * (v_s - v_current);
00189 
00190   //change every other state along the eligibility trace
00191   StateActionMap::iterator state = e.begin();
00192 
00193   for (; state != e.end(); ++state) {
00194 
00195     bool currentState = state->first == initial;
00196 
00197     ActionValueMap::iterator action = state->second.begin();
00198 
00199     for (; action != state->second.end(); ++action) {
00200 
00201       if (!currentState || !ActionEquality()(action->first,previousAction)) {
00202 
00203         //the current state action pair has alredy been delt with
00204 
00205         double &e_trace = e[state->first][action->first];
00206 //         cout << "e " << e_trace << " ";
00207         e_trace *= p.lambda * p.gamma;
00208 //         cout << e_trace << endl;
00209 
00210         double &v = value[state->first][action->first];
00211         v += delta * e_trace;
00212 
00213       }
00214 
00215     }
00216 
00217   }
00218 
00219 //   printE(e);
00220 
00221   v_s = v_s_prime;
00222 
00223 }
00224 
00225 
00226 //used for filterstate:
00227 void SarsaActionSelector::policyChanged(PartialPolicy* newPolicy) throw() {
00228   policy = dynamic_cast<GraphPolicy*>(newPolicy); //update
00229   if(policy == NULL)
00230     throw runtime_error("the new policy is not a GraphPolicy, SarsaActionSelector cannot continue");
00231 }
00232 
00233 void SarsaActionSelector::goalChanged(std::vector<actasp::AspRule> newGoalRules) throw() {
00234   goalRules = newGoalRules; //update
00235   if (!(newGoalRules == goalRules)) {
00236     notFilteredToFiltered.clear(); //not valid anymore
00237   }
00238 }
00239 bool SarsaActionSelector::stateCompare(const std::set<actasp::AspFluent> state, const std::set<actasp::AspFluent> otherstate) {
00240   if (state.size() != otherstate.size()) {
00241     return false;
00242   }
00243   std::set<actasp::AspFluent>::const_iterator thisIt = state.begin();
00244   std::set<actasp::AspFluent>::const_iterator otherIt = otherstate.begin();
00245   for (; thisIt!=state.end(); ++thisIt) {
00246     std::string thisstring = thisIt->toString(0);
00247     std::string otherstring = otherIt->toString(0);
00248     if (thisstring.compare(otherstring)!=0) { //different
00249       return false;
00250     }
00251     ++otherIt;
00252   }
00253   return true;
00254 }
00255 
00256 
00257 void SarsaActionSelector::actionStarted(const AspFluent&) throw() {
00258 
00259   AnswerSet state = reasoner->currentStateQuery(vector<AspRule>());
00260   initialNotFiltered.clear();
00261   initialNotFiltered.insert(state.getFluents().begin(), state.getFluents().end());
00262   initial.clear();
00263 
00264   if (FILTER && policy != NULL) {
00265 
00266     set<AspFluent> &filtered = notFilteredToFiltered[initialNotFiltered];
00267     if (!filtered.empty()) { //there is already the filtered state in the map
00268       initial = filtered;
00269     } else {
00270       std::vector<actasp::AnswerSet> plansFromHere = policy->plansFrom(initialNotFiltered);
00271       AnswerSet filteredState = reasoner->filterState(plansFromHere, goalRules);
00272       initial.insert(filteredState.getFluents().begin(), filteredState.getFluents().end());
00273       filtered = initial;
00274     }
00275 
00276   } // end of if filter
00277 
00278   else { // no filter
00279     initial.insert(state.getFluents().begin(), state.getFluents().end());
00280   }
00281 }
00282 
00283 
00284 void SarsaActionSelector::actionTerminated(const AspFluent& action) throw() {
00285 
00286   if (final.empty()) { //we have the first state-action pair, we can initialize v_s
00287     ActionValueMap &initState = value[initial];
00288     if (initState.find(action) == initState.end()) {
00289       //use the default value
00290       value[initial][action] = defval->value(action);
00291     }
00292     v_s = value[initial][action];
00293   }
00294 
00295 
00296   AnswerSet state = reasoner->currentStateQuery(vector<AspRule>());
00297   finalNotFiltered.clear();
00298   finalNotFiltered.insert(state.getFluents().begin(), state.getFluents().end());
00299   final.clear();
00300 
00301   if (FILTER && policy != NULL) {
00302     set<AspFluent> &filtered = notFilteredToFiltered[finalNotFiltered]; //check if we already filtered this.
00303     if (!filtered.empty()) { //there is already the filtered state in the map
00304       final = filtered;
00305     } else {
00306       // this optimization is useful in grid environment, but not really in the robots
00307       //set<AspFluent> expected = policy.nextExpected(initialNotFiltered,action); //first, check if the not filtered state was expected.
00308       //if (stateCompare(expected, finalNotFiltered)) { //final is as expected by policy, so final can be derived from initial
00309       //final = reasoner->actionEffects(action, initial);
00310       //filtered = final;
00311       //}
00312       //else { //really need to compute filtered..
00313       std::vector<actasp::AnswerSet> plansFromHere = policy->plansFrom(finalNotFiltered);
00314       AnswerSet filteredState = reasoner->filterState(plansFromHere, goalRules);
00315       final.insert(filteredState.getFluents().begin(), filteredState.getFluents().end());
00316       filtered = final;
00317       //}
00318     }
00319 
00320     if (final.empty()) { //added to avoid seg fault at goal..
00321       if (reasoner->currentStateQuery(goalRules).isSatisfied()) {
00322         final.insert(state.getFluents().begin(), state.getFluents().end());
00323       }
00324     }
00325 
00326   } // end of if filter
00327 
00328   else {  // no filter
00329     final.insert(state.getFluents().begin(), state.getFluents().end());
00330   }
00331 
00332   previousAction = action;
00333 
00334 }
00335 
00336 void SarsaActionSelector::episodeEnded() throw() {
00337 
00338 
00339   if (!initial.empty()) {
00340     //update the last state-action pair
00341     updateValue(0.);
00342   }
00343 
00344   e.clear();
00345   initial.clear();
00346   final.clear();
00347   previousAction = AspFluent("nopreviousaction(0)");
00348   v_s = 0;
00349 }
00350 
00351 void SarsaActionSelector::saveValueInitialState(const std::string& fileName) {
00352   ofstream initialValue(fileName.c_str(), ofstream::app);
00353 
00354   AnswerSet initialAnswerSet= reasoner->currentStateQuery(vector<AspRule>());
00355   State initialState(initialAnswerSet.getFluents().begin(), initialAnswerSet.getFluents().end());
00356 
00357   if (FILTER && policy != NULL) {
00358 
00359     set<AspFluent> &filtered = notFilteredToFiltered[initialState];
00360     if (!filtered.empty()) { //there is already the filtered state in the map
00361       initialState.clear();
00362       initialState = filtered;
00363     } else {
00364       std::vector<actasp::AnswerSet> plansFromHere = policy->plansFrom(initialState);
00365       AnswerSet filteredState = reasoner->filterState(plansFromHere, goalRules);
00366       initialState.clear();
00367       initialState.insert(filteredState.getFluents().begin(), filteredState.getFluents().end());
00368       filtered = initial;
00369     }
00370 
00371   } // end of if filter
00372 
00373   //else nothing, keep the query state
00374 
00375   ActionValueMap &initial_value_map = value[initialState];
00376   ActionValueMap::iterator action_value = initial_value_map.begin();
00377 
00378   time_t rawtime;
00379   struct tm * timeinfo;
00380   char time_string[10];
00381   time(&rawtime);
00382   timeinfo = localtime(&rawtime);
00383   strftime(time_string,10,"%R",timeinfo);
00384 
00385   stringstream actionNames;
00386 
00387   actionNames << time_string << " ";
00388 
00389   if (FILTER && policy != NULL)
00390     initialValue << "filtered state: ";
00391   else
00392     initialValue << "state: ";
00393 
00394   std::set< actasp::AspFluent> state_to_print = initialState;
00395   for (std::set< actasp::AspFluent>::iterator it = state_to_print.begin(); it != state_to_print.end(); ++it) {
00396     initialValue << it->toString() << " ";
00397   }
00398   initialValue << endl << "action: ";
00399 
00400   for (; action_value != initial_value_map.end(); ++action_value) {
00401     initialValue << action_value->second << " ";
00402     actionNames << action_value->first.toString() << " ";
00403   }
00404   initialValue << actionNames.str() << endl << endl;
00405   initialValue.close();
00406 }
00407 
00408 
00409 void SarsaActionSelector::readFrom(std::istream & fromStream) throw() {
00410 
00411   ROS_DEBUG("Loading value function");
00412 
00413   const string whiteSpaces(" \t");
00414 
00415   value.clear();
00416 
00417   while (fromStream.good() &&  !fromStream.eof()) {
00418 
00419     string stateLine;
00420     getline(fromStream,stateLine);
00421 
00422     size_t firstChar = min(stateLine.find_first_of(whiteSpaces),static_cast<size_t>(0));
00423     size_t lastChar = min(stateLine.find_last_not_of(whiteSpaces),stateLine.size());
00424     stateLine = stateLine.substr(firstChar,lastChar-firstChar+1);
00425 
00426     stringstream stateStream(stateLine);
00427 
00428 
00429 
00430     if (stateLine.empty())
00431       return;
00432 
00433     State state;
00434     copy(istream_iterator<string>(stateStream), istream_iterator<string>(), inserter(state, state.begin()));
00435 
00436 
00437     string actionLine;
00438     getline(fromStream,actionLine);
00439 
00440     while (actionLine.find("-----") == string::npos) {
00441 
00442       size_t firstChar = min(actionLine.find_first_of(whiteSpaces), static_cast<size_t>(0));
00443       size_t lastChar = min(actionLine.find_last_not_of(whiteSpaces),actionLine.size());
00444       actionLine = actionLine.substr(firstChar,lastChar-firstChar+1);
00445 
00446       stringstream actionStream(actionLine);
00447 
00448       double actionValue;
00449       string fluentString;
00450 
00451       actionStream >> actionValue >> fluentString;
00452 
00453       AspFluent action(fluentString);
00454 
00455       value[state].insert(make_pair(action,actionValue));
00456 
00457       getline(fromStream,actionLine);
00458     }
00459 
00460   }
00461 }
00462 
00463 void SarsaActionSelector::writeTo(std::ostream & toStream) throw() {
00464 
00465   ROS_DEBUG("Storing value function");
00466 
00467   StateActionMap::const_iterator stateIt = value.begin();
00468   //for each state
00469   ofstream stat("stats.txt", ios::app);
00470   AspFluent initialState("pos(10,0,0)");
00471 
00472   // cout << "value map: " << endl;
00473 
00474   for (; stateIt != value.end(); ++stateIt) {
00475 
00476     // cout << "state: ";
00477     // set<AspFluent>::iterator printing = stateIt->first.begin();
00478     // for (; printing != stateIt->first.end(); ++printing)
00479     //   cout << printing->toString() << " ";
00480     // cout << "  ";
00481 
00482 
00483     if (stateIt->first.find(initialState) != stateIt->first.end()) {
00484       ActionValueMap::const_iterator actionIt= stateIt->second.begin();
00485       for (; actionIt != stateIt->second.end(); ++actionIt) {
00486         if (stateIt->first.find(initialState) != stateIt->first.end())
00487           stat << value[stateIt->first][actionIt->first] << " ";
00488       }
00489 
00490     }
00491 
00492     //write the state in a line
00493     copy(stateIt->first.begin(), stateIt->first.end(), ostream_iterator<string>(toStream, " "));
00494 
00495     //for each action
00496     ActionValueMap::const_iterator actionIt = stateIt->second.begin();
00497     for (; actionIt != stateIt->second.end(); ++actionIt) {
00498 
00499       // cout << "action: " << actionIt->first.toString() << " ";
00500       // cout << "value: " << actionIt->second << "   ";
00501 
00502 
00503       toStream << endl;
00504 
00505       //write the value, then the action
00506 
00507       toStream <<  actionIt->second << " " << actionIt->first.toString();
00508 
00509     }
00510 
00511     // cout << endl;
00512 
00513     //a separator for the next state
00514     toStream << endl << "-----" << endl;
00515   }
00516   stat << endl;
00517   stat.close();
00518 
00519   // cout << endl << endl;
00520 
00521 }
00522 
00523 
00524 
00525 void SarsaActionSelector::readMapFrom(std::istream & fromStream) throw() {
00526 //coming soon
00527 }
00528 
00529 void SarsaActionSelector::writeMapTo(std::ostream & toStream) throw() {
00530 //coming soon
00531 }
00532 
00533 
00534 }


bwi_kr_execution
Author(s): Matteo Leonetti, Piyush Khandelwal
autogenerated on Thu Jun 6 2019 17:57:37