env.cpp
Go to the documentation of this file.
#include <cstring>
#include <iostream>
#include <vector>

#include <ros/ros.h>
#include "std_msgs/String.h"

#include <rl_msgs/RLStateReward.h>
#include <rl_msgs/RLEnvDescription.h>
#include <rl_msgs/RLAction.h>
#include <rl_msgs/RLExperimentInfo.h>
#include <rl_msgs/RLEnvSeedExperience.h>

#include <ros/callback_queue.h>
#include <tf/transform_broadcaster.h>

#include <rl_common/core.hh>
#include <rl_common/Random.h>

#include <rl_env/RobotCarVel.hh>
#include <rl_env/fourrooms.hh>
#include <rl_env/tworooms.hh>
#include <rl_env/taxi.hh>
#include <rl_env/FuelRooms.hh>
#include <rl_env/stocks.hh>
#include <rl_env/energyrooms.hh>
#include <rl_env/MountainCar.hh>
#include <rl_env/CartPole.hh>
#include <rl_env/LightWorld.hh>

#include <getopt.h>
#include <stdlib.h>
00030 #define NODE "RLEnvironment"
00031 
00032 static ros::Publisher out_env_desc;
00033 static ros::Publisher out_env_sr;
00034 static ros::Publisher out_seed;
00035 
00036 Environment* e;
00037 Random rng;
00038 bool PRINTS = false;//true;
00039 int seed = 1;
00040 char* envType;
00041 
00042 // some default parameters
00043 bool stochastic = true;
00044 int nstocks = 3;
00045 int nsectors = 3;
00046 int delay = 0;
00047 bool lag = false;
00048 bool highvar = false;
00049 
/**
 * Print the command-line usage text to stdout and terminate the process
 * with exit(-1). Never returns.
 */
void displayHelp(){
  std::cout << "\n Call env --env type [options]\n";
  std::cout << "Env types: taxi tworooms fourrooms energy fuelworld mcar cartpole car2to7 car7to2 carrandom stocks lightworld\n";
  std::cout << "\n Options:\n";
  std::cout << "--seed value (integer seed for random number generator)\n";
  std::cout << "--deterministic (deterministic version of domain)\n";
  std::cout << "--stochastic (stochastic version of domain)\n";
  std::cout << "--delay value (# steps of action delay (for mcar and tworooms))\n";
  std::cout << "--lag (turn on brake lag for car driving domain)\n";
  // --nolag is accepted by the option parser, so it is documented here too.
  std::cout << "--nolag (turn off brake lag for car driving domain)\n";
  std::cout << "--highvar (have high variation fuel costs in Fuel World)\n";
  std::cout << "--nsectors value (# sectors for stocks domain)\n";
  std::cout << "--nstocks value (# stocks for stocks domain)\n";
  std::cout << "--prints (turn on debug printing of actions/rewards)\n";

  std::cout << "\n For more info, see: http://www.ros.org/wiki/rl_env\n";
  exit(-1);
}
00067 
00068 
00070 void processAction(const rl_msgs::RLAction::ConstPtr &actionIn){
00071 
00072   rl_msgs::RLStateReward sr;
00073 
00074   // process action from the agent, affecting the environment
00075   sr.reward = e->apply(actionIn->action);
00076   sr.state = e->sensation();
00077   sr.terminal = e->terminal();
00078 
00079   // publish the state-reward message
00080   if (PRINTS) cout << "Got action " << actionIn->action << " at state: " << sr.state[0] << ", " << sr.state[1] << ", reward: " << sr.reward << endl;
00081 
00082   out_env_sr.publish(sr);
00083 
00084 }
00085 
00088 void processEpisodeInfo(const rl_msgs::RLExperimentInfo::ConstPtr &infoIn){
00089   // start new episode if terminal
00090   if (PRINTS) cout << "Episode " << infoIn->episode_number << " terminated with reward: " << infoIn->episode_reward << ", start new episode " << endl;
00091 
00092   e->reset();
00093 
00094   rl_msgs::RLStateReward sr;
00095   sr.reward = 0;
00096   sr.state = e->sensation();
00097   sr.terminal = false;
00098   out_env_sr.publish(sr);
00099 }
00100 
00101 
00103 void initEnvironment(){
00104 
00105   // init the environment
00106   e = NULL;
00107   rl_msgs::RLEnvDescription desc;
00108   
00109 
00110   if (strcmp(envType, "cartpole") == 0){
00111     desc.title = "Environment: Cart Pole\n";
00112     e = new CartPole(rng, stochastic);
00113   }
00114 
00115   else if (strcmp(envType, "mcar") == 0){
00116     desc.title = "Environment: Mountain Car\n";
00117     e = new MountainCar(rng, stochastic, false, delay);
00118   }
00119 
00120   // taxi
00121   else if (strcmp(envType, "taxi") == 0){
00122     desc.title = "Environment: Taxi\n";
00123     e = new Taxi(rng, stochastic);
00124   }
00125 
00126   // Light World
00127   else if (strcmp(envType, "lightworld") == 0){
00128     desc.title = "Environment: Light World\n";
00129     e = new LightWorld(rng, stochastic, 4);
00130   }
00131 
00132   // two rooms
00133   else if (strcmp(envType, "tworooms") == 0){
00134     desc.title = "Environment: TwoRooms\n";
00135     e = new TwoRooms(rng, stochastic, true, delay, false);
00136   }
00137 
00138   // car vel, 2 to 7
00139   else if (strcmp(envType, "car2to7") == 0){
00140     desc.title = "Environment: Car Velocity 2 to 7 m/s\n";
00141     e = new RobotCarVel(rng, false, true, false, lag);
00142   }
00143   // car vel, 7 to 2
00144   else if (strcmp(envType, "car7to2") == 0){
00145     desc.title = "Environment: Car Velocity 7 to 2 m/s\n";
00146     e = new RobotCarVel(rng, false, false, false, lag);
00147   }
00148   // car vel, random vels
00149   else if (strcmp(envType, "carrandom") == 0){
00150     desc.title = "Environment: Car Velocity Random Velocities\n";
00151     e = new RobotCarVel(rng, true, false, false, lag);
00152   }
00153 
00154   // four rooms
00155   else if (strcmp(envType, "fourrooms") == 0){
00156     desc.title = "Environment: FourRooms\n";
00157     e = new FourRooms(rng, stochastic, true, false);
00158   }
00159 
00160   // four rooms with energy level
00161   else if (strcmp(envType, "energy") == 0){
00162     desc.title = "Environment: EnergyRooms\n";
00163     e = new EnergyRooms(rng, stochastic, true, false);
00164   }
00165 
00166   // gridworld with fuel (fuel stations on top and bottom with random costs)
00167   else if (strcmp(envType, "fuelworld") == 0){
00168     desc.title = "Environment: FuelWorld\n";
00169     e = new FuelRooms(rng, highvar, stochastic);
00170   }
00171 
00172   // stocks
00173   else if (strcmp(envType, "stocks") == 0){
00174     desc.title = "Environment: Stocks\n";
00175     e = new Stocks(rng, stochastic, nsectors, nstocks);
00176   }
00177 
00178   else {
00179     std::cerr << "Invalid env type" << endl;
00180     displayHelp();
00181     exit(-1);
00182   }
00183 
00184   // fill in some more description info
00185   desc.num_actions = e->getNumActions();
00186   desc.episodic = e->isEpisodic();
00187 
00188   std::vector<float> maxFeats;
00189   std::vector<float> minFeats;
00190 
00191   e->getMinMaxFeatures(&minFeats, &maxFeats);
00192   desc.num_states = minFeats.size();
00193   desc.min_state_range = minFeats;
00194   desc.max_state_range = maxFeats;
00195 
00196   desc.stochastic = stochastic;
00197   float minReward;
00198   float maxReward;
00199   e->getMinMaxReward(&minReward, &maxReward);
00200   desc.max_reward = maxReward;
00201   desc.reward_range = maxReward - minReward;
00202 
00203   cout << desc.title << endl;
00204 
00205   // publish environment description
00206   out_env_desc.publish(desc);
00207 
00208   sleep(1);
00209 
00210   // send experiences
00211   std::vector<experience> seeds = e->getSeedings();
00212   for (unsigned i = 0; i < seeds.size(); i++){
00213     rl_msgs::RLEnvSeedExperience seed;
00214     seed.from_state = seeds[i].s;
00215     seed.to_state   = seeds[i].next;
00216     seed.action     = seeds[i].act;
00217     seed.reward     = seeds[i].reward;
00218     seed.terminal   = seeds[i].terminal;
00219     out_seed.publish(seed);
00220   }
00221 
00222   // now send first state message
00223   rl_msgs::RLStateReward sr;
00224   sr.terminal = false;
00225   sr.reward = 0;
00226   sr.state = e->sensation();
00227   out_env_sr.publish(sr);
00228   
00229 }
00230 
00231 
00233 int main(int argc, char *argv[])
00234 {
00235   ros::init(argc, argv, NODE);
00236   ros::NodeHandle node;
00237 
00238   if (argc < 2){
00239     cout << "--env type  option is required" << endl;
00240     displayHelp();
00241     exit(-1);
00242   }
00243 
00244   // env and seed are required
00245   if (argc < 3){
00246     displayHelp();
00247     exit(-1);
00248   }
00249 
00250   // parse options to change these parameters
00251   envType = argv[1];
00252   seed = std::atoi(argv[2]);
00253 
00254   // parse env type first
00255   bool gotEnv = false;
00256   for (int i = 1; i < argc-1; i++){
00257     if (strcmp(argv[i], "--env") == 0){
00258       gotEnv = true;
00259       envType = argv[i+1];
00260     }
00261   }
00262   if (!gotEnv) {
00263     cout << "--env type  option is required" << endl;
00264     displayHelp();
00265   }
00266 
00267   // now parse other options
00268   char ch;
00269   const char* optflags = "ds:";
00270   int option_index = 0;
00271   static struct option long_options[] = {
00272     {"env", 1, 0, 'e'},
00273     {"deterministic", 0, 0, 'd'},
00274     {"stochastic", 0, 0, 's'},
00275     {"delay", 1, 0, 'a'},
00276     {"nsectors", 1, 0, 'c'},
00277     {"nstocks", 1, 0, 't'},
00278     {"lag", 0, 0, 'l'},
00279     {"nolag", 0, 0, 'o'},
00280     {"seed", 1, 0, 'x'},
00281     {"prints", 0, 0, 'p'},
00282     {"highvar", 0, 0, 'v'}
00283   };
00284 
00285   while(-1 != (ch = getopt_long_only(argc, argv, optflags, long_options, &option_index))) {
00286     switch(ch) {
00287 
00288     case 'x':
00289       seed = std::atoi(optarg);
00290       cout << "seed: " << seed << endl;
00291       break;
00292 
00293     case 'd':
00294       stochastic = false;
00295       cout << "stochastic: " << stochastic << endl;
00296       break;
00297 
00298     case 'v':
00299       {
00300         if (strcmp(envType, "fuelworld") == 0){
00301           highvar = true;
00302           cout << "fuel world fuel cost variation: " << highvar << endl;
00303         } else {
00304           cout << "--highvar is only a valid option for the fuelworld domain." << endl;
00305           exit(-1);
00306         }
00307         break;
00308       }
00309 
00310     case 's':
00311       stochastic = true;
00312       cout << "stochastic: " << stochastic << endl;
00313       break;
00314 
00315     case 'a':
00316       {
00317         if (strcmp(envType, "mcar") == 0 || strcmp(envType, "tworooms") == 0){
00318           delay = std::atoi(optarg);
00319           cout << "delay steps: " << delay << endl;
00320         } else {
00321           cout << "--delay option is only valid for the mcar and tworooms domains" << endl;
00322           exit(-1);
00323         }
00324         break;
00325       }
00326 
00327     case 'c':
00328       {
00329         if (strcmp(envType, "stocks") == 0){
00330           nsectors = std::atoi(optarg);
00331           cout << "nsectors: " << nsectors << endl;
00332         } else {
00333           cout << "--nsectors option is only valid for the stocks domain" << endl;
00334           exit(-1);
00335         }
00336         break;
00337       }
00338     
00339     case 't':
00340       {
00341         if (strcmp(envType, "stocks") == 0){
00342           nstocks = std::atoi(optarg);
00343           cout << "nstocks: " << nstocks << endl;
00344         } else {
00345           cout << "--nstocks option is only valid for the stocks domain" << endl;
00346           exit(-1);
00347         }
00348         break;
00349       }
00350       
00351     case 'l':
00352       {
00353         if (strcmp(envType, "car2to7") == 0 || strcmp(envType, "car7to2") == 0 || strcmp(envType, "carrandom") == 0){
00354           lag = true;
00355           cout << "lag: " << lag << endl;
00356         } else {
00357           cout << "--lag option is only valid for car velocity tasks" << endl;
00358           exit(-1);
00359         }
00360         break;
00361       }
00362 
00363     case 'o':
00364        {
00365          if (strcmp(envType, "car2to7") == 0 || strcmp(envType, "car7to2") == 0 || strcmp(envType, "carrandom") == 0){
00366            lag = false;
00367            cout << "lag: " << lag << endl;
00368          } else {
00369            cout << "--nolag option is only valid for car velocity tasks" << endl;
00370            exit(-1);
00371          }
00372          break;
00373        }
00374 
00375     case 'e':
00376       // already processed this one
00377       cout << "env: " << envType << endl;
00378       break;
00379 
00380     case 'p':
00381       PRINTS = true;
00382       break;
00383 
00384     case 'h':
00385     case '?':
00386     case 0:
00387     default:
00388       displayHelp();
00389       break;
00390     }
00391   }
00392 
00393 
00394   int qDepth = 1;
00395 
00396   // Set up Publishers
00397   ros::init(argc, argv, "my_tf_broadcaster");
00398   tf::Transform transform;
00399 
00400   // Set up Publishers
00401   out_env_desc = node.advertise<rl_msgs::RLEnvDescription>("rl_env/rl_env_description",qDepth,true);
00402   out_env_sr = node.advertise<rl_msgs::RLStateReward>("rl_env/rl_state_reward",qDepth,false);
00403   out_seed = node.advertise<rl_msgs::RLEnvSeedExperience>("rl_env/rl_seed",20,false);
00404 
00405   // Set up subscribers
00406   ros::TransportHints noDelay = ros::TransportHints().tcpNoDelay(true);
00407   ros::Subscriber rl_action =  node.subscribe("rl_agent/rl_action", qDepth, processAction, noDelay);
00408   ros::Subscriber rl_exp_info =  node.subscribe("rl_agent/rl_experiment_info", qDepth, processEpisodeInfo, noDelay);
00409 
00410   // publish env description, first state
00411   // Setup RL World
00412   rng = Random(1+seed);
00413   initEnvironment();
00414 
00415 
00416   ROS_INFO(NODE ": starting main loop");
00417   
00418   ros::spin();                          // handle incoming data
00419   //while (ros::ok()){
00420   //  ros::getGlobalCallbackQueue()->callAvailable(ros::WallDuration(0.1));
00421   //}
00422 
00423   return 0;
00424 }
00425 
00426 
00427 
00428 


rl_env
Author(s):
autogenerated on Thu Jun 6 2019 22:00:23