Go to the documentation of this file.
00005 #include <rl_common/Random.h>
00006 #include <rl_common/core.hh>
00008 #include <stdio.h>
00009 #include <string.h>
00010 #include <sys/time.h>
00013 // Environments //
00015 #include <rl_env/RobotCarVel.hh>
00016 #include <rl_env/fourrooms.hh>
00017 #include <rl_env/tworooms.hh>
00018 #include <rl_env/taxi.hh>
00019 #include <rl_env/FuelRooms.hh>
00020 #include <rl_env/stocks.hh>
00021 #include <rl_env/energyrooms.hh>
00022 #include <rl_env/MountainCar.hh>
00023 #include <rl_env/CartPole.hh>
00024 #include <rl_env/LightWorld.hh>
00028 // Agents //
00030 #include <rl_agent/QLearner.hh>
00031 #include <rl_agent/ModelBasedAgent.hh>
00032 #include <rl_agent/DiscretizationAgent.hh>
00033 #include <rl_agent/SavedPolicy.hh>
00034 #include <rl_agent/Dyna.hh>
00035 #include <rl_agent/Sarsa.hh>
00039 #include <vector>
00040 #include <sstream>
00041 #include <iostream>
00043 #include <getopt.h>
00044 #include <stdlib.h>
// Run-length settings shared by main().
// Episodes per trial (for non-episodic domains: steps per trial); --nepisodes overrides it.
unsigned NUMEPISODES{1000};
// Number of independent trials; each trial builds a fresh agent.
const unsigned NUMTRIALS{1};
// Per-episode step cap; main() lowers it for the car velocity domains.
unsigned MAXSTEPS{1000};
// Debug printing of setup/progress details, enabled by --prints.
bool PRINTS{false};
/// Prints the command-line usage summary for the experiment runner to
/// standard output, then terminates the process with exit(-1).
/// Called when a required option (--agent / --env) is missing or when
/// getopt reports an unknown or malformed option; it never returns.
void displayHelp(){
  std::cout << "\n Call experiment --agent type --env type [options]\n";
  std::cout << "Agent types: qlearner sarsa modelbased rmax texplore dyna savedpolicy\n";
  std::cout << "Env types: taxi tworooms fourrooms energy fuelworld mcar cartpole car2to7 car7to2 carrandom stocks lightworld\n";

  std::cout << "\n Agent Options:\n";
  std::cout << "--gamma value (discount factor between 0 and 1)\n";
  std::cout << "--epsilon value (epsilon for epsilon-greedy exploration)\n";
  std::cout << "--alpha value (learning rate alpha)\n";
  std::cout << "--initialvalue value (initial q values)\n";
  std::cout << "--actrate value (action selection rate (Hz))\n";
  // The option parsed by main() is "--lambda"; the help previously said "--lamba".
  std::cout << "--lambda value (lambda for eligibility traces)\n";
  std::cout << "--m value (parameter for R-Max)\n";
  std::cout << "--k value (For Dyna: # of model based updates to do between each real world update)\n";
  std::cout << "--history value (# steps of history to use for planning with delay)\n";
  std::cout << "--filename file (file to load saved policy from for savedpolicy agent)\n";
  std::cout << "--model type (tabular,tree,m5tree)\n";
  std::cout << "--planner type (vi,pi,sweeping,uct,parallel-uct,delayed-uct,delayed-parallel-uct)\n";
  std::cout << "--explore type (unknown,greedy,epsilongreedy,unvisitedstates,unvisitedactions,variancenovelty)\n";
  std::cout << "--combo type (average,weighted,best,separate)\n";
  std::cout << "--nmodels value (# of models)\n";
  std::cout << "--nstates value (optionally discretize domain into value # of states on each feature)\n";
  std::cout << "--reltrans (learn relative transitions)\n";
  std::cout << "--abstrans (learn absolute transitions)\n";
  std::cout << "--v value (For TEXPLORE: b/v coefficient for rewarding state-actions where models disagree; --b is an alias)\n";
  std::cout << "--n value (For TEXPLORE: n coefficient for rewarding state-actions which are novel)\n";

  std::cout << "\n Env Options:\n";
  std::cout << "--deterministic (deterministic version of domain)\n";
  std::cout << "--stochastic (stochastic version of domain)\n";
  std::cout << "--delay value (# steps of action delay (for mcar and tworooms))\n";
  std::cout << "--lag (turn on brake lag for car driving domain)\n";
  std::cout << "--nolag (turn off brake lag for car driving domain)\n";
  std::cout << "--highvar (have high variation in fuel costs in Fuel World)\n";
  std::cout << "--nsectors value (# sectors for stocks domain)\n";
  std::cout << "--nstocks value (# stocks for stocks domain)\n";

  std::cout << "\n--prints (turn on debug printing of actions/rewards)\n";
  std::cout << "--nepisodes value (# of episodes to run (1000 default))\n";
  std::cout << "--seed value (integer seed for random number generator)\n";

  std::cout << "\n For more info, see: http://www.ros.org/wiki/rl_experiment\n";

  exit(-1);

}
00099 int main(int argc, char **argv) {
00101   // default params for env and agent
00102   char* agentType = NULL;
00103   char* envType = NULL;
00104   float discountfactor = 0.99;
00105   float epsilon = 0.1;
00106   float alpha = 0.3;
00107   float initialvalue = 0.0;
00108   float actrate = 10.0;
00109   float lambda = 0.1;
00110   int M = 5;
00111   int modelType = C45TREE;
00112   int exploreType = GREEDY;
00113   int predType = BEST;
00114   int plannerType = PAR_ETUCT_ACTUAL;
00115   int nmodels = 1;
00116   bool reltrans = true;
00117   bool deptrans = false;
00118   float v = 0;
00119   float n = 0;
00120   float featPct = 0.2;
00121   int nstates = 0;
00122   int k = 1000;
00123   char *filename = NULL;
00124   bool stochastic = true;
00125   int nstocks = 3;
00126   int nsectors = 3;
00127   int delay = 0;
00128   bool lag = false;
00129   bool highvar = false;
00130   int history = 0;
00131   int seed = 1;
00132   // change some of these parameters based on command line args
00134   // parse agent type
00135   bool gotAgent = false;
00136   for (int i = 1; i < argc-1; i++){
00137     if (strcmp(argv[i], "--agent") == 0){
00138       gotAgent = true;
00139       agentType = argv[i+1];
00140     }
00141   }
00142   if (!gotAgent) {
00143     cout << "--agent type  option is required" << endl;
00144     displayHelp();
00145   }
00147   // set some default options for rmax or texplore
00148   if (strcmp(agentType, "rmax") == 0){
00149     modelType = RMAX;
00150     exploreType = EXPLORE_UNKNOWN;
00151     predType = BEST;
00152     plannerType = VALUE_ITERATION;
00153     nmodels = 1;
00154     reltrans = false;
00155     M = 5;
00156     history = 0;
00157   } else if (strcmp(agentType, "texplore") == 0){
00158     modelType = C45TREE;
00159     exploreType = DIFF_AND_NOVEL_BONUS;
00160     v = 0;
00161     n = 0;
00162     predType = AVERAGE;
00163     plannerType = PAR_ETUCT_ACTUAL;
00164     nmodels = 5;
00165     reltrans = true;
00166     M = 0;
00167     history = 0;
00168   }
00170   // parse env type
00171   bool gotEnv = false;
00172   for (int i = 1; i < argc-1; i++){
00173     if (strcmp(argv[i], "--env") == 0){
00174       gotEnv = true;
00175       envType = argv[i+1];
00176     }
00177   }
00178   if (!gotEnv) {
00179     cout << "--env type  option is required" << endl;
00180     displayHelp();
00181   }
00183   // parse other arguments
00184   char ch;
00185   const char* optflags = "geairlmoxpcn:";
00186   int option_index = 0;
00187   static struct option long_options[] = {
00188     {"gamma", 1, 0, 'g'},
00189     {"discountfactor", 1, 0, 'g'},
00190     {"epsilon", 1, 0, 'e'},
00191     {"alpha", 1, 0, 'a'},
00192     {"initialvalue", 1, 0, 'i'},
00193     {"actrate", 1, 0, 'r'},
00194     {"lambda", 1, 0, 'l'},
00195     {"m", 1, 0, 'm'},
00196     {"model", 1, 0, 'o'},
00197     {"explore", 1, 0, 'x'},
00198     {"planner", 1, 0, 'p'},
00199     {"combo", 1, 0, 'c'},
00200     {"nmodels", 1, 0, '#'},
00201     {"reltrans", 0, 0, 't'},
00202     {"abstrans", 0, 0, '0'},
00203     {"seed", 1, 0, 's'},
00204     {"agent", 1, 0, 'q'},
00205     {"prints", 0, 0, 'd'},
00206     {"nstates", 1, 0, 'w'},
00207     {"k", 1, 0, 'k'},
00208     {"filename", 1, 0, 'f'},
00209     {"history", 1, 0, 'y'},
00210     {"b", 1, 0, 'b'},
00211     {"v", 1, 0, 'v'},
00212     {"n", 1, 0, 'n'},
00214     {"env", 1, 0, 1},
00215     {"deterministic", 0, 0, 2},
00216     {"stochastic", 0, 0, 3},
00217     {"delay", 1, 0, 4},
00218     {"nsectors", 1, 0, 5},
00219     {"nstocks", 1, 0, 6},
00220     {"lag", 0, 0, 7},
00221     {"nolag", 0, 0, 8},
00222     {"highvar", 0, 0, 11},
00223     {"nepisodes", 1, 0, 12}
00225   };
00227   bool epsilonChanged = false;
00228   bool actrateChanged = false;
00229   bool mChanged = false;
00230   bool bvnChanged = false;
00231   bool lambdaChanged = false;
00233   while(-1 != (ch = getopt_long_only(argc, argv, optflags, long_options, &option_index))) {
00234     switch(ch) {
00236     case 'g':
00237       discountfactor = std::atof(optarg);
00238       cout << "discountfactor: " << discountfactor << endl;
00239       break;
00241     case 'e':
00242       epsilonChanged = true;
00243       epsilon = std::atof(optarg);
00244       cout << "epsilon: " << epsilon << endl;
00245       break;
00247     case 'y':
00248       {
00249         if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0){
00250           history = std::atoi(optarg);
00251           cout << "history: " << history << endl;
00252         } else {
00253           cout << "--history is not a valid option for agent: " << agentType << endl;
00254           exit(-1);
00255         }
00256         break;
00257       }
00259     case 'k':
00260       {
00261         if (strcmp(agentType, "dyna") == 0){
00262           k = std::atoi(optarg);
00263           cout << "k: " << k << endl;
00264         } else {
00265           cout << "--k is only a valid option for the Dyna agent" << endl;
00266           exit(-1);
00267         }
00268         break;
00269       }
00271     case 'f':
00272       filename = optarg;
00273       cout << "policy filename: " <<  filename << endl;
00274       break;
00276     case 'a':
00277       {
00278         if (strcmp(agentType, "qlearner") == 0 || strcmp(agentType, "dyna") == 0 || strcmp(agentType, "sarsa") == 0){
00279           alpha = std::atof(optarg);
00280           cout << "alpha: " << alpha << endl;
00281         } else {
00282           cout << "--alpha option is only valid for Q-Learning, Dyna, and Sarsa" << endl;
00283           exit(-1);
00284         }
00285         break;
00286       }
00288     case 'i':
00289       {
00290         if (strcmp(agentType, "qlearner") == 0 || strcmp(agentType, "dyna") == 0 || strcmp(agentType, "sarsa") == 0){
00291           initialvalue = std::atof(optarg);
00292           cout << "initialvalue: " << initialvalue << endl;
00293         } else {
00294           cout << "--initialvalue option is only valid for Q-Learning, Dyna, and Sarsa" << endl;
00295           exit(-1);
00296         }
00297         break;
00298       }
00300     case 'r':
00301       {
00302         actrateChanged = true;
00303         if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") == 0){
00304           actrate = std::atof(optarg);
00305           cout << "actrate: " << actrate << endl;
00306         } else {
00307           cout << "Model-free methods do not require an action rate" << endl;
00308           exit(-1);
00309         }
00310         break;
00311       }
00313     case 'l':
00314       {
00315         lambdaChanged = true;
00316         if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") == 0 || strcmp(agentType, "sarsa") == 0){
00317           lambda = std::atof(optarg);
00318           cout << "lambda: " << lambda << endl;
00319         } else {
00320           cout << "--lambda option is invalid for this agent: " << agentType << endl;
00321           exit(-1);
00322         }
00323         break;
00324       }
00326     case 'm':
00327       {
00328         mChanged = true;
00329         if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") == 0){
00330           M = std::atoi(optarg);
00331           cout << "M: " << M << endl;
00332         } else {
00333           cout << "--M option only useful for model-based agents, not " << agentType << endl;
00334           exit(-1);
00335         }
00336         break;
00337       }
00339     case 'o':
00340       {
00341         if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") == 0){
00342           if (strcmp(optarg, "tabular") == 0) modelType = RMAX;
00343           else if (strcmp(optarg, "tree") == 0) modelType = C45TREE;
00344           else if (strcmp(optarg, "texplore") == 0) modelType = C45TREE;
00345           else if (strcmp(optarg, "c45tree") == 0) modelType = C45TREE;
00346           else if (strcmp(optarg, "m5tree") == 0) modelType = M5ALLMULTI;
00347           if (strcmp(agentType, "rmax") == 0 && modelType != RMAX){
00348             cout << "R-Max should use tabular model" << endl;
00349             exit(-1);
00350           }
00351         } else {
00352           cout << "Model-free methods do not need a model, --model option does nothing for this agent type" << endl;
00353           exit(-1);
00354         }
00355         cout << "model: " << modelNames[modelType] << endl;
00356         break;
00357       }
00359     case 'x':
00360       {
00361         if (strcmp(optarg, "unknown") == 0) exploreType = EXPLORE_UNKNOWN;
00362         else if (strcmp(optarg, "greedy") == 0) exploreType = GREEDY;
00363         else if (strcmp(optarg, "epsilongreedy") == 0) exploreType = EPSILONGREEDY;
00364         else if (strcmp(optarg, "unvisitedstates") == 0) exploreType = UNVISITED_BONUS;
00365         else if (strcmp(optarg, "unvisitedactions") == 0) exploreType = UNVISITED_ACT_BONUS;
00366         else if (strcmp(optarg, "variancenovelty") == 0) exploreType = DIFF_AND_NOVEL_BONUS;
00367         if (strcmp(agentType, "rmax") == 0 && exploreType != EXPLORE_UNKNOWN){
00368           cout << "R-Max should use \"--explore unknown\" exploration" << endl;
00369           exit(-1);
00370         }
00371         else if (strcmp(agentType, "texplore") != 0 && strcmp(agentType, "modelbased") != 0 && strcmp(agentType, "rmax") != 0 && (exploreType != GREEDY && exploreType != EPSILONGREEDY)) {
00372           cout << "Model free methods must use either greedy or epsilon-greedy exploration!" << endl;
00373           exploreType = EPSILONGREEDY;
00374           exit(-1);
00375         }
00376         cout << "explore: " << exploreNames[exploreType] << endl;
00377         break;
00378       }
00380     case 'p':
00381       {
00382         if (strcmp(optarg, "vi") == 0) plannerType = VALUE_ITERATION;
00383         else if (strcmp(optarg, "valueiteration") == 0) plannerType = VALUE_ITERATION;
00384         else if (strcmp(optarg, "policyiteration") == 0) plannerType = POLICY_ITERATION;
00385         else if (strcmp(optarg, "pi") == 0) plannerType = POLICY_ITERATION;
00386         else if (strcmp(optarg, "sweeping") == 0) plannerType = PRI_SWEEPING;
00387         else if (strcmp(optarg, "prioritizedsweeping") == 0) plannerType = PRI_SWEEPING;
00388         else if (strcmp(optarg, "uct") == 0) plannerType = ET_UCT_ACTUAL;
00389         else if (strcmp(optarg, "paralleluct") == 0) plannerType = PAR_ETUCT_ACTUAL;
00390         else if (strcmp(optarg, "realtimeuct") == 0) plannerType = PAR_ETUCT_ACTUAL;
00391         else if (strcmp(optarg, "realtime-uct") == 0) plannerType = PAR_ETUCT_ACTUAL;
00392         else if (strcmp(optarg, "parallel-uct") == 0) plannerType = PAR_ETUCT_ACTUAL;
00393         else if (strcmp(optarg, "delayeduct") == 0) plannerType = POMDP_ETUCT;
00394         else if (strcmp(optarg, "delayed-uct") == 0) plannerType = POMDP_ETUCT;
00395         else if (strcmp(optarg, "delayedparalleluct") == 0) plannerType = POMDP_PAR_ETUCT;
00396         else if (strcmp(optarg, "delayed-parallel-uct") == 0) plannerType = POMDP_PAR_ETUCT;
00397         if (strcmp(agentType, "texplore") != 0 && strcmp(agentType, "modelbased") != 0 && strcmp(agentType, "rmax") != 0){
00398           cout << "Model-free methods do not require planners, --planner option does nothing with this agent" << endl;
00399           exit(-1);
00400         }
00401         if (strcmp(agentType, "rmax") == 0 && plannerType != VALUE_ITERATION){
00402           cout << "Typical implementation of R-Max would use value iteration, but another planner type is ok" << endl;
00403         }
00404         cout << "planner: " << plannerNames[plannerType] << endl;
00405         break;
00406       }
00408     case 'c':
00409       {
00410         if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0){
00411           if (strcmp(optarg, "average") == 0) predType = AVERAGE;
00412           else if (strcmp(optarg, "weighted") == 0) predType = WEIGHTAVG;
00413           else if (strcmp(optarg, "best") == 0) predType = BEST;
00414           else if (strcmp(optarg, "separate") == 0) predType = SEPARATE;
00415           cout << "predType: " << comboNames[predType] << endl;
00416         } else {
00417           cout << "--combo is an invalid option for agent: " << agentType << endl;
00418           exit(-1);
00419         }
00420         break;
00421       }
00423     case '#':
00424       {
00425         if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0){
00426           nmodels = std::atoi(optarg);
00427           cout << "nmodels: " << nmodels << endl;
00428         } else {
00429           cout << "--nmodels is an invalid option for agent: " << agentType << endl;
00430           exit(-1);
00431         }
00432         if (nmodels < 1){
00433           cout << "nmodels must be > 0" << endl;
00434           exit(-1);
00435         }
00436         break;
00437       }
00439     case 't':
00440       {
00441         if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0){
00442           reltrans = true;
00443           cout << "reltrans: " << reltrans << endl;
00444         } else {
00445           cout << "--reltrans is an invalid option for agent: " << agentType << endl;
00446           exit(-1);
00447         }
00448         break;
00449       }
00451     case '0':
00452       {
00453         if (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0){
00454           reltrans = false;
00455           cout << "reltrans: " << reltrans << endl;
00456         } else {
00457           cout << "--abstrans is an invalid option for agent: " << agentType << endl;
00458           exit(-1);
00459         }
00460         break;
00461       }
00463     case 's':
00464       seed = std::atoi(optarg);
00465       cout << "seed: " << seed << endl;
00466       break;
00468     case 'q':
00469       // already processed this one
00470       cout << "agent: " << agentType << endl;
00471       break;
00473     case 'd':
00474       PRINTS = true;
00475       break;
00477     case 'w':
00478       nstates = std::atoi(optarg);
00479       cout << "nstates for discretization: " << nstates << endl;
00480       break;
00482     case 'v':
00483     case 'b':
00484       {
00485         bvnChanged = true;
00486         if (strcmp(agentType, "texplore") == 0){
00487           v = std::atof(optarg);
00488           cout << "v coefficient (variance bonus): " << v << endl;
00489         }
00490         else {
00491           cout << "--v and --b are invalid options for agent: " << agentType << endl;
00492           exit(-1);
00493         }
00494         break;
00495       }
00497     case 'n':
00498       {
00499         bvnChanged = true;
00500         if (strcmp(agentType, "texplore") == 0){
00501           n = std::atof(optarg);
00502           cout << "n coefficient (novelty bonus): " << n << endl;
00503         }
00504         else {
00505           cout << "--n is an invalid option for agent: " << agentType << endl;
00506           exit(-1);
00507         }
00508         break;
00509       }
00511     case 2:
00512       stochastic = false;
00513       cout << "stochastic: " << stochastic << endl;
00514       break;
00516     case 11:
00517       {
00518         if (strcmp(envType, "fuelworld") == 0){
00519           highvar = true;
00520           cout << "fuel world fuel cost variation: " << highvar << endl;
00521         } else {
00522           cout << "--highvar is only a valid option for the fuelworld domain." << endl;
00523           exit(-1);
00524         }
00525         break;
00526       }
00528     case 3:
00529       stochastic = true;
00530       cout << "stochastic: " << stochastic << endl;
00531       break;
00533     case 4:
00534       {
00535         if (strcmp(envType, "mcar") == 0 || strcmp(envType, "tworooms") == 0){
00536           delay = std::atoi(optarg);
00537           cout << "delay steps: " << delay << endl;
00538         } else {
00539           cout << "--delay option is only valid for the mcar and tworooms domains" << endl;
00540           exit(-1);
00541         }
00542         break;
00543       }
00545     case 5:
00546       {
00547         if (strcmp(envType, "stocks") == 0){
00548           nsectors = std::atoi(optarg);
00549           cout << "nsectors: " << nsectors << endl;
00550         } else {
00551           cout << "--nsectors option is only valid for the stocks domain" << endl;
00552           exit(-1);
00553         }
00554         break;
00555       }
00557     case 6:
00558       {
00559         if (strcmp(envType, "stocks") == 0){
00560           nstocks = std::atoi(optarg);
00561           cout << "nstocks: " << nstocks << endl;
00562         } else {
00563           cout << "--nstocks option is only valid for the stocks domain" << endl;
00564           exit(-1);
00565         }
00566         break;
00567       }
00569     case 7:
00570       {
00571         if (strcmp(envType, "car2to7") == 0 || strcmp(envType, "car7to2") == 0 || strcmp(envType, "carrandom") == 0){
00572           lag = true;
00573           cout << "lag: " << lag << endl;
00574         } else {
00575           cout << "--lag option is only valid for car velocity tasks" << endl;
00576           exit(-1);
00577         }
00578         break;
00579       }
00581     case 8:
00582       {
00583         if (strcmp(envType, "car2to7") == 0 || strcmp(envType, "car7to2") == 0 || strcmp(envType, "carrandom") == 0){
00584           lag = false;
00585           cout << "lag: " << lag << endl;
00586         } else {
00587           cout << "--nolag option is only valid for car velocity tasks" << endl;
00588           exit(-1);
00589         }
00590         break;
00591       }
00593     case 1:
00594       // already processed this one
00595       cout << "env: " << envType << endl;
00596       break;
00598     case 12:
00599       NUMEPISODES = std::atoi(optarg);
00600       cout << "Num Episodes: " << NUMEPISODES << endl;
00601       break;
00603     case 'h':
00604     case '?':
00605     case 0:
00606     default:
00607       displayHelp();
00608       break;
00609     }
00610   }
00612   // default back to greedy if no coefficients
00613   if (exploreType == DIFF_AND_NOVEL_BONUS && v == 0 && n == 0)
00614     exploreType = GREEDY;
00616   // check for conflicting options
00617   // changed epsilon but not doing epsilon greedy exploration
00618   if (epsilonChanged && exploreType != EPSILONGREEDY){
00619     cout << "No reason to change epsilon when not using epsilon-greedy exploration" << endl;
00620     exit(-1);
00621   }
00623   // set history value but not doing uct w/history planner
00624   if (history > 0 && (plannerType == VALUE_ITERATION || plannerType == POLICY_ITERATION || plannerType == PRI_SWEEPING)){
00625     cout << "No reason to set history higher than 0 if not using a UCT planner" << endl;
00626     exit(-1);
00627   }
00629   // set action rate but not doing real-time planner
00630   if (actrateChanged && (plannerType == VALUE_ITERATION || plannerType == POLICY_ITERATION || plannerType == PRI_SWEEPING)){
00631     cout << "No reason to set actrate if not using a UCT planner" << endl;
00632     exit(-1);
00633   }
00635   // set lambda but not doing uct (lambda)
00636   if (lambdaChanged && (strcmp(agentType, "texplore") == 0 || strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") == 0) && (plannerType == VALUE_ITERATION || plannerType == POLICY_ITERATION || plannerType == PRI_SWEEPING)){
00637     cout << "No reason to set actrate if not using a UCT planner" << endl;
00638     exit(-1);
00639   }
00641   // set n/v/b but not doing that diff_novel exploration
00642   if (bvnChanged && exploreType != DIFF_AND_NOVEL_BONUS){
00643     cout << "No reason to set n or v if not doing variance & novelty exploration" << endl;
00644     exit(-1);
00645   }
00647   // set combo other than best but only doing 1 model
00648   if (predType != BEST && nmodels == 1){
00649     cout << "No reason to have model combo other than best with nmodels = 1" << endl;
00650     exit(-1);
00651   }
00653   // set M but not doing explore unknown
00654   if (mChanged && exploreType != EXPLORE_UNKNOWN){
00655     cout << "No reason to set M if not doing R-max style Explore Unknown exploration" << endl;
00656     exit(-1);
00657   }
00659   if (PRINTS){
00660     if (stochastic)
00661       cout << "Stohastic\n";
00662     else
00663       cout << "Deterministic\n";
00664   }
00666   Random rng(1 + seed);
00668   std::vector<int> statesPerDim;
00670   // Construct environment here.
00671   Environment* e;
00673   if (strcmp(envType, "cartpole") == 0){
00674     if (PRINTS) cout << "Environment: Cart Pole\n";
00675     e = new CartPole(rng, stochastic);
00676   }
00678   else if (strcmp(envType, "mcar") == 0){
00679     if (PRINTS) cout << "Environment: Mountain Car\n";
00680     e = new MountainCar(rng, stochastic, false, delay);
00681   }
00683   // taxi
00684   else if (strcmp(envType, "taxi") == 0){
00685     if (PRINTS) cout << "Environment: Taxi\n";
00686     e = new Taxi(rng, stochastic);
00687   }
00689   // Light World
00690   else if (strcmp(envType, "lightworld") == 0){
00691     if (PRINTS) cout << "Environment: Light World\n";
00692     e = new LightWorld(rng, stochastic, 4);
00693   }
00695   // two rooms
00696   else if (strcmp(envType, "tworooms") == 0){
00697     if (PRINTS) cout << "Environment: TwoRooms\n";
00698     e = new TwoRooms(rng, stochastic, true, delay, false);
00699   }
00701   // car vel, 2 to 7
00702   else if (strcmp(envType, "car2to7") == 0){
00703     if (PRINTS) cout << "Environment: Car Velocity 2 to 7 m/s\n";
00704     e = new RobotCarVel(rng, false, true, false, lag);
00705     statesPerDim.resize(4,0);
00706     statesPerDim[0] = 12;
00707     statesPerDim[1] = 120;
00708     statesPerDim[2] = 4;
00709     statesPerDim[3] = 10;
00710     MAXSTEPS = 100;
00711   }
00712   // car vel, 7 to 2
00713   else if (strcmp(envType, "car7to2") == 0){
00714     if (PRINTS) cout << "Environment: Car Velocity 7 to 2 m/s\n";
00715     e = new RobotCarVel(rng, false, false, false, lag);
00716     statesPerDim.resize(4,0);
00717     statesPerDim[0] = 12;
00718     statesPerDim[1] = 120;
00719     statesPerDim[2] = 4;
00720     statesPerDim[3] = 10;
00721     MAXSTEPS = 100;
00722   }
00723   // car vel, random vels
00724   else if (strcmp(envType, "carrandom") == 0){
00725     if (PRINTS) cout << "Environment: Car Velocity Random Velocities\n";
00726     e = new RobotCarVel(rng, true, false, false, lag);
00727     statesPerDim.resize(4,0);
00728     statesPerDim[0] = 12;
00729     statesPerDim[1] = 48;
00730     statesPerDim[2] = 4;
00731     statesPerDim[3] = 10;
00732     MAXSTEPS = 100;
00733   }
00735   // four rooms
00736   else if (strcmp(envType, "fourrooms") == 0){
00737     if (PRINTS) cout << "Environment: FourRooms\n";
00738     e = new FourRooms(rng, stochastic, true, false);
00739   }
00741   // four rooms with energy level
00742   else if (strcmp(envType, "energy") == 0){
00743     if (PRINTS) cout << "Environment: EnergyRooms\n";
00744     e = new EnergyRooms(rng, stochastic, true, false);
00745   }
00747   // gridworld with fuel (fuel stations on top and bottom with random costs)
00748   else if (strcmp(envType, "fuelworld") == 0){
00749     if (PRINTS) cout << "Environment: FuelWorld\n";
00750     e = new FuelRooms(rng, highvar, stochastic);
00751   }
00753   // stocks
00754   else if (strcmp(envType, "stocks") == 0){
00755     if (PRINTS) cout << "Enironment: Stocks with " << nsectors
00756                      << " sectors and " << nstocks << " stocks\n";
00757     e = new Stocks(rng, stochastic, nsectors, nstocks);
00758   }
00760   else {
00761     std::cerr << "Invalid env type" << endl;
00762     exit(-1);
00763   }
00765   const int numactions = e->getNumActions(); // Most agents will need this?
00767   std::vector<float> minValues;
00768   std::vector<float> maxValues;
00769   e->getMinMaxFeatures(&minValues, &maxValues);
00770   bool episodic = e->isEpisodic();
00772   cout << "Environment is ";
00773   if (!episodic) cout << "NOT ";
00774   cout << "episodic." << endl;
00776   // lets just check this for now
00777   for (unsigned i = 0; i < minValues.size(); i++){
00778     if (PRINTS) cout << "Feat " << i << " min: " << minValues[i]
00779                      << " max: " << maxValues[i] << endl;
00780   }
00782   // get max/min reward for the domain
00783   float rMax = 0.0;
00784   float rMin = -1.0;
00786   e->getMinMaxReward(&rMin, &rMax);
00787   float rRange = rMax - rMin;
00788   if (PRINTS) cout << "Min Reward: " << rMin
00789                    << ", Max Reward: " << rMax << endl;
00791   // set rmax as a bonus for certain exploration types
00792   if (rMax <= 0.0 && (exploreType == TWO_MODE_PLUS_R ||
00793                       exploreType == CONTINUOUS_BONUS_R ||
00794                       exploreType == CONTINUOUS_BONUS ||
00795                       exploreType == THRESHOLD_BONUS_R)){
00796     rMax = 1.0;
00797   }
00800   float rsum = 0;
00802   if (statesPerDim.size() == 0){
00803     cout << "set statesPerDim to " << nstates << " for all dim" << endl;
00804     statesPerDim.resize(minValues.size(), nstates);
00805   }
00807   for (unsigned j = 0; j < NUMTRIALS; ++j) {
00809     // Construct agent here.
00810     Agent* agent;
00812     if (strcmp(agentType, "qlearner") == 0){
00813       if (PRINTS) cout << "Agent: QLearner" << endl;
00814       agent = new QLearner(numactions,
00815                            discountfactor,
00816                            initialvalue, //0.0, // initialvalue
00817                            alpha, // alpha
00818                            epsilon, // epsilon
00819                            rng);
00820     }
00822     else if (strcmp(agentType, "dyna") == 0){
00823       if (PRINTS) cout << "Agent: Dyna" << endl;
00824       agent = new Dyna(numactions,
00825                        discountfactor,
00826                        initialvalue, //0.0, // initialvalue
00827                        alpha, // alpha
00828                        k, // k
00829                        epsilon, // epsilon
00830                        rng);
00831     }
00833     else if (strcmp(agentType, "sarsa") == 0){
00834       if (PRINTS) cout << "Agent: SARSA" << endl;
00835       agent = new Sarsa(numactions,
00836                         discountfactor,
00837                         initialvalue, //0.0, // initialvalue
00838                         alpha, // alpha
00839                         epsilon, // epsilon
00840                         lambda,
00841                         rng);
00842     }
00844     else if (strcmp(agentType, "modelbased") == 0 || strcmp(agentType, "rmax") || strcmp(agentType, "texplore")){
00845       if (PRINTS) cout << "Agent: Model Based" << endl;
00846       agent = new ModelBasedAgent(numactions,
00847                                   discountfactor,
00848                                   rMax, rRange,
00849                                   modelType,
00850                                   exploreType,
00851                                   predType,
00852                                   nmodels,
00853                                   plannerType,
00854                                   epsilon, // epsilon
00855                                   lambda,
00856                                   (1.0/actrate), //0.1, //0.1, //0.01, // max time
00857                                   M,
00858                                   minValues, maxValues,
00859                                   statesPerDim,//0,
00860                                   history, v, n,
00861                                   deptrans, reltrans, featPct, stochastic, episodic,
00862                                   rng);
00863     }
00865     else if (strcmp(agentType, "savedpolicy") == 0){
00866       if (PRINTS) cout << "Agent: Saved Policy" << endl;
00867       agent = new SavedPolicy(numactions,filename);
00868     }
00870     else {
00871       std::cerr << "ERROR: Invalid agent type" << endl;
00872       exit(-1);
00873     }
00875     // start discrete agent if we're discretizing (if nstates > 0 and not agent type 'c')
00876     int totalStates = 1;
00877     Agent* a2 = agent;
00878     // not for model based when doing continuous model
00879     if (nstates > 0 && (modelType != M5ALLMULTI || strcmp(agentType, "qlearner") == 0)){
00880       int totalStates = powf(nstates,minValues.size());
00881       if (PRINTS) cout << "Discretize with " << nstates << ", total: " << totalStates << endl;
00882       agent = new DiscretizationAgent(nstates, a2,
00883                                       minValues, maxValues, PRINTS);
00884     }
00885     else {
00886       totalStates = 1;
00887       for (unsigned i = 0; i < minValues.size(); i++){
00888         int range = 1+maxValues[i] - minValues[i];
00889         totalStates *= range;
00890       }
00891       if (PRINTS) cout << "No discretization, total: " << totalStates << endl;
00892     }
00894     // before we start, seed the agent with some experiences
00895     agent->seedExp(e->getSeedings());
00897     // STEP BY STEP DOMAIN
00898     if (!episodic){
00900       // performance tracking
00901       float sum = 0;
00902       int steps = 0;
00903       float trialSum = 0.0;
00905       int a = 0;
00906       float r = 0;
00909       // non-episodic
00911       for (unsigned i = 0; i < NUMEPISODES; ++i){
00913         std::vector<float> es = e->sensation();
00915         // first step
00916         if (i == 0){
00918           // first action
00919           a = agent->first_action(es);
00920           r = e->apply(a);
00922         } else {
00923           // next action
00924           a = agent->next_action(r, es);
00925           r = e->apply(a);
00926         }
00928         // update performance
00929         sum += r;
00930         ++steps;
00932         std::cerr << r << endl;
00934       }
00937       rsum += sum;
00938       trialSum += sum;
00939       if (PRINTS) cout << "Rsum(trial " << j << "): " << trialSum << " Avg: "
00940                        << (rsum / (float)(j+1))<< endl;
00942     }
00945     else {
00948       // episodic
00950       for (unsigned i = 0; i < NUMEPISODES; ++i) {
00952         // performance tracking
00953         float sum = 0;
00954         int steps = 0;
00956         // first action
00957         std::vector<float> es = e->sensation();
00958         int a = agent->first_action(es);
00959         float r = e->apply(a);
00961         // update performance
00962         sum += r;
00963         ++steps;
00965         while (!e->terminal() && steps < MAXSTEPS) {
00967           // perform an action
00968           es = e->sensation();
00969           a = agent->next_action(r, es);
00970           r = e->apply(a);
00972           // update performance info
00973           sum += r;
00974           ++steps;
00976         }
00978         // terminal/last state
00979         if (e->terminal()){
00980           agent->last_action(r);
00981         }else{
00982           agent->next_action(r, e->sensation());
00983         }
00985         e->reset();
00986         std::cerr << sum << endl;
00988         rsum += sum;
00990       }
00992     }
00994     if (NUMTRIALS > 1) delete agent;
00996   }
00998   if (PRINTS) cout << "Avg Rsum: " << (rsum / (float)NUMTRIALS) << endl;
01000 } // end main

autogenerated on Thu Jun 6 2019 22:00:27