#include "pr2_overhead_grasping/random_forest.h"

#include <omp.h>
#include <signal.h>
#include <stdio.h>

#include <cmath>
00006
00007
00008 using namespace std;
00009 using namespace pr2_overhead_grasping;
// File-level tuning globals. Neither is referenced in the part of the file
// visible here.
// NOTE(review): presumably the number of attributes tried per split and a
// minimum-gain cut-off -- confirm against other users before relying on them.
int ATTRS_TRY = 44;
double MAX_GAIN_THRESH = 0.0;
00012
00013 namespace collision_detection {
00014
00015 RandomTree::RandomTree(int c_d_tree_num) {
00016 d_tree_num = c_d_tree_num;
00017 }
00018
00019 RandomTree::RandomTree(RandomTreeMsg::Ptr r_tree) {
00020 rand_tree = r_tree;
00021 num_classes = rand_tree->num_classes;
00022 }
00023
00024 bool RandomTree::attrCompare(int inst_i, int inst_j, int attr) {
00025 return (dataset->at(inst_i)->features[attr] < dataset->at(inst_j)->features[attr]);
00026 }
00027
00028 struct AttrComp {
00029 int attr;
00030 vector< SensorPoint::Ptr >* dataset;
00031 AttrComp(vector< SensorPoint::Ptr >* c_dataset, int c_attr) {
00032 dataset = c_dataset;
00033 attr = c_attr;
00034 }
00035 bool operator()(int inst_i, int inst_j) {
00036 return (dataset->at(inst_i)->features[attr] < dataset->at(inst_j)->features[attr]);
00037 }
00038 };
00039
// Ordered set of ints whose comparator is a type-erased bool(int, int) callable.
typedef set<int, boost::function<bool (int,int)> > AttrSet;
00041 void RandomTree::findBestSplit(vector<int>* insts, vector<int>& attrs,
00042 pair<int, float>& ret) {
00043 long double max_gain = 0.0;
00044 float best_split_f = 0;
00045 int best_split_attr = 0;
00046 for(uint32_t k=0;k<insts->size();k++)
00047 if((uint32_t) insts->at(k) >= dataset->size())
00048 printf("WTF\n");
00049 for(uint32_t a=0;a<attrs.size();a++) {
00050
00051
00052 AttrComp attr_comp(dataset, attrs[a]);
00053
00054
00055 vector<int> insts_temp(insts->begin(), insts->end());
00056 set<int, AttrComp> insts_sorted(attr_comp);
00057
00058 insts_sorted.insert(insts_temp.begin(), insts_temp.end());
00059
00060 map<int, int> class_sums;
00061 map<int, int> class_cur_sums;
00062 map<int, int>::iterator cs_iter, ccs_iter;
00063
00064 for(uint32_t i =0;i<insts->size();i++) {
00065 int label = dataset->at(insts->at(i))->label;
00066 if(class_sums.count(label) == 0) {
00067 class_sums[label] = 0;
00068 class_cur_sums[label] = 0;
00069 }
00070 class_sums[label]++;
00071 }
00072
00073
00074 uint32_t inst_ctr = 0;
00075 for(AttrSet::iterator inst = insts_sorted.begin();
00076 inst != insts_sorted.end();inst++) {
00077 if(inst_ctr == insts_sorted.size()-1)
00078 break;
00079
00080 class_cur_sums[dataset->at(*inst)->label]++;
00081 long double entropy = 0.0, gain = 0.0;
00082 cs_iter = class_sums.begin();
00083 ccs_iter = class_cur_sums.begin();
00084 for(int j=0;j<num_classes;j++) {
00085 long double prec_in_class_l = ccs_iter->second / (long double) (insts->size());
00086 if(prec_in_class_l > 0.00001)
00087 entropy -= prec_in_class_l * log2(prec_in_class_l);
00088 long double prec_in_class_r = (cs_iter->second - ccs_iter->second) / (long double) (insts->size());
00089 if(prec_in_class_r > 0.00001)
00090 entropy -= prec_in_class_r * log2(prec_in_class_r);
00091
00092 gain += entropy * cs_iter->second;
00093 cs_iter++; ccs_iter++;
00094 }
00095
00096
00097 if(gain > max_gain) {
00098 best_split_f = dataset->at(*inst)->features[attrs[a]];
00099 best_split_attr = attrs[a];
00100
00101 max_gain = gain;
00102
00103
00104
00105 }
00106
00107 inst_ctr++;
00108 }
00109 }
00110 ret.first = best_split_attr; ret.second = best_split_f;
00111 }
00112
00113 void RandomTree::splitNode(vector<int>* node_inds,
00114 pair<int, float>& split_pt,
00115 pair<vector<int>*, vector<int>* >& split_nodes) {
00116 pair<vector<int>*, vector<int>* > nodes;
00117 split_nodes.first = new vector<int>; split_nodes.second = new vector<int>;
00118
00119 for(uint32_t i=0;i<node_inds->size();i++) {
00120
00121 if(dataset->at(node_inds->at(i))->features[split_pt.first] <= split_pt.second)
00122 split_nodes.first->push_back(node_inds->at(i));
00123 else
00124 split_nodes.second->push_back(node_inds->at(i));
00125 }
00126 assert(split_nodes.first->size() > 0);
00127 assert(split_nodes.second->size() > 0);
00128 }
00129
00130 void RandomTree::growTree(vector< SensorPoint::Ptr >* c_dataset,
00131 vector<int>* inds) {
00133
00134
00135
00136
00138 dataset = c_dataset;
00139
00140 int num_attrs = dataset->at(0)->features.size();
00141
00142
00143 rand_tree = boost::shared_ptr<RandomTreeMsg>(new RandomTreeMsg);
00144
00145 rand_tree->tree_num = d_tree_num;
00146 rand_tree->attr_split.resize(2*inds->size());
00147
00148 rand_tree->val_split.resize(2*inds->size());
00149 rand_tree->r_node_inds.resize(2*inds->size(), -999);
00150 vector<int> r_node_stack;
00151
00152 int tm_i = 0;
00153
00155
00156 map<int, vector<int> > class_inds;
00157 for(uint32_t i=0;i<inds->size();i++)
00158 class_inds[dataset->at(inds->at(i))->label].push_back(inds->at(i));
00159 num_classes = class_inds.size();
00160
00161 int minority_class = 0, min_class = class_inds.begin()->second.size();
00162 for(map<int, vector<int> >::iterator i=class_inds.begin();i!=class_inds.end();i++) {
00163 if(i->second.size() < (uint32_t) min_class) {
00164 min_class = i->second.size();
00165 minority_class = i->first;
00166 }
00167 }
00168
00169 vector<bool> oobs(dataset->size(), true);
00170
00171 int num_samples = class_inds[minority_class].size();
00172
00173 vector<int>* head_node = new vector<int>(num_samples*class_inds.size());
00174 int i_cntr = 0;
00175 for(map<int, vector<int> >::iterator i=class_inds.begin();i!=class_inds.end();i++) {
00176
00177 for(int j=0;j<num_samples;j++) {
00178 int sample_ind = i->second[rand() % i->second.size()];
00179
00180 head_node->at(i_cntr++) = sample_ind;
00181 oobs[sample_ind] = false;
00182 }
00183 }
00184
00185
00186
00187
00188
00189
00190
00191
00193
00195
00196 int last_r_node_ind = -1;
00197 vector<vector<int>* > node_stack;
00198 node_stack.push_back(head_node);
00199
00200 while(ros::ok()) {
00201 if(node_stack.size() == 0)
00202 break;
00203 vector<int>* cur_node = node_stack.back();
00204
00205
00206 bool homo = true;
00207 for(uint32_t i=0;i<cur_node->size();i++) {
00208
00209 if(dataset->at(cur_node->at(0))->label != dataset->at(cur_node->at(i))->label) {
00210 homo = false;
00211 break;
00212 }
00213 }
00214 if(homo) {
00215
00216
00217
00218
00219
00220
00221
00222
00223 rand_tree->attr_split[tm_i] = dataset->at(cur_node->at(0))->label;
00224 rand_tree->val_split[tm_i] = 0.0;
00225 rand_tree->r_node_inds[tm_i] = -1;
00226
00227 if(r_node_stack.size() > 0) {
00228 last_r_node_ind = r_node_stack.back();
00229 rand_tree->r_node_inds[r_node_stack.back()] = tm_i + 1;
00230 r_node_stack.pop_back();
00231 tm_i++;
00232 }
00233
00234 node_stack.pop_back();
00235
00236
00237
00238 continue;
00239 }
00240 assert(cur_node->size() > 1);
00241
00242
00243 int num_try = (int) sqrt(num_attrs);
00244 vector<int> attrs(num_try, 0);
00245 for(int i=0;i<num_try;i++) {
00246 attrs[i] = rand() % num_attrs;
00247 }
00248
00249
00250
00251
00252 pair<int, float> split_pt;
00253 findBestSplit(cur_node, attrs, split_pt);
00254
00255 if(split_pt.first < 0) {
00256
00257 assert(false);
00258
00259 split_pt.first *= -1;
00260
00261 pair<vector<int>*, vector<int>* > new_nodes;
00262 splitNode(cur_node, split_pt, new_nodes);
00263
00264
00265 vector<int> class_sums_l(num_classes, 0);
00266 for(uint32_t i=0;i<new_nodes.first->size();i++)
00267 class_sums_l[dataset->at(new_nodes.first->at(i))->label]++;
00268 int popular_class_l = max_element(class_sums_l.begin(), class_sums_l.end()) - class_sums_l.begin();
00269
00270
00271 rand_tree->attr_split[tm_i] = popular_class_l;
00272 rand_tree->val_split[tm_i] = 0.0;
00273 rand_tree->r_node_inds[tm_i] = -1;
00274
00275 rand_tree->r_node_inds[r_node_stack.back()] = tm_i + 1;
00276 r_node_stack.pop_back();
00277 tm_i++;
00278
00279
00280 vector<int> class_sums_r(num_classes, 0);
00281 for(uint32_t i=0;i<new_nodes.second->size();i++)
00282 class_sums_r[dataset->at(new_nodes.second->at(i))->label]++;
00283 int popular_class_r = max_element(class_sums_r.begin(), class_sums_r.end()) - class_sums_r.begin();
00284
00285
00286 rand_tree->attr_split[tm_i] = popular_class_r;
00287 rand_tree->val_split[tm_i] = 0.0;
00288 rand_tree->r_node_inds[tm_i] = -1;
00289
00290 rand_tree->r_node_inds[r_node_stack.back()] = tm_i + 1;
00291 r_node_stack.pop_back();
00292 tm_i++;
00293 } else {
00294
00295
00296 pair<vector<int>*, vector<int>* > new_nodes;
00297 splitNode(cur_node, split_pt, new_nodes);
00298
00299 node_stack.pop_back();
00300
00301 node_stack.push_back(new_nodes.second);
00302
00303 node_stack.push_back(new_nodes.first);
00304
00305
00306 rand_tree->attr_split[tm_i] = split_pt.first;
00307 rand_tree->val_split[tm_i] = split_pt.second;
00308 r_node_stack.push_back(tm_i);
00309 tm_i++;
00310 }
00311
00312
00313
00314 }
00315
00316 rand_tree->r_node_inds[last_r_node_ind] = tm_i - 1;
00317
00318
00319
00320
00321
00322 rand_tree->attr_split.resize(tm_i);
00323 rand_tree->val_split.resize(tm_i);
00324 rand_tree->r_node_inds.resize(tm_i);
00325
00326 for(uint32_t i=0;i<oobs.size();i++)
00327 if(oobs[i])
00328 rand_tree->out_of_bags.push_back(i);
00329 rand_tree->num_classes = num_classes;
00330
00331
00332 }
00333
00334 void RandomTree::writeTree(string& bag_file, bool is_first) {
00335
00336 rosbag::Bag bag;
00337 int bagmode;
00338 if(is_first)
00339 bagmode = rosbag::bagmode::Write;
00340 else
00341 bagmode = rosbag::bagmode::Append;
00342 bag.open(bag_file, bagmode);
00343 bag.write("trees", ros::Time::now(), rand_tree);
00344 bag.close();
00345 }
00346
00347 int RandomTree::classifyInstance(SensorPoint::Ptr inst) {
00348 int ind = 0, attr;
00349
00350
00351
00352
00353
00354
00355
00356
00357 while(ros::ok()) {
00358
00359 attr = rand_tree->attr_split[ind];
00360 assert(attr >= 0 && attr < inst->features.size());
00361
00362 assert(ind != -999);
00363 if(rand_tree->r_node_inds[ind] < 0) {
00364
00365
00366 return attr;
00367 }
00368 float feat = inst->features[attr];
00369 if(is_abs)
00370 feat = abs(feat);
00371 if(feat <= rand_tree->val_split[ind]) {
00372
00373 ind++;
00374 } else {
00375
00376 ind = rand_tree->r_node_inds[ind];
00377 }
00378 }
00379
00380 return -1;
00381 }
00382
00383 void RandomForest::loadDataset() {
00384 dataset = new vector< SensorPoint::Ptr >;
00385 string bag_path;
00386 XmlRpc::XmlRpcValue bag_names, bag_labels;
00387 nh_priv->getParam("bag_path", bag_path);
00388 nh_priv->getParam("bag_names", bag_names);
00389 nh_priv->getParam("bag_labels", bag_labels);
00390 for(int i=0;i<bag_names.size();i++) {
00391 string file_loc = bag_path + (string) bag_names[i];
00392 loadDataBag(file_loc, bag_labels[i]);
00393 }
00394 }
00395
00396 void RandomForest::loadDataBag(string& data_bag, int label) {
00397
00398 rosbag::Bag bag(data_bag);
00399 rosbag::View view(bag, rosbag::TopicQuery("/collision_data"));
00400 BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00401 SensorPoint::Ptr sp = m.instantiate<SensorPoint>();
00402 if(sp != NULL) {
00403 sp->label = label;
00404 if(is_abs)
00405 for(uint32_t i=0;i<sp->features.size();i++)
00406 sp->features[i] = abs(sp->features[i]);
00407 dataset->push_back(sp);
00408 }
00409 }
00410 assert(dataset->size() != 0);
00411 }
00412
00413 void RandomForest::setDataset(vector< SensorPoint::Ptr >* datas) {
00414 dataset = datas;
00415 }
00416
00417 void RandomForest::growForest(vector< SensorPoint::Ptr >* c_dataset,
00418 vector<int>* inds, int c_num_trees) {
00419 dataset = c_dataset;
00420 num_trees = c_num_trees;
00421
00422 trees = new RandomTree*[num_trees];
00423 int i, NUM_CHUNKS = 10;
00424 RandomForest* this_forest = this;
00425 #pragma omp parallel shared(this_forest, inds, NUM_CHUNKS) private(i)
00426 {
00427 #pragma omp for schedule(dynamic, NUM_CHUNKS)
00428 for(i=0;i<this_forest->num_trees;i++) {
00429 ROS_INFO("Growing tree %d", i+1);
00430
00431 while(ros::ok()) {
00432 try {
00433 this_forest->trees[i] = new RandomTree(i);
00434 this_forest->trees[i]->growTree(this_forest->dataset, inds);
00435 break;
00436 }
00437 catch (bad_alloc &) {
00438 ROS_INFO("Memory unavailable...restaring...");
00439 ros::Duration(1.0).sleep();
00440 }
00441 }
00442 }
00443 }
00444 }
00445
00446 RandomForest::~RandomForest() {
00447 for(int i=0;i<num_trees;i++)
00448 delete trees[i];
00449 }
00450
00451 void RandomForest::writeForest() {
00452 string forest_bag_name;
00453 nh_priv->getParam("forest_bag_name", forest_bag_name);
00454 writeForest(forest_bag_name);
00455 }
00456
00457 void RandomForest::writeForest(string name) {
00458 string bag_path;
00459 nh_priv->getParam("bag_path", bag_path);
00460 string file_loc = bag_path + name;
00461 for(int i=0;i<num_trees;i++) {
00462 trees[i]->writeTree(file_loc, i == 0);
00463 }
00464 }
00465
00466
00467 void RandomForest::growWriteForest() {
00468 int c_num_trees;
00469 bool is_validation;
00470 nh_priv->getParam("num_trees", c_num_trees);
00471 nh_priv->getParam("is_validation", is_validation);
00472 loadDataset();
00473 vector<int>* inds = new vector<int>(dataset->size());
00474 for(uint32_t i=0;i<dataset->size();i++)
00475 inds->at(i) = i;
00476 if(is_validation) {
00477 int pos_id;
00478 bool classify_first;
00479 nh_priv->getParam("pos_id", pos_id);
00480 nh_priv->getParam("classify_first", classify_first);
00481 ROS_INFO("Running ten-fold cross validation.");
00482 vector<map<int, int> > votes_total;
00483 RandomForest::runTenFold(dataset, pos_id, c_num_trees, votes_total, classify_first);
00484
00485 bool filter;
00486 nh_priv->getParam("filter", filter);
00487 if(filter) {
00488 double filter_thresh;
00489 string filter_out, bag_path;
00490 ROS_INFO("Filtering bad data.");
00491 nh_priv->getParam("filter_thresh", filter_thresh);
00492 nh_priv->getParam("bag_path", bag_path);
00493 nh_priv->getParam("filter_out", filter_out);
00494 string file_loc = bag_path + filter_out;
00495 rosbag::Bag bag;
00496 int bagmode = rosbag::bagmode::Write;
00497 bag.open(file_loc, bagmode);
00498 for(uint32_t i = 0;i<votes_total.size();i++) {
00499 if(votes_total[i][pos_id] / (double) c_num_trees < filter_thresh &&
00500 dataset->at(i)->label != (uint32_t) pos_id) {
00501
00502 bag.write("/collision_data", ros::Time::now(), dataset->at(i));
00503 }
00504 }
00505 bag.close();
00506 }
00507 ROS_INFO("Done, please exit.");
00508 } else {
00509 ROS_INFO("Growing forest");
00510 growForest(dataset, inds, c_num_trees);
00511 ROS_INFO("Writing forest");
00512 writeForest();
00513 ROS_INFO("Finished writing forest, please exit");
00514 }
00515 }
00516
00517 void RandomForest::loadForest() {
00518 string forest_bag_name;
00519 string bag_path;
00520 nh_priv->getParam("bag_path", bag_path);
00521 nh_priv->getParam("forest_bag_name", forest_bag_name);
00522 string file_loc = bag_path + forest_bag_name;
00523 rosbag::Bag bag(file_loc);
00524 rosbag::View view(bag, rosbag::TopicQuery("trees"));
00525 int ind = 0;
00526 trees = new RandomTree*[view.size()];
00527 BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00528 RandomTreeMsg::Ptr tree = m.instantiate<RandomTreeMsg>();
00529 if(tree != NULL) {
00530 trees[ind++] = new RandomTree(tree);
00531 trees[ind-1]->is_abs = is_abs;
00532 oobs.push_back(tree->out_of_bags);
00533 }
00534 }
00535 num_trees = ind;
00536 num_classes = trees[0]->num_classes;
00537 std_msgs::Bool loaded;
00538 loaded.data = true;
00539 loaded_pub.publish(loaded);
00540 trees_loaded = true;
00541 ROS_INFO("[random_forest] All trees loaded.");
00542 }
00543
00544 void RandomForest::loadCovMat() {
00545 string mahal_bag_name;
00546 string bag_path;
00547 nh_priv->getParam("bag_path", bag_path);
00548 nh_priv->getParam("mahal_bag_name", mahal_bag_name);
00549 string file_loc = bag_path + mahal_bag_name;
00550 rosbag::Bag bag(file_loc);
00551 rosbag::View view(bag, rosbag::TopicQuery("matrix"));
00552 BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00553 CovarianceMatrix::Ptr cov_mat_msg = m.instantiate<CovarianceMatrix>();
00554 if(cov_mat_msg != NULL) {
00555 MatrixXd cov_mat(cov_mat_msg->size, cov_mat_msg->size);
00556 int m_ind = 0;
00557 for(uint32_t i=0;i<cov_mat_msg->size;i++) {
00558 for(uint32_t j=0;j<cov_mat_msg->size;j++)
00559 cov_mat(i,j) = cov_mat_msg->cov_mat[m_ind++];
00560 means(i) = cov_mat_msg->means[i];
00561 }
00562 cov_inv = new LDLT<MatrixXd>(cov_mat);
00563 }
00564 }
00565 std_msgs::Bool loaded;
00566 loaded.data = true;
00567 loaded_pub.publish(loaded);
00568 trees_loaded = true;
00569 ROS_INFO("[random_forest] All trees loaded.");
00570 }
00571
00572 void RandomForest::createCovMat() {
00573 loadDataset();
00574
00575 int num_feats = dataset->at(0)->features.size();
00576
00577 VectorXd means = VectorXd::Zero(num_feats);
00578 for(uint32_t i=0;i<dataset->size();i++) {
00579 for(int j=0;j<num_feats;j++) {
00580 means(j) += dataset->at(i)->features[j];
00581 }
00582 }
00583 means /= dataset->size();
00584 int j=0, l=0;
00585 uint32_t i=0;
00586 MatrixXd var_mat = MatrixXd::Zero(num_feats, num_feats);
00587 #pragma omp parallel default(shared) private(j, l, i) num_threads(10)
00588 {
00589 #pragma omp for schedule(dynamic, 10)
00590 for(j=0;j<num_feats;j++) {
00591 ROS_INFO("j %d", j);
00592 for(l=0;l<num_feats;l++) {
00593 for(i=0;i<dataset->size();i++)
00594 var_mat(j,l) += (dataset->at(i)->features[j] - means(j)) * (dataset->at(i)->features[l] - means(l));
00595 var_mat(j,l) /= dataset->size();
00596 }
00597 }
00598 }
00599
00600 CovarianceMatrix cov_mat_msg;
00601 cov_mat_msg.size = num_feats;
00602 for(j=0;j<num_feats;j++)
00603 for(l=0;l<num_feats;l++)
00604 cov_mat_msg.cov_mat.push_back(var_mat(j,l));
00605
00606 string mahal_bag_name;
00607 string bag_path;
00608 nh_priv->getParam("bag_path", bag_path);
00609 nh_priv->getParam("mahal_bag_name", mahal_bag_name);
00610 string file_loc = bag_path + mahal_bag_name;
00611 rosbag::Bag bag;
00612 string forest_bag_name;
00613 int bagmode = rosbag::bagmode::Write;
00614 bag.open(file_loc, bagmode);
00615 bag.write("matrix", ros::Time::now(), cov_mat_msg);
00616 bag.close();
00617 }
00618
00619 void RandomForest::collectVotes(SensorPoint::Ptr inst, map<int, int>& class_votes) {
00620 int i;
00621
00622 for(i=0;i<num_trees;i++) {
00623 int class_vote = trees[i]->classifyInstance(inst);
00624 if(class_votes.count(class_vote) == 0)
00625 class_votes[class_vote] = 0;
00626 class_votes[class_vote]++;
00627
00628 }
00629
00630 }
00631
00632 void RandomForest::classifyCallback(const boost::shared_ptr<SensorPoint>& inst) {
00633 if(!trees_loaded) {
00634 ROS_INFO("[random_forest] Classifcation requested but trees not loaded.");
00635 return;
00636 }
00637 ClassVotes::Ptr class_votes(new ClassVotes);
00638 map<int, int> cvs;
00639 for(uint32_t i=0;i<inst->features.size();i++)
00640 if(inst->features[i] != inst->features[i]) {
00641 ROS_INFO("NAN %d", i);
00642 return;
00643 }
00644
00645 collectVotes(inst, cvs);
00646 class_votes->votes.resize(num_classes);
00647 class_votes->classes.resize(num_classes);
00648 int ind = 0;
00649 for(map<int, int>::iterator iter=cvs.begin();iter!=cvs.end();iter++) {
00650
00651
00652 class_votes->votes[iter->first] = iter->second;
00653 class_votes->classes[ind++] = iter->first;
00654 }
00655 class_votes->classifier_name = classifier_name;
00656 class_votes->classifier_id = classifier_id;
00657 results_pub.publish(class_votes);
00658 }
00659
00660
00661 int RandomForest::findFirstClass(vector<pair<map<int, int>, float > >* votes_list, int pos_id, float thresh) {
00662 int min_class = pos_id;
00663 float min_delay = 10000.0;
00664 for(uint32_t j=0;j<votes_list->size();j++) {
00665 int pred_class;
00666 if(votes_list->at(j).first[pos_id] > thresh)
00667 pred_class = pos_id;
00668 else {
00669 int max_class = 0, max_votes = -1;
00670 for(map<int, int>::iterator iter=votes_list->at(j).first.begin();iter!=votes_list->at(j).first.end();iter++) {
00671 if(iter->second > max_votes) {
00672 max_class = iter->first;
00673 max_votes = iter->second;
00674 }
00675 }
00676 pred_class = max_class;
00677 }
00678 if(votes_list->at(j).second < min_delay && pred_class != pos_id) {
00679 min_class = pred_class;
00680 min_delay = votes_list->at(j).second;
00681 }
00682 }
00683 return min_class;
00684 }
00685
00686 int RandomForest::findFrequentClass(vector<pair<map<int, int>, float > >* votes_list, int pos_id, float thresh) {
00687
00688 map<int, int> vote_counts;
00689 for(uint32_t j=0;j<votes_list->size();j++) {
00690 int pred_class;
00691 if(votes_list->at(j).first[pos_id] > thresh)
00692 pred_class = pos_id;
00693 else {
00694 int max_class = 0, max_votes = -1;
00695 for(map<int, int>::iterator iter=votes_list->at(j).first.begin();iter!=votes_list->at(j).first.end();iter++) {
00696 if(iter->second > max_votes) {
00697 max_class = iter->first;
00698 max_votes = iter->second;
00699 }
00700 }
00701 pred_class = max_class;
00702 }
00703 if(vote_counts.count(pred_class) == 0)
00704 vote_counts[pred_class] = 0;
00705 vote_counts[pred_class]++;
00706 }
00707 int max_class = 0, max_votes = -1;
00708 for(map<int, int>::iterator iter=vote_counts.begin();iter!=vote_counts.end();iter++) {
00709 if(iter->second > max_votes) {
00710 max_class = iter->first;
00711 max_votes = iter->second;
00712 }
00713 }
00714 return max_class;
00715 }
00716
00717 void RandomForest::runTenFold(vector< SensorPoint::Ptr >* train_test_data,
00718 int pos_id,
00719 int c_num_trees,
00720 vector<map<int,int> >& votes_total,
00721 bool classify_first) {
00722 int num_roc = 21;
00723 int NUM_FOLDS = 10;
00724 vector<int> traj_ids;
00725 map<int, int> label_cntr;
00726 vector<int> inv_label_cntr;
00727 map<int, vector<int> > traj_id_map;
00728 int label_ind = 0;
00729 for(uint32_t i=0;i<train_test_data->size();i++) {
00730 int traj_id = train_test_data->at(i)->traj_id;
00731 if(traj_id_map[traj_id].size() == 0)
00732 traj_ids.push_back(traj_id);
00733 traj_id_map[traj_id].push_back(i);
00734 if(label_cntr.count(train_test_data->at(i)->label) == 0) {
00735 inv_label_cntr.push_back(train_test_data->at(i)->label);
00736 label_cntr[train_test_data->at(i)->label] = label_ind++;
00737 }
00738 }
00739 int num_classes = label_cntr.size();
00740
00741
00742 FoldData::Ptr fold_save(new FoldData);
00743 random_shuffle(traj_ids.begin(), traj_ids.end());
00744 vector<int>* folds[NUM_FOLDS];
00745 vector<int>* f_tests[NUM_FOLDS];
00746 for(int i=0;i<NUM_FOLDS;i++) {
00747 folds[i] = new vector<int>;
00748 f_tests[i] = new vector<int>;
00749 }
00750 for(uint32_t i=0;i<traj_ids.size();i+=NUM_FOLDS) {
00751 for(int j=0;j<NUM_FOLDS;j++) {
00752 if(i+j == traj_ids.size())
00753 break;
00754 for(int k=0;k<NUM_FOLDS;k++) {
00755 int id = traj_ids[i+j];
00756 if(k != j) {
00757 for(uint32_t l=0;l<traj_id_map[id].size();l++) {
00758 folds[k]->push_back(traj_id_map[id][l]);
00759 fold_save->fold_data.push_back(traj_id_map[id][l]);
00760 }
00761 } else {
00762 for(uint32_t l=0;l<traj_id_map[id].size();l++) {
00763 f_tests[k]->push_back(traj_id_map[id][l]);
00764 fold_save->test_data.push_back(traj_id_map[id][l]);
00765 }
00766 }
00767 }
00768 }
00769 }
00770
00771
00772
00773
00774
00775
00776
00777
00778
00779
00780
00781
00782 MatrixXi confusion_mats[num_roc];
00783 for(int j=0;j<num_roc;j++)
00784 confusion_mats[j] = MatrixXi::Zero(num_classes, num_classes);
00785
00786
00787 votes_total.resize(train_test_data->size());
00788
00789 vector<float> roc_list;
00790 for(int i=0;i<NUM_FOLDS;i++) {
00791 RandomForest rf;
00792 printf("Growing Forest %d\n", i+1);
00793
00794 for(uint32_t k=0;k<folds[i]->size();k++)
00795 if((uint32_t) folds[i]->at(k) >= train_test_data->size())
00796 printf("WTF\n");
00797
00798
00799 rf.growForest(train_test_data, folds[i], c_num_trees);
00800
00801
00802 printf("Evaluating...\n");
00803
00804
00805 map<int, pair<int, vector<pair<map<int, int>, float > >* > > votes_map;
00806 for(uint32_t j=0;j<f_tests[i]->size();j++) {
00807 int cur_traj_id = train_test_data->at(f_tests[i]->at(j))->traj_id;
00808
00809 if(votes_map.count(cur_traj_id) == 0) {
00810 pair<int, vector<pair<map<int, int>, float > >* > vote_pair;
00811 vote_pair.first = train_test_data->at(f_tests[i]->at(j))->label;
00812 vote_pair.second = new vector<pair<map<int, int>, float > >;
00813 votes_map[cur_traj_id] = vote_pair;
00814 }
00815 map<int, int> cur_votes;
00816 rf.collectVotes(train_test_data->at(f_tests[i]->at(j)), cur_votes);
00817 votes_map[cur_traj_id].second->push_back(
00818 make_pair(cur_votes, train_test_data->at(f_tests[i]->at(j))->detect_delay));
00819 for(map<int, int>::iterator vt_iter=cur_votes.begin();vt_iter!=cur_votes.end();
00820 vt_iter++)
00821 votes_total[f_tests[i]->at(j)][vt_iter->first] = vt_iter->second;
00822
00823 }
00824 float cur_percent = 1.0 / (num_roc + 1);
00825
00826 for(int k=0;k<num_roc;k++) {
00827 map<int, pair<int, vector<pair<map<int, int>, float > >* > >::iterator votes_map_iter;
00828 for(votes_map_iter=votes_map.begin();
00829 votes_map_iter!=votes_map.end();votes_map_iter++) {
00830 int act_class = votes_map_iter->second.first;
00831 int pred_class;
00832 if(classify_first)
00833 pred_class = findFirstClass(votes_map_iter->second.second, pos_id,
00834 cur_percent * c_num_trees);
00835 else
00836 pred_class = findFrequentClass(votes_map_iter->second.second, pos_id,
00837 cur_percent * c_num_trees);
00838 confusion_mats[k](label_cntr[pred_class],label_cntr[act_class])++;
00839 }
00840 if(i == 0)
00841 roc_list.push_back(cur_percent);
00842 cur_percent += 1.0 / (num_roc + 1);
00843 cout << confusion_mats[k] << endl;
00844 }
00845 }
00846
00847 vector<vector<double> > tpr_list(num_roc);
00848 vector<vector<double> > fpr_list(num_roc);
00849 vector<vector<double> > spc_list(num_roc);
00850 for(int j=0;j<num_roc;j++) {
00851 for(int k=0;k<num_classes;k++) {
00852 double tpr = -1.0, fpr = -1.0, spc = -1.0;
00853 double TP = confusion_mats[j](k, k);
00854 double FP = confusion_mats[j].row(k).sum() - confusion_mats[j](k,k);
00855 double TN = confusion_mats[j].trace() - confusion_mats[j](k,k);
00856 double FN = confusion_mats[j].sum() - confusion_mats[j].row(k).sum() - TN;
00857 if((TP + FN) != 0)
00858 tpr = TP / (TP + FN);
00859 if((FP + TN) != 0)
00860 fpr = FP / (FP + TN);
00861 if(fpr != -1)
00862 spc = 1.0 - fpr;
00863 tpr_list[j].push_back(tpr);
00864 fpr_list[j].push_back(fpr);
00865 spc_list[j].push_back(spc);
00866 }
00867 }
00868
00869 for(int k=0;k<num_classes;k++) {
00870 cout << "Class " << inv_label_cntr[k] << endl;
00871 cout << "ROC vals" << endl;
00872 for(int j=0;j<num_roc;j++)
00873 cout << roc_list[j] << ", ";
00874 cout << endl;
00875 cout << "TPR" << endl;
00876 for(int j=0;j<num_roc;j++)
00877 cout << tpr_list[j][k] << ", ";
00878 cout << endl;
00879 cout << "FPR" << endl;
00880 for(int j=0;j<num_roc;j++)
00881 cout << fpr_list[j][k] << ", ";
00882 cout << endl;
00883 cout << "SPC" << endl;
00884 for(int j=0;j<num_roc;j++)
00885 cout << spc_list[j][k] << ", ";
00886 cout << endl;
00887 }
00888 }
00889
00890 void RandomForest::variableImportance() {
00891 loadDataset();
00892 loadForest();
00893 uint32_t num_feats = dataset->at(0)->features.size();
00894 int base_right = 0;
00895 int num_oobs = 0;
00896 vector<int> num_right(num_feats, 0);
00897 for(int i=0;i<num_trees;i++) {
00898 printf("Tree %d\n", i);
00899 int oobs_size = oobs[i].size();
00900 for(int j=0;j<oobs_size;j++) {
00901 int class_vote = trees[i]->classifyInstance(dataset->at(oobs[i][j]));
00902 base_right += class_vote == dataset->at(oobs[i][j])->label;
00903 num_oobs++;
00904 }
00905 uint32_t f;
00906 int j;
00907 int class_vote;
00908 for(f=0;f<num_feats;f++) {
00909 printf("Feature %d\n", f);
00910 vector<int> var_permute(oobs_size);
00911 vector<float> f_vals(oobs_size);
00912 for(j=0;j<oobs_size;j++) {
00913 var_permute[j] = oobs[i][j];
00914 f_vals[j] = dataset->at(var_permute[j])->features[f];
00915 }
00916 random_shuffle(var_permute.begin(), var_permute.end());
00917 int n_right = 0;
00918 #pragma omp parallel default(shared) private(j, class_vote) num_threads(10) reduction( + : n_right)
00919 {
00920 #pragma omp for schedule(dynamic, 10)
00921 for(j=0;j<oobs_size;j++) {
00922 dataset->at(oobs[i][j])->features[f] = f_vals[var_permute[j]];
00923 class_vote = trees[i]->classifyInstance(dataset->at(oobs[i][j]));
00924 n_right += class_vote == dataset->at(oobs[i][j])->label;
00925 dataset->at(oobs[i][j])->features[f] = f_vals[j];
00926 }
00927 }
00928 num_right[f] += n_right;
00929 }
00930
00931 }
00932 for(uint32_t f=0;f<num_feats;f++) {
00933 float raw_score = (base_right - num_right[f]) / num_trees;
00934 printf("%4d score: %f\n", f, raw_score);
00935 }
00936 }
00937
00938 void RandomForest::randomPermuteData() {
00939 loadDataset();
00940
00941 rosbag::Bag bag;
00942 int bagmode;
00943 bagmode = rosbag::bagmode::Write;
00944 string bag_path, data_bag_name;
00945 nh_priv->getParam("data_bag_name", data_bag_name);
00946 nh_priv->getParam("bag_path", bag_path);
00947 string file_loc = bag_path + data_bag_name;
00948 bag.open(file_loc, bagmode);
00949 for(uint32_t i=0;i<dataset->size();i++) {
00950 SensorPoint sp;
00951 sp.features.resize(dataset->at(0)->features.size());
00952 for(uint32_t j=0;j<dataset->at(0)->features.size();j++)
00953 sp.features[j] = dataset->at(rand() % dataset->size())->features[j];
00954 bag.write("/collision_data", ros::Time::now(), sp);
00955 }
00956 bag.close();
00957 }
00958
00959
00960
00961
00962
00963
00964
00965
00966
00967
00968
00969
00970
00971
00972
00973
00974
00975
00976
00977
00978
00979
00980
00981
00982
00983
00984
00985
00986
00987
00988
00989
00990
00991
00992
00993
00994
00995
00996
00997
00998
00999
01000
01001 double RandomForest::mahalanobisDist(LDLT<MatrixXd>* cov_inv, VectorXd& means, VectorXd& pt) {
01002 return sqrt( (pt - means).dot(cov_inv->solve(pt - means)) );
01003 }
01004
01005 void RandomForest::doMahalanobis() {
01006 loadDataset();
01007 vector<vector< SensorPoint::Ptr > > datasets(1000);
01008 for(uint32_t i=0;i<dataset->size();i++)
01009 datasets[dataset->at(i)->label].push_back(dataset->at(i));
01010 for(uint32_t i=0;i<datasets.size();i++)
01011 if(datasets[i].size() == 0) {
01012 datasets.resize(i);
01013 num_classes = i;
01014 break;
01015 }
01016 datasets[0].resize(10000);
01017 datasets[1].resize(10000);
01018 int num_feats = datasets[0][0]->features.size();
01019
01020 vector<LDLT<MatrixXd>* > cov_inv(num_classes);
01021
01022 vector<VectorXd> means(num_classes);
01023 for(int k=0;k<num_classes;k++) {
01024 means[k] = VectorXd::Zero(num_feats);
01025 for(uint32_t i=0;i<datasets[k].size();i++) {
01026 for(int j=0;j<num_feats;j++) {
01027 means[k](j) += datasets[k][i]->features[j];
01028 }
01029 }
01030 means[k] /= datasets[k].size();
01031 }
01032 int k=0, j=0, l=0;
01033 uint32_t i=0;
01034 for(k=0;k<num_classes;k++) {
01035
01036 MatrixXd var_mat = MatrixXd::Zero(num_feats, num_feats);
01037 #pragma omp parallel default(shared) private(j, l, i) num_threads(10)
01038 {
01039 #pragma omp for schedule(dynamic, 10)
01040 for(j=0;j<num_feats;j++) {
01041 ROS_INFO("j %d", j);
01042 for(l=0;l<num_feats;l++) {
01043 for(i=0;i<datasets[k].size();i++)
01044 var_mat(j,l) += (datasets[k][i]->features[j] - means[k](j)) * (datasets[k][i]->features[l] - means[k](l));
01045 var_mat(j,l) /= datasets[k].size();
01046 }
01047 }
01048 }
01049
01050 LDLT<MatrixXd>* qr = new LDLT<MatrixXd>(var_mat);
01051 cov_inv[k] = qr;
01052 }
01053 for(int k=0;k<num_classes;k++) {
01054 for(int l=0;l<num_classes;l++) {
01055 ArrayXd dists(datasets[l].size());
01056 VectorXd feat_vec(num_feats);
01057 int ind = 0, data_size = datasets[l].size();
01058 #pragma omp parallel default(shared) private(ind, j) num_threads(10)
01059 {
01060 #pragma omp for schedule(dynamic, 10)
01061 for(ind=0;ind<data_size;ind++) {
01062 for(j=0;j<num_feats;j++)
01063 feat_vec(j) = datasets[l][ind]->features[j];
01064 dists[ind] = mahalanobisDist(cov_inv[k], means[k], feat_vec);
01065 }
01066 }
01067 int nans = 0;
01068 if(dists[i] != dists[i]) {
01069 dists[i] = 0;
01070 nans++;
01071 }
01072
01073 double mean_dist = dists.sum() / (dists.size() - nans);
01074 VectorXd diff = dists - mean_dist;
01075 double var_dist = diff.dot(diff) / (dists.size() - nans);
01076 double std_dist = sqrt(var_dist);
01077 double min_dist = dists.minCoeff();
01078 double max_dist = dists.maxCoeff();
01079 printf("cov %d, data %d, mean_dist %f, std_dist %f, min %f, max %f, nans %d\n", k, l, mean_dist, std_dist, min_dist, max_dist, nans);
01080 }
01081 printf("cov %d, num_samps %d, rank 0\n", k, (int) datasets[k].size());
01082 }
01083 printf("num_classes %d, num_feats %d\n", num_classes, num_feats);
01084 }
01085
01086 void RandomForest::onInit() {
01087 nh = new ros::NodeHandle;
01088 nh_priv = new ros::NodeHandle("~");
01089
01090 std::string results_topic, classify_topic, data_bag, forest_bag, loaded_topic;
01091 bool random_permute, training_mode, means, variable_import;
01092
01093 nh_priv->param<bool>("is_abs", is_abs, false);
01094
01095 nh_priv->param<bool>("variable_import", variable_import, false);
01096 if(variable_import) {
01097 variableImportance();
01098 return;
01099 }
01100
01101 nh_priv->param<bool>("random_permute", random_permute, false);
01102 if(random_permute) {
01103 randomPermuteData();
01104 return;
01105 }
01106
01107 nh_priv->param<bool>("mahalanobis", means, false);
01108 if(means) {
01109 doMahalanobis();
01110 return;
01111 }
01112
01113 nh_priv->param<bool>("training_mode", training_mode, false);
01114 if(training_mode) {
01115 growWriteForest();
01116 return;
01117 }
01118
01119 nh_priv->getParam("classify_topic", classify_topic);
01120 nh_priv->getParam("results_topic", results_topic);
01121 nh_priv->getParam("loaded_topic", loaded_topic);
01122 nh_priv->getParam("classifier_id", classifier_id);
01123 nh_priv->getParam("classifier_name", classifier_name);
01124
01125 classify_sub = nh->subscribe(classify_topic.c_str(), 2,
01126 &RandomForest::classifyCallback, this);
01127 ROS_INFO("[random_forest] Subscribed to %s", classify_topic.c_str());
01128 results_pub = nh->advertise<ClassVotes>(results_topic, 1);
01129 ROS_INFO("[random_forest] Publishing on %s", results_topic.c_str());
01130 loaded_pub = nh->advertise<std_msgs::Bool>(loaded_topic, 1);
01131 ROS_INFO("[random_forest] Publishing on %s", loaded_topic.c_str());
01132
01133 trees_loaded = false;
01134 loadForest();
01135 }
01136
01137 }
01138
01139 using namespace collision_detection;
01140
void INTHandler(int sig);

// Signal handler that asks on stdout whether to exit. The received signal is
// ignored while waiting for the answer; 'y' or 'Y' terminates the process,
// anything else re-installs this handler for SIGINT.
void INTHandler(int sig) {
  signal(sig, SIG_IGN);
  printf("Want to exit?\n");
  const char answer = getchar();
  const bool confirmed = (answer == 'y') || (answer == 'Y');
  if(!confirmed) {
    signal(SIGINT, INTHandler);
    return;
  }
  exit(0);
}
01153
01154 int main(int argc, char** argv) {
01155 ros::init(argc, argv, "random_forest", ros::init_options::AnonymousName);
01156
01157
01158 RandomForest rf;
01159 rf.onInit();
01160 ros::spin();
01161 printf("Exiting\n");
01162
01163
01164
01165
01166
01167
01168
01169
01170
01171
01172
01173
01174
01175
01176
01177
01178
01179
01180
01181
01182
01183
01184
01185
01186
01187
01188
01189
01190
01191
01192
01193
01194
01195
01196
01197
01198
01199
01200
01201
01202
01203
01204
01205
01206
01207
01208
01209
01210
01211
01212
01213
01214
01215
01216
01217
01218
01219
01220
01221
01222
01223
01224
01225
01226
01227
01228
01229
01230
01231
01232
01233
01234
01235
01236
01237
01238
01239
01240
01241
01242
01243
01244
01245
01246
01247
01248
01249
01250
01251
01252
01253
01254
01255
01256
01257
01258
01259
01260
01261 return 0;
01262 }