Go to the documentation of this file.00001
00012 #ifndef _PO_ParallelETUCT_HH_
00013 #define _PO_ParallelETUCT_HH_
00014
#include <rl_common/Random.h>
#include <rl_common/core.hh>
#include <rl_common/ExperienceFile.hh>

#include "../Models/FactoredModel.hh"
#include "../Models/C45Tree.hh"

#include <pthread.h>

#include <deque>
#include <fstream>
#include <map>
#include <set>
#include <sstream>
#include <vector>
00027
/**
 * Free function whose signature (void* in, void* out) is the one required of
 * a pthread start routine.
 *
 * Presumably the entry point of the planning thread, with @p arg pointing at
 * the PO_ParallelETUCT instance to run -- confirm in the implementation file.
 *
 * @param arg  opaque pointer handed over by the thread creator.
 * @return     opaque pointer; its meaning is defined by the implementation.
 */
void *poParallelSearchStart(void *arg);
00030
/**
 * Free function whose signature (void* in, void* out) is the one required of
 * a pthread start routine.
 *
 * Presumably the entry point of the model-learning thread, with @p arg
 * pointing at the PO_ParallelETUCT instance to run -- confirm in the
 * implementation file.
 *
 * @param arg  opaque pointer handed over by the thread creator.
 * @return     opaque pointer; its meaning is defined by the implementation.
 */
void *poParallelModelLearningStart(void *arg);
00033
00035 class PO_ParallelETUCT: public Planner {
00036 public:
00037
00054 PO_ParallelETUCT(int numactions, float gamma, float rrange, float lambda,
00055 int MAX_ITER, float MAX_TIME, int MAX_DEPTH, int modelType,
00056 const std::vector<float> &featmax, const std::vector<float> &featmin,
00057 const std::vector<int> &statesPerDim, bool trackActual, int historySize,
00058 Random rng = Random());
00059
00062 PO_ParallelETUCT(const PO_ParallelETUCT &);
00063
00064 virtual ~PO_ParallelETUCT();
00065
00066 virtual void setModel(MDPModel* model);
00067 virtual bool updateModelWithExperience(const std::vector<float> &last,
00068 int act,
00069 const std::vector<float> &curr,
00070 float reward, bool term);
00071 virtual void planOnNewModel();
00072 virtual int getBestAction(const std::vector<float> &s);
00073
00074 virtual void setSeeding(bool seed);
00075 virtual void setFirst();
00076
00077 bool PLANNERDEBUG;
00078 bool POLICYDEBUG;
00079 bool MODELDEBUG;
00080 bool ACTDEBUG;
00081 bool UCTDEBUG;
00082 bool PTHREADDEBUG;
00083 bool ATHREADDEBUG;
00084 bool MTHREADDEBUG;
00085 bool TIMINGDEBUG;
00086 bool REALSTATEDEBUG;
00087 bool HISTORYDEBUG;
00088
00090 MDPModel* model;
00091
00093 MDPModel* modelcopy;
00094
00098 typedef const std::vector<float> *state_t;
00099
00100
00101
00103
00105
00106
00107 bool modelThreadStarted;
00108 bool planThreadStarted;
00109
00110
00112 pthread_t planThread;
00113
00115 pthread_t modelThread;
00116
00117
00119 std::vector<experience> expList;
00120
00122 state_t discPlanState;
00123
00125 std::vector<float> actualPlanState;
00126
00128 state_t startState;
00129
00130
00132 pthread_mutex_t update_mutex;
00134 pthread_mutex_t nactions_mutex;
00136 pthread_mutex_t plan_state_mutex;
00138 pthread_mutex_t model_mutex;
00140 pthread_mutex_t list_mutex;
00142 pthread_mutex_t history_mutex;
00143
00145 pthread_mutex_t statespace_mutex;
00146
00147
00148 pthread_cond_t list_cond;
00149
00150
00162 float uctSearch(const std::vector<float> &actS, state_t state, int depth);
00163
00165 std::vector<float> selectRandomState();
00166
00168 void parallelModelLearning();
00169
00171 void parallelSearch();
00172
00174 void loadPolicy(const char* filename);
00175
00177 void logValues(ofstream *of, int xmin, int xmax, int ymin, int ymax);
00178
00180 std::vector<float> addVec(const std::vector<float> &a, const std::vector<float> &b);
00181
00183 std::vector<float> subVec(const std::vector<float> &a, const std::vector<float> &b);
00184
00185 protected:
00186
00187
00188 struct state_info;
00189 struct model_info;
00190
00192 struct state_samples {
00193 std::vector<state_t> samples;
00194 };
00195
00197 struct state_info {
00198
00199
00200 StateActionInfo* model;
00201
00202
00203 std::vector<float> Q;
00204
00205
00206 int uctVisits;
00207 std::vector<int> uctActions;
00208 short unsigned int visited;
00209 short unsigned int id;
00210
00211
00212 bool needsUpdate;
00213
00214
00215 pthread_mutex_t statemodel_mutex;
00216 pthread_mutex_t stateinfo_mutex;
00217
00218 };
00219
00221 void initStateInfo(state_t s,state_info* info, int id);
00222
00226 state_t canonicalize(const std::vector<float> &s);
00227
00229 void deleteInfo(state_info* info);
00230
00232 void createPolicy();
00233
00235 void printStates();
00236
00238 void calculateReachableStates();
00239
00241 void removeUnreachableStates();
00242
00244 void updateStateActionFromModel(state_t s, int a, state_info* info);
00245
00247 void updateStateActionHistoryFromModel(const std::vector<float> modState, int a, StateActionInfo *newModel);
00248
00250 double getSeconds();
00251
00252
00254 void resetAndUpdateStateActions();
00255
00258 std::vector<float> simulateNextState(const std::vector<float> &actS, state_t state, state_info* info, int action, float* reward, bool* term);
00259
00261 int selectUCTAction(state_info* info);
00262
00264 void canonNextStates(StateActionInfo* modelInfo);
00265
00267 void initStates();
00268
00270 void fillInState(std::vector<float>s, int depth);
00271
00272 virtual void savePolicy(const char* filename);
00273
00275 std::vector<float> discretizeState(const std::vector<float> &s);
00276
00277 private:
00278
00282 std::set<std::vector<float> > statespace;
00283
00285 std::map<state_t, state_info> statedata;
00286
00287 std::vector<float> featmax;
00288 std::vector<float> featmin;
00289
00291 std::deque<float> saHistory;
00292
00293 state_t prevstate;
00294 int prevact;
00295 state_info* previnfo;
00296
00297 double planTime;
00298 double initTime;
00299 double setTime;
00300 bool seedMode;
00301
00302 int nstates;
00303 int nsaved;
00304 int nactions;
00305 int lastUpdate;
00306
00307 bool timingType;
00308
00309 const int numactions;
00310 const float gamma;
00311 const float rrange;
00312 const float lambda;
00313
00314 const int MAX_ITER;
00315 const float MAX_TIME;
00316 const int MAX_DEPTH;
00317 const int modelType;
00318 const std::vector<int> &statesPerDim;
00319 const bool trackActual;
00320 const int HISTORY_SIZE;
00321 const int HISTORY_FL_SIZE;
00322
00323 ExperienceFile expfile;
00324 };
00325
00326 #endif