PO_ETUCT.hh
Go to the documentation of this file.
00001 
00010 #ifndef _PO_ETUCT_HH_
00011 #define _PO_ETUCT_HH_
00012 
00013 #include <rl_common/Random.h>
00014 #include <rl_common/core.hh>
00015 
00016 #include "../Models/FactoredModel.hh"
00017 
00018 #include <set>
00019 #include <vector>
00020 #include <map>
00021 #include <deque>
00022 
00024 class PO_ETUCT: public Planner {
00025 public:
00026 
00043   PO_ETUCT(int numactions, float gamma, float rrange, float lambda,
00044            int MAX_ITER, float MAX_TIME, int MAX_DEPTH,  int modelType,
00045            const std::vector<float> &featmax, const std::vector<float> &featmin,
00046            const std::vector<int> &statesPerDim, bool trackActual, 
00047            int history, Random rng = Random());
00048   
00051   PO_ETUCT(const PO_ETUCT &);
00052 
00053   virtual ~PO_ETUCT();
00054 
00055   virtual void setModel(MDPModel* model);
00056   virtual bool updateModelWithExperience(const std::vector<float> &last, 
00057                                          int act, 
00058                                          const std::vector<float> &curr, 
00059                                          float reward, bool term);
00060   virtual void planOnNewModel();
00061   virtual int getBestAction(const std::vector<float> &s);
00062 
00063   virtual void setSeeding(bool seed);
00064   virtual void setFirst();
00065 
00067   void logValues(ofstream *of, int xmin, int xmax, int ymin, int ymax);
00068   
00070   std::vector<float> discretizeState(const std::vector<float> &s);
00071 
00072   bool PLANNERDEBUG;
00073   bool MODELDEBUG;
00074   bool ACTDEBUG;
00075   bool UCTDEBUG;
00076   bool REALSTATEDEBUG;
00077   bool HISTORYDEBUG;
00078 
00080   MDPModel* model;
00081 
00085   typedef const std::vector<float> *state_t;
00086 
00087 
00088 protected:
00089 
00090 
00091   struct state_info;
00092   struct model_info;
00093 
00095   struct state_samples {
00096     std::vector<state_t> samples;
00097   };
00098 
00100   struct state_info {
00101 
00102     // data filled in from models
00103     StateActionInfo* model;
00104 
00105     // q values from policy creation
00106     std::vector<float> Q;
00107 
00108     // uct experience data
00109     int uctVisits;
00110     std::vector<int> uctActions;
00111     short unsigned int visited;
00112     short unsigned int id;
00113 
00114     // needs update
00115     bool needsUpdate;
00116 
00117   };
00118 
00119 
00121   void initStateInfo(state_t s, state_info* info);
00122   
00126   state_t canonicalize(const std::vector<float> &s);
00127 
00129   void deleteInfo(state_info* info);
00130   
00132   void initNewState(state_t s);
00133   
00135   void createPolicy();
00136   
00138   void printStates();
00139   
00141   void calculateReachableStates();
00142   
00144   void removeUnreachableStates();
00145 
00147   void updateStateActionFromModel(state_t s, int a, state_info* info);
00148 
00150   void updateStateActionHistoryFromModel(const std::vector<float> modState, int a, StateActionInfo *newModel);
00151 
00153   double getSeconds();
00154 
00156   void resetUCTCounts();
00157   
00169   float uctSearch(const std::vector<float> &actualS, state_t state, int depth);
00170   
00173   std::vector<float> simulateNextState(const std::vector<float> &actualState, state_t discState, state_info* info, int action, float* reward, bool* term);
00174   
00176   int selectUCTAction(state_info* info);
00177   
00179   void canonNextStates(StateActionInfo* modelInfo);
00180   
00181   virtual void savePolicy(const char* filename);
00182   
00184   std::vector<float> addVec(const std::vector<float> &a, const std::vector<float> &b);
00185   
00187   std::vector<float> subVec(const std::vector<float> &a, const std::vector<float> &b);
00188 
00189 private:
00190 
00194   std::set<std::vector<float> > statespace;
00195 
00197   std::map<state_t, state_info> statedata;
00198 
00200   std::deque<float> saHistory;
00201 
00202   std::vector<float> featmax;
00203   std::vector<float> featmin;
00204   
00205   state_t prevstate;
00206   int prevact;
00207   state_info* previnfo;
00208 
00209   double planTime;
00210 
00211   bool seedMode;
00212 
00213   int nstates;
00214   int nactions; 
00215   int lastUpdate;
00216   bool timingType;
00217 
00218   const int numactions;
00219   const float gamma;
00220   const float rrange;
00221   const float lambda;
00222 
00223   const int MAX_ITER;
00224   const float MAX_TIME;
00225   const int MAX_DEPTH;
00226   const int modelType;
00227   const std::vector<int> &statesPerDim;
00228   const bool trackActual;
00229   const int HISTORY_SIZE;
00230   const int HISTORY_FL_SIZE;
00231 
00232 };
00233 
00234 #endif


rl_agent
Author(s): Todd Hester
autogenerated on Thu Jun 6 2019 22:00:13