00001
00006 #ifndef _MODELBASED_HH_
00007 #define _MODELBASED_HH_
00008
#include <rl_common/Random.h>
#include <rl_common/core.hh>

#include <fstream>
#include <map>
#include <set>
#include <vector>
00015
00017 class ModelBasedAgent: public Agent {
00018 public:
00045 ModelBasedAgent(int numactions, float gamma, float rmax, float rrange,
00046 int modelType, int exploreType,
00047 int predType, int nModels, int plannerType,
00048 float epsilon, float lambda, float MAX_TIME,
00049 float m, const std::vector<float> &featmin,
00050 const std::vector<float> &featmax,
00051 int statesPerDim, int history, float v, float n,
00052 bool depTrans, bool relTrans, float featPct,
00053 bool stoch, bool episodic, Random rng = Random());
00054
00080 ModelBasedAgent(int numactions, float gamma, float rmax, float rrange,
00081 int modelType, int exploreType,
00082 int predType, int nModels, int plannerType,
00083 float epsilon, float lambda, float MAX_TIME,
00084 float m, const std::vector<float> &featmin,
00085 const std::vector<float> &featmax,
00086 std::vector<int> statesPerDim, int history, float v, float n,
00087 bool depTrans, bool relTrans, float featPct,
00088 bool stoch, bool episodic, Random rng = Random());
00089
00091 void initParams();
00092
00095 ModelBasedAgent(const ModelBasedAgent &);
00096
00097 virtual ~ModelBasedAgent();
00098
00099 virtual int first_action(const std::vector<float> &s);
00100 virtual int next_action(float r, const std::vector<float> &s);
00101 virtual void last_action(float r);
00102 virtual void seedExp(std::vector<experience> seeds);
00103 virtual void setDebug(bool d);
00104 virtual void savePolicy(const char* filename);
00105
00107 void logValues(ofstream *of, int xmin, int xmax, int ymin, int ymax);
00108
00109 bool AGENTDEBUG;
00110 bool POLICYDEBUG;
00111 bool ACTDEBUG;
00112 bool SIMPLEDEBUG;
00113 bool TIMEDEBUG;
00114
00115 bool seeding;
00116
00118 MDPModel* model;
00119
00121 Planner* planner;
00122
00123 float planningTime;
00124 float modelUpdateTime;
00125 float actionTime;
00126
00127 std::vector<float> featmin;
00128 std::vector<float> featmax;
00129
00130 protected:
00131
00135 typedef const std::vector<float> *state_t;
00136
00138 void saveStateAndAction(const std::vector<float> &s, int act);
00139
00141 int chooseAction(const std::vector<float> &s);
00142
00144 void initModel(int nfactors);
00145
00147 void initPlanner();
00148
00150 void updateWithNewExperience(const std::vector<float> &last,
00151 const std::vector<float> & curr,
00152 int lastact, float reward, bool term);
00153
00155 double getSeconds();
00156
00157 private:
00158
00160 std::vector<float> prevstate;
00162 int prevact;
00163
00164 int nstates;
00165 int nactions;
00166
00167 bool modelNeedsUpdate;
00168 int lastUpdate;
00169
00170 int BATCH_FREQ;
00171
00172 bool modelChanged;
00173
00174 const int numactions;
00175 const float gamma;
00176
00177 const float rmax;
00178 const float rrange;
00179 const float qmax;
00180 const int modelType;
00181 const int exploreType;
00182 const int predType;
00183 const int nModels;
00184 const int plannerType;
00185
00186 const float epsilon;
00187 const float lambda;
00188 const float MAX_TIME;
00189
00190 const float M;
00191 const std::vector<int> statesPerDim;
00192 const int history;
00193 const float v;
00194 const float n;
00195 const bool depTrans;
00196 const bool relTrans;
00197 const float featPct;
00198 const bool stoch;
00199 const bool episodic;
00200
00201 Random rng;
00202
00203 };
00204
00205 #endif