Go to the documentation of this file.00001 #ifndef PR2_OVERHEAD_GRASPING_RANDOM_FOREST_H
00002 #define PR2_OVERHEAD_GRASPING_RANDOM_FOREST_H
00003 #include <ros/ros.h>
00004 #include "sensor_msgs/JointState.h"
00005 #include "std_msgs/UInt16MultiArray.h"
00006 #include "std_msgs/UInt32MultiArray.h"
00007 #include "std_msgs/Float32MultiArray.h"
00008 #include "std_msgs/MultiArrayLayout.h"
00009 #include "std_msgs/MultiArrayDimension.h"
00010 #include "std_msgs/UInt16.h"
00011 #include "std_msgs/Bool.h"
00012 #include "pr2_msgs/AccelerometerState.h"
00013 #include "pr2_msgs/PressureState.h"
00014 #include "pr2_controllers_msgs/JointTrajectoryControllerState.h"
00015 #include "boost/thread/thread.hpp"
00016 #include "boost/format.hpp"
00017 #include "boost/foreach.hpp"
00018 #include "rosbag/bag.h"
00019 #include "rosbag/view.h"
00020
00021
00022 #include "pr2_overhead_grasping/SensorPoint.h"
00023 #include "pr2_overhead_grasping/ClassVotes.h"
00024 #include "pr2_overhead_grasping/CollisionDescription.h"
00025 #include "pr2_overhead_grasping/FoldData.h"
00026 #include "pr2_overhead_grasping/RandomTreeMsg.h"
00027 #include "pr2_overhead_grasping/CovarianceMatrix.h"
00028 #include "roslib/Header.h"
#include <ros/package.h>
#include <std_srvs/Empty.h>
#include <nodelet/nodelet.h>
#include <math.h>
#include <cstddef>
#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include <Eigen/Eigen>
#include <Eigen/Dense>
#include <Eigen/QR>
#include <Eigen/Cholesky>
00039
00040 using namespace Eigen;
00041
00042 using namespace std;
00043 using namespace pr2_overhead_grasping;
00044
00045 namespace collision_detection {
00046
00047 typedef VectorXf Dynamic1D;
00048
/// Number of attributes (features) per instance.
/// NOTE(review): presumably the feature-vector length of a SensorPoint; confirm.
static const int NUM_ATTRS = 1120;
00050
00051 class RandomTree {
00052 public:
00053 int d_tree_num;
00054 int num_classes;
00055 vector< SensorPoint::Ptr >* dataset;
00056
00057 RandomTree(int c_d_tree_num);
00058 RandomTree(RandomTreeMsg::Ptr);
00059 void growTree(vector< SensorPoint::Ptr >* c_dataset,
00060 vector<int>* inds);
00061 int classifyInstance(SensorPoint::Ptr inst);
00062 void writeTree(string& bag_file, bool is_first);
00063
00064 bool is_abs;
00065
00066 protected:
00067
00068 RandomTreeMsg::Ptr rand_tree;
00069
00070 bool attrCompare(int inst_i, int inst_j, int attr);
00071 void findBestSplit(vector<int>* insts, vector<int>& attrs, pair<int, float>& ret);
00072 void splitNode(vector<int>* node_inds,
00073 pair<int, float>& split_pt,
00074 pair<vector<int>*, vector<int>* >& split_nodes);
00075 };
00076
00077 class RandomForest {
00078 public:
00079 RandomForest() {}
00080 ~RandomForest();
00081 void loadDataset();
00082 void loadDataBag(string& data_bag, int label);
00083 void growForest(vector< SensorPoint::Ptr >* c_dataset,
00084 vector<int>* inds, int c_num_trees=100);
00085 void growWriteForest();
00086 void loadForest();
00087 void collectVotes(SensorPoint::Ptr inst, map<int, int>& class_votes);
00088 void writeForest();
00089 void writeForest(string file);
00090 void classifyCallback(const boost::shared_ptr<SensorPoint>& inst);
00091
00092 void setDataset(vector< SensorPoint::Ptr >* datas);
00093 static void runTenFold(vector< SensorPoint::Ptr >* train_test_data,
00094 int roc_id,
00095 int num_trees,
00096 vector<map<int, int> >& votes_total,
00097 bool classify_first=true);
00098 static int findFirstClass(vector<pair<map<int, int>, float > >* votes_list,
00099 int pos_id, float thresh);
00100 static int findFrequentClass(vector<pair<map<int, int>, float > >* votes_list,
00101 int pos_id, float thresh);
00102 void variableImportance();
00103 void randomPermuteData();
00104 void onInit();
00105
00106 void loadCovMat();
00107 void doMahalanobis();
00108 double mahalanobisDist(LDLT<MatrixXd>* cov_inv, VectorXd& means, VectorXd& pt);
00109 void createCovMat();
00110
00111
00112 protected:
00113
00114 ros::NodeHandle* nh;
00115 ros::NodeHandle* nh_priv;
00116
00117 int num_trees, num_classes;
00118 string classifier_name;
00119 int classifier_id;
00120
00121 RandomTree** trees;
00122 vector<vector<uint32_t> > oobs;
00123 LDLT<MatrixXd>* cov_inv;
00124 VectorXd means;
00125 vector< SensorPoint::Ptr >* dataset;
00126 vector<int> labels;
00127 ros::Subscriber classify_sub;
00128 ros::Publisher results_pub;
00129 ros::Publisher loaded_pub;
00130 boost::thread setup_thread;
00131 bool trees_loaded;
00132 bool is_abs;
00133 };
00134 }
00135
00136
00137 #endif // PR2_OVERHEAD_GRASPING_RANDOM_FOREST_H