random_forest.h
Go to the documentation of this file.
00001 #ifndef PR2_OVERHEAD_GRASPING_RANDOM_FOREST_H
00002 #define PR2_OVERHEAD_GRASPING_RANDOM_FOREST_H
00003 #include <ros/ros.h>
00004 #include "sensor_msgs/JointState.h"
00005 #include "std_msgs/UInt16MultiArray.h"
00006 #include "std_msgs/UInt32MultiArray.h"
00007 #include "std_msgs/Float32MultiArray.h"
00008 #include "std_msgs/MultiArrayLayout.h"
00009 #include "std_msgs/MultiArrayDimension.h"
00010 #include "std_msgs/UInt16.h"
00011 #include "std_msgs/Bool.h"
00012 #include "pr2_msgs/AccelerometerState.h"
00013 #include "pr2_msgs/PressureState.h"
00014 #include "pr2_controllers_msgs/JointTrajectoryControllerState.h"
00015 #include "boost/thread/thread.hpp"
00016 #include "boost/format.hpp"
00017 #include "boost/foreach.hpp"
00018 #include "rosbag/bag.h"
00019 #include "rosbag/view.h"
00020 //#include "sensor_msgs/Float32MultiArray.h"
00021 //#include "sensor_msgs/MultiArrayDimension.h"
00022 #include "pr2_overhead_grasping/SensorPoint.h"
00023 #include "pr2_overhead_grasping/ClassVotes.h"
00024 #include "pr2_overhead_grasping/CollisionDescription.h"
00025 #include "pr2_overhead_grasping/FoldData.h"
00026 #include "pr2_overhead_grasping/RandomTreeMsg.h"
00027 #include "pr2_overhead_grasping/CovarianceMatrix.h"
00028 #include "roslib/Header.h"
00029 #include <ros/package.h>
00030 #include <std_srvs/Empty.h>
00031 #include <math.h>
00032 #include <nodelet/nodelet.h>
00033 #include <algorithm>
00034 #include <vector>
00035 #include <Eigen/Eigen>
00036 #include <Eigen/Dense>
00037 #include <Eigen/QR>
00038 #include <Eigen/Cholesky>
00039 //USING_PART_OF_NAMESPACE_EIGEN
00040 using namespace Eigen;
00041 
00042 using namespace std;
00043 using namespace pr2_overhead_grasping;
00044 
00045 namespace collision_detection {
00046 
00047   typedef VectorXf Dynamic1D;
00048 
00049   const int NUM_ATTRS = 1120;
00050 
00051   class RandomTree { 
00052     public:
00053       int d_tree_num;
00054       int num_classes;
00055       vector< SensorPoint::Ptr >* dataset;
00056 
00057       RandomTree(int c_d_tree_num);
00058       RandomTree(RandomTreeMsg::Ptr);
00059       void growTree(vector< SensorPoint::Ptr >* c_dataset,
00060                             vector<int>* inds);
00061       int classifyInstance(SensorPoint::Ptr inst);
00062       void writeTree(string& bag_file, bool is_first);
00063 
00064       bool is_abs;
00065 
00066     protected:
00067       
00068       RandomTreeMsg::Ptr rand_tree;
00069 
00070       bool attrCompare(int inst_i, int inst_j, int attr);
00071       void findBestSplit(vector<int>* insts, vector<int>& attrs, pair<int, float>& ret);
00072       void splitNode(vector<int>* node_inds, 
00073                      pair<int, float>& split_pt,
00074                      pair<vector<int>*, vector<int>* >& split_nodes);
00075   };
00076 
00077   class RandomForest  {
00078     public:
00079       RandomForest() {}
00080       ~RandomForest();
00081       void loadDataset();
00082       void loadDataBag(string& data_bag, int label);
00083       void growForest(vector< SensorPoint::Ptr >* c_dataset,
00084                             vector<int>* inds, int c_num_trees=100);
00085       void growWriteForest();
00086       void loadForest();
00087       void collectVotes(SensorPoint::Ptr inst, map<int, int>& class_votes);
00088       void writeForest();
00089       void writeForest(string file);
00090       void classifyCallback(const boost::shared_ptr<SensorPoint>& inst);
00091       //void runTests(vector< SensorPoint::ConstPtr >* test_data, vector<int>* test_labels, int num_roc);
00092       void setDataset(vector< SensorPoint::Ptr >* datas);
00093       static void runTenFold(vector< SensorPoint::Ptr >* train_test_data, 
00094                                          int roc_id,
00095                                          int num_trees,
00096                                          vector<map<int, int> >& votes_total,
00097                                          bool classify_first=true);
00098       static int findFirstClass(vector<pair<map<int, int>, float > >* votes_list, 
00099                                 int pos_id, float thresh);
00100       static int findFrequentClass(vector<pair<map<int, int>, float > >* votes_list, 
00101                                 int pos_id, float thresh);
00102       void variableImportance();
00103       void randomPermuteData();
00104       void onInit();
00105 
00106       void loadCovMat();
00107       void doMahalanobis();
00108       double mahalanobisDist(LDLT<MatrixXd>* cov_inv, VectorXd& means, VectorXd& pt);
00109       void createCovMat();
00110 
00111 
00112     protected:
00113 
00114       ros::NodeHandle* nh;
00115       ros::NodeHandle* nh_priv;
00116 
00117       int num_trees, num_classes;
00118       string classifier_name;
00119       int classifier_id;
00120       //vector<RandomTree* >* trees;
00121       RandomTree** trees;
00122       vector<vector<uint32_t> > oobs;
00123       LDLT<MatrixXd>* cov_inv;
00124       VectorXd means;
00125       vector< SensorPoint::Ptr >* dataset;
00126       vector<int> labels;
00127       ros::Subscriber classify_sub;
00128       ros::Publisher results_pub;
00129       ros::Publisher loaded_pub;
00130       boost::thread setup_thread;
00131       bool trees_loaded;
00132       bool is_abs;
00133   };
00134 }
00135 
00136 
00137 #endif // PR2_OVERHEAD_GRASPING_RANDOM_FOREST_H


kelsey_sandbox
Author(s): kelsey
autogenerated on Wed Nov 27 2013 11:52:04