Go to the documentation of this file.00001
00011 #include <ANN/ANN.h>
00012 #include <cba/cba.h>
00013 #include <lfd_common/action_complete.h>
00014 #include <lfd_common/classification_point.h>
00015 #include <lfd_common/conf_classification.h>
00016 #include <lfd_common/demonstration.h>
00017 #include <lfd_common/state.h>
00018 #include <limits>
00019 #include <ros/ros.h>
00020 #include <std_msgs/Int32.h>
00021 #include <vector>
00022
00023 using namespace std;
00024
00025 cba_learner::cba_learner()
00026 {
00027
00028 execute = node.advertise<std_msgs::Int32>("execute", 1);
00029 add_point = node.advertise<lfd_common::classification_point>("add_point", -1);
00030 change_point = node.advertise<lfd_common::classification_point>("change_point", -1);
00031 state_listener = node.subscribe<lfd_common::state>("state_listener", 1, &cba_learner::state_listener_callback, this);
00032 a_complete = node.advertiseService("a_complete", &cba_learner::a_complete_callback, this);
00033 classify = node.serviceClient<lfd_common::conf_classification>("classify");
00034 correction = node.serviceClient<lfd_common::demonstration>("correction");
00035 demonstration = node.serviceClient<lfd_common::demonstration>("demonstration");
00036
00037
00038 ros::param::param<int>(MAX_DATA_POINTS, max_pts, DEFAULT_MAX_POINTS);
00039
00040 ros::param::param<double>(DIST_THRESH_MULT, dist_mult, DEFAULT_DIST_MULT);
00041
00042
00043 labels = NULL;
00044 s = NULL;
00045 sc = NULL;
00046 a = -1;
00047 s_size = -1;
00048 pts = 0;
00049 action_complete = true;
00050 autonomous_action = false;
00051 dist_thresh = 0;
00052
00053 ROS_INFO("CBA Learner Initialized");
00054 }
00055
00056 cba_learner::~cba_learner()
00057 {
00058
00059 if (s != NULL)
00060 free(s);
00061 if (sc != NULL)
00062 free(sc);
00063
00064 if (labels != NULL)
00065 free(labels);
00066
00067 if (pts > 0)
00068 annDeallocPts(ann_data);
00069 annClose();
00070 }
00071
00072 void cba_learner::step()
00073 {
00074
00075 if (s == NULL)
00076 return;
00077
00078
00079 if (action_complete)
00080 {
00081
00082 prediction *p = classify_state();
00083
00084
00085 double d = nearest_neighbor();
00086
00087
00088 if (p->c > conf_thresh(p->l, p->db) && d < dist_thresh)
00089 {
00090
00091 std_msgs::Int32 a;
00092 a.data = p->l;
00093 execute.publish(a);
00094 action_complete = false;
00095 autonomous_action = true;
00096
00097 memcpy(sc, s, sizeof(float) * s_size);
00098 }
00099 else
00100 {
00101
00102 if (!demonstration.exists())
00103 ROS_WARN("Could not connect to 'demonstration' service.");
00104 else
00105 {
00106
00107 lfd_common::demonstration dem;
00108 for (int i = 0; i < s_size; i++)
00109 dem.request.s.state_vector.push_back(s[i]);
00110
00111 demonstration.call(dem);
00112
00113
00114 if (dem.response.valid)
00115 {
00116
00117 bool duplicate;
00118 for (int i = 0; i < pts; i++)
00119 {
00120 for (int j = 0; j < s_size; j++)
00121 if (ann_data[i][j] == s[j])
00122 duplicate = true;
00123 else
00124 {
00125 duplicate = false;
00126 break;
00127 }
00128 if (duplicate)
00129 {
00130
00131 if (labels[i] != dem.response.a)
00132 {
00133
00134 labels[i] = dem.response.a;
00135
00136 lfd_common::classification_point cp;
00137 for (int i = 0; i < s_size; i++)
00138 cp.s.state_vector.push_back(s[i]);
00139 cp.l = dem.response.a;
00140 change_point.publish(cp);
00141
00142
00143 update_thresholds();
00144 }
00145 break;
00146 }
00147 }
00148
00149
00150 if (!duplicate)
00151 {
00152
00153 if (pts == max_pts)
00154 ROS_WARN("Too many data points -- latest point ignored.");
00155 else
00156 {
00157
00158 for (int i = 0; i < s_size; i++)
00159 ann_data[pts][i] = s[i];
00160
00161 labels[pts] = dem.response.a;
00162 pts++;
00163
00164
00165 lfd_common::classification_point cp;
00166 for (int i = 0; i < s_size; i++)
00167 cp.s.state_vector.push_back(s[i]);
00168 cp.l = dem.response.a;
00169 add_point.publish(cp);
00170
00171
00172 update_thresholds();
00173 }
00174 }
00175
00176
00177 std_msgs::Int32 a;
00178 a.data = dem.response.a;
00179 execute.publish(a);
00180 action_complete = false;
00181 autonomous_action = false;
00182 }
00183 }
00184 }
00185
00186 free(p);
00187 }
00188 }
00189
00190 prediction *cba_learner::classify_state()
00191 {
00192
00193 prediction *p = (prediction *)malloc(sizeof(prediction));
00194
00195
00196 if (classify.exists())
00197 {
00198
00199 lfd_common::conf_classification cc;
00200 for (int i = 0; i < s_size; i++)
00201 cc.request.s.state_vector.push_back(s[i]);
00202
00203
00204 classify.call(cc);
00205
00206
00207 p->c = cc.response.c;
00208 p->l = cc.response.l;
00209 p->db = cc.response.db;
00210 }
00211 else
00212 {
00213 ROS_WARN("Could not connect to 'classify' service.");
00214
00215 p->c = -numeric_limits<float>::infinity();
00216 p->db = -1;
00217 p->l = -1;
00218 }
00219
00220 return p;
00221 }
00222
00223 double cba_learner::nearest_neighbor()
00224 {
00225
00226 if (pts == 0)
00227 return numeric_limits<float>::infinity();
00228 else
00229 {
00230
00231 ANNidxArray index = new ANNidx[1];
00232 ANNdistArray dists = new ANNdist[1];
00233
00234
00235 ANNkd_tree *kd_tree = new ANNkd_tree(ann_data, pts, s_size);
00236
00237
00238 ANNpoint pt = annAllocPt(s_size);
00239 for (int i = 0; i < s_size; i++)
00240 pt[i] = s[i];
00241
00242
00243 kd_tree->annkSearch(pt, 1, index, dists, ANN_EPSILON);
00244
00245 double d = sqrt(dists[0]);
00246
00247
00248 annDeallocPt(pt);
00249 delete kd_tree;
00250
00251 return d;
00252 }
00253 }
00254
00255 double cba_learner::conf_thresh(int l, int db)
00256 {
00257
00258 for (uint i = 0; i < conf_thresholds.size(); i++)
00259 if (conf_thresholds.at(i)->l == l && conf_thresholds.at(i)->db == db)
00260 return conf_thresholds.at(i)->thresh;
00261
00262
00263 return numeric_limits<float>::infinity();
00264 }
00265
00266 void cba_learner::update_thresholds()
00267 {
00268
00269 if (pts < 2)
00270 return;
00271
00272
00273 for (uint i = 0; i < conf_thresholds.size(); i++)
00274 free(conf_thresholds.at(i));
00275 conf_thresholds.clear();
00276
00277
00278 ANNidxArray index = new ANNidx[2];
00279 ANNdistArray dists = new ANNdist[2];
00280
00281 ANNkd_tree *kd_tree = new ANNkd_tree(ann_data, pts, s_size);
00282
00283
00284 double total_dist = 0;
00285 for (int i = 0; i < pts; i++)
00286 {
00287
00288 kd_tree->annkSearch(ann_data[i], 2, index, dists, ANN_EPSILON);
00289
00290 total_dist += sqrt(dists[1]);
00291
00292
00293 if (classify.exists())
00294 {
00295
00296 lfd_common::conf_classification cc;
00297 for (int j = 0; j < s_size; j++)
00298 cc.request.s.state_vector.push_back(ann_data[i][j]);
00299
00300 classify.call(cc);
00301
00302
00303 conf *c = NULL;
00304 for (uint j = 0; j < conf_thresholds.size() && c == NULL; j++)
00305 if (conf_thresholds.at(j)->l == cc.response.l && conf_thresholds.at(j)->db == cc.response.db)
00306 c = conf_thresholds.at(j);
00307
00308 if (c == NULL)
00309 {
00310 c = (conf *)malloc(sizeof(conf));
00311 c->l = cc.response.l;
00312 c->db = cc.response.db;
00313 c->thresh = 0;
00314 c->cnt = 0;
00315 conf_thresholds.push_back(c);
00316 }
00317
00318
00319 if (cc.response.l != labels[i])
00320 {
00321
00322 c->thresh += cc.response.c;
00323 c->cnt++;
00324 }
00325 }
00326 else
00327 ROS_WARN("Could not connect to 'classify' service.");
00328 }
00329
00330
00331 dist_thresh = (total_dist / pts) * dist_mult;
00332
00333
00334 for (uint i = 0; i < conf_thresholds.size(); i++)
00335 conf_thresholds.at(i)->thresh /= conf_thresholds.at(i)->cnt;
00336
00337
00338 delete kd_tree;
00339 }
00340
00341 void cba_learner::state_listener_callback(const lfd_common::state::ConstPtr &msg)
00342 {
00343
00344 if (s_size == -1)
00345 {
00346 s_size = msg->state_vector.size();
00347
00348 ann_data = annAllocPts(max_pts, s_size);
00349
00350 s = (float *)malloc(sizeof(float) * s_size);
00351 sc = (float *)malloc(sizeof(float) * s_size);
00352 labels = (int *)malloc(sizeof(int32_t) * max_pts);
00353 }
00354
00355
00356 if (s_size != (int)msg->state_vector.size())
00357 ROS_WARN("WARNING: State sizes do not match -- Ignoring current state.");
00358
00359
00360 for (int i = 0; i < s_size; i++)
00361 s[i] = msg->state_vector[i];
00362 }
00363
00364 bool cba_learner::a_complete_callback(lfd_common::action_complete::Request &req,
00365 lfd_common::action_complete::Response &resp)
00366 {
00367
00368 action_complete = true;
00369
00370
00371 if (autonomous_action && req.valid_correction)
00372 {
00373
00374 bool duplicate = false;
00375 for (int i = 0; i < pts; i++)
00376 {
00377 for (int j = 0; j < s_size; j++)
00378 if (ann_data[i][j] == sc[j])
00379 duplicate = true;
00380 else
00381 {
00382 duplicate = false;
00383 break;
00384 }
00385 if (duplicate)
00386 {
00387
00388 if (labels[i] != req.a)
00389 {
00390
00391 labels[i] = req.a;
00392
00393 lfd_common::classification_point cp;
00394 for (int i = 0; i < s_size; i++)
00395 cp.s.state_vector.push_back(s[i]);
00396 cp.l = req.a;
00397 change_point.publish(cp);
00398
00399
00400 update_thresholds();
00401 }
00402 }
00403 }
00404
00405
00406 if (!duplicate)
00407 {
00408
00409 if (pts == max_pts)
00410 ROS_WARN("Too many data points -- latest point ignored.");
00411 else
00412 {
00413
00414 for (int i = 0; i < s_size; i++)
00415 ann_data[pts][i] = s[i];
00416
00417 labels[pts] = req.a;
00418 pts++;
00419
00420
00421 lfd_common::classification_point cp;
00422 for (int i = 0; i < s_size; i++)
00423 cp.s.state_vector.push_back(s[i]);
00424 cp.l = req.a;
00425 add_point.publish(cp);
00426
00427
00428 update_thresholds();
00429 }
00430 }
00431 }
00432
00433
00434 autonomous_action = false;
00435 return true;
00436 }
00437
00438 int main(int argc, char **argv)
00439 {
00440
00441 ros::init(argc, argv, "cba");
00442
00443
00444 cba_learner cba;
00445
00446
00447 while (ros::ok())
00448 {
00449 ros::spinOnce();
00450
00451 cba.step();
00452 }
00453
00454 return EXIT_SUCCESS;
00455 }