00001 #include <rl_agent/Sarsa.hh>
00002 #include <algorithm>
00003
00004 Sarsa::Sarsa(int numactions, float gamma,
00005 float initialvalue, float alpha, float ep, float lambda,
00006 Random rng):
00007 numactions(numactions), gamma(gamma),
00008 initialvalue(initialvalue), alpha(alpha),
00009 epsilon(ep), lambda(lambda),
00010 rng(rng)
00011 {
00012
00013 currentq = NULL;
00014 ACTDEBUG = false;
00015 ELIGDEBUG = false;
00016
00017 }
00018
00019 Sarsa::~Sarsa() {}
00020
00021 int Sarsa::first_action(const std::vector<float> &s) {
00022
00023 if (ACTDEBUG){
00024 cout << "First - in state: ";
00025 printState(s);
00026 cout << endl;
00027 }
00028
00029
00030 for (std::map<state_t, std::vector<float> >::iterator i = eligibility.begin();
00031 i != eligibility.end(); i++){
00032
00033 std::vector<float> & elig_s = (*i).second;
00034 for (int j = 0; j < numactions; j++){
00035 elig_s[j] = 0.0;
00036 }
00037 }
00038
00039
00040 state_t si = canonicalize(s);
00041 std::vector<float> &Q_s = Q[si];
00042
00043
00044 const std::vector<float>::iterator a =
00045 rng.uniform() < epsilon
00046 ? Q_s.begin() + rng.uniformDiscrete(0, numactions - 1)
00047 : random_max_element(Q_s.begin(), Q_s.end());
00048
00049
00050 std::vector<float> &elig_s = eligibility[si];
00051 elig_s[a-Q_s.begin()] = 1.0;
00052
00053 if (ACTDEBUG){
00054 cout << " act: " << (a-Q_s.begin()) << " val: " << *a << endl;
00055 for (int iAct = 0; iAct < numactions; iAct++){
00056 cout << " Action: " << iAct
00057 << " val: " << Q_s[iAct] << endl;
00058 }
00059 cout << "Took action " << (a-Q_s.begin()) << " from state ";
00060 printState(s);
00061 cout << endl;
00062 }
00063
00064 return a - Q_s.begin();
00065 }
00066
00067 int Sarsa::next_action(float r, const std::vector<float> &s) {
00068
00069 if (ACTDEBUG){
00070 cout << "Next: got reward " << r << " in state: ";
00071 printState(s);
00072 cout << endl;
00073 }
00074
00075
00076 state_t st = canonicalize(s);
00077 std::vector<float> &Q_s = Q[st];
00078 const std::vector<float>::iterator max =
00079 random_max_element(Q_s.begin(), Q_s.end());
00080
00081
00082 const std::vector<float>::iterator a =
00083 rng.uniform() < epsilon
00084 ? Q_s.begin() + rng.uniformDiscrete(0, numactions - 1)
00085 : max;
00086
00087
00088 for (std::map<state_t, std::vector<float> >::iterator i = eligibility.begin();
00089 i != eligibility.end(); i++){
00090
00091 state_t si = (*i).first;
00092 std::vector<float> & elig_s = (*i).second;
00093 for (int j = 0; j < numactions; j++){
00094 if (elig_s[j] > 0.0){
00095 if (ELIGDEBUG) {
00096 cout << "updating state " << (*((*i).first))[0] << ", " << (*((*i).first))[1] << " act: " << j << " with elig: " << elig_s[j] << endl;
00097 }
00098
00099 Q[si][j] += alpha * elig_s[j] * (r + gamma * (*a) - Q[si][j]);
00100 elig_s[j] *= lambda;
00101 }
00102 }
00103
00104 }
00105
00106
00107 eligibility[st][a-Q_s.begin()] = 1.0;
00108
00109 if (ACTDEBUG){
00110 cout << " act: " << (a-Q_s.begin()) << " val: " << *a << endl;
00111 for (int iAct = 0; iAct < numactions; iAct++){
00112 cout << " Action: " << iAct
00113 << " val: " << Q_s[iAct] << endl;
00114 }
00115 cout << "Took action " << (a-Q_s.begin()) << " from state ";
00116 printState(s);
00117 cout << endl;
00118 }
00119
00120 return a - Q_s.begin();
00121 }
00122
00123 void Sarsa::last_action(float r) {
00124
00125 if (ACTDEBUG){
00126 cout << "Last: got reward " << r << endl;
00127 }
00128
00129
00130 for (std::map<state_t, std::vector<float> >::iterator i = eligibility.begin();
00131 i != eligibility.end(); i++){
00132
00133 state_t si = (*i).first;
00134 std::vector<float> & elig_s = (*i).second;
00135 for (int j = 0; j < numactions; j++){
00136 if (elig_s[j] > 0.0){
00137 if (ELIGDEBUG){
00138 cout << "updating state " << (*((*i).first))[0] << ", " << (*((*i).first))[1] << " act: " << j << " with elig: " << elig_s[j] << endl;
00139 }
00140
00141 Q[si][j] += alpha * elig_s[j] * (r - Q[si][j]);
00142 elig_s[j] = 0.0;
00143 }
00144 }
00145 }
00146
00147 }
00148
00149 Sarsa::state_t Sarsa::canonicalize(const std::vector<float> &s) {
00150 const std::pair<std::set<std::vector<float> >::iterator, bool> result =
00151 statespace.insert(s);
00152 state_t retval = &*result.first;
00153 if (result.second) {
00154 std::vector<float> &Q_s = Q[retval];
00155 Q_s.resize(numactions,initialvalue);
00156 std::vector<float> &elig = eligibility[retval];
00157 elig.resize(numactions,0);
00158 }
00159 return retval;
00160 }
00161
00162
00163
00164 std::vector<float>::iterator
00165 Sarsa::random_max_element(
00166 std::vector<float>::iterator start,
00167 std::vector<float>::iterator end) {
00168
00169 std::vector<float>::iterator max =
00170 std::max_element(start, end);
00171 int n = std::count(max, end, *max);
00172 if (n > 1) {
00173 n = rng.uniformDiscrete(1, n);
00174 while (n > 1) {
00175 max = std::find(max + 1, end, *max);
00176 --n;
00177 }
00178 }
00179 return max;
00180 }
00181
00182
00183
00184
00185 void Sarsa::setDebug(bool d){
00186 ACTDEBUG = d;
00187 }
00188
00189
00190 void Sarsa::printState(const std::vector<float> &s){
00191 for (unsigned j = 0; j < s.size(); j++){
00192 cout << s[j] << ", ";
00193 }
00194 }
00195
00196
00197
00198 void Sarsa::seedExp(std::vector<experience> seeds){
00199
00200
00201 for (unsigned i = 0; i < seeds.size(); i++){
00202 experience e = seeds[i];
00203
00204 std::vector<float> &Q_s = Q[canonicalize(e.s)];
00205
00206
00207 const std::vector<float>::iterator a = Q_s.begin() + e.act;
00208
00209
00210 Q_s[e.act] += alpha * (e.reward + gamma * (*a) - Q_s[e.act]);
00211
00212
00213
00214
00215
00216
00217
00218
00219
00220
00221
00222
00223 }
00224
00225
00226 }
00227
00228 void Sarsa::logValues(ofstream *of, int xmin, int xmax, int ymin, int ymax){
00229 std::vector<float> s;
00230 s.resize(2, 0.0);
00231 for (int i = xmin ; i < xmax; i++){
00232 for (int j = ymin; j < ymax; j++){
00233 s[0] = j;
00234 s[1] = i;
00235 std::vector<float> &Q_s = Q[canonicalize(s)];
00236 const std::vector<float>::iterator max =
00237 random_max_element(Q_s.begin(), Q_s.end());
00238 *of << (*max) << ",";
00239 }
00240 }
00241 }
00242
00243
00244 float Sarsa::getValue(std::vector<float> state){
00245
00246 state_t s = canonicalize(state);
00247
00248
00249 std::vector<float> &Q_s = Q[s];
00250
00251
00252 const std::vector<float>::iterator a =
00253 random_max_element(Q_s.begin(), Q_s.end());
00254
00255
00256 float valSum = 0.0;
00257 float cnt = 0;
00258 for (std::set<std::vector<float> >::iterator i = statespace.begin();
00259 i != statespace.end(); i++){
00260
00261 state_t s = canonicalize(*i);
00262
00263
00264 std::vector<float> &Q_s = Q[s];
00265
00266 for (int j = 0; j < numactions; j++){
00267 valSum += Q_s[j];
00268 cnt++;
00269 }
00270 }
00271
00272 cout << "Avg Value: " << (valSum / cnt) << endl;
00273
00274 return *a;
00275 }
00276
00277
00278 void Sarsa::savePolicy(const char* filename){
00279
00280 ofstream policyFile(filename, ios::out | ios::binary | ios::trunc);
00281
00282
00283 std::set< std::vector<float> >::iterator i = statespace.begin();
00284 int fsize = (*i).size();
00285 policyFile.write((char*)&fsize, sizeof(int));
00286
00287
00288 policyFile.write((char*)&numactions, sizeof(int));
00289
00290
00291 for (std::set< std::vector<float> >::iterator i = statespace.begin();
00292 i != statespace.end(); i++){
00293
00294 state_t s = canonicalize(*i);
00295 std::vector<float> *Q_s = &(Q[s]);
00296
00297
00298 policyFile.write((char*)&((*i)[0]), sizeof(float)*fsize);
00299
00300
00301 policyFile.write((char*)&((*Q_s)[0]), sizeof(float)*numactions);
00302
00303 }
00304
00305 policyFile.close();
00306 }
00307
00308