$search
00001 #include "learned_controller.h" 00002 #include <iostream> 00003 #include <fstream> 00004 #include <ros/ros.h> 00005 #include <ros/package.h> 00006 00008 LearnedSpeedControl::LearnedSpeedControl(): 00009 SpeedControl(),numactions(5) 00010 { 00011 LOADDEBUG = false; 00012 loaded = false; 00013 00014 ROS_WARN("Learned Speed Controller, load policy"); 00015 00016 // create agent and load file 00017 std::string policyPath = (ros::package::getPath("art_pilot") 00018 + "/src/pilot/control1400.pol"); 00019 loadPolicy(policyPath.c_str()); 00020 00021 // init state vector 00022 s.resize(4,0); 00023 } 00024 00026 LearnedSpeedControl::~LearnedSpeedControl() 00027 { 00028 statespace.clear(); 00029 Q.clear(); 00030 s.clear(); 00031 } 00032 00045 void LearnedSpeedControl::adjust(float speed, float error, 00046 float *brake_req, float *throttle_req) 00047 { 00048 00049 // lets get actual target vel 00050 float targetVel = speed + error; 00051 00052 if (targetVel > 11.0 || targetVel < 0.0){ 00053 ROS_DEBUG("Target Vel out of range: %f", targetVel); 00054 } 00055 00056 ROS_DEBUG("Speed %f, error %f, target %f", speed, error, targetVel); 00057 ROS_DEBUG("Throt_pos %f, Throt_req %f, brake_pos %f, brake_req %f", 00058 throttle_position_, *throttle_req, brake_position_, *brake_req); 00059 00060 // out of range 00061 if (targetVel < 0) 00062 targetVel = 0; 00063 if (targetVel > 11) 00064 targetVel = 11; 00065 if (speed < 0) 00066 speed = 0; 00067 if (speed > 12) 00068 speed = 12; 00069 00070 // convert to discrete 00071 float f1 = 0.5; 00072 float f2 = 0.1; 00073 float f3 = 0.1; 00074 float f4 = 0.1; 00075 00076 float EPSILON = 0.001; 00077 00078 s[0] = (int)((targetVel+EPSILON) / f1); 00079 s[1] = (int)((speed+EPSILON) / f2); 00080 s[2] = (int)((*throttle_req+EPSILON) / f3); 00081 s[3] = (int)((*brake_req+EPSILON) / f4); 00082 00083 ROS_DEBUG("State: %f, %f, %f, %f", s[0], s[1], s[2], s[3]); 00084 00085 // get action from agent 00086 int act = getAction(s); 00087 00088 // fix trouble starting from full brake 00089 if (speed < 0.01 && targetVel > 0 && *brake_req > 0.0 && act != 3){ 00090 ROS_WARN("Chose bad accel from stop. State %f, %f, %f, %f, action %i", 00091 s[0], s[1], s[2], s[3], act); 00092 act = 3; 00093 } 00094 00095 // set throttle and brake based on action 00096 if (act == 0){ 00097 // no change 00098 } else if (act == 1){ 00099 *throttle_req = 0; 00100 *brake_req += 0.1; 00101 } else if (act == 2){ 00102 *throttle_req = 0; 00103 *brake_req -= 0.1; 00104 } else if (act == 3){ 00105 *throttle_req += 0.1; 00106 *brake_req = 0; 00107 } else if (act == 4){ 00108 *throttle_req -= 0.1; 00109 *brake_req = 0; 00110 } else { 00111 ROS_DEBUG("ERROR: invalid action: %i", act); 00112 } 00113 00114 // dont allow throttle over 0.4 00115 if(*throttle_req > 0.4) 00116 *throttle_req = 0.4; 00117 00118 ROS_DEBUG("action %i, throttle %f, brake %f", act, *throttle_req, *brake_req); 00119 00120 } 00121 00123 void LearnedSpeedControl::configure(art_pilot::PilotConfig &newconfig) 00124 { 00125 } 00126 00128 void LearnedSpeedControl::reset(void) 00129 { 00130 } 00131 00132 00133 00134 00135 00136 int LearnedSpeedControl::getAction(const std::vector<float> &s) { 00137 00138 // Get action values 00139 std::vector<float> &Q_s = Q[canonicalize(s)]; 00140 const std::vector<float>::iterator max = 00141 std::max_element(Q_s.begin(), Q_s.end()); 00142 00143 // Choose an action 00144 const std::vector<float>::iterator a = max; 00145 00146 return a - Q_s.begin(); 00147 } 00148 00149 00150 LearnedSpeedControl::state_t LearnedSpeedControl::canonicalize(const std::vector<float> &s) { 00151 const std::pair<std::set<std::vector<float> >::iterator, bool> result = 00152 statespace.insert(s); 00153 state_t retval = &*result.first; // Dereference iterator then get pointer 00154 if (result.second) { // s is new, so initialize Q(s,a) for all a 00155 if (loaded){ 00156 ROS_ERROR("State unknown in policy: %f, %f, %f, %f", s[0], s[1], s[2], s[3]); 00157 } 00158 std::vector<float> &Q_s = Q[retval]; 00159 Q_s.resize(numactions,0.0); 00160 } 00161 return retval; 00162 } 00163 00164 00165 void LearnedSpeedControl::loadPolicy(const char* filename){ 00166 00167 using namespace std; 00168 00169 ifstream policyFile(filename, ios::in | ios::binary); 00170 00171 // first part, save the vector size 00172 int fsize; 00173 policyFile.read((char*)&fsize, sizeof(int)); 00174 if (fsize != 4){ 00175 ROS_WARN("this policy is not valid, loaded nfeat %i, instead of 4", fsize); 00176 } 00177 00178 // save numactions 00179 int nact; 00180 policyFile.read((char*)&nact, sizeof(int)); 00181 00182 if (nact != 5){ 00183 ROS_DEBUG("this policy is not valid, loaded nact %i, instead of 5", nact); 00184 } 00185 00186 // go through all states, loading q values 00187 while(!policyFile.eof()){ 00188 std::vector<float> state; 00189 state.resize(fsize, 0.0); 00190 00191 // load state 00192 policyFile.read((char*)&(state[0]), sizeof(float)*fsize); 00193 if (LOADDEBUG){ 00194 ROS_DEBUG("load policy for state %f, %f, %f, %f", state[0], state[1], state[2], state[3]); 00195 } 00196 00197 state_t s = canonicalize(state); 00198 00199 if (policyFile.eof()) break; 00200 00201 // load q values 00202 policyFile.read((char*)&(Q[s][0]), sizeof(float)*numactions); 00203 00204 if (LOADDEBUG){ 00205 ROS_DEBUG("Q values: %f, %f, %f, %f, %f", Q[s][0],Q[s][1],Q[s][2],Q[s][3],Q[s][4]); 00206 } 00207 } 00208 00209 policyFile.close(); 00210 ROS_DEBUG("Policy loaded!!!"); 00211 loaded = true; 00212 }