learned_controller.cc
Go to the documentation of this file.
00001 #include "learned_controller.h"
00002 #include <iostream>
00003 #include <fstream>
00004 #include <ros/ros.h>
00005 #include <ros/package.h>
00006 
00008 LearnedSpeedControl::LearnedSpeedControl():
00009   SpeedControl(),numactions(5)
00010 {
00011   LOADDEBUG = false;
00012   loaded = false;
00013 
00014   ROS_WARN("Learned Speed Controller, load policy");
00015 
00016   // create agent and load file
00017   std::string policyPath = (ros::package::getPath("art_pilot")
00018                             + "/src/pilot/control1400.pol");
00019   loadPolicy(policyPath.c_str());
00020 
00021   // init state vector
00022   s.resize(4,0);
00023 }
00024 
00026 LearnedSpeedControl::~LearnedSpeedControl()
00027 {
00028   statespace.clear();
00029   Q.clear();
00030   s.clear();
00031 }
00032 
00045 void LearnedSpeedControl::adjust(float speed, float error,
00046                                 float *brake_req, float *throttle_req)
00047 {
00048 
00049   // lets get actual target vel
00050   float targetVel = speed + error;
00051 
00052   if (targetVel > 11.0 || targetVel < 0.0){
00053     ROS_DEBUG("Target Vel out of range: %f", targetVel);
00054   }
00055 
00056   ROS_DEBUG("Speed %f, error %f, target %f", speed, error, targetVel);
00057   ROS_DEBUG("Throt_pos %f, Throt_req %f, brake_pos %f, brake_req %f",
00058            throttle_position_, *throttle_req, brake_position_, *brake_req);
00059 
00060   // out of range
00061   if (targetVel < 0)
00062     targetVel = 0;
00063   if (targetVel > 11)
00064     targetVel = 11;
00065   if (speed < 0)
00066     speed = 0;
00067   if (speed > 12)
00068     speed = 12;
00069 
00070   // convert to discrete
00071   float f1 = 0.5;
00072   float f2 = 0.1;
00073   float f3 = 0.1;
00074   float f4 = 0.1;
00075 
00076   float EPSILON = 0.001;
00077 
00078   s[0] = (int)((targetVel+EPSILON) / f1);
00079   s[1] = (int)((speed+EPSILON) / f2);
00080   s[2] = (int)((*throttle_req+EPSILON) / f3);
00081   s[3] = (int)((*brake_req+EPSILON) / f4);
00082 
00083   ROS_DEBUG("State: %f, %f, %f, %f", s[0], s[1], s[2], s[3]);
00084 
00085   // get action from agent
00086   int act = getAction(s);
00087   
00088   // fix trouble starting from full brake
00089   if (speed < 0.01 && targetVel > 0 && *brake_req > 0.0 && act != 3){
00090     ROS_WARN("Chose bad accel from stop. State %f, %f, %f, %f, action %i",
00091              s[0], s[1], s[2], s[3], act);
00092     act = 3;
00093   }
00094 
00095   // set throttle and brake based on action
00096   if (act == 0){
00097     // no change
00098   } else if (act == 1){
00099     *throttle_req = 0;
00100     *brake_req += 0.1;
00101   } else if (act == 2){
00102     *throttle_req = 0;
00103     *brake_req -= 0.1;
00104   } else if (act == 3){
00105     *throttle_req += 0.1;
00106     *brake_req = 0;
00107   } else if (act == 4){
00108     *throttle_req -= 0.1;
00109     *brake_req = 0;
00110   } else {
00111     ROS_DEBUG("ERROR: invalid action: %i", act);
00112   }
00113 
00114   // dont allow throttle over 0.4
00115   if(*throttle_req > 0.4)
00116     *throttle_req = 0.4;
00117 
00118   ROS_DEBUG("action %i, throttle %f, brake %f", act, *throttle_req, *brake_req);
00119 
00120 }
00121 
00123 void LearnedSpeedControl::configure(art_pilot::PilotConfig &newconfig)
00124 {
00125 }
00126 
00128 void LearnedSpeedControl::reset(void)
00129 {
00130 }
00131 
00132 
00133 
00134 
00135 
00136 int LearnedSpeedControl::getAction(const std::vector<float> &s) {
00137 
00138   // Get action values
00139   std::vector<float> &Q_s = Q[canonicalize(s)];
00140   const std::vector<float>::iterator max =
00141     std::max_element(Q_s.begin(), Q_s.end());
00142 
00143   // Choose an action
00144   const std::vector<float>::iterator a = max;
00145 
00146   return a - Q_s.begin();
00147 }
00148 
00149 
00150 LearnedSpeedControl::state_t LearnedSpeedControl::canonicalize(const std::vector<float> &s) {
00151   const std::pair<std::set<std::vector<float> >::iterator, bool> result =
00152     statespace.insert(s);
00153   state_t retval = &*result.first; // Dereference iterator then get pointer 
00154   if (result.second) { // s is new, so initialize Q(s,a) for all a
00155     if (loaded){
00156       ROS_ERROR("State unknown in policy: %f, %f, %f, %f", s[0], s[1], s[2], s[3]);
00157     }
00158     std::vector<float> &Q_s = Q[retval];
00159     Q_s.resize(numactions,0.0);
00160   }
00161   return retval; 
00162 }
00163 
00164 
00165 void LearnedSpeedControl::loadPolicy(const char* filename){
00166 
00167   using namespace std;
00168 
00169   ifstream policyFile(filename, ios::in | ios::binary);
00170 
00171   // first part, save the vector size
00172   int fsize;
00173   policyFile.read((char*)&fsize, sizeof(int));
00174   if (fsize != 4){
00175     ROS_WARN("this policy is not valid, loaded nfeat %i, instead of 4", fsize);
00176   }
00177 
00178   // save numactions
00179   int nact;
00180   policyFile.read((char*)&nact, sizeof(int));
00181 
00182   if (nact != 5){
00183     ROS_DEBUG("this policy is not valid, loaded nact %i, instead of 5", nact);
00184   }
00185 
00186   // go through all states, loading q values
00187   while(!policyFile.eof()){
00188     std::vector<float> state;
00189     state.resize(fsize, 0.0);
00190 
00191     // load state
00192     policyFile.read((char*)&(state[0]), sizeof(float)*fsize);
00193     if (LOADDEBUG){
00194       ROS_DEBUG("load policy for state %f, %f, %f, %f", state[0], state[1], state[2], state[3]); 
00195     }
00196 
00197     state_t s = canonicalize(state);
00198 
00199     if (policyFile.eof()) break;
00200 
00201     // load q values
00202     policyFile.read((char*)&(Q[s][0]), sizeof(float)*numactions);
00203     
00204     if (LOADDEBUG){
00205       ROS_DEBUG("Q values: %f, %f, %f, %f, %f", Q[s][0],Q[s][1],Q[s][2],Q[s][3],Q[s][4]);
00206     }
00207   }
00208   
00209   policyFile.close();
00210   ROS_DEBUG("Policy loaded!!!");
00211   loaded = true;
00212 }


art_pilot
Author(s): Austin Robot Technology, Jack O'Quin
autogenerated on Fri Jan 3 2014 11:09:32