00001
00005 #include <rl_common/Random.h>
00006 #include <rl_common/core.hh>
00007
00008 #include <stdio.h>
00009 #include <string.h>
00010 #include <sys/time.h>
00011
00013
00015 #include <rl_env/RobotCarVel.hh>
00016 #include <rl_env/fourrooms.hh>
00017 #include <rl_env/tworooms.hh>
00018 #include <rl_env/taxi.hh>
00019 #include <rl_env/FuelRooms.hh>
00020 #include <rl_env/stocks.hh>
00021 #include <rl_env/energyrooms.hh>
00022 #include <rl_env/MountainCar.hh>
00023 #include <rl_env/CartPole.hh>
00024 #include <rl_env/LightWorld.hh>
00025
00026
00028
00030 #include <rl_agent/QLearner.hh>
00031 #include <rl_agent/ModelBasedAgent.hh>
00032 #include <rl_agent/DiscretizationAgent.hh>
00033 #include <rl_agent/SavedPolicy.hh>
00034 #include <rl_agent/Dyna.hh>
00035 #include <rl_agent/Sarsa.hh>
00036
00037
00038
00039 #include <vector>
00040 #include <sstream>
00041 #include <iostream>
00042
00043 #include <getopt.h>
00044 #include <stdlib.h>
00045
unsigned NUMEPISODES = 1000;   // # of episodes per trial; overridden by --nepisodes
const unsigned NUMTRIALS = 1;  // # of independent trials averaged in the final Rsum
unsigned MAXSTEPS = 1000;      // per-episode step cap (reduced to 100 for the car domains)
bool PRINTS = false;           // verbose debug printing of actions/rewards; enabled by --prints
00050
00051
00052 void displayHelp(){
00053 cout << "\n Call experiment --agent type --env type [options]\n";
00054 cout << "Agent types: qlearner sarsa modelbased rmax texplore dyna savedpolicy\n";
00055 cout << "Env types: taxi tworooms fourrooms energy fuelworld mcar cartpole car2to7 car7to2 carrandom stocks lightworld\n";
00056
00057 cout << "\n Agent Options:\n";
00058 cout << "--gamma value (discount factor between 0 and 1)\n";
00059 cout << "--epsilon value (epsilon for epsilon-greedy exploration)\n";
00060 cout << "--alpha value (learning rate alpha)\n";
00061 cout << "--initialvalue value (initial q values)\n";
00062 cout << "--actrate value (action selection rate (Hz))\n";
00063 cout << "--lamba value (lamba for eligibility traces)\n";
00064 cout << "--m value (parameter for R-Max)\n";
00065 cout << "--k value (For Dyna: # of model based updates to do between each real world update)\n";
00066 cout << "--history value (# steps of history to use for planning with delay)\n";
00067 cout << "--filename file (file to load saved policy from for savedpolicy agent)\n";
00068 cout << "--model type (tabular,tree,m5tree)\n";
00069 cout << "--planner type (vi,pi,sweeping,uct,parallel-uct,delayed-uct,delayed-parallel-uct)\n";
00070 cout << "--explore type (unknown,greedy,epsilongreedy,variancenovelty)\n";
00071 cout << "--combo type (average,best,separate)\n";
00072 cout << "--nmodels value (# of models)\n";
00073 cout << "--nstates value (optionally discretize domain into value # of states on each feature)\n";
00074 cout << "--reltrans (learn relative transitions)\n";
00075 cout << "--abstrans (learn absolute transitions)\n";
00076 cout << "--v value (For TEXPLORE: b/v coefficient for rewarding state-actions where models disagree)\n";
00077 cout << "--n value (For TEXPLORE: n coefficient for rewarding state-actions which are novel)\n";
00078
00079 cout << "\n Env Options:\n";
00080 cout << "--deterministic (deterministic version of domain)\n";
00081 cout << "--stochastic (stochastic version of domain)\n";
00082 cout << "--delay value (# steps of action delay (for mcar and tworooms)\n";
00083 cout << "--lag (turn on brake lag for car driving domain)\n";
00084 cout << "--highvar (have variation fuel costs in Fuel World)\n";
00085 cout << "--nsectors value (# sectors for stocks domain)\n";
00086 cout << "--nstocks value (# stocks for stocks domain)\n";
00087
00088 cout << "\n--prints (turn on debug printing of actions/rewards)\n";
00089 cout << "--nepisodes value (# of episodes to run (1000 default)\n";
00090 cout << "--seed value (integer seed for random number generator)\n";
00091
00092 cout << "\n For more info, see: http://www.ros.org/wiki/rl_experiment\n";
00093
00094 exit(-1);
00095
00096 }
00097
00098
00099 int main(int argc, char **argv) {
00100
00101
00102 char* agentType = NULL;
00103 char* envType = NULL;
00104 float discountfactor = 0.99;
00105 float epsilon = 0.1;
00106 float alpha = 0.3;
00107 float initialvalue = 0.0;
00108 float actrate = 10.0;
00109 float lambda = 0.1;
00110 int M = 5;
00111 int modelType = C45TREE;
00112 int exploreType = GREEDY;
00113 int predType = BEST;
00114 int plannerType = PAR_ETUCT_ACTUAL;
00115 int nmodels = 1;
00116 bool reltrans = true;
00117 bool deptrans = false;
00118 float v = 0;
00119 float n = 0;
00120 float featPct = 0.2;
00121 int nstates = 0;
00122 int k = 1000;
00123 char *filename = NULL;
00124 bool stochastic = true;
00125 int nstocks = 3;
00126 int nsectors = 3;
00127 int delay = 0;
00128 bool lag = false;
00129 bool highvar = false;
00130 int history = 0;
00131 int seed = 1;
00132
00133
00134
00135 bool gotAgent = false;
00136 for (int i = 1; i < argc-1; i++){
00137 if (strcmp(argv[i], "--agent") == 0){
00138 gotAgent = true;
00139 agentType = argv[i+1];
00140 }
00141 }
00142 if (!gotAgent) {
00143 cout << "--agent type option is required" << endl;
00144 displayHelp();
00145 }
00146
00147
00148 if (strcmp(agentType, "rmax") == 0){
00149 modelType = RMAX;
00150 exploreType = EXPLORE_UNKNOWN;
00151 predType = BEST;
00152 plannerType = VALUE_ITERATION;
00153 nmodels = 1;
00154 reltrans = false;
00155 M = 5;
00156 history = 0;
00157 } else if (strcmp(agentType, "texplore") == 0){
00158 modelType = C45TREE;
00159 exploreType = DIFF_AND_NOVEL_BONUS;
00160 v = 0;
00161 n = 0;
00162 predType = AVERAGE;
00163 plannerType = PAR_ETUCT_ACTUAL;
00164 nmodels = 5;
00165 reltrans = true;
00166 M = 0;
00167 history = 0;
00168 }
00169
00170
00171 bool gotEnv = false;
00172 for (int i = 1; i < argc-1; i++){
00173 if (strcmp(argv[i], "--env") == 0){
00174 gotEnv = true;
00175 envType = argv[i+1];
00176 }
00177 }
00178 if (!gotEnv) {
00179 cout << "--env type option is required" << endl;
00180 displayHelp();
00181 }
00182
00183
00184 char ch;
00185 const char* optflags = "geairlmoxpcn:";
00186 int option_index = 0;
00187 static struct option long_options[] = {
00188 {"gamma", 1, 0, 'g'},
00189 {"discountfactor", 1, 0, 'g'},
00190 {"epsilon", 1, 0, 'e'},
00191 {"alpha", 1, 0, 'a'},
00192 {"initialvalue", 1, 0, 'i'},
00193 {"actrate", 1, 0, 'r'},
00194 {"lambda", 1, 0, 'l'},
00195 {"m", 1, 0, 'm'},
00196 {"model", 1, 0, 'o'},
00197 {"explore", 1, 0, 'x'},
00198 {"planner", 1, 0, 'p'},
00199 {"combo", 1, 0, 'c'},
00200 {"nmodels", 1, 0, '#'},
00201 {"reltrans", 0, 0, 't'},
00202 {"abstrans", 0, 0, '0'},
00203 {"seed", 1, 0, 's'},
00204 {"agent", 1, 0, 'q'},
00205 {"prints", 0, 0, 'd'},
00206 {"nstates", 1, 0, 'w'},
00207 {"k", 1, 0, 'k'},
00208 {"filename", 1, 0, 'f'},
00209 {"history", 1, 0, 'y'},
00210 {"b", 1, 0, 'b'},
00211 {"v", 1, 0, 'v'},
00212 {"n", 1, 0, 'n'},
00213
00214 {"env", 1, 0, 1},
00215 {"deterministic", 0, 0, 2},
00216 {"stochastic", 0, 0, 3},
00217 {"delay", 1, 0, 4},
00218 {"nsectors", 1, 0, 5},
00219 {"nstocks", 1, 0, 6},
00220 {"lag", 0, 0, 7},
00221 {"nolag", 0, 0, 8},
00222 {"highvar", 0, 0, 11},
00223 {"nepisodes", 1, 0, 12}
00224
00225 };
00226
00227 bool epsilonChanged = false;
00228 bool actrateChanged = false;
00229 bool mChanged = false;
00230 bool bvnChanged = false;
00231 bool lambdaChanged = false;
00232
00233 while(-1 != (ch = getopt_long_only(argc, argv, optflags, long_options, &option_index))) {
00234 switch(ch) {
00235
00236 case 'g':
00237 discountfactor = std::atof(optarg);
00238 cout << "discountfactor: " << discountfactor << endl;
00239 break;
00240
00241 case 'e':
00242 epsilonChanged = true;
00243 epsilon = std::atof(optarg);
00244 cout << "epsilon: " << epsilon << endl;
00245 break;
00246
00247 case 'y':
00248 {
00249 if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0){
00250 history = std::atoi(optarg);
00251 cout << "history: " << history << endl;
00252 } else {
00253 cout << "--history is not a valid option for agent: " << agentType << endl;
00254 exit(-1);
00255 }
00256 break;
00257 }
00258
00259 case 'k':
00260 {
00261 if (strcmp(agentType, "dyna") == 0){
00262 k = std::atoi(optarg);
00263 cout << "k: " << k << endl;
00264 } else {
00265 cout << "--k is only a valid option for the Dyna agent" << endl;
00266 exit(-1);
00267 }
00268 break;
00269 }
00270
00271 case 'f':
00272 filename = optarg;
00273 cout << "policy filename: " << filename << endl;
00274 break;
00275
00276 case 'a':
00277 {
00278 if (strcmp(agentType, "qlearner") == 0 || strcmp(agentType, "dyna") == 0 || strcmp(agentType, "sarsa") == 0){
00279 alpha = std::atof(optarg);
00280 cout << "alpha: " << alpha << endl;
00281 } else {
00282 cout << "--alpha option is only valid for Q-Learning, Dyna, and Sarsa" << endl;
00283 exit(-1);
00284 }
00285 break;
00286 }
00287
00288 case 'i':
00289 {
00290 if (strcmp(agentType, "qlearner") == 0 || strcmp(agentType, "dyna") == 0 || strcmp(agentType, "sarsa") == 0){
00291 initialvalue = std::atof(optarg);
00292 cout << "initialvalue: " << initialvalue << endl;
00293 } else {
00294 cout << "--initialvalue option is only valid for Q-Learning, Dyna, and Sarsa" << endl;
00295 exit(-1);
00296 }
00297 break;
00298 }
00299
00300 case 'r':
00301 {
00302 actrateChanged = true;
00303 if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") == 0){
00304 actrate = std::atof(optarg);
00305 cout << "actrate: " << actrate << endl;
00306 } else {
00307 cout << "Model-free methods do not require an action rate" << endl;
00308 exit(-1);
00309 }
00310 break;
00311 }
00312
00313 case 'l':
00314 {
00315 lambdaChanged = true;
00316 if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") == 0 || strcmp(agentType, "sarsa") == 0){
00317 lambda = std::atof(optarg);
00318 cout << "lambda: " << lambda << endl;
00319 } else {
00320 cout << "--lambda option is invalid for this agent: " << agentType << endl;
00321 exit(-1);
00322 }
00323 break;
00324 }
00325
00326 case 'm':
00327 {
00328 mChanged = true;
00329 if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") == 0){
00330 M = std::atoi(optarg);
00331 cout << "M: " << M << endl;
00332 } else {
00333 cout << "--M option only useful for model-based agents, not " << agentType << endl;
00334 exit(-1);
00335 }
00336 break;
00337 }
00338
00339 case 'o':
00340 {
00341 if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") == 0){
00342 if (strcmp(optarg, "tabular") == 0) modelType = RMAX;
00343 else if (strcmp(optarg, "tree") == 0) modelType = C45TREE;
00344 else if (strcmp(optarg, "texplore") == 0) modelType = C45TREE;
00345 else if (strcmp(optarg, "c45tree") == 0) modelType = C45TREE;
00346 else if (strcmp(optarg, "m5tree") == 0) modelType = M5ALLMULTI;
00347 if (strcmp(agentType, "rmax") == 0 && modelType != RMAX){
00348 cout << "R-Max should use tabular model" << endl;
00349 exit(-1);
00350 }
00351 } else {
00352 cout << "Model-free methods do not need a model, --model option does nothing for this agent type" << endl;
00353 exit(-1);
00354 }
00355 cout << "model: " << modelNames[modelType] << endl;
00356 break;
00357 }
00358
00359 case 'x':
00360 {
00361 if (strcmp(optarg, "unknown") == 0) exploreType = EXPLORE_UNKNOWN;
00362 else if (strcmp(optarg, "greedy") == 0) exploreType = GREEDY;
00363 else if (strcmp(optarg, "epsilongreedy") == 0) exploreType = EPSILONGREEDY;
00364 else if (strcmp(optarg, "unvisitedstates") == 0) exploreType = UNVISITED_BONUS;
00365 else if (strcmp(optarg, "unvisitedactions") == 0) exploreType = UNVISITED_ACT_BONUS;
00366 else if (strcmp(optarg, "variancenovelty") == 0) exploreType = DIFF_AND_NOVEL_BONUS;
00367 if (strcmp(agentType, "rmax") == 0 && exploreType != EXPLORE_UNKNOWN){
00368 cout << "R-Max should use \"--explore unknown\" exploration" << endl;
00369 exit(-1);
00370 }
00371 else if (strcmp(agentType, "texplore") != 0 && strcmp(agentType, "modelbased") != 0 && strcmp(agentType, "rmax") != 0 && (exploreType != GREEDY && exploreType != EPSILONGREEDY)) {
00372 cout << "Model free methods must use either greedy or epsilon-greedy exploration!" << endl;
00373 exploreType = EPSILONGREEDY;
00374 exit(-1);
00375 }
00376 cout << "explore: " << exploreNames[exploreType] << endl;
00377 break;
00378 }
00379
00380 case 'p':
00381 {
00382 if (strcmp(optarg, "vi") == 0) plannerType = VALUE_ITERATION;
00383 else if (strcmp(optarg, "valueiteration") == 0) plannerType = VALUE_ITERATION;
00384 else if (strcmp(optarg, "policyiteration") == 0) plannerType = POLICY_ITERATION;
00385 else if (strcmp(optarg, "pi") == 0) plannerType = POLICY_ITERATION;
00386 else if (strcmp(optarg, "sweeping") == 0) plannerType = PRI_SWEEPING;
00387 else if (strcmp(optarg, "prioritizedsweeping") == 0) plannerType = PRI_SWEEPING;
00388 else if (strcmp(optarg, "uct") == 0) plannerType = ET_UCT_ACTUAL;
00389 else if (strcmp(optarg, "paralleluct") == 0) plannerType = PAR_ETUCT_ACTUAL;
00390 else if (strcmp(optarg, "realtimeuct") == 0) plannerType = PAR_ETUCT_ACTUAL;
00391 else if (strcmp(optarg, "realtime-uct") == 0) plannerType = PAR_ETUCT_ACTUAL;
00392 else if (strcmp(optarg, "parallel-uct") == 0) plannerType = PAR_ETUCT_ACTUAL;
00393 else if (strcmp(optarg, "delayeduct") == 0) plannerType = POMDP_ETUCT;
00394 else if (strcmp(optarg, "delayed-uct") == 0) plannerType = POMDP_ETUCT;
00395 else if (strcmp(optarg, "delayedparalleluct") == 0) plannerType = POMDP_PAR_ETUCT;
00396 else if (strcmp(optarg, "delayed-parallel-uct") == 0) plannerType = POMDP_PAR_ETUCT;
00397 if (strcmp(agentType, "texplore") != 0 && strcmp(agentType, "modelbased") != 0 && strcmp(agentType, "rmax") != 0){
00398 cout << "Model-free methods do not require planners, --planner option does nothing with this agent" << endl;
00399 exit(-1);
00400 }
00401 if (strcmp(agentType, "rmax") == 0 && plannerType != VALUE_ITERATION){
00402 cout << "Typical implementation of R-Max would use value iteration, but another planner type is ok" << endl;
00403 }
00404 cout << "planner: " << plannerNames[plannerType] << endl;
00405 break;
00406 }
00407
00408 case 'c':
00409 {
00410 if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0){
00411 if (strcmp(optarg, "average") == 0) predType = AVERAGE;
00412 else if (strcmp(optarg, "weighted") == 0) predType = WEIGHTAVG;
00413 else if (strcmp(optarg, "best") == 0) predType = BEST;
00414 else if (strcmp(optarg, "separate") == 0) predType = SEPARATE;
00415 cout << "predType: " << comboNames[predType] << endl;
00416 } else {
00417 cout << "--combo is an invalid option for agent: " << agentType << endl;
00418 exit(-1);
00419 }
00420 break;
00421 }
00422
00423 case '#':
00424 {
00425 if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0){
00426 nmodels = std::atoi(optarg);
00427 cout << "nmodels: " << nmodels << endl;
00428 } else {
00429 cout << "--nmodels is an invalid option for agent: " << agentType << endl;
00430 exit(-1);
00431 }
00432 if (nmodels < 1){
00433 cout << "nmodels must be > 0" << endl;
00434 exit(-1);
00435 }
00436 break;
00437 }
00438
00439 case 't':
00440 {
00441 if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0){
00442 reltrans = true;
00443 cout << "reltrans: " << reltrans << endl;
00444 } else {
00445 cout << "--reltrans is an invalid option for agent: " << agentType << endl;
00446 exit(-1);
00447 }
00448 break;
00449 }
00450
00451 case '0':
00452 {
00453 if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0){
00454 reltrans = false;
00455 cout << "reltrans: " << reltrans << endl;
00456 } else {
00457 cout << "--abstrans is an invalid option for agent: " << agentType << endl;
00458 exit(-1);
00459 }
00460 break;
00461 }
00462
00463 case 's':
00464 seed = std::atoi(optarg);
00465 cout << "seed: " << seed << endl;
00466 break;
00467
00468 case 'q':
00469
00470 cout << "agent: " << agentType << endl;
00471 break;
00472
00473 case 'd':
00474 PRINTS = true;
00475 break;
00476
00477 case 'w':
00478 nstates = std::atoi(optarg);
00479 cout << "nstates for discretization: " << nstates << endl;
00480 break;
00481
00482 case 'v':
00483 case 'b':
00484 {
00485 bvnChanged = true;
00486 if (strcmp(agentType, "texplore") == 0){
00487 v = std::atof(optarg);
00488 cout << "v coefficient (variance bonus): " << v << endl;
00489 }
00490 else {
00491 cout << "--v and --b are invalid options for agent: " << agentType << endl;
00492 exit(-1);
00493 }
00494 break;
00495 }
00496
00497 case 'n':
00498 {
00499 bvnChanged = true;
00500 if (strcmp(agentType, "texplore") == 0){
00501 n = std::atof(optarg);
00502 cout << "n coefficient (novelty bonus): " << n << endl;
00503 }
00504 else {
00505 cout << "--n is an invalid option for agent: " << agentType << endl;
00506 exit(-1);
00507 }
00508 break;
00509 }
00510
00511 case 2:
00512 stochastic = false;
00513 cout << "stochastic: " << stochastic << endl;
00514 break;
00515
00516 case 11:
00517 {
00518 if (strcmp(envType, "fuelworld") == 0){
00519 highvar = true;
00520 cout << "fuel world fuel cost variation: " << highvar << endl;
00521 } else {
00522 cout << "--highvar is only a valid option for the fuelworld domain." << endl;
00523 exit(-1);
00524 }
00525 break;
00526 }
00527
00528 case 3:
00529 stochastic = true;
00530 cout << "stochastic: " << stochastic << endl;
00531 break;
00532
00533 case 4:
00534 {
00535 if (strcmp(envType, "mcar") == 0 || strcmp(envType, "tworooms") == 0){
00536 delay = std::atoi(optarg);
00537 cout << "delay steps: " << delay << endl;
00538 } else {
00539 cout << "--delay option is only valid for the mcar and tworooms domains" << endl;
00540 exit(-1);
00541 }
00542 break;
00543 }
00544
00545 case 5:
00546 {
00547 if (strcmp(envType, "stocks") == 0){
00548 nsectors = std::atoi(optarg);
00549 cout << "nsectors: " << nsectors << endl;
00550 } else {
00551 cout << "--nsectors option is only valid for the stocks domain" << endl;
00552 exit(-1);
00553 }
00554 break;
00555 }
00556
00557 case 6:
00558 {
00559 if (strcmp(envType, "stocks") == 0){
00560 nstocks = std::atoi(optarg);
00561 cout << "nstocks: " << nstocks << endl;
00562 } else {
00563 cout << "--nstocks option is only valid for the stocks domain" << endl;
00564 exit(-1);
00565 }
00566 break;
00567 }
00568
00569 case 7:
00570 {
00571 if (strcmp(envType, "car2to7") == 0 || strcmp(envType, "car7to2") == 0 || strcmp(envType, "carrandom") == 0){
00572 lag = true;
00573 cout << "lag: " << lag << endl;
00574 } else {
00575 cout << "--lag option is only valid for car velocity tasks" << endl;
00576 exit(-1);
00577 }
00578 break;
00579 }
00580
00581 case 8:
00582 {
00583 if (strcmp(envType, "car2to7") == 0 || strcmp(envType, "car7to2") == 0 || strcmp(envType, "carrandom") == 0){
00584 lag = false;
00585 cout << "lag: " << lag << endl;
00586 } else {
00587 cout << "--nolag option is only valid for car velocity tasks" << endl;
00588 exit(-1);
00589 }
00590 break;
00591 }
00592
00593 case 1:
00594
00595 cout << "env: " << envType << endl;
00596 break;
00597
00598 case 12:
00599 NUMEPISODES = std::atoi(optarg);
00600 cout << "Num Episodes: " << NUMEPISODES << endl;
00601 break;
00602
00603 case 'h':
00604 case '?':
00605 case 0:
00606 default:
00607 displayHelp();
00608 break;
00609 }
00610 }
00611
00612
00613 if (exploreType == DIFF_AND_NOVEL_BONUS && v == 0 && n == 0)
00614 exploreType = GREEDY;
00615
00616
00617
00618 if (epsilonChanged && exploreType != EPSILONGREEDY){
00619 cout << "No reason to change epsilon when not using epsilon-greedy exploration" << endl;
00620 exit(-1);
00621 }
00622
00623
00624 if (history > 0 && (plannerType == VALUE_ITERATION || plannerType == POLICY_ITERATION || plannerType == PRI_SWEEPING)){
00625 cout << "No reason to set history higher than 0 if not using a UCT planner" << endl;
00626 exit(-1);
00627 }
00628
00629
00630 if (actrateChanged && (plannerType == VALUE_ITERATION || plannerType == POLICY_ITERATION || plannerType == PRI_SWEEPING)){
00631 cout << "No reason to set actrate if not using a UCT planner" << endl;
00632 exit(-1);
00633 }
00634
00635
00636 if (lambdaChanged && (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") == 0) && (plannerType == VALUE_ITERATION || plannerType == POLICY_ITERATION || plannerType == PRI_SWEEPING)){
00637 cout << "No reason to set actrate if not using a UCT planner" << endl;
00638 exit(-1);
00639 }
00640
00641
00642 if (bvnChanged && exploreType != DIFF_AND_NOVEL_BONUS){
00643 cout << "No reason to set n or v if not doing variance & novelty exploration" << endl;
00644 exit(-1);
00645 }
00646
00647
00648 if (predType != BEST && nmodels == 1){
00649 cout << "No reason to have model combo other than best with nmodels = 1" << endl;
00650 exit(-1);
00651 }
00652
00653
00654 if (mChanged && exploreType != EXPLORE_UNKNOWN){
00655 cout << "No reason to set M if not doing R-max style Explore Unknown exploration" << endl;
00656 exit(-1);
00657 }
00658
00659 if (PRINTS){
00660 if (stochastic)
00661 cout << "Stohastic\n";
00662 else
00663 cout << "Deterministic\n";
00664 }
00665
00666 Random rng(1 + seed);
00667
00668 std::vector<int> statesPerDim;
00669
00670
00671 Environment* e;
00672
00673 if (strcmp(envType, "cartpole") == 0){
00674 if (PRINTS) cout << "Environment: Cart Pole\n";
00675 e = new CartPole(rng, stochastic);
00676 }
00677
00678 else if (strcmp(envType, "mcar") == 0){
00679 if (PRINTS) cout << "Environment: Mountain Car\n";
00680 e = new MountainCar(rng, stochastic, false, delay);
00681 }
00682
00683
00684 else if (strcmp(envType, "taxi") == 0){
00685 if (PRINTS) cout << "Environment: Taxi\n";
00686 e = new Taxi(rng, stochastic);
00687 }
00688
00689
00690 else if (strcmp(envType, "lightworld") == 0){
00691 if (PRINTS) cout << "Environment: Light World\n";
00692 e = new LightWorld(rng, stochastic, 4);
00693 }
00694
00695
00696 else if (strcmp(envType, "tworooms") == 0){
00697 if (PRINTS) cout << "Environment: TwoRooms\n";
00698 e = new TwoRooms(rng, stochastic, true, delay, false);
00699 }
00700
00701
00702 else if (strcmp(envType, "car2to7") == 0){
00703 if (PRINTS) cout << "Environment: Car Velocity 2 to 7 m/s\n";
00704 e = new RobotCarVel(rng, false, true, false, lag);
00705 statesPerDim.resize(4,0);
00706 statesPerDim[0] = 12;
00707 statesPerDim[1] = 120;
00708 statesPerDim[2] = 4;
00709 statesPerDim[3] = 10;
00710 MAXSTEPS = 100;
00711 }
00712
00713 else if (strcmp(envType, "car7to2") == 0){
00714 if (PRINTS) cout << "Environment: Car Velocity 7 to 2 m/s\n";
00715 e = new RobotCarVel(rng, false, false, false, lag);
00716 statesPerDim.resize(4,0);
00717 statesPerDim[0] = 12;
00718 statesPerDim[1] = 120;
00719 statesPerDim[2] = 4;
00720 statesPerDim[3] = 10;
00721 MAXSTEPS = 100;
00722 }
00723
00724 else if (strcmp(envType, "carrandom") == 0){
00725 if (PRINTS) cout << "Environment: Car Velocity Random Velocities\n";
00726 e = new RobotCarVel(rng, true, false, false, lag);
00727 statesPerDim.resize(4,0);
00728 statesPerDim[0] = 12;
00729 statesPerDim[1] = 48;
00730 statesPerDim[2] = 4;
00731 statesPerDim[3] = 10;
00732 MAXSTEPS = 100;
00733 }
00734
00735
00736 else if (strcmp(envType, "fourrooms") == 0){
00737 if (PRINTS) cout << "Environment: FourRooms\n";
00738 e = new FourRooms(rng, stochastic, true, false);
00739 }
00740
00741
00742 else if (strcmp(envType, "energy") == 0){
00743 if (PRINTS) cout << "Environment: EnergyRooms\n";
00744 e = new EnergyRooms(rng, stochastic, true, false);
00745 }
00746
00747
00748 else if (strcmp(envType, "fuelworld") == 0){
00749 if (PRINTS) cout << "Environment: FuelWorld\n";
00750 e = new FuelRooms(rng, highvar, stochastic);
00751 }
00752
00753
00754 else if (strcmp(envType, "stocks") == 0){
00755 if (PRINTS) cout << "Enironment: Stocks with " << nsectors
00756 << " sectors and " << nstocks << " stocks\n";
00757 e = new Stocks(rng, stochastic, nsectors, nstocks);
00758 }
00759
00760 else {
00761 std::cerr << "Invalid env type" << endl;
00762 exit(-1);
00763 }
00764
00765 const int numactions = e->getNumActions();
00766
00767 std::vector<float> minValues;
00768 std::vector<float> maxValues;
00769 e->getMinMaxFeatures(&minValues, &maxValues);
00770 bool episodic = e->isEpisodic();
00771
00772 cout << "Environment is ";
00773 if (!episodic) cout << "NOT ";
00774 cout << "episodic." << endl;
00775
00776
00777 for (unsigned i = 0; i < minValues.size(); i++){
00778 if (PRINTS) cout << "Feat " << i << " min: " << minValues[i]
00779 << " max: " << maxValues[i] << endl;
00780 }
00781
00782
00783 float rMax = 0.0;
00784 float rMin = -1.0;
00785
00786 e->getMinMaxReward(&rMin, &rMax);
00787 float rRange = rMax - rMin;
00788 if (PRINTS) cout << "Min Reward: " << rMin
00789 << ", Max Reward: " << rMax << endl;
00790
00791
00792 if (rMax <= 0.0 && (exploreType == TWO_MODE_PLUS_R ||
00793 exploreType == CONTINUOUS_BONUS_R ||
00794 exploreType == CONTINUOUS_BONUS ||
00795 exploreType == THRESHOLD_BONUS_R)){
00796 rMax = 1.0;
00797 }
00798
00799
00800 float rsum = 0;
00801
00802 if (statesPerDim.size() == 0){
00803 cout << "set statesPerDim to " << nstates << " for all dim" << endl;
00804 statesPerDim.resize(minValues.size(), nstates);
00805 }
00806
00807 for (unsigned j = 0; j < NUMTRIALS; ++j) {
00808
00809
00810 Agent* agent;
00811
00812 if (strcmp(agentType, "qlearner") == 0){
00813 if (PRINTS) cout << "Agent: QLearner" << endl;
00814 agent = new QLearner(numactions,
00815 discountfactor,
00816 initialvalue,
00817 alpha,
00818 epsilon,
00819 rng);
00820 }
00821
00822 else if (strcmp(agentType, "dyna") == 0){
00823 if (PRINTS) cout << "Agent: Dyna" << endl;
00824 agent = new Dyna(numactions,
00825 discountfactor,
00826 initialvalue,
00827 alpha,
00828 k,
00829 epsilon,
00830 rng);
00831 }
00832
00833 else if (strcmp(agentType, "sarsa") == 0){
00834 if (PRINTS) cout << "Agent: SARSA" << endl;
00835 agent = new Sarsa(numactions,
00836 discountfactor,
00837 initialvalue,
00838 alpha,
00839 epsilon,
00840 lambda,
00841 rng);
00842 }
00843
00844 else if (strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") || strcmp(agentType, "texplore")){
00845 if (PRINTS) cout << "Agent: Model Based" << endl;
00846 agent = new ModelBasedAgent(numactions,
00847 discountfactor,
00848 rMax, rRange,
00849 modelType,
00850 exploreType,
00851 predType,
00852 nmodels,
00853 plannerType,
00854 epsilon,
00855 lambda,
00856 (1.0/actrate),
00857 M,
00858 minValues, maxValues,
00859 statesPerDim,
00860 history, v, n,
00861 deptrans, reltrans, featPct, stochastic, episodic,
00862 rng);
00863 }
00864
00865 else if (strcmp(agentType, "savedpolicy") == 0){
00866 if (PRINTS) cout << "Agent: Saved Policy" << endl;
00867 agent = new SavedPolicy(numactions,filename);
00868 }
00869
00870 else {
00871 std::cerr << "ERROR: Invalid agent type" << endl;
00872 exit(-1);
00873 }
00874
00875
00876 int totalStates = 1;
00877 Agent* a2 = agent;
00878
00879 if (nstates > 0 && (modelType != M5ALLMULTI || strcmp(agentType, "qlearner") == 0)){
00880 int totalStates = powf(nstates,minValues.size());
00881 if (PRINTS) cout << "Discretize with " << nstates << ", total: " << totalStates << endl;
00882 agent = new DiscretizationAgent(nstates, a2,
00883 minValues, maxValues, PRINTS);
00884 }
00885 else {
00886 totalStates = 1;
00887 for (unsigned i = 0; i < minValues.size(); i++){
00888 int range = 1+maxValues[i] - minValues[i];
00889 totalStates *= range;
00890 }
00891 if (PRINTS) cout << "No discretization, total: " << totalStates << endl;
00892 }
00893
00894
00895 agent->seedExp(e->getSeedings());
00896
00897
00898 if (!episodic){
00899
00900
00901 float sum = 0;
00902 int steps = 0;
00903 float trialSum = 0.0;
00904
00905 int a = 0;
00906 float r = 0;
00907
00909
00911 for (unsigned i = 0; i < NUMEPISODES; ++i){
00912
00913 std::vector<float> es = e->sensation();
00914
00915
00916 if (i == 0){
00917
00918
00919 a = agent->first_action(es);
00920 r = e->apply(a);
00921
00922 } else {
00923
00924 a = agent->next_action(r, es);
00925 r = e->apply(a);
00926 }
00927
00928
00929 sum += r;
00930 ++steps;
00931
00932 std::cerr << r << endl;
00933
00934 }
00936
00937 rsum += sum;
00938 trialSum += sum;
00939 if (PRINTS) cout << "Rsum(trial " << j << "): " << trialSum << " Avg: "
00940 << (rsum / (float)(j+1))<< endl;
00941
00942 }
00943
00944
00945 else {
00946
00948
00950 for (unsigned i = 0; i < NUMEPISODES; ++i) {
00951
00952
00953 float sum = 0;
00954 int steps = 0;
00955
00956
00957 std::vector<float> es = e->sensation();
00958 int a = agent->first_action(es);
00959 float r = e->apply(a);
00960
00961
00962 sum += r;
00963 ++steps;
00964
00965 while (!e->terminal() && steps < MAXSTEPS) {
00966
00967
00968 es = e->sensation();
00969 a = agent->next_action(r, es);
00970 r = e->apply(a);
00971
00972
00973 sum += r;
00974 ++steps;
00975
00976 }
00977
00978
00979 if (e->terminal()){
00980 agent->last_action(r);
00981 }else{
00982 agent->next_action(r, e->sensation());
00983 }
00984
00985 e->reset();
00986 std::cerr << sum << endl;
00987
00988 rsum += sum;
00989
00990 }
00991
00992 }
00993
00994 if (NUMTRIALS > 1) delete agent;
00995
00996 }
00997
00998 if (PRINTS) cout << "Avg Rsum: " << (rsum / (float)NUMTRIALS) << endl;
00999
01000 }
01001