Go to the documentation of this file.00001 #include "pr2_overhead_grasping/svm_one_class.h"
00002
00003 #include <omp.h>
00004 #include <stdio.h>
00005 #include <signal.h>
00006
00007
00008 using namespace std;
00009 using namespace pr2_overhead_grasping;
00010 using namespace std_msgs;
00011
00012
00013 namespace collision_detection {
00014
00015 void SVMOneClass::loadDataset() {
00016 dataset = new vector< SensorPoint::ConstPtr >;
00017 string bag_path;
00018 XmlRpc::XmlRpcValue bag_names, bag_labels;
00019 nh_priv->getParam("bag_path", bag_path);
00020 nh_priv->getParam("bag_names", bag_names);
00021 nh_priv->getParam("bag_labels", bag_labels);
00022 map<int, int> labels_dict;
00023 for(int i=0;i<bag_names.size();i++) {
00024 string file_loc = bag_path + (string) bag_names[i];
00025 loadDataBag(file_loc, bag_labels[i]);
00026 labels_dict[bag_labels[i]] = 0;
00027 }
00028 num_classes = labels_dict.size();
00029 printf("%d\n", num_classes);
00030 }
00031
00032 void SVMOneClass::loadDataBag(string& data_bag, int label) {
00033
00034 rosbag::Bag bag(data_bag);
00035 rosbag::View view(bag, rosbag::TopicQuery("/collision_data"));
00036 BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00037 SensorPoint::Ptr sp = m.instantiate<SensorPoint>();
00038 if(sp != NULL) {
00039 sp->label = label;
00040
00041
00042 int ind = 340;
00043 for(int i =420;i<920;i++)
00044 sp->features[ind++] = (sp->features[i]);
00045 for(int i =1000;i<(int) sp->features.size();i++)
00046 sp->features[ind++] = (sp->features[i]);
00047 sp->features.resize(sp->features.size() - 160);
00048
00049 dataset->push_back(sp);
00050 }
00051 }
00052 assert(dataset->size() != 0);
00053 }
00054
00055 SVMOneClass::~SVMOneClass() {
00056 }
00057
00058 void SVMOneClass::generateModel() {
00059 loadDataset();
00060
00061 uint32_t num_one_class = 0;
00062 for(uint32_t i=0;i<dataset->size();i++)
00063 if(dataset->at(i)->label == 0)
00064 num_one_class++;
00065 num_one_class = 4000;
00066 svm_problem* prob = new svm_problem;
00067 prob->l = num_one_class;
00068 prob->x = new svm_node*[num_one_class];
00069 prob->y = new double[num_one_class];
00070 uint32_t num_feats = dataset->at(0)->features.size();
00071
00072 int ind = 0;
00073 for(uint32_t i=0;i<dataset->size();i++) {
00074 if(dataset->at(i)->label != 0)
00075 continue;
00076 if(ind == num_one_class)
00077 break;
00078 prob->y[ind] = 1.0;
00079 prob->x[ind] = new svm_node[num_feats+1];
00080 for(uint32_t j=0;j<num_feats;j++) {
00081 prob->x[ind][j].index = j;
00082 prob->x[ind][j].value = dataset->at(i)->features[j];
00083 }
00084 prob->x[ind++][num_feats].index = -1;
00085 }
00086 printf("%d %d\n", num_one_class, ind);
00087
00088
00089
00090
00091 svm_parameter param;
00092 param.svm_type = ONE_CLASS;
00093 param.kernel_type = RBF;
00094 param.degree = 3;
00095 param.gamma = 0.001 / num_feats;
00096 param.coef0 = 0;
00097 param.nu = 0.98;
00098 param.cache_size = 8000;
00099 param.C = 1;
00100 param.eps = 1e-3;
00101 param.p = 0.1;
00102 param.shrinking = 1;
00103 param.probability = 0;
00104 param.nr_weight = 0;
00105 param.weight_label = new int[1];
00106 param.weight_label[0] = 0;
00107 param.weight = new double[1];
00108 param.weight[0] = 1;
00109
00110 ROS_INFO("Train in");
00111 svm_model* model = svm_train(prob, ¶m);
00112 ROS_INFO("Train out");
00113 int cm[2][2];
00114 cm[0][0] = 0; cm[1][0] = 0; cm[0][1] = 0; cm[1][1] = 0;
00115 for(uint32_t i=0;i<dataset->size();i++) {
00116 svm_node* inst = new svm_node[num_feats+1];
00117 for(uint32_t j=0;j<num_feats;j++) {
00118 inst[j].index = j;
00119 inst[j].value = dataset->at(i)->features[j];
00120 }
00121 inst[num_feats].index = -1;
00122 int pred = (int) svm_predict(model, inst);
00123 cm[dataset->at(i)->label != 0][pred > 0]++;
00124 delete[] inst;
00125 }
00126 printf("%4d %4d\n", cm[1][1], cm[0][1]);
00127 printf("%4d %4d\n", cm[1][0], cm[0][0]);
00128
00129
00130
00131
00132 ROS_INFO("Test out");
00133 }
00134
00135 void SVMOneClass::onInit() {
00136 nh = new ros::NodeHandle;
00137 nh_priv = new ros::NodeHandle("~");
00138
00139 std::string results_topic, classify_topic, data_bag, forest_bag, loaded_topic;
00140 bool training_mode, is_validation, is_data_summary;
00141
00142 nh_priv->getParam("is_validation", is_validation);
00143 if(is_validation) {
00144 return;
00145 }
00146
00147 nh_priv->getParam("training_mode", training_mode);
00148 if(training_mode) {
00149 generateModel();
00150 return;
00151 }
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166
00167
00168
00169
00170
00171 }
00172
00173 }
00174
00175 using namespace collision_detection;
00176
void INTHandler(int sig);

/**
 * Interactive signal handler. It prints a prompt and reads one character
 * from stdin. 'y' or 'Y' calls exit(0); anything else re-installs this
 * handler for SIGINT.
 *
 * NOTE(review): printf, getchar and exit are not async-signal-safe, so
 * calling them from a signal handler is undefined behaviour under POSIX.
 * Nothing in this file installs the handler; confirm whether it is still
 * needed before relying on it.
 */
void INTHandler(int sig)
{
  // Ignore further deliveries of this signal while waiting for an answer.
  signal(sig, SIG_IGN);
  printf("Want to exit?\n");
  const char answer = getchar();
  const bool wants_exit = (answer == 'y') || (answer == 'Y');
  if(!wants_exit) {
    signal(SIGINT, INTHandler);
    return;
  }
  exit(0);
}
00189
00190 int main(int argc, char** argv) {
00191 ros::init(argc, argv, "mahalanobis_dist", ros::init_options::AnonymousName);
00192
00193
00194 SVMOneClass md;
00195 md.onInit();
00196 ros::spin();
00197 printf("Exiting\n");
00198 return 0;
00199 }