random_forest.cpp
Go to the documentation of this file.
00001 #include "pr2_overhead_grasping/random_forest.h"
00002 //#include <pluginlib/class_list_macros.h>
00003 #include <omp.h>
00004 #include <stdio.h>
00005 #include <signal.h>
00006 //PLUGINLIB_DECLARE_CLASS(collision_detection, random_forest, collision_detection::RandomForest, nodelet::Nodelet)
00007 
00008 using namespace std;
00009 using namespace pr2_overhead_grasping;
// File-scope tuning values with external linkage. Neither is read by the code
// visible in this file (TODO confirm no other translation unit declares them
// `extern` before removing them).
int ATTRS_TRY(44);
double MAX_GAIN_THRESH(0.0);
00012 
00013 namespace collision_detection {
00014 
00015   RandomTree::RandomTree(int c_d_tree_num) {
00016     d_tree_num = c_d_tree_num;
00017   }
00018 
00019   RandomTree::RandomTree(RandomTreeMsg::Ptr r_tree) {
00020     rand_tree = r_tree;
00021     num_classes = rand_tree->num_classes;
00022   }
00023 
00024   bool RandomTree::attrCompare(int inst_i, int inst_j, int attr) { 
00025     return (dataset->at(inst_i)->features[attr] < dataset->at(inst_j)->features[attr]);
00026   }
00027 
00028   struct AttrComp {
00029     int attr;
00030     vector< SensorPoint::Ptr >* dataset;
00031     AttrComp(vector< SensorPoint::Ptr >* c_dataset, int c_attr) { 
00032       dataset = c_dataset;
00033       attr = c_attr;
00034     }
00035     bool operator()(int inst_i, int inst_j) {
00036       return (dataset->at(inst_i)->features[attr] < dataset->at(inst_j)->features[attr]);
00037     }
00038   };
00039 
00040   typedef set<int, boost::function<bool (int,int)> > AttrSet;
00041   void RandomTree::findBestSplit(vector<int>* insts, vector<int>& attrs,
00042                                              pair<int, float>& ret) {
00043     long double max_gain = 0.0;
00044     float best_split_f = 0;
00045     int best_split_attr = 0;
00046     for(uint32_t k=0;k<insts->size();k++)
00047       if((uint32_t) insts->at(k) >= dataset->size())
00048         printf("WTF\n");
00049     for(uint32_t a=0;a<attrs.size();a++) {
00050       //cout << "0-----------------------------------------------------------" << endl;
00051       //boost::function<bool (int,int)> attr_comp = boost::bind(&RandomTree::attrCompare, *this, _1, _2, attrs[a]);
00052       AttrComp attr_comp(dataset, attrs[a]);
00053       //cout << attrs[a] << " " << dataset->at(0)->features.size() << endl;
00054       //AttrSet insts_sorted(attr_comp);
00055       vector<int> insts_temp(insts->begin(), insts->end());
00056       set<int, AttrComp> insts_sorted(attr_comp);
00057       //cout << "0-----------------------------------------------------------" << endl;
00058       insts_sorted.insert(insts_temp.begin(), insts_temp.end()); 
00059       //cout << "0-----------------------------------------------------------" << endl;
00060       map<int, int> class_sums;
00061       map<int, int> class_cur_sums;
00062       map<int, int>::iterator cs_iter, ccs_iter;
00063       //cout << "1-----------------------------------------------------------" << endl;
00064       for(uint32_t i =0;i<insts->size();i++) {
00065         int label = dataset->at(insts->at(i))->label;
00066         if(class_sums.count(label) == 0) {
00067           class_sums[label] = 0;
00068           class_cur_sums[label] = 0;
00069         }
00070         class_sums[label]++;
00071       }
00072       //class_cur_sums[labels[insts_sorted[0]]]++;
00073       //cout << "2-----------------------------------------------------------" << endl;
00074       uint32_t inst_ctr = 0;
00075       for(AttrSet::iterator inst = insts_sorted.begin();
00076                                               inst != insts_sorted.end();inst++) {
00077         if(inst_ctr == insts_sorted.size()-1)
00078           break;
00079         //cout << dataset->at(insts_sorted[i])->features[attrs[a]] << endl;
00080         class_cur_sums[dataset->at(*inst)->label]++;
00081         long double entropy = 0.0, gain = 0.0;
00082         cs_iter = class_sums.begin();
00083         ccs_iter = class_cur_sums.begin();
00084         for(int j=0;j<num_classes;j++) { // find entropy
00085           long double prec_in_class_l = ccs_iter->second / (long double) (insts->size());
00086           if(prec_in_class_l > 0.00001)
00087             entropy -= prec_in_class_l * log2(prec_in_class_l);
00088           long double prec_in_class_r = (cs_iter->second - ccs_iter->second) / (long double) (insts->size());
00089           if(prec_in_class_r > 0.00001)
00090             entropy -= prec_in_class_r * log2(prec_in_class_r);
00091           //cout << prec_in_class_l << ", " << prec_in_class_r << ", ";
00092           gain += entropy * cs_iter->second;
00093           cs_iter++; ccs_iter++;
00094         }
00095         //cout << "3-----------------------------------------------------------" << endl;
00096         //cout << entropy << ", " << i << endl;
00097         if(gain > max_gain) { // find best entropy
00098           best_split_f = dataset->at(*inst)->features[attrs[a]]; // split point x <= best_split_f : left
00099           best_split_attr = attrs[a]; // attribute being compared
00100           //cout << *inst << ", " << insts_sorted.size() << ", " << dataset->at(*inst)->features[attrs[a]] << ", " << dataset->at(*inst)->features[attrs[a]] << ", " << *inst << ", " << *inst << ", " << gain << ", " << class_sums[0] << ", " << class_sums[1] << ", " << class_cur_sums[0] << ", " << class_cur_sums[1] << endl;
00101           max_gain = gain;
00102           //if(gain > MAX_GAIN_THRESH) 
00103             // return and say that this is a terminal split
00104             //return make_pair(-best_split_attr, best_split_f);
00105         }
00106         //cout << "4-----------------------------------------------------------" << endl;
00107         inst_ctr++;
00108       }
00109     }
00110     ret.first = best_split_attr; ret.second =  best_split_f;
00111   }
00112 
00113   void RandomTree::splitNode(vector<int>* node_inds, 
00114                              pair<int, float>& split_pt,
00115                              pair<vector<int>*, vector<int>* >& split_nodes) {
00116     pair<vector<int>*, vector<int>* > nodes;
00117     split_nodes.first = new vector<int>; split_nodes.second = new vector<int>;
00118     //cout << "YOOOOOOOO" << endl;
00119     for(uint32_t i=0;i<node_inds->size();i++) {
00120       //cout << dataset->at(node_inds->at(i))->features[split_pt.first] << endl;
00121       if(dataset->at(node_inds->at(i))->features[split_pt.first] <= split_pt.second)
00122         split_nodes.first->push_back(node_inds->at(i));
00123       else
00124         split_nodes.second->push_back(node_inds->at(i));
00125     }
00126     assert(split_nodes.first->size() > 0);
00127     assert(split_nodes.second->size() > 0);
00128   }
00129 
00130   void RandomTree::growTree(vector< SensorPoint::Ptr >* c_dataset,
00131                             vector<int>* inds) {
00133     // Balanced Random Tree
00134     // Taken from "Using Random Forest to Learn Imbalanced Data"
00135     // Chao Chen, Andy Liaw, Leo Breiman
00136     // 2004
00138     dataset = c_dataset;
00139     //cout << "YO 1" << endl;
00140     int num_attrs = dataset->at(0)->features.size();
00141 
00142     // Tree will be saved as this message
00143     rand_tree = boost::shared_ptr<RandomTreeMsg>(new RandomTreeMsg);
00144     //cout << "YO 2" << endl;
00145     rand_tree->tree_num = d_tree_num;
00146     rand_tree->attr_split.resize(2*inds->size());
00147     //cout << "YO 3" << endl;
00148     rand_tree->val_split.resize(2*inds->size());
00149     rand_tree->r_node_inds.resize(2*inds->size(), -999);
00150     vector<int> r_node_stack;
00151     //cout << "YO 4" << endl;
00152     int tm_i = 0;
00153 
00155     // make samples
00156     map<int, vector<int> > class_inds; // list of indices indexed by label
00157     for(uint32_t i=0;i<inds->size();i++) 
00158       class_inds[dataset->at(inds->at(i))->label].push_back(inds->at(i));
00159     num_classes = class_inds.size();
00160     // find smallest class in the sample
00161     int minority_class = 0, min_class = class_inds.begin()->second.size(); 
00162     for(map<int, vector<int> >::iterator i=class_inds.begin();i!=class_inds.end();i++) {
00163       if(i->second.size() < (uint32_t) min_class) {
00164         min_class = i->second.size();
00165         minority_class = i->first;
00166       }
00167     }
00168     //cout << "YO 5" << endl;
00169     vector<bool> oobs(dataset->size(), true);
00170     // Balance the sample sets
00171     int num_samples = class_inds[minority_class].size();
00172     // sampled top node to split
00173     vector<int>* head_node = new vector<int>(num_samples*class_inds.size()); 
00174     int i_cntr = 0;
00175     for(map<int, vector<int> >::iterator i=class_inds.begin();i!=class_inds.end();i++) {
00176       //cout << head_node->size() << " Size" << endl;
00177       for(int j=0;j<num_samples;j++) {
00178         int sample_ind = i->second[rand() % i->second.size()];
00179         //cout << sample_ind << endl;
00180         head_node->at(i_cntr++) = sample_ind;
00181         oobs[sample_ind] = false;
00182       }
00183     }
00184     //cout << "YO 6" << endl;
00185     /* Unbalanced tree
00186     int num_samples = dataset->size();
00187     for(uint32_t i=0;i<num_samples;i++) {
00188       int sample_ind = rand() % dataset->size();
00189       head_node->push_back(sample_ind);
00190       oobs[sample_ind] = false;
00191     }*/
00193 
00195     // Create tree
00196     int last_r_node_ind = -1;
00197     vector<vector<int>* > node_stack;
00198     node_stack.push_back(head_node);
00199     //cout << "YO 7" << endl;
00200     while(ros::ok()) {
00201       if(node_stack.size() == 0)
00202         break;
00203       vector<int>* cur_node = node_stack.back(); 
00204       //cout << node_stack.size() << endl;
00205 
00206       bool homo = true; // is the node homogenous?
00207       for(uint32_t i=0;i<cur_node->size();i++) {
00208         //cout << dataset->at(cur_node->at(0))->label << ", " << dataset->at(cur_node->at(i))->label << endl;
00209         if(dataset->at(cur_node->at(0))->label != dataset->at(cur_node->at(i))->label) {
00210           homo = false;
00211           break;
00212         }
00213       }
00214       if(homo) {
00215         // terminal: everything of same class
00216         // note that negative values indicate sentinels
00217         // the class of that terminal node is the absolute value of the first value
00218 
00219         // add terminal node to tree
00220         //cout << dataset->size() << endl;
00221         //cout << cur_node->size() << endl;
00222         //cout << cur_node->at(0) << endl;
00223         rand_tree->attr_split[tm_i] = dataset->at(cur_node->at(0))->label;
00224         rand_tree->val_split[tm_i] = 0.0;
00225         rand_tree->r_node_inds[tm_i] = -1; // this node is terminal
00226         // the parent of this node's right node is the next element
00227         if(r_node_stack.size() > 0) {
00228           last_r_node_ind = r_node_stack.back();
00229           rand_tree->r_node_inds[r_node_stack.back()] = tm_i + 1; 
00230           r_node_stack.pop_back();
00231           tm_i++;
00232         }
00233 
00234         node_stack.pop_back();
00235         //cout << "YO in 0" << endl;
00236         //delete cur_node;
00237         //cout << "YO ot 0" << endl;
00238         continue;
00239       }
00240       assert(cur_node->size() > 1);
00241 
00242       // get a random selection of attributes to check for a split
00243       int num_try = (int) sqrt(num_attrs);
00244       vector<int> attrs(num_try, 0); 
00245       for(int i=0;i<num_try;i++) {
00246         attrs[i] = rand() % num_attrs;
00247       }
00248 
00249       // find a splitting point on this node
00250       // <attribute to split on, value to split at> if x <= v :-> go left
00251       //cout << cur_node->size() << "Size in" << endl;
00252       pair<int, float> split_pt;
00253       findBestSplit(cur_node, attrs, split_pt);
00254       //cout << cur_node->size() << "Size out" << endl;
00255       if(split_pt.first < 0) {
00256         // DISABLED
00257         assert(false);
00258         // terminal, entropy is small
00259         split_pt.first *= -1;
00260         // left and right split nodes
00261         pair<vector<int>*, vector<int>* > new_nodes;
00262         splitNode(cur_node, split_pt, new_nodes);
00263 
00264         // left
00265         vector<int> class_sums_l(num_classes, 0);
00266         for(uint32_t i=0;i<new_nodes.first->size();i++) 
00267           class_sums_l[dataset->at(new_nodes.first->at(i))->label]++;
00268         int popular_class_l = max_element(class_sums_l.begin(), class_sums_l.end()) - class_sums_l.begin();
00269 
00270         // add terminal left node to tree
00271         rand_tree->attr_split[tm_i] = popular_class_l;
00272         rand_tree->val_split[tm_i] = 0.0;
00273         rand_tree->r_node_inds[tm_i] = -1; // this node is terminal
00274         // the parent of this node's right node is the next element
00275         rand_tree->r_node_inds[r_node_stack.back()] = tm_i + 1; 
00276         r_node_stack.pop_back();
00277         tm_i++;
00278 
00279         // right
00280         vector<int> class_sums_r(num_classes, 0);
00281         for(uint32_t i=0;i<new_nodes.second->size();i++) 
00282           class_sums_r[dataset->at(new_nodes.second->at(i))->label]++;
00283         int popular_class_r = max_element(class_sums_r.begin(), class_sums_r.end()) - class_sums_r.begin();
00284 
00285         // add terminal right node to tree
00286         rand_tree->attr_split[tm_i] = popular_class_r;
00287         rand_tree->val_split[tm_i] = 0.0;
00288         rand_tree->r_node_inds[tm_i] = -1; // this node is terminal
00289         // the parent of this node's right node is the next element
00290         rand_tree->r_node_inds[r_node_stack.back()] = tm_i + 1; 
00291         r_node_stack.pop_back();
00292         tm_i++;
00293       } else {
00294         // recurse upon both left and right nodes
00295         // left and right split nodes
00296         pair<vector<int>*, vector<int>* > new_nodes;
00297         splitNode(cur_node, split_pt, new_nodes);
00298 
00299         node_stack.pop_back();
00300         // normal
00301         node_stack.push_back(new_nodes.second); 
00302         // put the left node at the top of the stack
00303         node_stack.push_back(new_nodes.first); 
00304 
00305         // add split to tree
00306         rand_tree->attr_split[tm_i] = split_pt.first;
00307         rand_tree->val_split[tm_i] = split_pt.second;
00308         r_node_stack.push_back(tm_i);
00309         tm_i++;
00310       }
00311       //cout << "YO in 1" << endl;
00312       //delete cur_node;
00313       //cout << "YO ot 1" << endl;
00314     }
00315     //cout << "YO 8" << endl;
00316     rand_tree->r_node_inds[last_r_node_ind] = tm_i - 1; 
00317     //cout << r_node_stack[0] << ", " << rand_tree->r_node_inds[0] << endl;
00318     //exit(1);
00319 
00320     // we have an pre-order tree stored at d_tree
00321     // resize message back to minimum size
00322     rand_tree->attr_split.resize(tm_i);
00323     rand_tree->val_split.resize(tm_i);
00324     rand_tree->r_node_inds.resize(tm_i);
00325     //cout << tm_i << endl;
00326     for(uint32_t i=0;i<oobs.size();i++)
00327       if(oobs[i])
00328         rand_tree->out_of_bags.push_back(i);
00329     rand_tree->num_classes = num_classes;
00330     // rand_tree is fully built
00331     //cout << "YO TREE OUT" << endl;
00332   }
00333 
00334   void RandomTree::writeTree(string& bag_file, bool is_first) {
00335     // save tree to file
00336     rosbag::Bag bag;
00337     int bagmode;
00338     if(is_first)
00339       bagmode = rosbag::bagmode::Write;
00340     else
00341       bagmode = rosbag::bagmode::Append;
00342     bag.open(bag_file, bagmode);
00343     bag.write("trees", ros::Time::now(), rand_tree);
00344     bag.close();
00345   }
00346 
00347   int RandomTree::classifyInstance(SensorPoint::Ptr inst) {
00348     int ind = 0, attr;
00349 
00350     //cout << rand_tree->attr_split.size() << ", " << rand_tree->r_node_inds.size() << ", " << rand_tree->val_split.size() << endl;
00351     //for(int i=0;i<rand_tree->attr_split.size();i++)
00352     //  cout << i << ", " << rand_tree->attr_split[i] << ", " << rand_tree->val_split[i] << ", " << rand_tree->r_node_inds[i] << ", " << rand_tree->r_node_inds[rand_tree->r_node_inds[i]]  << endl;
00353       
00354 
00355     //exit(1);
00356     //cout << "Class in" << endl;
00357     while(ros::ok()) {
00358       //cout << attr << ", " << ind << ", " << rand_tree->r_node_inds[ind] << ", " << inst->features.size() << ", " << rand_tree->r_node_inds.size() <<", " << rand_tree->val_split.size() << ", " << rand_tree->attr_split.size() << endl;
00359       attr = rand_tree->attr_split[ind];
00360       assert(attr >= 0 && attr < inst->features.size());
00361       //cout << ind << " ";
00362       assert(ind != -999);
00363       if(rand_tree->r_node_inds[ind] < 0) {
00364         // terminal node
00365         //cout << "Class out" << endl;
00366         return attr;
00367       }
00368       float feat = inst->features[attr];
00369       if(is_abs)
00370         feat = abs(feat);
00371       if(feat <= rand_tree->val_split[ind]) {
00372         // go left
00373         ind++;
00374       } else {
00375         // go right
00376         ind = rand_tree->r_node_inds[ind];
00377       }
00378     }
00379     //NODELET_FATAL("BAD RANDOM TREE");
00380     return -1;
00381   }
00382 
00383   void RandomForest::loadDataset() {
00384     dataset = new vector< SensorPoint::Ptr >;
00385     string bag_path;
00386     XmlRpc::XmlRpcValue bag_names, bag_labels;
00387     nh_priv->getParam("bag_path", bag_path);
00388     nh_priv->getParam("bag_names", bag_names);
00389     nh_priv->getParam("bag_labels", bag_labels);
00390     for(int i=0;i<bag_names.size();i++) {
00391       string file_loc = bag_path + (string) bag_names[i];
00392       loadDataBag(file_loc, bag_labels[i]);
00393     }
00394   }
00395 
00396   void RandomForest::loadDataBag(string& data_bag, int label) {
00397     // load dataset
00398     rosbag::Bag bag(data_bag);
00399     rosbag::View view(bag, rosbag::TopicQuery("/collision_data"));
00400     BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00401       SensorPoint::Ptr sp = m.instantiate<SensorPoint>();
00402       if(sp != NULL) {
00403         sp->label = label;
00404         if(is_abs)
00405           for(uint32_t i=0;i<sp->features.size();i++)
00406             sp->features[i] = abs(sp->features[i]);
00407         dataset->push_back(sp);
00408       }
00409     }
00410     assert(dataset->size() != 0);
00411   }
00412 
00413   void RandomForest::setDataset(vector< SensorPoint::Ptr >* datas) {
00414     dataset = datas;
00415   }
00416 
00417   void RandomForest::growForest(vector< SensorPoint::Ptr >* c_dataset,
00418                             vector<int>* inds, int c_num_trees) {
00419     dataset = c_dataset;
00420     num_trees = c_num_trees;
00421     // grow trees
00422     trees = new RandomTree*[num_trees];
00423     int i, NUM_CHUNKS = 10;
00424     RandomForest* this_forest = this;
00425 #pragma omp parallel shared(this_forest, inds, NUM_CHUNKS) private(i)
00426     {
00427 #pragma omp for schedule(dynamic, NUM_CHUNKS)
00428       for(i=0;i<this_forest->num_trees;i++) {
00429         ROS_INFO("Growing tree %d", i+1);
00430         //cout << i << " " << omp_get_thread_num() << " " << omp_get_num_threads() << endl;
00431         while(ros::ok()) {
00432           try {
00433             this_forest->trees[i] = new RandomTree(i);
00434             this_forest->trees[i]->growTree(this_forest->dataset, inds);
00435             break;
00436           }
00437           catch (bad_alloc &) {
00438             ROS_INFO("Memory unavailable...restaring...");
00439             ros::Duration(1.0).sleep();
00440           }
00441         }
00442       }
00443     }
00444   }
00445 
00446   RandomForest::~RandomForest() {
00447     for(int i=0;i<num_trees;i++)
00448       delete trees[i];
00449   }
00450 
00451   void RandomForest::writeForest() {
00452     string forest_bag_name;
00453     nh_priv->getParam("forest_bag_name", forest_bag_name);
00454     writeForest(forest_bag_name);
00455   }
00456 
00457   void RandomForest::writeForest(string name) {
00458     string bag_path;
00459     nh_priv->getParam("bag_path", bag_path);
00460     string file_loc = bag_path + name;
00461     for(int i=0;i<num_trees;i++) {
00462       trees[i]->writeTree(file_loc, i == 0);
00463     }
00464   }
00465 
00466 
00467   void RandomForest::growWriteForest() {
00468     int c_num_trees;
00469     bool is_validation;
00470     nh_priv->getParam("num_trees", c_num_trees);
00471     nh_priv->getParam("is_validation", is_validation);
00472     loadDataset();
00473     vector<int>* inds = new vector<int>(dataset->size());
00474     for(uint32_t i=0;i<dataset->size();i++) 
00475       inds->at(i) = i;
00476     if(is_validation) {
00477       int pos_id;
00478       bool classify_first;
00479       nh_priv->getParam("pos_id", pos_id);
00480       nh_priv->getParam("classify_first", classify_first);
00481       ROS_INFO("Running ten-fold cross validation.");
00482       vector<map<int, int> > votes_total;
00483       RandomForest::runTenFold(dataset, pos_id, c_num_trees, votes_total, classify_first);
00484       // vote filter
00485       bool filter;
00486       nh_priv->getParam("filter", filter);
00487       if(filter) {
00488         double filter_thresh;
00489         string filter_out, bag_path;
00490         ROS_INFO("Filtering bad data.");
00491         nh_priv->getParam("filter_thresh", filter_thresh);
00492         nh_priv->getParam("bag_path", bag_path);
00493         nh_priv->getParam("filter_out", filter_out);
00494         string file_loc = bag_path + filter_out;
00495         rosbag::Bag bag;
00496         int bagmode = rosbag::bagmode::Write;
00497         bag.open(file_loc, bagmode);
00498         for(uint32_t i = 0;i<votes_total.size();i++) {
00499           if(votes_total[i][pos_id] / (double) c_num_trees < filter_thresh &&
00500               dataset->at(i)->label != (uint32_t) pos_id) {
00501             // write
00502             bag.write("/collision_data", ros::Time::now(), dataset->at(i));
00503           }
00504         }
00505         bag.close();
00506       }
00507       ROS_INFO("Done, please exit.");
00508     } else {
00509       ROS_INFO("Growing forest");
00510       growForest(dataset, inds, c_num_trees);
00511       ROS_INFO("Writing forest");
00512       writeForest();
00513       ROS_INFO("Finished writing forest, please exit");
00514     }
00515   }
00516 
00517   void RandomForest::loadForest() {
00518     string forest_bag_name;
00519     string bag_path;
00520     nh_priv->getParam("bag_path", bag_path);
00521     nh_priv->getParam("forest_bag_name", forest_bag_name);
00522     string file_loc = bag_path + forest_bag_name;
00523     rosbag::Bag bag(file_loc);
00524     rosbag::View view(bag, rosbag::TopicQuery("trees"));
00525     int ind = 0;
00526     trees = new RandomTree*[view.size()];
00527     BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00528       RandomTreeMsg::Ptr tree = m.instantiate<RandomTreeMsg>();
00529       if(tree != NULL) {
00530         trees[ind++] = new RandomTree(tree);
00531         trees[ind-1]->is_abs = is_abs;
00532         oobs.push_back(tree->out_of_bags);
00533       }
00534     }
00535     num_trees = ind;
00536     num_classes = trees[0]->num_classes;
00537     std_msgs::Bool loaded;
00538     loaded.data = true;
00539     loaded_pub.publish(loaded);
00540     trees_loaded = true;
00541     ROS_INFO("[random_forest] All trees loaded.");
00542   }
00543 
00544   void RandomForest::loadCovMat() {
00545     string mahal_bag_name;
00546     string bag_path;
00547     nh_priv->getParam("bag_path", bag_path);
00548     nh_priv->getParam("mahal_bag_name", mahal_bag_name);
00549     string file_loc = bag_path + mahal_bag_name;
00550     rosbag::Bag bag(file_loc);
00551     rosbag::View view(bag, rosbag::TopicQuery("matrix"));
00552     BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00553       CovarianceMatrix::Ptr cov_mat_msg = m.instantiate<CovarianceMatrix>();
00554       if(cov_mat_msg != NULL) {
00555         MatrixXd cov_mat(cov_mat_msg->size, cov_mat_msg->size);
00556         int m_ind = 0;
00557         for(uint32_t i=0;i<cov_mat_msg->size;i++) {
00558           for(uint32_t j=0;j<cov_mat_msg->size;j++)
00559             cov_mat(i,j) = cov_mat_msg->cov_mat[m_ind++];
00560           means(i) = cov_mat_msg->means[i];
00561         }
00562         cov_inv = new LDLT<MatrixXd>(cov_mat);
00563       }
00564     }
00565     std_msgs::Bool loaded;
00566     loaded.data = true;
00567     loaded_pub.publish(loaded);
00568     trees_loaded = true;
00569     ROS_INFO("[random_forest] All trees loaded.");
00570   }
00571 
00572   void RandomForest::createCovMat() {
00573     loadDataset();
00574     //dataset->resize(10000);
00575     int num_feats = dataset->at(0)->features.size();
00576 
00577     VectorXd means = VectorXd::Zero(num_feats);
00578     for(uint32_t i=0;i<dataset->size();i++) {
00579       for(int j=0;j<num_feats;j++) {
00580         means(j) += dataset->at(i)->features[j];
00581       }
00582     }
00583     means /= dataset->size();
00584     int j=0, l=0;
00585     uint32_t i=0;
00586     MatrixXd var_mat = MatrixXd::Zero(num_feats, num_feats);
00587 #pragma omp parallel default(shared) private(j, l, i) num_threads(10)
00588     {
00589 #pragma omp for schedule(dynamic, 10)
00590       for(j=0;j<num_feats;j++) {
00591         ROS_INFO("j %d", j);
00592         for(l=0;l<num_feats;l++) {
00593           for(i=0;i<dataset->size();i++) 
00594             var_mat(j,l) += (dataset->at(i)->features[j] - means(j)) * (dataset->at(i)->features[l] - means(l));
00595           var_mat(j,l) /= dataset->size();
00596         }
00597       }
00598     }
00599 
00600     CovarianceMatrix cov_mat_msg;
00601     cov_mat_msg.size = num_feats;
00602     for(j=0;j<num_feats;j++) 
00603       for(l=0;l<num_feats;l++) 
00604         cov_mat_msg.cov_mat.push_back(var_mat(j,l));
00605     // save to file
00606     string mahal_bag_name;
00607     string bag_path;
00608     nh_priv->getParam("bag_path", bag_path);
00609     nh_priv->getParam("mahal_bag_name", mahal_bag_name);
00610     string file_loc = bag_path + mahal_bag_name;
00611     rosbag::Bag bag;
00612     string forest_bag_name;
00613     int bagmode = rosbag::bagmode::Write;
00614     bag.open(file_loc, bagmode);
00615     bag.write("matrix", ros::Time::now(), cov_mat_msg);
00616     bag.close();
00617   }
00618 
00619   void RandomForest::collectVotes(SensorPoint::Ptr inst, map<int, int>& class_votes) {
00620     int i;
00621     //cout << "IN ";
00622     for(i=0;i<num_trees;i++) {
00623       int class_vote = trees[i]->classifyInstance(inst);
00624       if(class_votes.count(class_vote) == 0) 
00625         class_votes[class_vote] = 0;
00626       class_votes[class_vote]++;
00627       //cout<< class_vote << ", ";
00628     }
00629     //cout << " OUT" << endl;
00630   }
00631 
00632   void RandomForest::classifyCallback(const boost::shared_ptr<SensorPoint>& inst) {
00633     if(!trees_loaded) {
00634       ROS_INFO("[random_forest] Classifcation requested but trees not loaded.");
00635       return;
00636     }
00637     ClassVotes::Ptr class_votes(new ClassVotes);
00638     map<int, int> cvs;
00639     for(uint32_t i=0;i<inst->features.size();i++)
00640       if(inst->features[i] != inst->features[i]) {
00641         ROS_INFO("NAN %d", i);
00642         return;
00643       }
00644     //cout << endl;
00645     collectVotes(inst, cvs);
00646     class_votes->votes.resize(num_classes);
00647     class_votes->classes.resize(num_classes);
00648     int ind = 0;
00649     for(map<int, int>::iterator iter=cvs.begin();iter!=cvs.end();iter++) {
00650       //class_votes->votes.push_back(iter->second);
00651       //ROS_INFO("[random_forest] %d %d", iter->first, iter->second);
00652       class_votes->votes[iter->first] = iter->second;
00653       class_votes->classes[ind++] = iter->first;
00654     }
00655     class_votes->classifier_name = classifier_name;
00656     class_votes->classifier_id = classifier_id;
00657     results_pub.publish(class_votes);
00658   }
00659 
00660   // finds the first class which is not the positive class
00661   int RandomForest::findFirstClass(vector<pair<map<int, int>, float > >* votes_list, int pos_id, float thresh) {
00662     int min_class = pos_id;
00663     float min_delay = 10000.0;
00664     for(uint32_t j=0;j<votes_list->size();j++) {
00665       int pred_class;
00666       if(votes_list->at(j).first[pos_id] > thresh) 
00667         pred_class = pos_id;
00668       else {
00669         int max_class = 0, max_votes = -1;
00670         for(map<int, int>::iterator iter=votes_list->at(j).first.begin();iter!=votes_list->at(j).first.end();iter++) {
00671           if(iter->second > max_votes) {
00672             max_class = iter->first;
00673             max_votes = iter->second;
00674           }
00675         }
00676         pred_class = max_class;
00677       }
00678       if(votes_list->at(j).second < min_delay && pred_class != pos_id) {
00679         min_class = pred_class;
00680         min_delay = votes_list->at(j).second;
00681       }
00682     }
00683     return min_class;
00684   }
00685 
00686   int RandomForest::findFrequentClass(vector<pair<map<int, int>, float > >* votes_list, int pos_id, float thresh) {
00687     // all of the vote maps in this trajectory
00688     map<int, int> vote_counts;
00689     for(uint32_t j=0;j<votes_list->size();j++) {
00690       int pred_class;
00691       if(votes_list->at(j).first[pos_id] > thresh) 
00692         pred_class = pos_id;
00693       else {
00694         int max_class = 0, max_votes = -1;
00695         for(map<int, int>::iterator iter=votes_list->at(j).first.begin();iter!=votes_list->at(j).first.end();iter++) {
00696           if(iter->second > max_votes) {
00697             max_class = iter->first;
00698             max_votes = iter->second;
00699           }
00700         }
00701         pred_class = max_class;
00702       }
00703       if(vote_counts.count(pred_class) == 0)
00704         vote_counts[pred_class] = 0;
00705       vote_counts[pred_class]++;
00706     }
00707     int max_class = 0, max_votes = -1;
00708     for(map<int, int>::iterator iter=vote_counts.begin();iter!=vote_counts.end();iter++) {
00709       if(iter->second > max_votes) {
00710         max_class = iter->first;
00711         max_votes = iter->second;
00712       }
00713     }
00714     return max_class;
00715   }
00716 
00717   void RandomForest::runTenFold(vector< SensorPoint::Ptr >* train_test_data, 
00718                                        int pos_id,
00719                                        int c_num_trees,
00720                                        vector<map<int,int> >& votes_total,
00721                                        bool classify_first) {
00722     int num_roc = 21;
00723     int NUM_FOLDS = 10;
00724     vector<int> traj_ids;
00725     map<int, int> label_cntr; 
00726     vector<int> inv_label_cntr;
00727     map<int, vector<int> > traj_id_map;
00728     int label_ind = 0;
00729     for(uint32_t i=0;i<train_test_data->size();i++) {
00730       int traj_id = train_test_data->at(i)->traj_id;
00731       if(traj_id_map[traj_id].size() == 0)
00732         traj_ids.push_back(traj_id);
00733       traj_id_map[traj_id].push_back(i);
00734       if(label_cntr.count(train_test_data->at(i)->label) == 0) {
00735         inv_label_cntr.push_back(train_test_data->at(i)->label);
00736         label_cntr[train_test_data->at(i)->label] = label_ind++;
00737       }
00738     }
00739     int num_classes = label_cntr.size();
00740 
00741     // create folds
00742     FoldData::Ptr fold_save(new FoldData);
00743     random_shuffle(traj_ids.begin(), traj_ids.end());
00744     vector<int>* folds[NUM_FOLDS];
00745     vector<int>* f_tests[NUM_FOLDS];
00746     for(int i=0;i<NUM_FOLDS;i++) {
00747       folds[i] = new vector<int>;
00748       f_tests[i] = new vector<int>;
00749     }
00750     for(uint32_t i=0;i<traj_ids.size();i+=NUM_FOLDS) {
00751       for(int j=0;j<NUM_FOLDS;j++) {
00752         if(i+j == traj_ids.size())
00753           break;
00754         for(int k=0;k<NUM_FOLDS;k++) {
00755           int id = traj_ids[i+j];
00756           if(k != j) {
00757             for(uint32_t l=0;l<traj_id_map[id].size();l++) {
00758               folds[k]->push_back(traj_id_map[id][l]);
00759               fold_save->fold_data.push_back(traj_id_map[id][l]);
00760             }
00761           } else {
00762             for(uint32_t l=0;l<traj_id_map[id].size();l++) {
00763               f_tests[k]->push_back(traj_id_map[id][l]);
00764               fold_save->test_data.push_back(traj_id_map[id][l]);
00765             }
00766           }
00767         }
00768       }
00769     } 
00770     /*
00771     for(int k=0;k<NUM_FOLDS;k++) {
00772       fold_save->fold_sizes.push_back(folds[k]->size());
00773       fold_save->test_sizes.push_back(f_tests[k]->size());
00774     }
00775     rosbag::Bag bag;
00776     int bagmode = rosbag::bagmode::Write;
00777     bag.open("fold_save.bag", bagmode);
00778     bag.write("fold_save", ros::Time::now(), fold_save);
00779     bag.close();
00780     */
00781     
00782     MatrixXi confusion_mats[num_roc];
00783     for(int j=0;j<num_roc;j++) 
00784       confusion_mats[j] = MatrixXi::Zero(num_classes, num_classes);
00785 
00786 
00787     votes_total.resize(train_test_data->size());
00788 
00789     vector<float> roc_list;
00790     for(int i=0;i<NUM_FOLDS;i++) {
00791       RandomForest rf;
00792       printf("Growing Forest %d\n", i+1);
00793       //vector<int>* fold = new vector<int>();
00794       for(uint32_t k=0;k<folds[i]->size();k++)
00795         if((uint32_t) folds[i]->at(k) >= train_test_data->size())
00796           printf("WTF\n");
00797       //  cout << folds[i]->at(k) << " ";
00798       //rf.growForest(train_test_data, fold, c_num_trees);
00799       rf.growForest(train_test_data, folds[i], c_num_trees);
00800       //stringstream tmp; tmp << "forest_X_tmp_" << i << ".bag"; 
00801       //rf.writeForest(tmp.str());
00802       printf("Evaluating...\n");
00803 
00804       // maps traj_ids to a list of labels paired with vote maps
00805       map<int, pair<int, vector<pair<map<int, int>, float > >* > > votes_map;
00806       for(uint32_t j=0;j<f_tests[i]->size();j++) {
00807         int cur_traj_id = train_test_data->at(f_tests[i]->at(j))->traj_id;
00808         //vector<map<int, int> >* votes_list;
00809         if(votes_map.count(cur_traj_id) == 0) {
00810           pair<int, vector<pair<map<int, int>, float > >* > vote_pair;
00811           vote_pair.first = train_test_data->at(f_tests[i]->at(j))->label;
00812           vote_pair.second = new vector<pair<map<int, int>, float > >;
00813           votes_map[cur_traj_id] = vote_pair;
00814         }
00815         map<int, int> cur_votes;
00816         rf.collectVotes(train_test_data->at(f_tests[i]->at(j)), cur_votes);
00817         votes_map[cur_traj_id].second->push_back(
00818                 make_pair(cur_votes, train_test_data->at(f_tests[i]->at(j))->detect_delay));
00819         for(map<int, int>::iterator vt_iter=cur_votes.begin();vt_iter!=cur_votes.end();
00820                      vt_iter++)
00821           votes_total[f_tests[i]->at(j)][vt_iter->first] = vt_iter->second;
00822         //cout << votes_list[j]->size() << " ";
00823       }
00824       float cur_percent = 1.0 / (num_roc + 1);
00825 
00826       for(int k=0;k<num_roc;k++) {
00827         map<int, pair<int, vector<pair<map<int, int>, float > >* > >::iterator votes_map_iter;
00828         for(votes_map_iter=votes_map.begin();
00829             votes_map_iter!=votes_map.end();votes_map_iter++) {
00830           int act_class = votes_map_iter->second.first;
00831           int pred_class;
00832           if(classify_first) 
00833             pred_class = findFirstClass(votes_map_iter->second.second, pos_id, 
00834                                         cur_percent * c_num_trees);
00835           else
00836             pred_class = findFrequentClass(votes_map_iter->second.second, pos_id, 
00837                                         cur_percent * c_num_trees);
00838           confusion_mats[k](label_cntr[pred_class],label_cntr[act_class])++;
00839         }
00840         if(i == 0)
00841           roc_list.push_back(cur_percent);
00842         cur_percent += 1.0 / (num_roc + 1);
00843         cout << confusion_mats[k] << endl;
00844       }
00845     }
00846 
00847     vector<vector<double> > tpr_list(num_roc);
00848     vector<vector<double> > fpr_list(num_roc);
00849     vector<vector<double> > spc_list(num_roc);
00850     for(int j=0;j<num_roc;j++) {
00851       for(int k=0;k<num_classes;k++) {
00852         double tpr = -1.0, fpr = -1.0, spc = -1.0;
00853         double TP = confusion_mats[j](k, k);
00854         double FP = confusion_mats[j].row(k).sum() - confusion_mats[j](k,k);
00855         double TN = confusion_mats[j].trace() - confusion_mats[j](k,k);
00856         double FN = confusion_mats[j].sum() - confusion_mats[j].row(k).sum() - TN;
00857         if((TP + FN) != 0)
00858           tpr = TP / (TP + FN);
00859         if((FP + TN) != 0)
00860           fpr = FP / (FP + TN);
00861         if(fpr != -1)
00862           spc = 1.0 - fpr;
00863         tpr_list[j].push_back(tpr);
00864         fpr_list[j].push_back(fpr);
00865         spc_list[j].push_back(spc);
00866       }
00867     }
00868 
00869     for(int k=0;k<num_classes;k++) {
00870       cout << "Class " << inv_label_cntr[k] << endl;
00871       cout << "ROC vals" << endl;
00872       for(int j=0;j<num_roc;j++)
00873         cout << roc_list[j] << ", ";
00874       cout << endl;
00875       cout << "TPR" << endl;
00876       for(int j=0;j<num_roc;j++)
00877         cout << tpr_list[j][k] << ", ";
00878       cout << endl;
00879       cout << "FPR" << endl;
00880       for(int j=0;j<num_roc;j++)
00881         cout << fpr_list[j][k] << ", ";
00882       cout << endl;
00883       cout << "SPC" << endl;
00884       for(int j=0;j<num_roc;j++)
00885         cout << spc_list[j][k] << ", ";
00886       cout << endl;
00887     }
00888   }
00889 
00890   void RandomForest::variableImportance() {
00891     loadDataset();
00892     loadForest();
00893     uint32_t num_feats = dataset->at(0)->features.size();
00894     int base_right = 0;
00895     int num_oobs = 0;
00896     vector<int> num_right(num_feats, 0);
00897     for(int i=0;i<num_trees;i++) {
00898       printf("Tree %d\n", i);
00899       int oobs_size = oobs[i].size();
00900       for(int j=0;j<oobs_size;j++) {
00901         int class_vote = trees[i]->classifyInstance(dataset->at(oobs[i][j]));
00902         base_right += class_vote == dataset->at(oobs[i][j])->label;
00903         num_oobs++;
00904       }
00905       uint32_t f;
00906       int j;
00907       int class_vote;
00908       for(f=0;f<num_feats;f++) {
00909         printf("Feature %d\n", f);
00910         vector<int> var_permute(oobs_size);
00911         vector<float> f_vals(oobs_size);
00912         for(j=0;j<oobs_size;j++) {
00913           var_permute[j] = oobs[i][j];
00914           f_vals[j] = dataset->at(var_permute[j])->features[f];
00915         }
00916         random_shuffle(var_permute.begin(), var_permute.end());
00917         int n_right = 0;
00918 #pragma omp parallel default(shared) private(j, class_vote) num_threads(10) reduction( + : n_right)
00919     {
00920 #pragma omp for schedule(dynamic, 10)
00921         for(j=0;j<oobs_size;j++) {
00922           dataset->at(oobs[i][j])->features[f] = f_vals[var_permute[j]];
00923           class_vote = trees[i]->classifyInstance(dataset->at(oobs[i][j]));
00924           n_right += class_vote == dataset->at(oobs[i][j])->label;
00925           dataset->at(oobs[i][j])->features[f] = f_vals[j];
00926         }
00927       }
00928         num_right[f] += n_right;
00929       }
00930 
00931     }
00932     for(uint32_t f=0;f<num_feats;f++) {
00933       float raw_score = (base_right - num_right[f]) / num_trees;
00934       printf("%4d score: %f\n", f, raw_score);
00935     }
00936   }
00937 
00938   void RandomForest::randomPermuteData() {
00939     loadDataset();
00940 
00941     rosbag::Bag bag;
00942     int bagmode;
00943     bagmode = rosbag::bagmode::Write;
00944     string bag_path, data_bag_name;
00945     nh_priv->getParam("data_bag_name", data_bag_name);
00946     nh_priv->getParam("bag_path", bag_path);
00947     string file_loc = bag_path + data_bag_name;
00948     bag.open(file_loc, bagmode);
00949     for(uint32_t i=0;i<dataset->size();i++) {
00950       SensorPoint sp;
00951       sp.features.resize(dataset->at(0)->features.size());
00952       for(uint32_t j=0;j<dataset->at(0)->features.size();j++)
00953         sp.features[j] = dataset->at(rand() % dataset->size())->features[j];
00954       bag.write("/collision_data", ros::Time::now(), sp);
00955     }
00956     bag.close();
00957   }
00958 
00959 
00960   /*void RandomForest::runTests(vector< SensorPoint::ConstPtr >* test_data, 
00961                               vector<int>* test_labels, int num_roc = 10) {
00962     MatrixXi class_votes = MatrixXi::Zero(test_data->size(), trees->at(0)->num_classes);
00963     for(uint32_t i=0;i<test_data->size();i++) {
00964       class_votes.row(i) = collectVotes(test_data->at(i));
00965     }
00966     if(num_roc > 0) {
00967       // binary case where we can vary threshold
00968       VectorXf tpr = VectorXf::Zero(num_roc);
00969       VectorXf fpr = VectorXf::Zero(num_roc);
00970       float cur_percent = 1.0 / (num_roc + 1);
00971       for(int32_t i=0;i<num_roc;i++) {
00972         MatrixXi confusion_mat = MatrixXi::Zero(2, 2);
00973         for(int32_t j=0;j<class_votes.rows();j++) {
00974           uint32_t pred_class = 0;
00975           if(class_votes(j,0) / (float) class_votes.row(j).sum() < cur_percent)
00976             pred_class = 1;
00977           uint32_t actual_class = test_labels[j];
00978           confusion_mat(actual_class, pred_class)++;
00979         }
00980         cout << "Percentage on class 0:" << cur_percent << endl;
00981         cout << confusion_mat << endl;
00982         cur_percent += 1.0 / (num_roc + 1);
00983       }
00984     } else {
00985       // multi-class case where we only consider popular vote
00986       MatrixXi confusion_mat = MatrixXi::Zero(trees->at(0)->num_classes, trees->at(0)->num_classes);
00987       for(int32_t i=0;i<class_votes.cols();i++) {
00988         int32_t pred_class = 0, pred_class_votes = 0;
00989         for(int32_t j=0;j<class_votes.rows();j++) {
00990           if(class_votes(j, i) > pred_class_votes) {
00991             pred_class = j;
00992             pred_class_votes = class_votes(j, i);
00993           }
00994         }
00995         uint32_t actual_class = test_data->at(i)->label;
00996         confusion_mat(actual_class, pred_class)++;
00997       }
00998       cout << confusion_mat << endl;
00999     }
01000   }*/
01001   double RandomForest::mahalanobisDist(LDLT<MatrixXd>* cov_inv, VectorXd& means, VectorXd& pt) {
01002     return sqrt( (pt - means).dot(cov_inv->solve(pt - means)) );
01003   }
01004 
01005   void RandomForest::doMahalanobis() {
01006     loadDataset();
01007     vector<vector< SensorPoint::Ptr > > datasets(1000); // split by label
01008     for(uint32_t i=0;i<dataset->size();i++) 
01009       datasets[dataset->at(i)->label].push_back(dataset->at(i));
01010     for(uint32_t i=0;i<datasets.size();i++) 
01011       if(datasets[i].size() == 0) {
01012         datasets.resize(i);
01013         num_classes = i;
01014         break;
01015       }
01016     datasets[0].resize(10000);
01017     datasets[1].resize(10000);
01018     int num_feats = datasets[0][0]->features.size();
01019 
01020     vector<LDLT<MatrixXd>* > cov_inv(num_classes);
01021 
01022     vector<VectorXd> means(num_classes);
01023     for(int k=0;k<num_classes;k++) {
01024       means[k] = VectorXd::Zero(num_feats);
01025       for(uint32_t i=0;i<datasets[k].size();i++) {
01026         for(int j=0;j<num_feats;j++) {
01027           means[k](j) += datasets[k][i]->features[j];
01028         }
01029       }
01030       means[k] /= datasets[k].size();
01031     }
01032     int k=0, j=0, l=0;
01033     uint32_t i=0;
01034     for(k=0;k<num_classes;k++) {
01035       //ROS_INFO("k %d, num_classes %d", j, num_classes);
01036       MatrixXd var_mat = MatrixXd::Zero(num_feats, num_feats);
01037 #pragma omp parallel default(shared) private(j, l, i) num_threads(10)
01038     {
01039 #pragma omp for schedule(dynamic, 10)
01040       for(j=0;j<num_feats;j++) {
01041         ROS_INFO("j %d", j);
01042         for(l=0;l<num_feats;l++) {
01043           for(i=0;i<datasets[k].size();i++) 
01044             var_mat(j,l) += (datasets[k][i]->features[j] - means[k](j)) * (datasets[k][i]->features[l] - means[k](l));
01045           var_mat(j,l) /= datasets[k].size();
01046         }
01047       }
01048     }
01049       //cout<< var_mat << endl;
01050       LDLT<MatrixXd>* qr = new LDLT<MatrixXd>(var_mat);
01051       cov_inv[k] = qr;
01052     }
01053     for(int k=0;k<num_classes;k++) {
01054       for(int l=0;l<num_classes;l++) {
01055         ArrayXd dists(datasets[l].size());
01056         VectorXd feat_vec(num_feats);
01057         int ind = 0, data_size = datasets[l].size();
01058 #pragma omp parallel default(shared) private(ind, j) num_threads(10)
01059     {
01060 #pragma omp for schedule(dynamic, 10)
01061         for(ind=0;ind<data_size;ind++) {
01062           for(j=0;j<num_feats;j++) 
01063             feat_vec(j) = datasets[l][ind]->features[j];
01064           dists[ind] = mahalanobisDist(cov_inv[k], means[k], feat_vec);
01065         }
01066     }
01067         int nans = 0;
01068         if(dists[i] != dists[i]) {
01069           dists[i] = 0;
01070           nans++;
01071         }
01072         //cout << dists << endl;
01073         double mean_dist = dists.sum() / (dists.size() - nans);
01074         VectorXd diff = dists - mean_dist;
01075         double var_dist = diff.dot(diff) / (dists.size() - nans);
01076         double std_dist = sqrt(var_dist);
01077         double min_dist = dists.minCoeff();
01078         double max_dist = dists.maxCoeff();
01079         printf("cov %d, data %d, mean_dist %f, std_dist %f, min %f, max %f, nans %d\n", k, l, mean_dist, std_dist, min_dist, max_dist, nans);
01080       }
01081       printf("cov %d, num_samps %d, rank 0\n", k, (int) datasets[k].size()); //, cov_inv[k]->rank());
01082     }
01083     printf("num_classes %d, num_feats %d\n", num_classes, num_feats);
01084   }
01085 
01086   void RandomForest::onInit() {
01087     nh = new ros::NodeHandle;
01088     nh_priv = new ros::NodeHandle("~");
01089     
01090     std::string results_topic, classify_topic, data_bag, forest_bag, loaded_topic;
01091     bool random_permute, training_mode, means, variable_import;
01092 
01093     nh_priv->param<bool>("is_abs", is_abs, false);
01094 
01095     nh_priv->param<bool>("variable_import", variable_import, false);
01096     if(variable_import) {
01097       variableImportance();
01098       return;
01099     }
01100 
01101     nh_priv->param<bool>("random_permute", random_permute, false);
01102     if(random_permute) {
01103       randomPermuteData();
01104       return;
01105     }
01106 
01107     nh_priv->param<bool>("mahalanobis", means, false);
01108     if(means) {
01109       doMahalanobis();
01110       return;
01111     }
01112 
01113     nh_priv->param<bool>("training_mode", training_mode, false);
01114     if(training_mode) {
01115       growWriteForest();
01116       return;
01117     }
01118 
01119     nh_priv->getParam("classify_topic", classify_topic);
01120     nh_priv->getParam("results_topic", results_topic);
01121     nh_priv->getParam("loaded_topic", loaded_topic);
01122     nh_priv->getParam("classifier_id", classifier_id);
01123     nh_priv->getParam("classifier_name", classifier_name);
01124 
01125     classify_sub = nh->subscribe(classify_topic.c_str(), 2, 
01126                         &RandomForest::classifyCallback, this);
01127     ROS_INFO("[random_forest] Subscribed to %s", classify_topic.c_str());
01128     results_pub = nh->advertise<ClassVotes>(results_topic, 1);
01129     ROS_INFO("[random_forest] Publishing on %s", results_topic.c_str());
01130     loaded_pub = nh->advertise<std_msgs::Bool>(loaded_topic, 1);
01131     ROS_INFO("[random_forest] Publishing on %s", loaded_topic.c_str());
01132 
01133     trees_loaded = false;
01134     loadForest();
01135   }
01136 
01137 }
01138 
01139 using namespace collision_detection;
01140 
void INTHandler(int sig);

// Interactive SIGINT handler. It ignores further `sig` deliveries, asks on
// stdout whether to quit, and exits on 'y'/'Y'. Any other answer re-installs
// itself for SIGINT.
// NOTE(review): printf/getchar/exit are not async-signal-safe; this is only
// suitable for interactive debugging (main leaves it disabled).
void INTHandler(int sig) {
  signal(sig, SIG_IGN);
  printf("Want to exit?\n");
  const char answer = getchar();
  switch(answer) {
    case 'y':
    case 'Y':
      exit(0);
      break;
    default:
      signal(SIGINT, INTHandler);
      break;
  }
}
01153 
01154 int main(int argc, char** argv) {
01155   ros::init(argc, argv, "random_forest", ros::init_options::AnonymousName);
01156   //signal(SIGINT, INTHandler);
01157 
01158   RandomForest rf;
01159   rf.onInit();
01160   ros::spin();
01161   printf("Exiting\n");
01162 
01163   /*
01164   int num_attrs = 10;
01165   vector<SensorPoint::ConstPtr>* train = new vector<SensorPoint::ConstPtr>;
01166   for(int i=0;i<100;i++) {
01167     SensorPoint::Ptr add_data(new SensorPoint());
01168     add_data->features.resize(num_attrs);
01169     for(int j=0;j<num_attrs;j++) // add positives
01170       add_data->features[j] = 5.0 * rand() / double(RAND_MAX);
01171     add_data->traj_id = i / 10;
01172     add_data->label = 234;
01173     train->push_back(add_data);
01174   }
01175   for(int i=0;i<100;i++) {
01176     SensorPoint::Ptr add_data(new SensorPoint());
01177     add_data->features.resize(num_attrs);
01178     for(int j=0;j<num_attrs;j++) // add positives
01179       add_data->features[j] = 5.0 * rand() / double(RAND_MAX) + 5.1;
01180     add_data->traj_id = i / 10 + 10;
01181     add_data->label = 234;
01182     train->push_back(add_data);
01183   }
01184   for(int i=0;i<100;i++) {
01185     SensorPoint::Ptr add_data(new SensorPoint());
01186     add_data->features.resize(num_attrs);
01187     for(int j=0;j<num_attrs;j++) // add positives
01188       add_data->features[j] = 5.0 * rand() / double(RAND_MAX) + 70.0;
01189     add_data->traj_id = i / 10 + 20;
01190     add_data->label = 234;
01191     train->push_back(add_data);
01192   }
01193   for(int i=0;i<3000;i++) {
01194     SensorPoint::Ptr add_data(new SensorPoint());
01195     add_data->features.resize(num_attrs);
01196     for(int j=0;j<num_attrs;j++) // add negatives
01197       add_data->features[j] = 5.1 * (rand() / double(RAND_MAX));
01198     add_data->traj_id = i / 10 + 30;
01199     add_data->label = 51;
01200     train->push_back(add_data);
01201   }
01202 
01203   vector<SensorPoint::ConstPtr>* test = new vector<SensorPoint::ConstPtr>;
01204   for(int i=0;i<1000;i++) {
01205     SensorPoint::Ptr add_data(new SensorPoint());
01206     add_data->features.resize(num_attrs);
01207     for(int j=0;j<num_attrs;j++) // add positives
01208       add_data->features[j] = 5.0 * rand() / double(RAND_MAX);
01209     add_data->label = 1;
01210     test->push_back(add_data);
01211   }
01212   for(int i=0;i<1000;i++) {
01213     SensorPoint::Ptr add_data(new SensorPoint());
01214     add_data->features.resize(num_attrs);
01215     for(int j=0;j<num_attrs;j++) // add positives
01216       add_data->features[j] = 5.0 * rand() / double(RAND_MAX) + 20.0;
01217     add_data->label = 1;
01218     test->push_back(add_data);
01219   }
01220   for(int i=0;i<1000;i++) {
01221     SensorPoint::Ptr add_data(new SensorPoint());
01222     add_data->features.resize(num_attrs);
01223     for(int j=0;j<num_attrs;j++) // add positives
01224       add_data->features[j] = 5.0 * rand() / double(RAND_MAX) + 70.0;
01225     add_data->label = 1;
01226     test->push_back(add_data);
01227   }
01228   for(int i=0;i<3000;i++) {
01229     SensorPoint::Ptr add_data(new SensorPoint());
01230     add_data->features.resize(num_attrs);
01231     for(int j=0;j<num_attrs;j++) // add negatives
01232       add_data->features[j] = 10.0 * (rand() / double(RAND_MAX));
01233     add_data->label = 0;
01234     test->push_back(add_data);
01235   }
01236 
01237   if(0) {
01238     vector<int>* inds = new vector<int>(train->size());
01239     for(uint32_t i=0;i<train->size();i++) 
01240       inds->at(i) = i;
01241     RandomTree rt(0);
01242     rt.growTree(train, inds);
01243     for(int i=0;i<6000;i++)
01244       cout << rt.classifyInstance(test->at(i)) << ", " << train->at(i)->label << endl;
01245   }
01246   
01247   if(1) {
01248     vector<int>* inds = new vector<int>(train->size());
01249     for(uint32_t i=0;i<train->size();i++) 
01250       inds->at(i) = i;
01251     RandomForest rf;
01252     //rf.growForest(train, inds, 100);
01253     vector<int>* inds_test = new vector<int>(test->size());
01254     for(uint32_t i=0;i<test->size();i++) 
01255       inds_test->at(i) = i;
01256     RandomForest::runTenFold(train, 234, 20, 11);
01257     //rf.runTests(test, inds_test, 10);
01258   }*/
01259 
01260 
01261   return 0;
01262 }


kelsey_sandbox
Author(s): kelsey
autogenerated on Wed Nov 27 2013 11:52:04