Sarsa.cc
Go to the documentation of this file.
00001 #include <rl_agent/Sarsa.hh>
00002 #include <algorithm>
00003 
00004 Sarsa::Sarsa(int numactions, float gamma,
00005              float initialvalue, float alpha, float ep, float lambda,
00006              Random rng):
00007   numactions(numactions), gamma(gamma),
00008   initialvalue(initialvalue), alpha(alpha),
00009   epsilon(ep), lambda(lambda),
00010   rng(rng)
00011 {
00012 
00013   currentq = NULL;
00014   ACTDEBUG = false; //true; //false;
00015   ELIGDEBUG = false;
00016 
00017 }
00018 
00019 Sarsa::~Sarsa() {}
00020 
00021 int Sarsa::first_action(const std::vector<float> &s) {
00022 
00023   if (ACTDEBUG){
00024     cout << "First - in state: ";
00025     printState(s);
00026     cout << endl;
00027   }
00028 
00029   // clear all eligibility traces
00030   for (std::map<state_t, std::vector<float> >::iterator i = eligibility.begin();
00031        i != eligibility.end(); i++){
00032 
00033     std::vector<float> & elig_s = (*i).second;
00034     for (int j = 0; j < numactions; j++){
00035       elig_s[j] = 0.0;
00036     }
00037   }
00038 
00039   // Get action values
00040   state_t si = canonicalize(s);
00041   std::vector<float> &Q_s = Q[si];
00042 
00043   // Choose an action
00044   const std::vector<float>::iterator a =
00045     rng.uniform() < epsilon
00046     ? Q_s.begin() + rng.uniformDiscrete(0, numactions - 1) // Choose randomly
00047     : random_max_element(Q_s.begin(), Q_s.end()); // Choose maximum
00048 
00049   // set eligiblity to 1
00050   std::vector<float> &elig_s = eligibility[si];
00051   elig_s[a-Q_s.begin()] = 1.0;
00052 
00053   if (ACTDEBUG){
00054     cout << " act: " << (a-Q_s.begin()) << " val: " << *a << endl;
00055     for (int iAct = 0; iAct < numactions; iAct++){
00056       cout << " Action: " << iAct 
00057            << " val: " << Q_s[iAct] << endl;
00058     }
00059     cout << "Took action " << (a-Q_s.begin()) << " from state ";
00060     printState(s);
00061     cout << endl;
00062   }
00063 
00064   return a - Q_s.begin();
00065 }
00066 
00067 int Sarsa::next_action(float r, const std::vector<float> &s) {
00068 
00069   if (ACTDEBUG){
00070     cout << "Next: got reward " << r << " in state: ";
00071     printState(s);
00072     cout << endl;
00073   }
00074 
00075   // Get action values
00076   state_t st = canonicalize(s);
00077   std::vector<float> &Q_s = Q[st];
00078   const std::vector<float>::iterator max =
00079     random_max_element(Q_s.begin(), Q_s.end());
00080 
00081   // Choose an action
00082   const std::vector<float>::iterator a =
00083     rng.uniform() < epsilon
00084     ? Q_s.begin() + rng.uniformDiscrete(0, numactions - 1)
00085     : max;
00086 
00087   // Update value for all with positive eligibility
00088   for (std::map<state_t, std::vector<float> >::iterator i = eligibility.begin();
00089        i != eligibility.end(); i++){
00090 
00091     state_t si = (*i).first;
00092     std::vector<float> & elig_s = (*i).second;
00093     for (int j = 0; j < numactions; j++){
00094       if (elig_s[j] > 0.0){
00095         if (ELIGDEBUG) {
00096           cout << "updating state " << (*((*i).first))[0] << ", " << (*((*i).first))[1] << " act: " << j << " with elig: " << elig_s[j] << endl;
00097         }
00098         // update
00099         Q[si][j] += alpha * elig_s[j] * (r + gamma * (*a) - Q[si][j]);
00100         elig_s[j] *= lambda;
00101       }
00102     }
00103         
00104   }
00105 
00106   // Set elig to 1
00107   eligibility[st][a-Q_s.begin()] = 1.0;
00108 
00109   if (ACTDEBUG){
00110     cout << " act: " << (a-Q_s.begin()) << " val: " << *a << endl;
00111     for (int iAct = 0; iAct < numactions; iAct++){
00112       cout << " Action: " << iAct 
00113            << " val: " << Q_s[iAct] << endl;
00114     }
00115     cout << "Took action " << (a-Q_s.begin()) << " from state ";
00116     printState(s);
00117     cout << endl;
00118   }
00119 
00120   return a - Q_s.begin();
00121 }
00122 
00123 void Sarsa::last_action(float r) {
00124 
00125   if (ACTDEBUG){
00126     cout << "Last: got reward " << r << endl;
00127   }
00128 
00129   // Update value for all with positive eligibility
00130   for (std::map<state_t, std::vector<float> >::iterator i = eligibility.begin();
00131        i != eligibility.end(); i++){
00132     
00133     state_t si = (*i).first;
00134     std::vector<float> & elig_s = (*i).second;
00135     for (int j = 0; j < numactions; j++){
00136       if (elig_s[j] > 0.0){
00137         if (ELIGDEBUG){
00138           cout << "updating state " << (*((*i).first))[0] << ", " << (*((*i).first))[1] << " act: " << j << " with elig: " << elig_s[j] << endl;
00139         }
00140         // update
00141         Q[si][j] += alpha * elig_s[j] * (r - Q[si][j]);
00142         elig_s[j] = 0.0;
00143       }
00144     }  
00145   }
00146   
00147 }
00148 
00149 Sarsa::state_t Sarsa::canonicalize(const std::vector<float> &s) {
00150   const std::pair<std::set<std::vector<float> >::iterator, bool> result =
00151     statespace.insert(s);
00152   state_t retval = &*result.first; // Dereference iterator then get pointer 
00153   if (result.second) { // s is new, so initialize Q(s,a) for all a
00154     std::vector<float> &Q_s = Q[retval];
00155     Q_s.resize(numactions,initialvalue);
00156     std::vector<float> &elig = eligibility[retval];
00157     elig.resize(numactions,0);
00158   }
00159   return retval; 
00160 }
00161 
00162 
00163 
00164   std::vector<float>::iterator
00165 Sarsa::random_max_element(
00166                              std::vector<float>::iterator start,
00167                              std::vector<float>::iterator end) {
00168 
00169   std::vector<float>::iterator max =
00170     std::max_element(start, end);
00171   int n = std::count(max, end, *max);
00172   if (n > 1) {
00173     n = rng.uniformDiscrete(1, n);
00174     while (n > 1) {
00175       max = std::find(max + 1, end, *max);
00176       --n;
00177     }
00178   }
00179   return max;
00180 }
00181 
00182 
00183 
00184 
00185 void Sarsa::setDebug(bool d){
00186   ACTDEBUG = d;
00187 }
00188 
00189 
00190 void Sarsa::printState(const std::vector<float> &s){
00191   for (unsigned j = 0; j < s.size(); j++){
00192     cout << s[j] << ", ";
00193   }
00194 }
00195 
00196 
00197 
00198 void Sarsa::seedExp(std::vector<experience> seeds){
00199 
00200   // for each seeding experience, update our model
00201   for (unsigned i = 0; i < seeds.size(); i++){
00202     experience e = seeds[i];
00203      
00204     std::vector<float> &Q_s = Q[canonicalize(e.s)];
00205     
00206     // Get q value for action taken
00207     const std::vector<float>::iterator a = Q_s.begin() + e.act;
00208 
00209     // Update value of action just executed
00210     Q_s[e.act] += alpha * (e.reward + gamma * (*a) - Q_s[e.act]);
00211     
00212  
00213     /*
00214     cout << "Seeding with experience " << i << endl;
00215     cout << "last: " << (e.s)[0] << ", " << (e.s)[1] << ", " 
00216          << (e.s)[2] << endl;
00217     cout << "act: " << e.act << " r: " << e.reward << endl;
00218     cout << "next: " << (e.next)[0] << ", " << (e.next)[1] << ", " 
00219          << (e.next)[2] << ", " << e.terminal << endl;
00220     cout << "Q: " << *currentq << " max: " << *max << endl;
00221     */
00222 
00223   }
00224 
00225 
00226 }
00227 
00228 void Sarsa::logValues(ofstream *of, int xmin, int xmax, int ymin, int ymax){
00229   std::vector<float> s;
00230   s.resize(2, 0.0);
00231   for (int i = xmin ; i < xmax; i++){
00232     for (int j = ymin; j < ymax; j++){
00233       s[0] = j;
00234       s[1] = i;
00235       std::vector<float> &Q_s = Q[canonicalize(s)];
00236       const std::vector<float>::iterator max =
00237         random_max_element(Q_s.begin(), Q_s.end());
00238       *of << (*max) << ",";
00239     }
00240   }
00241 }
00242 
00243 
00244 float Sarsa::getValue(std::vector<float> state){
00245 
00246   state_t s = canonicalize(state);
00247 
00248   // Get Q values
00249   std::vector<float> &Q_s = Q[s];
00250 
00251   // Choose an action
00252   const std::vector<float>::iterator a =
00253     random_max_element(Q_s.begin(), Q_s.end()); // Choose maximum
00254 
00255   // Get avg value
00256   float valSum = 0.0;
00257   float cnt = 0;
00258   for (std::set<std::vector<float> >::iterator i = statespace.begin();
00259        i != statespace.end(); i++){
00260 
00261     state_t s = canonicalize(*i);
00262 
00263     // get state's info
00264     std::vector<float> &Q_s = Q[s];
00265       
00266     for (int j = 0; j < numactions; j++){
00267       valSum += Q_s[j];
00268       cnt++;
00269     }
00270   }
00271 
00272   cout << "Avg Value: " << (valSum / cnt) << endl;
00273 
00274   return *a;
00275 }
00276 
00277 
00278 void Sarsa::savePolicy(const char* filename){
00279 
00280   ofstream policyFile(filename, ios::out | ios::binary | ios::trunc);
00281 
00282   // first part, save the vector size
00283   std::set< std::vector<float> >::iterator i = statespace.begin();
00284   int fsize = (*i).size();
00285   policyFile.write((char*)&fsize, sizeof(int));
00286 
00287   // save numactions
00288   policyFile.write((char*)&numactions, sizeof(int));
00289 
00290   // go through all states, and save Q values
00291   for (std::set< std::vector<float> >::iterator i = statespace.begin();
00292        i != statespace.end(); i++){
00293 
00294     state_t s = canonicalize(*i);
00295     std::vector<float> *Q_s = &(Q[s]);
00296 
00297     // save state
00298     policyFile.write((char*)&((*i)[0]), sizeof(float)*fsize);
00299 
00300     // save q-values
00301     policyFile.write((char*)&((*Q_s)[0]), sizeof(float)*numactions);
00302 
00303   }
00304 
00305   policyFile.close();
00306 }
00307 
00308 


rl_agent
Author(s): Todd Hester
autogenerated on Thu Jun 6 2019 22:00:13