ETUCT.hh
Go to the documentation of this file.
00001 
00010 #ifndef _ETUCT_HH_
00011 #define _ETUCT_HH_
00012 
00013 #include <rl_common/Random.h>
00014 #include <rl_common/core.hh>
00015 
00016 #include "../Models/FactoredModel.hh"
00017 
00018 
00019 #include <set>
00020 #include <vector>
00021 #include <map>
00022 #include <deque>
00023 
00025 class ETUCT: public Planner {
00026 public:
00027 
00028 
00045   ETUCT(int numactions, float gamma, float rrange, float lambda,
00046         int MAX_ITER, float MAX_TIME, int MAX_DEPTH,  int modelType,
00047         const std::vector<float> &featmax, const std::vector<float> &featmin,
00048         const std::vector<int> &statesPerDim, bool trackActual, 
00049         int history, Random rng = Random());
00050   
00053   ETUCT(const ETUCT &);
00054 
00055   virtual ~ETUCT();
00056 
00057   virtual void setModel(MDPModel* model);
00058   virtual bool updateModelWithExperience(const std::vector<float> &last, 
00059                                          int act, 
00060                                          const std::vector<float> &curr, 
00061                                          float reward, bool term);
00062   virtual void planOnNewModel();
00063   virtual int getBestAction(const std::vector<float> &s);
00064 
00065   virtual void setSeeding(bool seed);
00066   virtual void setFirst();
00067 
00069   void logValues(ofstream *of, int xmin, int xmax, int ymin, int ymax);
00070 
00072   void initStates();
00073   
00075   void fillInState(std::vector<float>s, int depth);
00076   
00078   std::vector<float> discretizeState(const std::vector<float> &s);
00079 
00080   bool PLANNERDEBUG;
00081   bool MODELDEBUG;
00082   bool ACTDEBUG;
00083   bool UCTDEBUG;
00084   bool REALSTATEDEBUG;
00085   bool HISTORYDEBUG;
00086 
00088   MDPModel* model;
00089 
00093   typedef const std::vector<float> *state_t;
00094 
00095 
00096 
00097 protected:
00098 
00099 
00100   struct state_info;
00101   struct model_info;
00102 
00104   struct state_samples {
00105     std::vector<state_t> samples;
00106   };
00107 
00109   struct state_info {
00110 
00111     // data filled in from models
00112     std::map< std::deque<float>, StateActionInfo>* historyModel;
00113 
00114     // q values from policy creation
00115     std::vector<float> Q;
00116 
00117     // uct experience data
00118     int uctVisits;
00119     std::vector<int> uctActions;
00120     short unsigned int visited;
00121     short unsigned int id;
00122 
00123     // needs update
00124     bool needsUpdate;
00125 
00126   };
00127 
00128 
00129 
00131   void initStateInfo(state_t s, state_info* info);
00132   
00136   state_t canonicalize(const std::vector<float> &s);
00137 
00139   void deleteInfo(state_info* info);
00140   
00142   void initNewState(state_t s);
00143   
00145   void createPolicy();
00146   
00148   void printStates();
00149   
00151   void calculateReachableStates();
00152   
00154   void removeUnreachableStates();
00155 
00157   void updateStateActionFromModel(state_t s, int a, state_info* info);
00158   
00160   void updateStateActionHistoryFromModel(const std::vector<float> modState, int a, StateActionInfo *newModel);
00161 
00163   double getSeconds();
00164 
00166   void resetUCTCounts();
00167 
00179   float uctSearch(const std::vector<float> &actualS, state_t state, int depth, std::deque<float> history);
00180 
00183   std::vector<float> simulateNextState(const std::vector<float> &actualState, state_t discState, state_info* info, const std::deque<float> &searchHistory, int action, float* reward, bool* term);
00184 
00186   int selectUCTAction(state_info* info);
00187 
00189   void canonNextStates(StateActionInfo* modelInfo);
00190 
00191   virtual void savePolicy(const char* filename);
00192 
00194   std::vector<float> addVec(const std::vector<float> &a, const std::vector<float> &b);
00195 
00197   std::vector<float> subVec(const std::vector<float> &a, const std::vector<float> &b);
00198 
00199 private:
00200 
00204   std::set<std::vector<float> > statespace;
00205 
00207   std::map<state_t, state_info> statedata;
00208 
00210   std::deque<float> saHistory;
00211 
00212   std::vector<float> featmax;
00213   std::vector<float> featmin;
00214   
00215   state_t prevstate;
00216   int prevact;
00217   state_info* previnfo;
00218 
00219   double planTime;
00220 
00221   bool seedMode;
00222 
00223   int nstates;
00224   int nactions; 
00225   int lastUpdate;
00226   bool timingType;
00227 
00228   const int numactions;
00229   const float gamma;
00230   const float rrange;
00231   const float lambda;
00232 
00233   const int MAX_ITER;
00234   const float MAX_TIME;
00235   const int MAX_DEPTH;
00236   const int modelType;
00237   const std::vector<int> &statesPerDim;
00238   const bool trackActual;
00239   const int HISTORY_SIZE;
00240   const int HISTORY_FL_SIZE;
00241 
00242 };
00243 
00244 #endif


rl_agent
Author(s): Todd Hester
autogenerated on Thu Jun 6 2019 22:00:13