00001 #include "learned_controller.h"
00002 #include <iostream>
00003 #include <fstream>
00004 #include <ros/ros.h>
00005 #include <ros/package.h>
00006
00008 LearnedSpeedControl::LearnedSpeedControl():
00009 SpeedControl(),numactions(5)
00010 {
00011 LOADDEBUG = false;
00012 loaded = false;
00013
00014 ROS_WARN("Learned Speed Controller, load policy");
00015
00016
00017 std::string policyPath = (ros::package::getPath("art_pilot")
00018 + "/src/pilot/control1400.pol");
00019 loadPolicy(policyPath.c_str());
00020
00021
00022 s.resize(4,0);
00023 }
00024
00026 LearnedSpeedControl::~LearnedSpeedControl()
00027 {
00028 statespace.clear();
00029 Q.clear();
00030 s.clear();
00031 }
00032
00045 void LearnedSpeedControl::adjust(float speed, float error,
00046 float *brake_req, float *throttle_req)
00047 {
00048
00049
00050 float targetVel = speed + error;
00051
00052 if (targetVel > 11.0 || targetVel < 0.0){
00053 ROS_DEBUG("Target Vel out of range: %f", targetVel);
00054 }
00055
00056 ROS_DEBUG("Speed %f, error %f, target %f", speed, error, targetVel);
00057 ROS_DEBUG("Throt_pos %f, Throt_req %f, brake_pos %f, brake_req %f",
00058 throttle_position_, *throttle_req, brake_position_, *brake_req);
00059
00060
00061 if (targetVel < 0)
00062 targetVel = 0;
00063 if (targetVel > 11)
00064 targetVel = 11;
00065 if (speed < 0)
00066 speed = 0;
00067 if (speed > 12)
00068 speed = 12;
00069
00070
00071 float f1 = 0.5;
00072 float f2 = 0.1;
00073 float f3 = 0.1;
00074 float f4 = 0.1;
00075
00076 float EPSILON = 0.001;
00077
00078 s[0] = (int)((targetVel+EPSILON) / f1);
00079 s[1] = (int)((speed+EPSILON) / f2);
00080 s[2] = (int)((*throttle_req+EPSILON) / f3);
00081 s[3] = (int)((*brake_req+EPSILON) / f4);
00082
00083 ROS_DEBUG("State: %f, %f, %f, %f", s[0], s[1], s[2], s[3]);
00084
00085
00086 int act = getAction(s);
00087
00088
00089 if (speed < 0.01 && targetVel > 0 && *brake_req > 0.0 && act != 3){
00090 ROS_WARN("Chose bad accel from stop. State %f, %f, %f, %f, action %i",
00091 s[0], s[1], s[2], s[3], act);
00092 act = 3;
00093 }
00094
00095
00096 if (act == 0){
00097
00098 } else if (act == 1){
00099 *throttle_req = 0;
00100 *brake_req += 0.1;
00101 } else if (act == 2){
00102 *throttle_req = 0;
00103 *brake_req -= 0.1;
00104 } else if (act == 3){
00105 *throttle_req += 0.1;
00106 *brake_req = 0;
00107 } else if (act == 4){
00108 *throttle_req -= 0.1;
00109 *brake_req = 0;
00110 } else {
00111 ROS_DEBUG("ERROR: invalid action: %i", act);
00112 }
00113
00114
00115 if(*throttle_req > 0.4)
00116 *throttle_req = 0.4;
00117
00118 ROS_DEBUG("action %i, throttle %f, brake %f", act, *throttle_req, *brake_req);
00119
00120 }
00121
00123 void LearnedSpeedControl::configure(art_pilot::PilotConfig &newconfig)
00124 {
00125 }
00126
00128 void LearnedSpeedControl::reset(void)
00129 {
00130 }
00131
00132
00133
00134
00135
00136 int LearnedSpeedControl::getAction(const std::vector<float> &s) {
00137
00138
00139 std::vector<float> &Q_s = Q[canonicalize(s)];
00140 const std::vector<float>::iterator max =
00141 std::max_element(Q_s.begin(), Q_s.end());
00142
00143
00144 const std::vector<float>::iterator a = max;
00145
00146 return a - Q_s.begin();
00147 }
00148
00149
00150 LearnedSpeedControl::state_t LearnedSpeedControl::canonicalize(const std::vector<float> &s) {
00151 const std::pair<std::set<std::vector<float> >::iterator, bool> result =
00152 statespace.insert(s);
00153 state_t retval = &*result.first;
00154 if (result.second) {
00155 if (loaded){
00156 ROS_ERROR("State unknown in policy: %f, %f, %f, %f", s[0], s[1], s[2], s[3]);
00157 }
00158 std::vector<float> &Q_s = Q[retval];
00159 Q_s.resize(numactions,0.0);
00160 }
00161 return retval;
00162 }
00163
00164
00165 void LearnedSpeedControl::loadPolicy(const char* filename){
00166
00167 using namespace std;
00168
00169 ifstream policyFile(filename, ios::in | ios::binary);
00170
00171
00172 int fsize;
00173 policyFile.read((char*)&fsize, sizeof(int));
00174 if (fsize != 4){
00175 ROS_WARN("this policy is not valid, loaded nfeat %i, instead of 4", fsize);
00176 }
00177
00178
00179 int nact;
00180 policyFile.read((char*)&nact, sizeof(int));
00181
00182 if (nact != 5){
00183 ROS_DEBUG("this policy is not valid, loaded nact %i, instead of 5", nact);
00184 }
00185
00186
00187 while(!policyFile.eof()){
00188 std::vector<float> state;
00189 state.resize(fsize, 0.0);
00190
00191
00192 policyFile.read((char*)&(state[0]), sizeof(float)*fsize);
00193 if (LOADDEBUG){
00194 ROS_DEBUG("load policy for state %f, %f, %f, %f", state[0], state[1], state[2], state[3]);
00195 }
00196
00197 state_t s = canonicalize(state);
00198
00199 if (policyFile.eof()) break;
00200
00201
00202 policyFile.read((char*)&(Q[s][0]), sizeof(float)*numactions);
00203
00204 if (LOADDEBUG){
00205 ROS_DEBUG("Q values: %f, %f, %f, %f, %f", Q[s][0],Q[s][1],Q[s][2],Q[s][3],Q[s][4]);
00206 }
00207 }
00208
00209 policyFile.close();
00210 ROS_DEBUG("Policy loaded!!!");
00211 loaded = true;
00212 }