CartPole.cc
Go to the documentation of this file.
00001 
00006 #include <rl_env/CartPole.hh>
00007 
00008  
00009 CartPole::CartPole(Random &rand):
00010   noisy(false),
00011   rng(rand),
00012   s(4),
00013   cartPos(s[0]),
00014   cartVel(s[1]),
00015   poleAngle(s[2]),
00016   poleVel(s[3])
00017 {
00018   reset();
00019   //cout << *this << endl;
00020 }
00021  
00022 
00023 CartPole::CartPole(Random &rand, bool stochastic):
00024   noisy(stochastic),
00025   rng(rand),
00026   s(4),
00027   cartPos(s[0]),
00028   cartVel(s[1]),
00029   poleAngle(s[2]),
00030   poleVel(s[3])
00031 {
00032   reset();
00033 }
00034 
00035 
00036 CartPole::~CartPole() { }
00037 
00038 const std::vector<float> &CartPole::sensation() const { 
00039   //cout << "At state " << s[0] << ", " << s[1] << endl;
00040 
00041   return s; 
00042 }
00043 
00044 
00045 float CartPole::transition(float force){
00046 
00047   // transition
00048 
00049   float xacc;
00050   float thetaacc;
00051   float costheta;
00052   float sintheta;
00053   float temp;
00054   
00055   //Noise of 1.0 means possibly halfway to opposite action
00056   if (noisy){
00057     float thisNoise=1.0*FORCE_MAG*(rng.uniform(-0.5, 0.5));  
00058     force+=thisNoise;
00059   }
00060 
00061   costheta = cos(poleAngle);
00062   sintheta = sin(poleAngle);
00063 
00064   temp = (force + POLEMASS_LENGTH * poleVel * poleVel * sintheta) / TOTAL_MASS;
00065 
00066   thetaacc = (GRAVITY * sintheta - costheta * temp) / (LENGTH * (FOURTHIRDS - MASSPOLE * costheta * costheta / TOTAL_MASS));
00067 
00068   xacc = temp - POLEMASS_LENGTH * thetaacc * costheta / TOTAL_MASS;
00069 
00070   // Update the four state variables, using Euler's method. 
00071   cartPos += TAU * cartVel;
00072   cartVel += TAU * xacc;
00073   poleAngle += TAU * poleVel;
00074   poleVel += TAU * thetaacc;
00075   
00076   // These probably never happen because the pole would crash 
00077   while (poleAngle >= M_PI) {
00078     poleAngle -= 2.0 * M_PI;
00079   }
00080   while (poleAngle < -M_PI) {
00081     poleAngle += 2.0 * M_PI;
00082   }
00083 
00084   // dont velocities go past ranges
00085   if (fabs(cartVel) > 3){
00086     //    cout << "cart velocity out of range: " << cartVel << endl;
00087     if (cartVel > 0)
00088       cartVel = 3;
00089     else
00090       cartVel = -3;
00091   }
00092   if (fabs(poleVel) > M_PI){
00093     //    cout << "pole velocity out of range: " << poleVel << endl;
00094     if (poleVel > 0)
00095       poleVel = M_PI;
00096     else
00097       poleVel = -M_PI;
00098   }
00099 
00100   return reward();
00101 
00102 }
00103 
00104 
00105 float CartPole::apply(int action) {
00106 
00107   float force = 0;
00108   if (action == 1) {
00109     force = FORCE_MAG;
00110   } else {
00111     force = -FORCE_MAG;
00112   }
00113 
00114   return transition(force);
00115 }
00116 
00117  
00118 
00119 float CartPole::reward() {
00120 
00121   // normally +1 and 0 on goal
00122   if (terminal())
00123     return 0.0;
00124   else
00125     return 1.0;
00126 }
00127 
00128 
00129 
00130 bool CartPole::terminal() const {
00131   // current position past termination conditions (off track, pole angle)
00132   return (fabs(poleAngle) > (DEG_T_RAD*12.0) || fabs(cartPos) > 2.4);
00133 }
00134 
00135 
00136 
00137 void CartPole::reset() {
00138 
00139   GRAVITY = 9.8;
00140   MASSCART = 1.0;
00141   MASSPOLE = 0.1;
00142   TOTAL_MASS = (MASSPOLE + MASSCART);
00143   LENGTH = 0.5;   // actually half the pole's length 
00144   
00145   POLEMASS_LENGTH = (MASSPOLE * LENGTH);
00146   FORCE_MAG = 10.0;
00147   TAU = 0.02;     // seconds between state updates 
00148   
00149   FOURTHIRDS = 4.0 / 3.0;
00150   DEG_T_RAD = 0.01745329;
00151   RAD_T_DEG = 1.0/DEG_T_RAD;
00152 
00153   if (noisy){
00154     cartPos = rng.uniform(-0.5, 0.5);
00155     cartVel = rng.uniform(-0.5, 0.5);
00156     poleAngle = rng.uniform(-0.0625, 0.0625);
00157     poleVel = rng.uniform(-0.0625, 0.0625);
00158   } else {
00159     cartPos = 0.0;
00160     cartVel = 0.0;
00161     poleAngle = 0.0;
00162     poleVel = 0.0;
00163   }
00164 
00165 }
00166 
00167 
00168 
00169 int CartPole::getNumActions(){
00170   return 2;
00171 }
00172 
00173 
00174 void CartPole::setSensation(std::vector<float> newS){
00175   if (s.size() != newS.size()){
00176     cerr << "Error in sensation sizes" << endl;
00177   }
00178 
00179   for (unsigned i = 0; i < newS.size(); i++){
00180     s[i] = newS[i];
00181   }
00182 }
00183 
00184 std::vector<experience> CartPole::getSeedings() {
00185 
00186   // return seedings
00187   std::vector<experience> seeds;
00188 
00189   // single seed of each 4 terminal cases
00190   seeds.push_back(getExp(-2.4, -0.1, 0, 0, 0));
00191   seeds.push_back(getExp(2.4, 0.2, 0.1, 0.2, 1));
00192   seeds.push_back(getExp(0.4, 0.3, 0.2, 0.3, 0));
00193   seeds.push_back(getExp(-.3, 0.05, -0.2, -0.4, 1));
00194 
00195   reset();
00196 
00197   return seeds;
00198 
00199 }
00200 
00201 experience CartPole::getExp(float s0, float s1, float s2, float s3, int a){
00202 
00203   experience e;
00204 
00205   e.s.resize(4, 0.0);
00206   e.next.resize(4, 0.0);
00207 
00208   cartPos = s0;
00209   cartVel = s1;
00210   poleAngle = s2;
00211   poleVel = s3;
00212 
00213   e.act = a;
00214   e.s = sensation();
00215   e.reward = apply(e.act);
00216 
00217   e.terminal = terminal();
00218   e.next = sensation();
00219 
00220   return e;
00221 }
00222 
00223 void CartPole::getMinMaxFeatures(std::vector<float> *minFeat,
00224                                  std::vector<float> *maxFeat){
00225   
00226   minFeat->resize(s.size(), 0.0);
00227   maxFeat->resize(s.size(), 1.0);
00228 
00229   (*minFeat)[0] = -2.5;//3;
00230   (*maxFeat)[0] = 2.5;//3;
00231 
00232   (*minFeat)[1] = -3.0;
00233   (*maxFeat)[1] = 3.0;
00234 
00235   (*minFeat)[2] = -12.0 * DEG_T_RAD;
00236   (*maxFeat)[2] = 12.0 * DEG_T_RAD;
00237   
00238   (*minFeat)[3] = -M_PI;
00239   (*maxFeat)[3] = M_PI;
00240 
00241 }
00242 
00243 void CartPole::getMinMaxReward(float *minR,
00244                               float *maxR){
00245   
00246   *minR = 0.0;
00247   *maxR = 1.0;    
00248   
00249 }


rl_env
Author(s):
autogenerated on Thu Jun 6 2019 22:00:23