00001 #include <ros/ros.h>
00002 #include "std_msgs/String.h"
00003
00004 #include <rl_msgs/RLStateReward.h>
00005 #include <rl_msgs/RLEnvDescription.h>
00006 #include <rl_msgs/RLAction.h>
00007 #include <rl_msgs/RLExperimentInfo.h>
00008 #include <rl_msgs/RLEnvSeedExperience.h>
00009
00010 #include <ros/callback_queue.h>
00011 #include <tf/transform_broadcaster.h>
00012
00013 #include <rl_common/core.hh>
00014 #include <rl_common/Random.h>
00015
00016 #include <rl_env/RobotCarVel.hh>
00017 #include <rl_env/fourrooms.hh>
00018 #include <rl_env/tworooms.hh>
00019 #include <rl_env/taxi.hh>
00020 #include <rl_env/FuelRooms.hh>
00021 #include <rl_env/stocks.hh>
00022 #include <rl_env/energyrooms.hh>
00023 #include <rl_env/MountainCar.hh>
00024 #include <rl_env/CartPole.hh>
00025 #include <rl_env/LightWorld.hh>
00026
00027 #include <getopt.h>
00028 #include <stdlib.h>
00029
// Node name passed to ros::init() and used as the prefix of the startup log line.
00030 #define NODE "RLEnvironment"
00031
// Publishers, advertised in main():
//   out_env_desc -> "rl_env/rl_env_description" (latched)
//   out_env_sr   -> "rl_env/rl_state_reward"
//   out_seed     -> "rl_env/rl_seed"
00032 static ros::Publisher out_env_desc;
00033 static ros::Publisher out_env_sr;
00034 static ros::Publisher out_seed;
00035
// The active environment; allocated in initEnvironment() and used by both callbacks.
00036 Environment* e;
// Random number generator handed to the environment; re-seeded in main() with 1+seed.
00037 Random rng;
// When true (--prints), the callbacks print actions/rewards to stdout.
00038 bool PRINTS = false;
// RNG seed; main() overwrites it from argv[2] and again from --seed if given.
00039 int seed = 1;
// Environment name; main() takes it from argv[1], then from the value after --env.
00040 char* envType;
00041
00042
// Command-line tunables (set in main(), consumed in initEnvironment()):
//   stochastic        --stochastic / --deterministic
//   nstocks, nsectors --nstocks / --nsectors (stocks domain only)
//   delay             --delay (mcar and tworooms only)
//   lag               --lag / --nolag (car velocity tasks only)
//   highvar           --highvar (fuelworld only)
00043 bool stochastic = true;
00044 int nstocks = 3;
00045 int nsectors = 3;
00046 int delay = 0;
00047 bool lag = false;
00048 bool highvar = false;
00049
/**
 * Print the command-line usage summary to stdout and terminate the process.
 *
 * This function never returns: it always ends with exit(-1), so callers use
 * it both for "--help"-style requests and for reporting invalid invocations.
 */
void displayHelp(){
  std::cout << "\n Call env --env type [options]\n"
            << "Env types: taxi tworooms fourrooms energy fuelworld mcar cartpole car2to7 car7to2 carrandom stocks lightworld\n"
            << "\n Options:\n"
            << "--seed value (integer seed for random number generator)\n"
            << "--deterministic (deterministic version of domain)\n"
            << "--stochastic (stochastic version of domain)\n"
            << "--delay value (# steps of action delay (for mcar and tworooms))\n"
            << "--lag (turn on brake lag for car driving domain)\n"
            // --nolag is accepted by the option parser, so it is listed here too.
            << "--nolag (turn off brake lag for car driving domain)\n"
            << "--highvar (have variation fuel costs in Fuel World)\n"
            << "--nsectors value (# sectors for stocks domain)\n"
            << "--nstocks value (# stocks for stocks domain)\n"
            << "--prints (turn on debug printing of actions/rewards)\n"
            << "\n For more info, see: http://www.ros.org/wiki/rl_env\n";
  exit(-1);
}
00067
00068
00070 void processAction(const rl_msgs::RLAction::ConstPtr &actionIn){
00071
00072 rl_msgs::RLStateReward sr;
00073
00074
00075 sr.reward = e->apply(actionIn->action);
00076 sr.state = e->sensation();
00077 sr.terminal = e->terminal();
00078
00079
00080 if (PRINTS) cout << "Got action " << actionIn->action << " at state: " << sr.state[0] << ", " << sr.state[1] << ", reward: " << sr.reward << endl;
00081
00082 out_env_sr.publish(sr);
00083
00084 }
00085
00088 void processEpisodeInfo(const rl_msgs::RLExperimentInfo::ConstPtr &infoIn){
00089
00090 if (PRINTS) cout << "Episode " << infoIn->episode_number << " terminated with reward: " << infoIn->episode_reward << ", start new episode " << endl;
00091
00092 e->reset();
00093
00094 rl_msgs::RLStateReward sr;
00095 sr.reward = 0;
00096 sr.state = e->sensation();
00097 sr.terminal = false;
00098 out_env_sr.publish(sr);
00099 }
00100
00101
00103 void initEnvironment(){
00104
00105
00106 e = NULL;
00107 rl_msgs::RLEnvDescription desc;
00108
00109
00110 if (strcmp(envType, "cartpole") == 0){
00111 desc.title = "Environment: Cart Pole\n";
00112 e = new CartPole(rng, stochastic);
00113 }
00114
00115 else if (strcmp(envType, "mcar") == 0){
00116 desc.title = "Environment: Mountain Car\n";
00117 e = new MountainCar(rng, stochastic, false, delay);
00118 }
00119
00120
00121 else if (strcmp(envType, "taxi") == 0){
00122 desc.title = "Environment: Taxi\n";
00123 e = new Taxi(rng, stochastic);
00124 }
00125
00126
00127 else if (strcmp(envType, "lightworld") == 0){
00128 desc.title = "Environment: Light World\n";
00129 e = new LightWorld(rng, stochastic, 4);
00130 }
00131
00132
00133 else if (strcmp(envType, "tworooms") == 0){
00134 desc.title = "Environment: TwoRooms\n";
00135 e = new TwoRooms(rng, stochastic, true, delay, false);
00136 }
00137
00138
00139 else if (strcmp(envType, "car2to7") == 0){
00140 desc.title = "Environment: Car Velocity 2 to 7 m/s\n";
00141 e = new RobotCarVel(rng, false, true, false, lag);
00142 }
00143
00144 else if (strcmp(envType, "car7to2") == 0){
00145 desc.title = "Environment: Car Velocity 7 to 2 m/s\n";
00146 e = new RobotCarVel(rng, false, false, false, lag);
00147 }
00148
00149 else if (strcmp(envType, "carrandom") == 0){
00150 desc.title = "Environment: Car Velocity Random Velocities\n";
00151 e = new RobotCarVel(rng, true, false, false, lag);
00152 }
00153
00154
00155 else if (strcmp(envType, "fourrooms") == 0){
00156 desc.title = "Environment: FourRooms\n";
00157 e = new FourRooms(rng, stochastic, true, false);
00158 }
00159
00160
00161 else if (strcmp(envType, "energy") == 0){
00162 desc.title = "Environment: EnergyRooms\n";
00163 e = new EnergyRooms(rng, stochastic, true, false);
00164 }
00165
00166
00167 else if (strcmp(envType, "fuelworld") == 0){
00168 desc.title = "Environment: FuelWorld\n";
00169 e = new FuelRooms(rng, highvar, stochastic);
00170 }
00171
00172
00173 else if (strcmp(envType, "stocks") == 0){
00174 desc.title = "Environment: Stocks\n";
00175 e = new Stocks(rng, stochastic, nsectors, nstocks);
00176 }
00177
00178 else {
00179 std::cerr << "Invalid env type" << endl;
00180 displayHelp();
00181 exit(-1);
00182 }
00183
00184
00185 desc.num_actions = e->getNumActions();
00186 desc.episodic = e->isEpisodic();
00187
00188 std::vector<float> maxFeats;
00189 std::vector<float> minFeats;
00190
00191 e->getMinMaxFeatures(&minFeats, &maxFeats);
00192 desc.num_states = minFeats.size();
00193 desc.min_state_range = minFeats;
00194 desc.max_state_range = maxFeats;
00195
00196 desc.stochastic = stochastic;
00197 float minReward;
00198 float maxReward;
00199 e->getMinMaxReward(&minReward, &maxReward);
00200 desc.max_reward = maxReward;
00201 desc.reward_range = maxReward - minReward;
00202
00203 cout << desc.title << endl;
00204
00205
00206 out_env_desc.publish(desc);
00207
00208 sleep(1);
00209
00210
00211 std::vector<experience> seeds = e->getSeedings();
00212 for (unsigned i = 0; i < seeds.size(); i++){
00213 rl_msgs::RLEnvSeedExperience seed;
00214 seed.from_state = seeds[i].s;
00215 seed.to_state = seeds[i].next;
00216 seed.action = seeds[i].act;
00217 seed.reward = seeds[i].reward;
00218 seed.terminal = seeds[i].terminal;
00219 out_seed.publish(seed);
00220 }
00221
00222
00223 rl_msgs::RLStateReward sr;
00224 sr.terminal = false;
00225 sr.reward = 0;
00226 sr.state = e->sensation();
00227 out_env_sr.publish(sr);
00228
00229 }
00230
00231
00233 int main(int argc, char *argv[])
00234 {
00235 ros::init(argc, argv, NODE);
00236 ros::NodeHandle node;
00237
00238 if (argc < 2){
00239 cout << "--env type option is required" << endl;
00240 displayHelp();
00241 exit(-1);
00242 }
00243
00244
00245 if (argc < 3){
00246 displayHelp();
00247 exit(-1);
00248 }
00249
00250
00251 envType = argv[1];
00252 seed = std::atoi(argv[2]);
00253
00254
00255 bool gotEnv = false;
00256 for (int i = 1; i < argc-1; i++){
00257 if (strcmp(argv[i], "--env") == 0){
00258 gotEnv = true;
00259 envType = argv[i+1];
00260 }
00261 }
00262 if (!gotEnv) {
00263 cout << "--env type option is required" << endl;
00264 displayHelp();
00265 }
00266
00267
00268 char ch;
00269 const char* optflags = "ds:";
00270 int option_index = 0;
00271 static struct option long_options[] = {
00272 {"env", 1, 0, 'e'},
00273 {"deterministic", 0, 0, 'd'},
00274 {"stochastic", 0, 0, 's'},
00275 {"delay", 1, 0, 'a'},
00276 {"nsectors", 1, 0, 'c'},
00277 {"nstocks", 1, 0, 't'},
00278 {"lag", 0, 0, 'l'},
00279 {"nolag", 0, 0, 'o'},
00280 {"seed", 1, 0, 'x'},
00281 {"prints", 0, 0, 'p'},
00282 {"highvar", 0, 0, 'v'}
00283 };
00284
00285 while(-1 != (ch = getopt_long_only(argc, argv, optflags, long_options, &option_index))) {
00286 switch(ch) {
00287
00288 case 'x':
00289 seed = std::atoi(optarg);
00290 cout << "seed: " << seed << endl;
00291 break;
00292
00293 case 'd':
00294 stochastic = false;
00295 cout << "stochastic: " << stochastic << endl;
00296 break;
00297
00298 case 'v':
00299 {
00300 if (strcmp(envType, "fuelworld") == 0){
00301 highvar = true;
00302 cout << "fuel world fuel cost variation: " << highvar << endl;
00303 } else {
00304 cout << "--highvar is only a valid option for the fuelworld domain." << endl;
00305 exit(-1);
00306 }
00307 break;
00308 }
00309
00310 case 's':
00311 stochastic = true;
00312 cout << "stochastic: " << stochastic << endl;
00313 break;
00314
00315 case 'a':
00316 {
00317 if (strcmp(envType, "mcar") == 0 || strcmp(envType, "tworooms") == 0){
00318 delay = std::atoi(optarg);
00319 cout << "delay steps: " << delay << endl;
00320 } else {
00321 cout << "--delay option is only valid for the mcar and tworooms domains" << endl;
00322 exit(-1);
00323 }
00324 break;
00325 }
00326
00327 case 'c':
00328 {
00329 if (strcmp(envType, "stocks") == 0){
00330 nsectors = std::atoi(optarg);
00331 cout << "nsectors: " << nsectors << endl;
00332 } else {
00333 cout << "--nsectors option is only valid for the stocks domain" << endl;
00334 exit(-1);
00335 }
00336 break;
00337 }
00338
00339 case 't':
00340 {
00341 if (strcmp(envType, "stocks") == 0){
00342 nstocks = std::atoi(optarg);
00343 cout << "nstocks: " << nstocks << endl;
00344 } else {
00345 cout << "--nstocks option is only valid for the stocks domain" << endl;
00346 exit(-1);
00347 }
00348 break;
00349 }
00350
00351 case 'l':
00352 {
00353 if (strcmp(envType, "car2to7") == 0 || strcmp(envType, "car7to2") == 0 || strcmp(envType, "carrandom") == 0){
00354 lag = true;
00355 cout << "lag: " << lag << endl;
00356 } else {
00357 cout << "--lag option is only valid for car velocity tasks" << endl;
00358 exit(-1);
00359 }
00360 break;
00361 }
00362
00363 case 'o':
00364 {
00365 if (strcmp(envType, "car2to7") == 0 || strcmp(envType, "car7to2") == 0 || strcmp(envType, "carrandom") == 0){
00366 lag = false;
00367 cout << "lag: " << lag << endl;
00368 } else {
00369 cout << "--nolag option is only valid for car velocity tasks" << endl;
00370 exit(-1);
00371 }
00372 break;
00373 }
00374
00375 case 'e':
00376
00377 cout << "env: " << envType << endl;
00378 break;
00379
00380 case 'p':
00381 PRINTS = true;
00382 break;
00383
00384 case 'h':
00385 case '?':
00386 case 0:
00387 default:
00388 displayHelp();
00389 break;
00390 }
00391 }
00392
00393
00394 int qDepth = 1;
00395
00396
00397 ros::init(argc, argv, "my_tf_broadcaster");
00398 tf::Transform transform;
00399
00400
00401 out_env_desc = node.advertise<rl_msgs::RLEnvDescription>("rl_env/rl_env_description",qDepth,true);
00402 out_env_sr = node.advertise<rl_msgs::RLStateReward>("rl_env/rl_state_reward",qDepth,false);
00403 out_seed = node.advertise<rl_msgs::RLEnvSeedExperience>("rl_env/rl_seed",20,false);
00404
00405
00406 ros::TransportHints noDelay = ros::TransportHints().tcpNoDelay(true);
00407 ros::Subscriber rl_action = node.subscribe("rl_agent/rl_action", qDepth, processAction, noDelay);
00408 ros::Subscriber rl_exp_info = node.subscribe("rl_agent/rl_experiment_info", qDepth, processEpisodeInfo, noDelay);
00409
00410
00411
00412 rng = Random(1+seed);
00413 initEnvironment();
00414
00415
00416 ROS_INFO(NODE ": starting main loop");
00417
00418 ros::spin();
00419
00420
00421
00422
00423 return 0;
00424 }
00425
00426
00427
00428