SavedPolicy.cc
Go to the documentation of this file.
00001 #include <rl_agent/SavedPolicy.hh>
00002 #include <algorithm>
00003 
00004 SavedPolicy::SavedPolicy(int numactions, const char* filename):
00005   numactions(numactions)
00006 {
00007 
00008   ACTDEBUG = false;
00009   LOADDEBUG = false;
00010   loaded = false;
00011 
00012   loadPolicy(filename);
00013 
00014   
00015 }
00016 
00017 SavedPolicy::~SavedPolicy() {}
00018 
00019 int SavedPolicy::first_action(const std::vector<float> &s) {
00020 
00021   if (ACTDEBUG){
00022     cout << "First - in state: ";
00023     printState(s);
00024     cout << endl;
00025   }
00026 
00027   // Get action values
00028   std::vector<float> &Q_s = Q[canonicalize(s)];
00029 
00030   // Choose an action
00031   const std::vector<float>::iterator a =
00032     std::max_element(Q_s.begin(), Q_s.end()); // Choose maximum
00033 
00034   if (ACTDEBUG){
00035     cout << " act: " << (a-Q_s.begin()) << " val: " << *a << endl;
00036     for (int iAct = 0; iAct < numactions; iAct++){
00037       cout << " Action: " << iAct 
00038            << " val: " << Q_s[iAct] << endl;
00039     }
00040     cout << "Took action " << (a-Q_s.begin()) << " from state ";
00041     printState(s);
00042     cout << endl;
00043   }
00044 
00045   return a - Q_s.begin();
00046 }
00047 
00048 int SavedPolicy::next_action(float r, const std::vector<float> &s) {
00049 
00050   if (ACTDEBUG){
00051     cout << "Next: got reward " << r << " in state: ";
00052     printState(s);
00053     cout << endl;
00054   }
00055 
00056   // Get action values
00057   std::vector<float> &Q_s = Q[canonicalize(s)];
00058   const std::vector<float>::iterator max =
00059     std::max_element(Q_s.begin(), Q_s.end());
00060 
00061   // Choose an action
00062   const std::vector<float>::iterator a = max;
00063 
00064   if (ACTDEBUG){
00065     cout << " act: " << (a-Q_s.begin()) << " val: " << *a << endl;
00066     for (int iAct = 0; iAct < numactions; iAct++){
00067       cout << " Action: " << iAct 
00068            << " val: " << Q_s[iAct] << endl;
00069     }
00070     cout << "Took action " << (a-Q_s.begin()) << " from state ";
00071     printState(s);
00072     cout << endl;
00073   }
00074 
00075   return a - Q_s.begin();
00076 }
00077 
00078 void SavedPolicy::last_action(float r) {
00079 
00080   if (ACTDEBUG){
00081     cout << "Last: got reward " << r << endl;
00082   }
00083 
00084 }
00085 
00086 SavedPolicy::state_t SavedPolicy::canonicalize(const std::vector<float> &s) {
00087   const std::pair<std::set<std::vector<float> >::iterator, bool> result =
00088     statespace.insert(s);
00089   state_t retval = &*result.first; // Dereference iterator then get pointer 
00090   if (result.second) { // s is new, so initialize Q(s,a) for all a
00091     if (loaded){
00092       cout << "State unknown in policy!!!" << endl;
00093       for (unsigned i = 0; i < s.size(); i++){
00094         cout << s[i] << ", ";
00095       }
00096       cout << endl;
00097     }
00098     std::vector<float> &Q_s = Q[retval];
00099     Q_s.resize(numactions,0.0);
00100   }
00101   return retval; 
00102 }
00103 
00104 
00105 
00106 void SavedPolicy::printState(const std::vector<float> &s){
00107   for (unsigned j = 0; j < s.size(); j++){
00108     cout << s[j] << ", ";
00109   }
00110 }
00111 
00112 
00113 
00114 void SavedPolicy::seedExp(std::vector<experience> seeds){
00115   return;
00116 }
00117 
00118 
00119 void SavedPolicy::loadPolicy(const char* filename){
00120 
00121   ifstream policyFile(filename, ios::in | ios::binary);
00122 
00123   // first part, save the vector size
00124   int fsize;
00125   policyFile.read((char*)&fsize, sizeof(int));
00126   if (LOADDEBUG) cout << "Numfeats loaded: " << fsize << endl;
00127 
00128   // save numactions
00129   int nact;
00130   policyFile.read((char*)&nact, sizeof(int));
00131 
00132   if (nact != numactions){
00133     cout << "this policy is not valid loaded nact: " << nact 
00134          << " was told: " << numactions << endl;
00135     exit(-1);
00136   }
00137 
00138   // go through all states, loading q values
00139   while(!policyFile.eof()){
00140     std::vector<float> state;
00141     state.resize(fsize, 0.0);
00142 
00143     // load state
00144     policyFile.read((char*)&(state[0]), sizeof(float)*fsize);
00145     if (LOADDEBUG){
00146       cout << "load policy for state: ";
00147       printState(state);
00148     }
00149 
00150     state_t s = canonicalize(state);
00151 
00152     if (policyFile.eof()) break;
00153 
00154     // load q values
00155     policyFile.read((char*)&(Q[s][0]), sizeof(float)*numactions);
00156     
00157     if (LOADDEBUG){
00158       cout << "Q values: " << endl;
00159       for (int iAct = 0; iAct < numactions; iAct++){
00160         cout << " Action: " << iAct << " val: " << Q[s][iAct] << endl;
00161       }
00162     }
00163   }
00164   
00165   policyFile.close();
00166   cout << "Policy loaded!!!" << endl;
00167   loaded = true;
00168 }


rl_agent
Author(s): Todd Hester
autogenerated on Thu Jun 6 2019 22:00:13