QLearningActionSelector.cpp
Go to the documentation of this file.
00001 
00002 
00003 #include "QLearningActionSelector.h"
00004 
00005 #include "RewardFunction.h"
00006 
00007 #include <cstdlib>
00008 #include <iterator>
00009 #include <algorithm>
00010 
00011 #include <sstream>
00012 #include <iostream>
00013 #include <fstream>
00014 
00015 #define ROSOUTPUT 0
00016 #if ROSOUTPUT
00017 #include <ros/console.h>
00018 #else
00019 #include <iostream>
00020 // #define ROS_DEBUG( X ) std::cout << "* "<< X << std::endl;
00021 // #define ROS_DEBUG_STREAM( X ) std::cout << "* " << X << std::endl;
00022 // #define ROS_INFO_STREAM( X ) std::cout << "** " << X<< std::endl ;
00023 
00024 #define ROS_DEBUG( X )
00025 #define ROS_DEBUG_STREAM( X ) 
00026 #define ROS_INFO_STREAM( X ) 
00027 #endif
00028 #define EPSILON 0.1
00029 
00030 using namespace actasp;
00031 using namespace std;
00032 
00033 namespace bwi_krexec {
00034 
00035 struct CompareValues {
00036 
00037   CompareValues(QLearningActionSelector::ActionValueMap& value) : value(value) {}
00038 
00039   bool operator()(const AspFluent& first, const AspFluent& second) {
00040     return value[first] < value[second];
00041   }
00042 
00043   QLearningActionSelector::ActionValueMap& value;
00044 };
00045 
00046 QLearningActionSelector::QLearningActionSelector(double alpha, RewardFunction<State> *reward, 
00047                                                  actasp::AspKR *reasoner, DefaultActionValue *defval) :
00048   reasoner(reasoner),
00049   defval(defval),
00050   alpha(alpha),
00051   reward(reward),
00052   value(),
00053   initial(),
00054   final(),
00055   previousAction("noaction(0)"),
00056   count(0)  {}
00057 
00058 
/**
 * Orders key/value pairs by their second member; used with max_element() to
 * find the best-valued action in an action-value table.
 *
 * The call operator is a const template so it binds directly to the
 * container's value_type. The previous signature took
 * `const pair<AspFluent, double>&`; if ActionValueMap is a std::map (whose
 * value_type is pair<const AspFluent, double>), that forced a converted
 * temporary, copying the AspFluent, on every comparison.
 */
struct CompareSecond {
  template <typename FirstPair, typename SecondPair>
  bool operator()(const FirstPair& first, const SecondPair& second) const {
    return first.second < second.second;
  }
};
00064 
00065 actasp::ActionSet::const_iterator QLearningActionSelector::choose(const actasp::ActionSet& options) throw() {
00066 
00067   if (!(initial.empty() || final.empty())) {
00068 
00069     ActionValueMap::const_iterator bestValuePair = max_element(value[final].begin(), value[final].end(),CompareSecond());
00070     
00071     double bestValue = 0;
00072     if(bestValuePair != value[final].end())
00073       bestValue = bestValuePair->second;
00074 
00075     double rew = reward->r(initial,previousAction,final);
00076 
00077 
00078     ROS_INFO_STREAM("old value: " << value[initial][previousAction]);
00079     ROS_INFO_STREAM("reward: " << rew);
00080 
00081     value[initial][previousAction] = (1 - alpha) * value[initial][previousAction] + alpha * (rew + bestValue);
00082 
00083     ROS_INFO_STREAM("new value: " << value[initial][previousAction]);
00084 
00085     initial.clear();
00086     final.clear();
00087 
00088   }
00089 
00090 
00091   stringstream ss;
00092   ss << "Evaluating options: ";
00093   copy(options.begin(), options.end(), ostream_iterator<string>(ss, " "));
00094   ss << endl;
00095 
00096   AnswerSet currentState = reasoner->currentStateQuery(vector<AspRule>());
00097   State state(currentState.getFluents().begin(), currentState.getFluents().end());
00098 
00099   ActionSet::const_iterator optIt = options.begin();
00100   for (; optIt != options.end(); ++optIt) {
00101     ActionValueMap &thisState = value[state];
00102     
00103     if(thisState.find(*optIt) == thisState.end()) {
00104       //initialize to default
00105       thisState[*optIt] = defval->value(*optIt);
00106     }
00107     ss << value[state][*optIt] << " ";
00108   }
00109 
00110   ROS_INFO_STREAM(ss.str());
00111 
00112   double prob = EPSILON;
00113 
00114   if (rand() <= prob * RAND_MAX) { //random
00115     ActionSet::const_iterator chosen =  options.begin();
00116     advance(chosen, rand() % options.size());
00117 //     std::cout << "choosing random " << std::endl;
00118     return chosen;
00119   }
00120 
00121   actasp::ActionSet::const_iterator best = max_element(options.begin(), options.end(),CompareValues(value[state]));
00122 
00123   return best;
00124 
00125 }
00126 
00127 void QLearningActionSelector::actionStarted(const AspFluent&) throw() {
00128   initial.clear();
00129  
00130   AnswerSet currentState = reasoner->currentStateQuery(vector<AspRule>());
00131   initial.insert(currentState.getFluents().begin(), currentState.getFluents().end());
00132 }
00133 
00134 
00135 void QLearningActionSelector::actionTerminated(const AspFluent& action) throw() {
00136   AnswerSet currentState = reasoner->currentStateQuery(vector<AspRule>());
00137   final.clear();
00138   final.insert(currentState.getFluents().begin(), currentState.getFluents().end());
00139   previousAction = action;
00140 }
00141 
00142 void QLearningActionSelector::episodeEnded() {
00143   if(initial.empty())
00144     return;
00145     
00146   ROS_INFO_STREAM("old value: " << value[initial][previousAction]);
00147   value[initial][previousAction] = (1 - alpha) * value[initial][previousAction] + alpha * reward->r(initial,previousAction,final);
00148   ROS_INFO_STREAM("new value: " << value[initial][previousAction]);
00149 
00150   initial.clear();
00151   final.clear();
00152   ++count;
00153 }
00154 
00155 
00156 void QLearningActionSelector::readFrom(std::istream & fromStream) throw() {
00157 
00158   ROS_DEBUG("Loading value function");
00159 
00160   const string whiteSpaces(" \t");
00161 
00162   value.clear();
00163 
00164   while (fromStream.good() &&  !fromStream.eof()) {
00165 
00166     string stateLine;
00167     getline(fromStream,stateLine);
00168 
00169     size_t firstChar = min(stateLine.find_first_of(whiteSpaces),static_cast<size_t>(0));
00170     size_t lastChar = min(stateLine.find_last_not_of(whiteSpaces),stateLine.size());
00171     stateLine = stateLine.substr(firstChar,lastChar-firstChar+1);
00172 
00173     stringstream stateStream(stateLine);
00174 
00175 
00176 
00177     if (stateLine.empty())
00178       return;
00179 
00180     State state;
00181     copy(istream_iterator<string>(stateStream), istream_iterator<string>(), inserter(state, state.begin()));
00182 
00183 
00184     string actionLine;
00185     getline(fromStream,actionLine);
00186 
00187     while (actionLine.find("-----") == string::npos) {
00188 
00189       size_t firstChar = min(actionLine.find_first_of(whiteSpaces), static_cast<size_t>(0));
00190       size_t lastChar = min(actionLine.find_last_not_of(whiteSpaces),actionLine.size());
00191       actionLine = actionLine.substr(firstChar,lastChar-firstChar+1);
00192 
00193       stringstream actionStream(actionLine);
00194 
00195       double actionValue;
00196       string fluentString;
00197 
00198       actionStream >> actionValue >> fluentString;
00199 
00200       AspFluent action(fluentString);
00201 
00202       value[state].insert(make_pair(action,actionValue));
00203 
00204       getline(fromStream,actionLine);
00205     }
00206 
00207   }
00208 }
00209 
00210 
00211 void QLearningActionSelector::writeTo(std::ostream & toStream) throw() {
00212 
00213   ROS_DEBUG("Storing value function");
00214 
00215   StateActionMap::const_iterator stateIt = value.begin();
00216   //for each state
00217       ofstream stat("stats.txt", ios::app);
00218     AspFluent initialState("pos(2,0,0)");
00219   
00220   for (; stateIt != value.end(); ++stateIt) {
00221     
00222 
00223     if(stateIt->first.find(initialState) != stateIt->first.end()) {
00224         ActionValueMap::const_iterator actionIt= stateIt->second.begin();
00225         for (; actionIt != stateIt->second.end(); ++actionIt) {
00226            if(stateIt->first.find(initialState) != stateIt->first.end())
00227             stat << value[stateIt->first][actionIt->first] << " ";
00228         }
00229         
00230      }
00231      
00232 
00233 
00234     //write the state in a line
00235     copy(stateIt->first.begin(), stateIt->first.end(), ostream_iterator<string>(toStream, " "));
00236 
00237     //for each action
00238     ActionValueMap::const_iterator actionIt = stateIt->second.begin();
00239     for (; actionIt != stateIt->second.end(); ++actionIt) {
00240 
00241       toStream << endl;
00242 
00243       //write the value, then the action
00244 
00245       toStream <<  actionIt->second << " " << actionIt->first.toString();
00246 
00247     }
00248 
00249     //a separator for the next state
00250     toStream << endl << "-----" << endl;
00251   }
00252        stat << endl;
00253      stat.close();
00254 
00255 }
00256 }


bwi_kr_execution
Author(s): Matteo Leonetti, Piyush Khandelwal
autogenerated on Thu Jun 6 2019 17:57:37