00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "SarsaActionSelector.h"
00012 #include "DefaultActionValue.h"
00013 #include "RewardFunction.h"
00014
00015 #include <actasp/AspKR.h>
00016 #include <actasp/AspFluent.h>
00017
00018 #define ROSOUTPUT 0
00019 #if ROSOUTPUT
00020 #include <ros/console.h>
00021 #else
00022 #include <iostream>
00023
00024
00025
00026
00027 #define ROS_DEBUG( X )
00028 #define ROS_DEBUG_STREAM( X )
00029 #define ROS_INFO_STREAM( X )
00030 #endif
00031
00032
00033 #include <algorithm>
00034 #include <iterator>
00035 #include <sstream>
00036 #include <fstream>
00037 #include <iostream>
00038 #include <ctime>
00039
00040 #define FILTER true
00041
00042
00043 using namespace std;
00044 using namespace actasp;
00045
00046 namespace bwi_krexec {
00047
00048 SarsaActionSelector::SarsaActionSelector(actasp::FilteringKR* reasoner, DefaultActionValue *defval,
00049 RewardFunction<State>*reward, const SarsaParams& p) :
00050 reasoner(reasoner), defval(defval), p(p), reward(reward),value(),e(),
00051 initial(), final(), previousAction("nopreviousaction(0)"), v_s(0), policy(NULL) {}
00052
00053 struct CompareValues {
00054
00055 CompareValues(SarsaActionSelector::ActionValueMap& value) : value(value) {}
00056
00057 bool operator()(const AspFluent& first, const AspFluent& second) {
00058
00059
00060
00061
00062
00063 return value[first] < value[second];
00064
00065 }
00066
00067 SarsaActionSelector::ActionValueMap& value;
00068 };
00069
00070 actasp::ActionSet::const_iterator SarsaActionSelector::choose(const actasp::ActionSet& options) throw() {
00071
00072 stringstream ss;
00073 ss << "Evaluating options: ";
00074 copy(options.begin(), options.end(), ostream_iterator<string>(ss, " "));
00075 ss << endl;
00076
00077 AnswerSet currentState = reasoner->currentStateQuery(vector<AspRule>());
00078
00079 set<AspFluent> currentSet(currentState.getFluents().begin(), currentState.getFluents().end());
00080 set<AspFluent> stateFluents;
00081
00082
00083 if (FILTER && policy != NULL) {
00084
00085
00086
00087 set<AspFluent> &filtered = notFilteredToFiltered[currentSet];
00088 if (!filtered.empty()) {
00089 stateFluents = filtered;
00090 } else {
00091 vector<AnswerSet> plansFromHere = policy->plansFrom(currentSet);
00092 AnswerSet filteredCurrentState = reasoner->filterState(plansFromHere, goalRules);
00093 set<AspFluent> temp(filteredCurrentState.getFluents().begin(), filteredCurrentState.getFluents().end());
00094 stateFluents = temp;
00095 filtered = stateFluents;
00096 }
00097
00098 }
00099
00100 else {
00101 stateFluents = currentSet;
00102 }
00103
00104
00105
00106
00107
00108
00109
00110 ActionSet::const_iterator optIt = options.begin();
00111 for (; optIt != options.end(); ++optIt) {
00112
00113 ActionValueMap &thisState = value[stateFluents];
00114
00115 if (thisState.find(*optIt) == thisState.end()) {
00116
00117 thisState[*optIt] = defval->value(*optIt);
00118 }
00119 ss << value[stateFluents][*optIt] << " ";
00120 }
00121
00122 ROS_DEBUG(ss.str());
00123
00124 ROS_INFO_STREAM(ss.str());
00125
00126 double prob = p.epsilon;
00127
00128 ActionSet::const_iterator chosen;
00129
00130 if (rand() <= prob * RAND_MAX) {
00131
00132 chosen = options.begin();
00133 advance(chosen, rand() % options.size());
00134 } else {
00135
00136 chosen = max_element(options.begin(), options.end(),CompareValues(value[stateFluents]));
00137 }
00138
00139
00140 if (!(initial.empty() || final.empty())) {
00141
00142
00143
00144 double v_s_prime = value[final][*chosen];
00145 updateValue(v_s_prime);
00146
00147 initial.clear();
00148 final.clear();
00149
00150 }
00151
00152
00153
00154 return chosen;
00155
00156 }
00157
00158 void printE(SarsaActionSelector::StateActionMap &e) {
00159 cout << "--- E table ---" << endl;
00160 SarsaActionSelector::StateActionMap::iterator state = e.begin();
00161 for (; state != e.end(); ++state) {
00162 SarsaActionSelector::ActionValueMap::iterator action = state->second.begin();
00163 for (; action != state->second.end(); ++action) {
00164 copy(state->first.begin(), state->first.end(), ostream_iterator<string>(cout , " "));
00165 cout << endl << action->second << " " << action->first.toString() << endl;
00166 }
00167 }
00168 cout << "---" << endl;
00169 }
00170
00171 void SarsaActionSelector::updateValue(double v_s_prime) {
00172
00173
00174
00175
00176 double rew = reward->r(initial,previousAction,final);
00177
00178 double delta = rew + p.gamma * v_s_prime - v_s;
00179
00180
00181 double &e_current = e[initial][previousAction];
00182
00183 e_current = p.alpha + (1 - p.alpha) *(p.gamma * p.lambda * e_current);
00184
00185
00186
00187 double &v_current = value[initial][previousAction];
00188 v_current += delta * e_current + p.alpha * (v_s - v_current);
00189
00190
00191 StateActionMap::iterator state = e.begin();
00192
00193 for (; state != e.end(); ++state) {
00194
00195 bool currentState = state->first == initial;
00196
00197 ActionValueMap::iterator action = state->second.begin();
00198
00199 for (; action != state->second.end(); ++action) {
00200
00201 if (!currentState || !ActionEquality()(action->first,previousAction)) {
00202
00203
00204
00205 double &e_trace = e[state->first][action->first];
00206
00207 e_trace *= p.lambda * p.gamma;
00208
00209
00210 double &v = value[state->first][action->first];
00211 v += delta * e_trace;
00212
00213 }
00214
00215 }
00216
00217 }
00218
00219
00220
00221 v_s = v_s_prime;
00222
00223 }
00224
00225
00226
00227 void SarsaActionSelector::policyChanged(PartialPolicy* newPolicy) throw() {
00228 policy = dynamic_cast<GraphPolicy*>(newPolicy);
00229 if(policy == NULL)
00230 throw runtime_error("the new policy is not a GraphPolicy, SarsaActionSelector cannot continue");
00231 }
00232
00233 void SarsaActionSelector::goalChanged(std::vector<actasp::AspRule> newGoalRules) throw() {
00234 goalRules = newGoalRules;
00235 if (!(newGoalRules == goalRules)) {
00236 notFilteredToFiltered.clear();
00237 }
00238 }
00239 bool SarsaActionSelector::stateCompare(const std::set<actasp::AspFluent> state, const std::set<actasp::AspFluent> otherstate) {
00240 if (state.size() != otherstate.size()) {
00241 return false;
00242 }
00243 std::set<actasp::AspFluent>::const_iterator thisIt = state.begin();
00244 std::set<actasp::AspFluent>::const_iterator otherIt = otherstate.begin();
00245 for (; thisIt!=state.end(); ++thisIt) {
00246 std::string thisstring = thisIt->toString(0);
00247 std::string otherstring = otherIt->toString(0);
00248 if (thisstring.compare(otherstring)!=0) {
00249 return false;
00250 }
00251 ++otherIt;
00252 }
00253 return true;
00254 }
00255
00256
00257 void SarsaActionSelector::actionStarted(const AspFluent&) throw() {
00258
00259 AnswerSet state = reasoner->currentStateQuery(vector<AspRule>());
00260 initialNotFiltered.clear();
00261 initialNotFiltered.insert(state.getFluents().begin(), state.getFluents().end());
00262 initial.clear();
00263
00264 if (FILTER && policy != NULL) {
00265
00266 set<AspFluent> &filtered = notFilteredToFiltered[initialNotFiltered];
00267 if (!filtered.empty()) {
00268 initial = filtered;
00269 } else {
00270 std::vector<actasp::AnswerSet> plansFromHere = policy->plansFrom(initialNotFiltered);
00271 AnswerSet filteredState = reasoner->filterState(plansFromHere, goalRules);
00272 initial.insert(filteredState.getFluents().begin(), filteredState.getFluents().end());
00273 filtered = initial;
00274 }
00275
00276 }
00277
00278 else {
00279 initial.insert(state.getFluents().begin(), state.getFluents().end());
00280 }
00281 }
00282
00283
00284 void SarsaActionSelector::actionTerminated(const AspFluent& action) throw() {
00285
00286 if (final.empty()) {
00287 ActionValueMap &initState = value[initial];
00288 if (initState.find(action) == initState.end()) {
00289
00290 value[initial][action] = defval->value(action);
00291 }
00292 v_s = value[initial][action];
00293 }
00294
00295
00296 AnswerSet state = reasoner->currentStateQuery(vector<AspRule>());
00297 finalNotFiltered.clear();
00298 finalNotFiltered.insert(state.getFluents().begin(), state.getFluents().end());
00299 final.clear();
00300
00301 if (FILTER && policy != NULL) {
00302 set<AspFluent> &filtered = notFilteredToFiltered[finalNotFiltered];
00303 if (!filtered.empty()) {
00304 final = filtered;
00305 } else {
00306
00307
00308
00309
00310
00311
00312
00313 std::vector<actasp::AnswerSet> plansFromHere = policy->plansFrom(finalNotFiltered);
00314 AnswerSet filteredState = reasoner->filterState(plansFromHere, goalRules);
00315 final.insert(filteredState.getFluents().begin(), filteredState.getFluents().end());
00316 filtered = final;
00317
00318 }
00319
00320 if (final.empty()) {
00321 if (reasoner->currentStateQuery(goalRules).isSatisfied()) {
00322 final.insert(state.getFluents().begin(), state.getFluents().end());
00323 }
00324 }
00325
00326 }
00327
00328 else {
00329 final.insert(state.getFluents().begin(), state.getFluents().end());
00330 }
00331
00332 previousAction = action;
00333
00334 }
00335
00336 void SarsaActionSelector::episodeEnded() throw() {
00337
00338
00339 if (!initial.empty()) {
00340
00341 updateValue(0.);
00342 }
00343
00344 e.clear();
00345 initial.clear();
00346 final.clear();
00347 previousAction = AspFluent("nopreviousaction(0)");
00348 v_s = 0;
00349 }
00350
00351 void SarsaActionSelector::saveValueInitialState(const std::string& fileName) {
00352 ofstream initialValue(fileName.c_str(), ofstream::app);
00353
00354 AnswerSet initialAnswerSet= reasoner->currentStateQuery(vector<AspRule>());
00355 State initialState(initialAnswerSet.getFluents().begin(), initialAnswerSet.getFluents().end());
00356
00357 if (FILTER && policy != NULL) {
00358
00359 set<AspFluent> &filtered = notFilteredToFiltered[initialState];
00360 if (!filtered.empty()) {
00361 initialState.clear();
00362 initialState = filtered;
00363 } else {
00364 std::vector<actasp::AnswerSet> plansFromHere = policy->plansFrom(initialState);
00365 AnswerSet filteredState = reasoner->filterState(plansFromHere, goalRules);
00366 initialState.clear();
00367 initialState.insert(filteredState.getFluents().begin(), filteredState.getFluents().end());
00368 filtered = initial;
00369 }
00370
00371 }
00372
00373
00374
00375 ActionValueMap &initial_value_map = value[initialState];
00376 ActionValueMap::iterator action_value = initial_value_map.begin();
00377
00378 time_t rawtime;
00379 struct tm * timeinfo;
00380 char time_string[10];
00381 time(&rawtime);
00382 timeinfo = localtime(&rawtime);
00383 strftime(time_string,10,"%R",timeinfo);
00384
00385 stringstream actionNames;
00386
00387 actionNames << time_string << " ";
00388
00389 if (FILTER && policy != NULL)
00390 initialValue << "filtered state: ";
00391 else
00392 initialValue << "state: ";
00393
00394 std::set< actasp::AspFluent> state_to_print = initialState;
00395 for (std::set< actasp::AspFluent>::iterator it = state_to_print.begin(); it != state_to_print.end(); ++it) {
00396 initialValue << it->toString() << " ";
00397 }
00398 initialValue << endl << "action: ";
00399
00400 for (; action_value != initial_value_map.end(); ++action_value) {
00401 initialValue << action_value->second << " ";
00402 actionNames << action_value->first.toString() << " ";
00403 }
00404 initialValue << actionNames.str() << endl << endl;
00405 initialValue.close();
00406 }
00407
00408
00409 void SarsaActionSelector::readFrom(std::istream & fromStream) throw() {
00410
00411 ROS_DEBUG("Loading value function");
00412
00413 const string whiteSpaces(" \t");
00414
00415 value.clear();
00416
00417 while (fromStream.good() && !fromStream.eof()) {
00418
00419 string stateLine;
00420 getline(fromStream,stateLine);
00421
00422 size_t firstChar = min(stateLine.find_first_of(whiteSpaces),static_cast<size_t>(0));
00423 size_t lastChar = min(stateLine.find_last_not_of(whiteSpaces),stateLine.size());
00424 stateLine = stateLine.substr(firstChar,lastChar-firstChar+1);
00425
00426 stringstream stateStream(stateLine);
00427
00428
00429
00430 if (stateLine.empty())
00431 return;
00432
00433 State state;
00434 copy(istream_iterator<string>(stateStream), istream_iterator<string>(), inserter(state, state.begin()));
00435
00436
00437 string actionLine;
00438 getline(fromStream,actionLine);
00439
00440 while (actionLine.find("-----") == string::npos) {
00441
00442 size_t firstChar = min(actionLine.find_first_of(whiteSpaces), static_cast<size_t>(0));
00443 size_t lastChar = min(actionLine.find_last_not_of(whiteSpaces),actionLine.size());
00444 actionLine = actionLine.substr(firstChar,lastChar-firstChar+1);
00445
00446 stringstream actionStream(actionLine);
00447
00448 double actionValue;
00449 string fluentString;
00450
00451 actionStream >> actionValue >> fluentString;
00452
00453 AspFluent action(fluentString);
00454
00455 value[state].insert(make_pair(action,actionValue));
00456
00457 getline(fromStream,actionLine);
00458 }
00459
00460 }
00461 }
00462
00463 void SarsaActionSelector::writeTo(std::ostream & toStream) throw() {
00464
00465 ROS_DEBUG("Storing value function");
00466
00467 StateActionMap::const_iterator stateIt = value.begin();
00468
00469 ofstream stat("stats.txt", ios::app);
00470 AspFluent initialState("pos(10,0,0)");
00471
00472
00473
00474 for (; stateIt != value.end(); ++stateIt) {
00475
00476
00477
00478
00479
00480
00481
00482
00483 if (stateIt->first.find(initialState) != stateIt->first.end()) {
00484 ActionValueMap::const_iterator actionIt= stateIt->second.begin();
00485 for (; actionIt != stateIt->second.end(); ++actionIt) {
00486 if (stateIt->first.find(initialState) != stateIt->first.end())
00487 stat << value[stateIt->first][actionIt->first] << " ";
00488 }
00489
00490 }
00491
00492
00493 copy(stateIt->first.begin(), stateIt->first.end(), ostream_iterator<string>(toStream, " "));
00494
00495
00496 ActionValueMap::const_iterator actionIt = stateIt->second.begin();
00497 for (; actionIt != stateIt->second.end(); ++actionIt) {
00498
00499
00500
00501
00502
00503 toStream << endl;
00504
00505
00506
00507 toStream << actionIt->second << " " << actionIt->first.toString();
00508
00509 }
00510
00511
00512
00513
00514 toStream << endl << "-----" << endl;
00515 }
00516 stat << endl;
00517 stat.close();
00518
00519
00520
00521 }
00522
00523
00524
00525 void SarsaActionSelector::readMapFrom(std::istream & fromStream) throw() {
00526
00527 }
00528
00529 void SarsaActionSelector::writeMapTo(std::ostream & toStream) throw() {
00530
00531 }
00532
00533
00534 }