Go to the documentation of this file.00001
00006 #include <rl_env/CartPole.hh>
00007
00008
00009 CartPole::CartPole(Random &rand):
00010 noisy(false),
00011 rng(rand),
00012 s(4),
00013 cartPos(s[0]),
00014 cartVel(s[1]),
00015 poleAngle(s[2]),
00016 poleVel(s[3])
00017 {
00018 reset();
00019
00020 }
00021
00022
00023 CartPole::CartPole(Random &rand, bool stochastic):
00024 noisy(stochastic),
00025 rng(rand),
00026 s(4),
00027 cartPos(s[0]),
00028 cartVel(s[1]),
00029 poleAngle(s[2]),
00030 poleVel(s[3])
00031 {
00032 reset();
00033 }
00034
00035
00036 CartPole::~CartPole() { }
00037
00038 const std::vector<float> &CartPole::sensation() const {
00039
00040
00041 return s;
00042 }
00043
00044
00045 float CartPole::transition(float force){
00046
00047
00048
00049 float xacc;
00050 float thetaacc;
00051 float costheta;
00052 float sintheta;
00053 float temp;
00054
00055
00056 if (noisy){
00057 float thisNoise=1.0*FORCE_MAG*(rng.uniform(-0.5, 0.5));
00058 force+=thisNoise;
00059 }
00060
00061 costheta = cos(poleAngle);
00062 sintheta = sin(poleAngle);
00063
00064 temp = (force + POLEMASS_LENGTH * poleVel * poleVel * sintheta) / TOTAL_MASS;
00065
00066 thetaacc = (GRAVITY * sintheta - costheta * temp) / (LENGTH * (FOURTHIRDS - MASSPOLE * costheta * costheta / TOTAL_MASS));
00067
00068 xacc = temp - POLEMASS_LENGTH * thetaacc * costheta / TOTAL_MASS;
00069
00070
00071 cartPos += TAU * cartVel;
00072 cartVel += TAU * xacc;
00073 poleAngle += TAU * poleVel;
00074 poleVel += TAU * thetaacc;
00075
00076
00077 while (poleAngle >= M_PI) {
00078 poleAngle -= 2.0 * M_PI;
00079 }
00080 while (poleAngle < -M_PI) {
00081 poleAngle += 2.0 * M_PI;
00082 }
00083
00084
00085 if (fabs(cartVel) > 3){
00086
00087 if (cartVel > 0)
00088 cartVel = 3;
00089 else
00090 cartVel = -3;
00091 }
00092 if (fabs(poleVel) > M_PI){
00093
00094 if (poleVel > 0)
00095 poleVel = M_PI;
00096 else
00097 poleVel = -M_PI;
00098 }
00099
00100 return reward();
00101
00102 }
00103
00104
00105 float CartPole::apply(int action) {
00106
00107 float force = 0;
00108 if (action == 1) {
00109 force = FORCE_MAG;
00110 } else {
00111 force = -FORCE_MAG;
00112 }
00113
00114 return transition(force);
00115 }
00116
00117
00118
00119 float CartPole::reward() {
00120
00121
00122 if (terminal())
00123 return 0.0;
00124 else
00125 return 1.0;
00126 }
00127
00128
00129
00130 bool CartPole::terminal() const {
00131
00132 return (fabs(poleAngle) > (DEG_T_RAD*12.0) || fabs(cartPos) > 2.4);
00133 }
00134
00135
00136
00137 void CartPole::reset() {
00138
00139 GRAVITY = 9.8;
00140 MASSCART = 1.0;
00141 MASSPOLE = 0.1;
00142 TOTAL_MASS = (MASSPOLE + MASSCART);
00143 LENGTH = 0.5;
00144
00145 POLEMASS_LENGTH = (MASSPOLE * LENGTH);
00146 FORCE_MAG = 10.0;
00147 TAU = 0.02;
00148
00149 FOURTHIRDS = 4.0 / 3.0;
00150 DEG_T_RAD = 0.01745329;
00151 RAD_T_DEG = 1.0/DEG_T_RAD;
00152
00153 if (noisy){
00154 cartPos = rng.uniform(-0.5, 0.5);
00155 cartVel = rng.uniform(-0.5, 0.5);
00156 poleAngle = rng.uniform(-0.0625, 0.0625);
00157 poleVel = rng.uniform(-0.0625, 0.0625);
00158 } else {
00159 cartPos = 0.0;
00160 cartVel = 0.0;
00161 poleAngle = 0.0;
00162 poleVel = 0.0;
00163 }
00164
00165 }
00166
00167
00168
00169 int CartPole::getNumActions(){
00170 return 2;
00171 }
00172
00173
00174 void CartPole::setSensation(std::vector<float> newS){
00175 if (s.size() != newS.size()){
00176 cerr << "Error in sensation sizes" << endl;
00177 }
00178
00179 for (unsigned i = 0; i < newS.size(); i++){
00180 s[i] = newS[i];
00181 }
00182 }
00183
00184 std::vector<experience> CartPole::getSeedings() {
00185
00186
00187 std::vector<experience> seeds;
00188
00189
00190 seeds.push_back(getExp(-2.4, -0.1, 0, 0, 0));
00191 seeds.push_back(getExp(2.4, 0.2, 0.1, 0.2, 1));
00192 seeds.push_back(getExp(0.4, 0.3, 0.2, 0.3, 0));
00193 seeds.push_back(getExp(-.3, 0.05, -0.2, -0.4, 1));
00194
00195 reset();
00196
00197 return seeds;
00198
00199 }
00200
00201 experience CartPole::getExp(float s0, float s1, float s2, float s3, int a){
00202
00203 experience e;
00204
00205 e.s.resize(4, 0.0);
00206 e.next.resize(4, 0.0);
00207
00208 cartPos = s0;
00209 cartVel = s1;
00210 poleAngle = s2;
00211 poleVel = s3;
00212
00213 e.act = a;
00214 e.s = sensation();
00215 e.reward = apply(e.act);
00216
00217 e.terminal = terminal();
00218 e.next = sensation();
00219
00220 return e;
00221 }
00222
00223 void CartPole::getMinMaxFeatures(std::vector<float> *minFeat,
00224 std::vector<float> *maxFeat){
00225
00226 minFeat->resize(s.size(), 0.0);
00227 maxFeat->resize(s.size(), 1.0);
00228
00229 (*minFeat)[0] = -2.5;
00230 (*maxFeat)[0] = 2.5;
00231
00232 (*minFeat)[1] = -3.0;
00233 (*maxFeat)[1] = 3.0;
00234
00235 (*minFeat)[2] = -12.0 * DEG_T_RAD;
00236 (*maxFeat)[2] = 12.0 * DEG_T_RAD;
00237
00238 (*minFeat)[3] = -M_PI;
00239 (*maxFeat)[3] = M_PI;
00240
00241 }
00242
00243 void CartPole::getMinMaxReward(float *minR,
00244 float *maxR){
00245
00246 *minR = 0.0;
00247 *maxR = 1.0;
00248
00249 }