cba.cpp
Go to the documentation of this file.
00001 
#include <ANN/ANN.h>
#include <cba/cba.h>
#include <lfd_common/action_complete.h>
#include <lfd_common/classification_point.h>
#include <lfd_common/conf_classification.h>
#include <lfd_common/demonstration.h>
#include <lfd_common/state.h>
#include <ros/ros.h>
#include <std_msgs/Int32.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <vector>
00022 
00023 using namespace std;
00024 
00025 cba_learner::cba_learner()
00026 {
00027   // add subscriptions and services
00028   execute = node.advertise<std_msgs::Int32>("execute", 1);
00029   add_point = node.advertise<lfd_common::classification_point>("add_point", -1);
00030   change_point = node.advertise<lfd_common::classification_point>("change_point", -1);
00031   state_listener = node.subscribe<lfd_common::state>("state_listener", 1, &cba_learner::state_listener_callback, this);
00032   a_complete = node.advertiseService("a_complete", &cba_learner::a_complete_callback, this);
00033   classify = node.serviceClient<lfd_common::conf_classification>("classify");
00034   correction = node.serviceClient<lfd_common::demonstration>("correction");
00035   demonstration = node.serviceClient<lfd_common::demonstration>("demonstration");
00036 
00037   // check for the maximum number of data points to allocate
00038   ros::param::param<int>(MAX_DATA_POINTS, max_pts, DEFAULT_MAX_POINTS);
00039   // check for the distance threshold multiplier
00040   ros::param::param<double>(DIST_THRESH_MULT, dist_mult, DEFAULT_DIST_MULT);
00041 
00042   // initial CBA values
00043   labels = NULL;
00044   s = NULL;
00045   sc = NULL;
00046   a = -1;
00047   s_size = -1;
00048   pts = 0;
00049   action_complete = true;
00050   autonomous_action = false;
00051   dist_thresh = 0;
00052 
00053   ROS_INFO("CBA Learner Initialized");
00054 }
00055 
00056 cba_learner::~cba_learner()
00057 {
00058   // clear the state vectors if they exists
00059   if (s != NULL)
00060     free(s);
00061   if (sc != NULL)
00062     free(sc);
00063   // check the labels set
00064   if (labels != NULL)
00065     free(labels);
00066   // cleanup ANN
00067   if (pts > 0)
00068     annDeallocPts(ann_data);
00069   annClose();
00070 }
00071 
00072 void cba_learner::step()
00073 {
00074   // check if the state exists
00075   if (s == NULL)
00076     return;
00077 
00078   // check if the agent has reported their action finished
00079   if (action_complete)
00080   {
00081     // request a prediction from the classifier
00082     prediction *p = classify_state();
00083 
00084     // calculate the nearest neighbor distance
00085     double d = nearest_neighbor();
00086 
00087     // check against the thresholds
00088     if (p->c > conf_thresh(p->l, p->db) && d < dist_thresh)
00089     {
00090       // report the action to be executed
00091       std_msgs::Int32 a;
00092       a.data = p->l;
00093       execute.publish(a);
00094       action_complete = false;
00095       autonomous_action = true;
00096       // save a copy of the current state
00097       memcpy(sc, s, sizeof(float) * s_size);
00098     }
00099     else
00100     {
00101       // request a demonstration
00102       if (!demonstration.exists())
00103         ROS_WARN("Could not connect to 'demonstration' service.");
00104       else
00105       {
00106         // create the service request
00107         lfd_common::demonstration dem;
00108         for (int i = 0; i < s_size; i++)
00109           dem.request.s.state_vector.push_back(s[i]);
00110         // send the service request
00111         demonstration.call(dem);
00112 
00113         // check if the user provided a demonstration
00114         if (dem.response.valid)
00115         {
00116           // check if we already have this data point in our set
00117           bool duplicate;
00118           for (int i = 0; i < pts; i++)
00119           {
00120             for (int j = 0; j < s_size; j++)
00121               if (ann_data[i][j] == s[j])
00122                 duplicate = true;
00123               else
00124               {
00125                 duplicate = false;
00126                 break;
00127               }
00128             if (duplicate)
00129             {
00130               // check if the label was changed
00131               if (labels[i] != dem.response.a)
00132               {
00133                 // change the label
00134                 labels[i] = dem.response.a;
00135                 //update the point in the classifier
00136                 lfd_common::classification_point cp;
00137                 for (int i = 0; i < s_size; i++)
00138                   cp.s.state_vector.push_back(s[i]);
00139                 cp.l = dem.response.a;
00140                 change_point.publish(cp);
00141 
00142                 // update the thresholds
00143                 update_thresholds();
00144               }
00145               break;
00146             }
00147           }
00148 
00149           // new data point
00150           if (!duplicate)
00151           {
00152             // update the data set
00153             if (pts == max_pts)
00154               ROS_WARN("Too many data points -- latest point ignored.");
00155             else
00156             {
00157               // add the data to the ANN set
00158               for (int i = 0; i < s_size; i++)
00159                 ann_data[pts][i] = s[i];
00160               // set the label
00161               labels[pts] = dem.response.a;
00162               pts++;
00163 
00164               // send the point to the classifier
00165               lfd_common::classification_point cp;
00166               for (int i = 0; i < s_size; i++)
00167                 cp.s.state_vector.push_back(s[i]);
00168               cp.l = dem.response.a;
00169               add_point.publish(cp);
00170 
00171               // update the thresholds
00172               update_thresholds();
00173             }
00174           }
00175 
00176           // report the action to be executed
00177           std_msgs::Int32 a;
00178           a.data = dem.response.a;
00179           execute.publish(a);
00180           action_complete = false;
00181           autonomous_action = false;
00182         }
00183       }
00184     }
00185     // cleanup
00186     free(p);
00187   }
00188 }
00189 
00190 prediction *cba_learner::classify_state()
00191 {
00192   // allocate the prediction
00193   prediction *p = (prediction *)malloc(sizeof(prediction));
00194 
00195   // check for the classifier
00196   if (classify.exists())
00197   {
00198     // create the service request
00199     lfd_common::conf_classification cc;
00200     for (int i = 0; i < s_size; i++)
00201       cc.request.s.state_vector.push_back(s[i]);
00202 
00203     // send the service request
00204     classify.call(cc);
00205 
00206     // fill in the struct
00207     p->c = cc.response.c;
00208     p->l = cc.response.l;
00209     p->db = cc.response.db;
00210   }
00211   else
00212   {
00213     ROS_WARN("Could not connect to 'classify' service.");
00214     // fill the prediction with negative infinity confidence
00215     p->c = -numeric_limits<float>::infinity();
00216     p->db = -1;
00217     p->l = -1;
00218   }
00219 
00220   return p;
00221 }
00222 
00223 double cba_learner::nearest_neighbor()
00224 {
00225   // check if we have any data points yet
00226   if (pts == 0)
00227     return numeric_limits<float>::infinity();
00228   else
00229   {
00230     // calculate NN using ANN
00231     ANNidxArray index = new ANNidx[1];
00232     ANNdistArray dists = new ANNdist[1];
00233 
00234     // create the search structure
00235     ANNkd_tree *kd_tree = new ANNkd_tree(ann_data, pts, s_size);
00236 
00237     // create the data point
00238     ANNpoint pt = annAllocPt(s_size);
00239     for (int i = 0; i < s_size; i++)
00240       pt[i] = s[i];
00241 
00242     // calculate nearest neighbor
00243     kd_tree->annkSearch(pt, 1, index, dists, ANN_EPSILON);
00244     // unsquare the distance
00245     double d = sqrt(dists[0]);
00246 
00247     // cleanup
00248     annDeallocPt(pt);
00249     delete kd_tree;
00250 
00251     return d;
00252   }
00253 }
00254 
00255 double cba_learner::conf_thresh(int l, int db)
00256 {
00257   // check the thresholds for the given action label and decision boundary pair
00258   for (uint i = 0; i < conf_thresholds.size(); i++)
00259     if (conf_thresholds.at(i)->l == l && conf_thresholds.at(i)->db == db)
00260       return conf_thresholds.at(i)->thresh;
00261 
00262   // no threshold found for the given pair -- we return infinity
00263   return numeric_limits<float>::infinity();
00264 }
00265 
00266 void cba_learner::update_thresholds()
00267 {
00268   // check if we have at least two points
00269   if (pts < 2)
00270     return;
00271 
00272   // clear the old confidence thresholds
00273   for (uint i = 0; i < conf_thresholds.size(); i++)
00274     free(conf_thresholds.at(i));
00275   conf_thresholds.clear();
00276 
00277   // calculate NN using ANN
00278   ANNidxArray index = new ANNidx[2];
00279   ANNdistArray dists = new ANNdist[2];
00280   // create the search structure
00281   ANNkd_tree *kd_tree = new ANNkd_tree(ann_data, pts, s_size);
00282 
00283   // go through each point and calculate its nearest neighbor and its confidence
00284   double total_dist = 0;
00285   for (int i = 0; i < pts; i++)
00286   {
00287     // calculate second nearest neighbor (first is itself)
00288     kd_tree->annkSearch(ann_data[i], 2, index, dists, ANN_EPSILON);
00289     // add the unsquared distance
00290     total_dist += sqrt(dists[1]);
00291 
00292     // check for the classifier
00293     if (classify.exists())
00294     {
00295       // create the service request
00296       lfd_common::conf_classification cc;
00297       for (int j = 0; j < s_size; j++)
00298         cc.request.s.state_vector.push_back(ann_data[i][j]);
00299       // send the service request
00300       classify.call(cc);
00301 
00302       // find this confidence threshold value
00303       conf *c = NULL;
00304       for (uint j = 0; j < conf_thresholds.size() && c == NULL; j++)
00305         if (conf_thresholds.at(j)->l == cc.response.l && conf_thresholds.at(j)->db == cc.response.db)
00306           c = conf_thresholds.at(j);
00307       // if none was found we create a new one
00308       if (c == NULL)
00309       {
00310         c = (conf *)malloc(sizeof(conf));
00311         c->l = cc.response.l;
00312         c->db = cc.response.db;
00313         c->thresh = 0;
00314         c->cnt = 0;
00315         conf_thresholds.push_back(c);
00316       }
00317 
00318       // check if the prediction was correct
00319       if (cc.response.l != labels[i])
00320       {
00321         // add to the average
00322         c->thresh += cc.response.c;
00323         c->cnt++;
00324       }
00325     }
00326     else
00327       ROS_WARN("Could not connect to 'classify' service.");
00328   }
00329 
00330   // set the distance threshold
00331   dist_thresh = (total_dist / pts) * dist_mult;
00332 
00333   // set the confidence thresholds
00334   for (uint i = 0; i < conf_thresholds.size(); i++)
00335     conf_thresholds.at(i)->thresh /= conf_thresholds.at(i)->cnt;
00336 
00337   // cleanup
00338   delete kd_tree;
00339 }
00340 
00341 void cba_learner::state_listener_callback(const lfd_common::state::ConstPtr &msg)
00342 {
00343   // set the state size if it is not yet set
00344   if (s_size == -1)
00345   {
00346     s_size = msg->state_vector.size();
00347     // allocate space for ANN
00348     ann_data = annAllocPts(max_pts, s_size);
00349     // allocate space for the states and data sets
00350     s = (float *)malloc(sizeof(float) * s_size);
00351     sc = (float *)malloc(sizeof(float) * s_size);
00352     labels = (int *)malloc(sizeof(int32_t) * max_pts);
00353   }
00354 
00355   // check if the state sizes match
00356   if (s_size != (int)msg->state_vector.size())
00357     ROS_WARN("WARNING: State sizes do not match -- Ignoring current state.");
00358 
00359   // copy the state vector
00360   for (int i = 0; i < s_size; i++)
00361     s[i] = msg->state_vector[i];
00362 }
00363 
00364 bool cba_learner::a_complete_callback(lfd_common::action_complete::Request &req,
00365                                       lfd_common::action_complete::Response &resp)
00366 {
00367   // set the action complete value
00368   action_complete = true;
00369 
00370   // check if a correction was given for an autonomous action
00371   if (autonomous_action && req.valid_correction)
00372   {
00373     // find the data point
00374     bool duplicate = false;
00375     for (int i = 0; i < pts; i++)
00376     {
00377       for (int j = 0; j < s_size; j++)
00378         if (ann_data[i][j] == sc[j])
00379           duplicate = true;
00380         else
00381         {
00382           duplicate = false;
00383           break;
00384         }
00385       if (duplicate)
00386       {
00387         // check if the label was changed
00388         if (labels[i] != req.a)
00389         {
00390           // change the label
00391           labels[i] = req.a;
00392           //update the point in the classifier
00393           lfd_common::classification_point cp;
00394           for (int i = 0; i < s_size; i++)
00395             cp.s.state_vector.push_back(s[i]);
00396           cp.l = req.a;
00397           change_point.publish(cp);
00398 
00399           // update the thresholds
00400           update_thresholds();
00401         }
00402       }
00403     }
00404 
00405     // new data point
00406     if (!duplicate)
00407     {
00408       // update the data set
00409       if (pts == max_pts)
00410         ROS_WARN("Too many data points -- latest point ignored.");
00411       else
00412       {
00413         // add the data to the ANN set
00414         for (int i = 0; i < s_size; i++)
00415           ann_data[pts][i] = s[i];
00416         // set the label
00417         labels[pts] = req.a;
00418         pts++;
00419 
00420         // send the point to the classifier
00421         lfd_common::classification_point cp;
00422         for (int i = 0; i < s_size; i++)
00423           cp.s.state_vector.push_back(s[i]);
00424         cp.l = req.a;
00425         add_point.publish(cp);
00426 
00427         // update the thresholds
00428         update_thresholds();
00429       }
00430     }
00431   }
00432 
00433   // if the action is complete, it cannot be autonomous anymore
00434   autonomous_action = false;
00435   return true;
00436 }
00437 
00438 int main(int argc, char **argv)
00439 {
00440   // initialize ROS and the node
00441   ros::init(argc, argv, "cba");
00442 
00443   // initialize the CBA learner
00444   cba_learner cba;
00445 
00446   // continue until a ctrl-c has occurred
00447   while (ros::ok())
00448   {
00449     ros::spinOnce();
00450     // step through the algorithm once
00451     cba.step();
00452   }
00453 
00454   return EXIT_SUCCESS;
00455 }


cba
Author(s): Russell Toris
autogenerated on Thu Jan 2 2014 11:23:56