Go to the documentation of this file.00001 #ifndef PR2_OVERHEAD_GRASPING_RANDOM_FOREST_H
00002 #define PR2_OVERHEAD_GRASPING_RANDOM_FOREST_H
00003 #include <ros/ros.h>
00004 #include "sensor_msgs/JointState.h"
00005 #include "std_msgs/UInt16MultiArray.h"
00006 #include "std_msgs/UInt32MultiArray.h"
00007 #include "std_msgs/Float32MultiArray.h"
00008 #include "std_msgs/Float32.h"
00009 #include "std_msgs/MultiArrayLayout.h"
00010 #include "std_msgs/MultiArrayDimension.h"
00011 #include "std_msgs/UInt16.h"
00012 #include "std_msgs/Bool.h"
00013 #include "pr2_msgs/AccelerometerState.h"
00014 #include "pr2_msgs/PressureState.h"
00015 #include "pr2_controllers_msgs/JointTrajectoryControllerState.h"
00016 #include "boost/thread/thread.hpp"
00017 #include "boost/format.hpp"
00018 #include "boost/foreach.hpp"
00019 #include "rosbag/bag.h"
00020 #include "rosbag/view.h"
00021
00022
00023 #include "pr2_overhead_grasping/SensorPoint.h"
00024 #include "pr2_overhead_grasping/ClassVotes.h"
00025 #include "pr2_overhead_grasping/CollisionDescription.h"
00026 #include "pr2_overhead_grasping/RandomTreeMsg.h"
00027 #include "pr2_overhead_grasping/CovarianceMatrix.h"
00028 #include "roslib/Header.h"
00029 #include <ros/package.h>
00030 #include <std_srvs/Empty.h>
00031 #include <math.h>
00032 #include <nodelet/nodelet.h>
00033 #include <algorithm>
00034 #include <vector>
00035 #include <Eigen/Eigen>
00036 #include <Eigen/Dense>
00037 #include <Eigen/QR>
00038 #include <Eigen/Cholesky>
00039
00040 using namespace Eigen;
00041
00042 using namespace std;
00043 using namespace pr2_overhead_grasping;
00044
00045 namespace collision_detection {
00046
00047 struct DistFinder {
00048 MatrixXd V;
00049 VectorXd D;
00050 VectorXd means;
00051 void makeInv(MatrixXd& var_mat, VectorXd& c_means, double min_eig_val=0.01) {
00052 SelfAdjointEigenSolver<MatrixXd> solver(var_mat);
00053 D = solver.eigenvalues();
00054 double thresh = D.maxCoeff() * 0.65;
00055 double max_thresh = D.maxCoeff() * 0.18;
00056 int rank = D.size();
00057 cout << "Eigenvalues";
00058 for(int i=0;i<D.size();i++)
00059 printf("%e ", D(i));
00060 cout << endl;
00061 for(uint32_t i=0;i<D.size();i++) {
00062 if((D(i) > thresh || D(i) < max_thresh) && false) {
00063
00064 D(i) = 0;
00065 rank--;
00066 }
00067 else
00068 D(i) = 1.0 / D(i);
00069 }
00070 ROS_INFO("Rank: %d", rank);
00071 V = solver.eigenvectors();
00072 means = c_means;
00073 }
00074 double dist(VectorXd& pt) {
00075
00076 static VectorXd left_prod;
00077 left_prod = (pt - means) * V;
00078 return sqrt((left_prod * D.asDiagonal()).dot(left_prod.transpose()));
00079 }
00080 };
00081
00082 class MahalanobisDist {
00083 public:
00084 MahalanobisDist() {}
00085 ~MahalanobisDist();
00086 void loadDataset();
00087 void loadDataBag(string& data_bag, int label);
00088 void classifyCallback(const boost::shared_ptr<SensorPoint>& inst);
00089 static void runTenFold(vector< SensorPoint::ConstPtr >* train_test_data,
00090 int roc_id,
00091 int num_trees,
00092 vector<map<int, int> >& votes_total,
00093 bool classify_first=true);
00094 static int findFirstClass(vector<pair<map<int, int>, float > >* votes_list,
00095 int pos_id, float thresh);
00096 static int findFrequentClass(vector<pair<map<int, int>, float > >* votes_list,
00097 int pos_id, float thresh);
00098 void onInit();
00099
00100 void loadCovMat();
00101 void doMahalanobis();
00102 double mahalanobisDist(MatrixXd& cov_inv, VectorXd& means, VectorXd& pt);
00103 void createCovMat();
00104 void makeInv(MatrixXd& A, MatrixXd& A_inv, double min_eig_val=0.0001);
00105 void summarizeData();
00106
00107
00108 protected:
00109
00110 ros::NodeHandle* nh;
00111 ros::NodeHandle* nh_priv;
00112
00113 int num_classes;
00114 string classifier_name;
00115 int classifier_id;
00116
00117 void saveCovMat(MatrixXd& var_mat, VectorXd& means);
00118 DistFinder cov_inv;
00119 VectorXd means;
00120 vector< SensorPoint::ConstPtr >* dataset;
00121 vector<int> labels;
00122 ros::Subscriber classify_sub;
00123 ros::Publisher results_pub;
00124 ros::Publisher loaded_pub;
00125 boost::thread setup_thread;
00126 bool classifier_loaded;
00127 };
00128 }
00129
00130
00131 #endif // PR2_OVERHEAD_GRASPING_RANDOM_FOREST_H