#include <rl_agent/ModelBasedAgent.hh>
#include <algorithm>

#include <sys/time.h>

#include "../Planners/ValueIteration.hh"
#include "../Planners/PolicyIteration.hh"
#include "../Planners/PrioritizedSweeping.hh"
#include "../Planners/ETUCT.hh"
#include "../Planners/ParallelETUCT.hh"
#include "../Planners/PO_ETUCT.hh"
#include "../Planners/PO_ParallelETUCT.hh"
#include "../Planners/MBS.hh"

#include "../Models/RMaxModel.hh"
#include "../Models/FactoredModel.hh"
#include "../Models/ExplorationModel.hh"

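// Constructor taking a per-feature vector of discretization sizes, so the
// planner can discretize each state dimension by a different amount.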
ModelBasedAgent::ModelBasedAgent(int numactions, float gamma,
                                 float rmax, float rrange,
                                 int modelType, int exploreType,
                                 int predType, int nModels, int plannerType,
                                 float epsilon, float lambda, float MAX_TIME,
                                 float m, const std::vector<float> &featmin,
                                 const std::vector<float> &featmax,
                                 std::vector<int> nstatesPerDim, int history, float v, float n,
                                 bool depTrans, bool relTrans, float featPct, bool stoch, bool episodic,
                                 Random rng):
  featmin(featmin), featmax(featmax),
  numactions(numactions), gamma(gamma), rmax(rmax), rrange(rrange),
  qmax(rmax/(1.0-gamma)),
  modelType(modelType), exploreType(exploreType),
  predType(predType), nModels(nModels), plannerType(plannerType),
  epsilon(epsilon), lambda(lambda), MAX_TIME(MAX_TIME),
  M(m), statesPerDim(nstatesPerDim), history(history), v(v), n(n),
  depTrans(depTrans), relTrans(relTrans), featPct(featPct),
  stoch(stoch), episodic(episodic), rng(rng)
{

  if (statesPerDim[0] > 0){
    cout << "MBA: Planner will use states discretized by various amounts per dim with continuous model" << endl;
  }

  initParams();

}

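// Constructor taking a single discretization size that is applied uniformly
// to every state dimension.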
ModelBasedAgent::ModelBasedAgent(int numactions, float gamma,
                                 float rmax, float rrange,
                                 int modelType, int exploreType,
                                 int predType, int nModels, int plannerType,
                                 float epsilon, float lambda, float MAX_TIME,
                                 float m, const std::vector<float> &featmin,
                                 const std::vector<float> &featmax,
                                 int nstatesPerDim, int history, float v, float n,
                                 bool depTrans, bool relTrans, float featPct,
                                 bool stoch, bool episodic, Random rng):
  featmin(featmin), featmax(featmax),
  numactions(numactions), gamma(gamma), rmax(rmax), rrange(rrange),
  qmax(rmax/(1.0-gamma)),
  modelType(modelType), exploreType(exploreType),
  predType(predType), nModels(nModels), plannerType(plannerType),
  epsilon(epsilon), lambda(lambda), MAX_TIME(MAX_TIME),
  M(m), statesPerDim(featmin.size(), nstatesPerDim), history(history), v(v), n(n),
  depTrans(depTrans), relTrans(relTrans), featPct(featPct),
  stoch(stoch), episodic(episodic), rng(rng)
{

  if (statesPerDim[0] > 0){
    cout << "MBA: Planner will use states discretized by " << statesPerDim[0] << " with continuous model" << endl;
  }

  initParams();

}
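
// A rough construction sketch (hypothetical values; only enum names that
// appear in this file are used, predType 0 is assumed to be a valid choice,
// and Random is assumed to be seedable from an int):
//
//   std::vector<float> featmin(2, 0.0);
//   std::vector<float> featmax(2, 10.0);
//   Random rng(1);
//   ModelBasedAgent agent(2, 0.99,             // numactions, gamma
//                         1.0, 1.0,            // rmax, rrange
//                         C45TREE, NO_EXPLORE, // modelType, exploreType
//                         0, 5, ET_UCT,        // predType, nModels, plannerType
//                         0.1, 0.05, 0.1,      // epsilon, lambda, MAX_TIME
//                         5, featmin, featmax, // m, feature bounds
//                         0, 0,                // nstatesPerDim (0 = no discretization), history
//                         1, 1,                // v, n
//                         false, false, 0.2,   // depTrans, relTrans, featPct
//                         true, true, rng);    // stoch, episodic, rng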
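// Shared initialization for both constructors: reset counters and timers,
// default the debug flags, and reject exploration settings that this agent
// cannot handle.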
void ModelBasedAgent::initParams(){

  nstates = 0;
  nactions = 0;

  model = NULL;
  planner = NULL;

  modelUpdateTime = 0.0;
  planningTime = 0.0;
  actionTime = 0.0;

  modelChanged = false;

  BATCH_FREQ = 1;

  TIMEDEBUG = false;
  AGENTDEBUG = false;
  ACTDEBUG = false;
  SIMPLEDEBUG = false;

  if (qmax <= 0.1 && (exploreType == TWO_MODE_PLUS_R ||
                      exploreType == CONTINUOUS_BONUS_R ||
                      exploreType == CONTINUOUS_BONUS ||
                      exploreType == THRESHOLD_BONUS_R)) {
    std::cerr << "For this exploration type, rmax needs to be an additional positive bonus value, not a replacement for the q-value" << endl;
    exit(-1);
  }

  if (exploreType == TWO_MODE || exploreType == TWO_MODE_PLUS_R){
    std::cerr << "This exploration type does not work in this agent." << endl;
    exit(-1);
  }

  seeding = false;

  if (SIMPLEDEBUG)
    cout << "qmax: " << qmax << endl;

}

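// The agent owns its planner and model and deletes them on destruction.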
ModelBasedAgent::~ModelBasedAgent() {
  delete planner;
  delete model;
  featmin.clear();
  featmax.clear();
  prevstate.clear();
}

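// Start of an episode: lazily build the model from the first state's
// dimensionality, ask the planner for an action, and remember (s, act) so
// the next reward can be attributed to this transition.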
int ModelBasedAgent::first_action(const std::vector<float> &s) {
  if (AGENTDEBUG) cout << "first_action(s)" << endl;

  if (model == NULL)
    initModel(s.size());

  planner->setFirst();

  if (plannerType == PARALLEL_ET_UCT || plannerType == PAR_ETUCT_ACTUAL)
    planner->planOnNewModel();

  int act = chooseAction(s);

  saveStateAndAction(s, act);

  if (ACTDEBUG){
    cout << "Took action " << act << " from state " << s[0];
    for (unsigned i = 1; i < s.size(); i++){
      cout << "," << s[i];
    }
    cout << endl;
  }

  return act;

}

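// One interaction step: fold the observed transition into the model
// (replanning if it changed), then choose and record the next action.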
int ModelBasedAgent::next_action(float r, const std::vector<float> &s) {
  if (AGENTDEBUG) {
    cout << "next_action(r = " << r
         << ", s = " << &s << ")" << endl;
  }

  if (SIMPLEDEBUG) cout << "Got Reward " << r;

  updateWithNewExperience(prevstate, s, prevact, r, false);

  int act = chooseAction(s);

  saveStateAndAction(s, act);

  if (ACTDEBUG){
    cout << "Took action " << act << " from state " << s[0];
    for (unsigned i = 1; i < s.size(); i++){
      cout << "," << s[i];
    }
    cout << endl;
  }

  return act;

}

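// End of an episode: only the final reward remains to be recorded, so the
// previous state doubles as the "next" state and the transition is marked
// terminal.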
void ModelBasedAgent::last_action(float r) {
  if (AGENTDEBUG) cout << "last_action(r = " << r
                       << ")" << endl;

  if (AGENTDEBUG) cout << "Got Reward " << r;

  updateWithNewExperience(prevstate, prevstate, prevact, r, true);

}


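// Construct the MDP model once the number of state features is known.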
void ModelBasedAgent::initModel(int nfactors){
  if (AGENTDEBUG) cout << "initModel nfactors: " << nfactors << endl;

  bool needConf =
    (exploreType != NO_EXPLORE && exploreType != EXPLORE_UNKNOWN &&
     exploreType != EPSILONGREEDY && exploreType != UNVISITED_BONUS &&
     exploreType != UNVISITED_ACT_BONUS);

  std::vector<float> featRange(featmax.size(), 0);
  for (unsigned i = 0; i < featmax.size(); i++){
    featRange[i] = featmax[i] - featmin[i];
    cout << "feature " << i << " has range " << featRange[i] << endl;
  }
  cout << "reward range: " << rrange << endl;

  float treeRangePct = 0.0001;

  if (modelType == M5MULTI || modelType == M5SINGLE ||
      modelType == M5ALLMULTI || modelType == M5ALLSINGLE ||
      modelType == ALLM5TYPES){
    treeRangePct = 0.0001;
  }

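  // Advance the rng by a modelType-dependent number of draws (presumably to
  // keep random sequences comparable across model types), then build the
  // underlying model.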
  for (int i = 0; i < modelType; i++){
    rng.uniform(0, 1);
  }

  if (modelType == RMAX) {
    model = new RMaxModel(M, numactions, rng);
  }

  else if (modelType == C45TREE || modelType == STUMP ||
           modelType == M5MULTI || modelType == M5SINGLE ||
           modelType == M5ALLMULTI || modelType == M5ALLSINGLE ||
           modelType == ALLM5TYPES ||
           modelType == LSTMULTI || modelType == LSTSINGLE ||
           modelType == GPREGRESS || modelType == GPTREE){

    model = new FactoredModel(0, numactions, M, modelType, predType, nModels, treeRangePct, featRange, rrange, needConf, depTrans, relTrans, featPct, stoch, episodic, rng);
  }

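  // For exploration modes other than none / epsilon-greedy, wrap the base
  // model in an ExplorationModel so exploration bonuses can be applied to
  // its predictions.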
  if (exploreType != NO_EXPLORE && exploreType != EPSILONGREEDY){
    MDPModel* m2 = model;

    model = new ExplorationModel(m2, modelType, exploreType,
                                 predType, nModels, M, numactions,
                                 rmax, qmax, rrange, nfactors, v, n,
                                 featmax, featmin, rng);

  }

  initPlanner();
  planner->setModel(model);

}

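// Instantiate the planner selected by plannerType; an unrecognized type is
// a fatal configuration error.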
void ModelBasedAgent::initPlanner(){
  if (AGENTDEBUG) cout << "InitPlanner type: " << plannerType << endl;

  int max_path = 200;

  if (plannerType == VALUE_ITERATION){
    planner = new ValueIteration(numactions, gamma, 500000, 10.0, modelType, featmax, featmin, statesPerDim, rng);
  }
  else if (plannerType == MBS_VI){
    planner = new MBS(numactions, gamma, 500000, 10.0, modelType, featmax, featmin, statesPerDim, history, rng);
  }
  else if (plannerType == POLICY_ITERATION){
    planner = new PolicyIteration(numactions, gamma, 500000, 10.0, modelType, featmax, featmin, statesPerDim, rng);
  }
  else if (plannerType == PRI_SWEEPING){
    planner = new PrioritizedSweeping(numactions, gamma, 10.0, true, modelType, featmax, featmin, rng);
  }
  else if (plannerType == MOD_PRI_SWEEPING){
    planner = new PrioritizedSweeping(numactions, gamma, 10.0, false, modelType, featmax, featmin, rng);
  }
  else if (plannerType == ET_UCT){
    planner = new ETUCT(numactions, gamma, rrange, lambda, 500000, MAX_TIME, max_path, modelType, featmax, featmin, statesPerDim, false, history, rng);
  }
  else if (plannerType == POMDP_ETUCT){
    planner = new PO_ETUCT(numactions, gamma, rrange, lambda, 500000, MAX_TIME, max_path, modelType, featmax, featmin, statesPerDim, true, history, rng);
  }
  else if (plannerType == POMDP_PAR_ETUCT){
    planner = new PO_ParallelETUCT(numactions, gamma, rrange, lambda, 500000, MAX_TIME, max_path, modelType, featmax, featmin, statesPerDim, true, history, rng);
  }
  else if (plannerType == ET_UCT_ACTUAL){
    planner = new ETUCT(numactions, gamma, rrange, lambda, 500000, MAX_TIME, max_path, modelType, featmax, featmin, statesPerDim, true, history, rng);
  }
  else if (plannerType == PARALLEL_ET_UCT){
    planner = new ParallelETUCT(numactions, gamma, rrange, lambda, 500000, MAX_TIME, max_path, modelType, featmax, featmin, statesPerDim, false, history, rng);
  }
  else if (plannerType == PAR_ETUCT_ACTUAL){
    planner = new ParallelETUCT(numactions, gamma, rrange, lambda, 500000, MAX_TIME, max_path, modelType, featmax, featmin, statesPerDim, true, history, rng);
  }
  else if (plannerType == ET_UCT_L1){
    planner = new ETUCT(numactions, gamma, rrange, 1.0, 500000, MAX_TIME, max_path, modelType, featmax, featmin, statesPerDim, false, history, rng);
  }
  else {
    std::cerr << "ERROR: invalid planner type: " << plannerType << endl;
    exit(-1);
  }

}

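// Add one transition to the model via the planner, and replan whenever the
// model reports a change (throttled by BATCH_FREQ and skipped while seeding,
// except for RMAX models).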
void ModelBasedAgent::updateWithNewExperience(const std::vector<float> &last,
                                              const std::vector<float> &curr,
                                              int lastact, float reward,
                                              bool terminal){
  if (AGENTDEBUG) cout << "updateWithNewExperience(last = " << &last
                       << ", curr = " << &curr
                       << ", lastact = " << lastact
                       << ", r = " << reward
                       << ", t = " << terminal
                       << ")" << endl;

  double initTime = 0;
  double timeTwo = 0;
  double timeThree = 0;

  if (model == NULL)
    initModel(last.size());

  if (TIMEDEBUG) initTime = getSeconds();

  modelChanged = planner->updateModelWithExperience(last, lastact, curr, reward, terminal) || modelChanged;

  if (TIMEDEBUG) timeTwo = getSeconds();

  if (AGENTDEBUG) cout << "Agent Added exp: " << modelChanged << endl;

  if (modelChanged && (!seeding || modelType == RMAX)
      && (nactions % BATCH_FREQ == 0)){
    planner->planOnNewModel();
    modelChanged = false;
  }

  if (TIMEDEBUG){

    timeThree = getSeconds();

    planningTime += (timeThree - timeTwo);
    modelUpdateTime += (timeTwo - initTime);

    if (nactions % 10 == 0){
      cout << nactions
           << " UpdateModel " << modelUpdateTime / (float)nactions
           << " createPolicy " << planningTime / (float)nactions << endl;
    }
  }

}


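// Ask the planner for the greedy action; with EPSILONGREEDY exploration a
// random action is substituted with probability epsilon.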
int ModelBasedAgent::chooseAction(const std::vector<float> &s){
  if (AGENTDEBUG) cout << "chooseAction(s = " << &s
                       << ")" << endl;

  double initTime = 0;
  double timeTwo = 0;

  if (TIMEDEBUG) initTime = getSeconds();
  int act = planner->getBestAction(s);
  if (TIMEDEBUG) {
    timeTwo = getSeconds();
    planningTime += (timeTwo - initTime);
  }

  if (exploreType == EPSILONGREEDY && rng.bernoulli(epsilon)){
    act = rng.uniformDiscrete(0, numactions-1);
  }

  if (SIMPLEDEBUG){
    cout << endl << "Action " << nactions
         << ": State " << s[0];
    for (unsigned i = 1; i < s.size(); i++){
      cout << "," << s[i];
    }
    cout << ", Took action " << act << ", ";
  }

  nactions++;

  return act;
}

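// Remember the state/action pair so the next reward can be credited to it.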
void ModelBasedAgent::saveStateAndAction(const std::vector<float> &s, int act){
  if (AGENTDEBUG) cout << "saveStateAndAction(s = " << &s
                       << ", act = " << act
                       << ")" << endl;
  prevstate = s;
  prevact = act;

}


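// Wall-clock time in seconds, used for the timing statistics above.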
double ModelBasedAgent::getSeconds(){
  struct timezone tz;
  timeval timeT;
  gettimeofday(&timeT, &tz);
  return timeT.tv_sec + (timeT.tv_usec / 1000000.0);
}

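// Feed a batch of pre-collected experiences into the model before normal
// interaction begins, then plan once on the seeded model.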
void ModelBasedAgent::seedExp(std::vector<experience> seeds){
  if (AGENTDEBUG) cout << "seed experiences" << endl;

  if (seeds.size() == 0) return;

  if (model == NULL)
    initModel(seeds[0].s.size());

  seeding = true;
  planner->setSeeding(true);

  for (unsigned i = 0; i < seeds.size(); i++){
    experience e = seeds[i];

    updateWithNewExperience(e.s, e.next, e.act, e.reward, e.terminal);

  }

  seeding = false;
  planner->setSeeding(false);

  if (seeds.size() > 0)
    planner->planOnNewModel();

}

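// setDebug toggles agent-level debug output; savePolicy delegates to the
// planner.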
void ModelBasedAgent::setDebug(bool d){
  AGENTDEBUG = d;
}

void ModelBasedAgent::savePolicy(const char* filename){
  planner->savePolicy(filename);
}

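// Value logging is only supported by the ETUCT and ParallelETUCT planners,
// so the planner pointer is downcast based on plannerType.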
void ModelBasedAgent::logValues(ofstream *of, int xmin, int xmax, int ymin, int ymax){

  if (plannerType == PARALLEL_ET_UCT){
    ((ParallelETUCT*)planner)->logValues(of, xmin, xmax, ymin, ymax);
  }
  else if (plannerType == ET_UCT){
    ((ETUCT*)planner)->logValues(of, xmin, xmax, ymin, ymax);
  }

}