svm_one_class.cpp
Go to the documentation of this file.
00001 #include "pr2_overhead_grasping/svm_one_class.h"
00002 //#include <pluginlib/class_list_macros.h>
00003 #include <omp.h>
00004 #include <stdio.h>
00005 #include <signal.h>
00006 //PLUGINLIB_DECLARE_CLASS(collision_detection, mahalanobis_dist, collision_detection::SVMOneClass, nodelet::Nodelet)
00007 
00008 using namespace std;
00009 using namespace pr2_overhead_grasping;
00010 using namespace std_msgs;
00011 
00012 
00013 namespace collision_detection {
00014 
00015   void SVMOneClass::loadDataset() {
00016     dataset = new vector< SensorPoint::ConstPtr >;
00017     string bag_path;
00018     XmlRpc::XmlRpcValue bag_names, bag_labels;
00019     nh_priv->getParam("bag_path", bag_path);
00020     nh_priv->getParam("bag_names", bag_names);
00021     nh_priv->getParam("bag_labels", bag_labels);
00022     map<int, int> labels_dict;
00023     for(int i=0;i<bag_names.size();i++) {
00024       string file_loc = bag_path + (string) bag_names[i];
00025       loadDataBag(file_loc, bag_labels[i]);
00026       labels_dict[bag_labels[i]] = 0;
00027     }
00028     num_classes = labels_dict.size();
00029     printf("%d\n", num_classes);
00030   }
00031 
00032   void SVMOneClass::loadDataBag(string& data_bag, int label) {
00033     // load dataset
00034     rosbag::Bag bag(data_bag);
00035     rosbag::View view(bag, rosbag::TopicQuery("/collision_data"));
00036     BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00037       SensorPoint::Ptr sp = m.instantiate<SensorPoint>();
00038       if(sp != NULL) {
00039         sp->label = label;
00040 
00041         // remove fingertip data
00042         int ind = 340;
00043         for(int i =420;i<920;i++)
00044           sp->features[ind++] = (sp->features[i]);
00045         for(int i =1000;i<(int) sp->features.size();i++)
00046           sp->features[ind++] = (sp->features[i]);
00047         sp->features.resize(sp->features.size() - 160);
00048 
00049         dataset->push_back(sp);
00050       }
00051     }
00052     assert(dataset->size() != 0);
00053   }
00054 
00055   SVMOneClass::~SVMOneClass() {
00056   }
00057 
00058   void SVMOneClass::generateModel() {
00059     loadDataset();
00060     //dataset->resize(100);
00061     uint32_t num_one_class = 0;
00062     for(uint32_t i=0;i<dataset->size();i++) 
00063       if(dataset->at(i)->label == 0)
00064         num_one_class++;
00065     num_one_class = 4000;
00066     svm_problem* prob = new svm_problem;
00067     prob->l = num_one_class;
00068     prob->x = new svm_node*[num_one_class];
00069     prob->y = new double[num_one_class];
00070     uint32_t num_feats = dataset->at(0)->features.size();
00071     //num_feats = 30;
00072     int ind = 0;
00073     for(uint32_t i=0;i<dataset->size();i++) {
00074       if(dataset->at(i)->label != 0)
00075         continue;
00076       if(ind == num_one_class)
00077         break;
00078       prob->y[ind] = 1.0;
00079       prob->x[ind] = new svm_node[num_feats+1];
00080       for(uint32_t j=0;j<num_feats;j++) {
00081         prob->x[ind][j].index = j;
00082         prob->x[ind][j].value = dataset->at(i)->features[j];
00083       }
00084       prob->x[ind++][num_feats].index = -1;
00085     }
00086     printf("%d %d\n", num_one_class, ind);
00087     //prob->y[0] = -1.0;
00088     //prob->y[5] = -1.0;
00089     //prob->y[3] = -1.0;
00090 
00091     svm_parameter param;
00092     param.svm_type = ONE_CLASS;
00093     param.kernel_type = RBF;
00094     param.degree = 3;
00095     param.gamma = 0.001 / num_feats;  // 1/num_features
00096     param.coef0 = 0;
00097     param.nu = 0.98;
00098     param.cache_size = 8000;
00099     param.C = 1;
00100     param.eps = 1e-3;
00101     param.p = 0.1;
00102     param.shrinking = 1;
00103     param.probability = 0;
00104     param.nr_weight = 0;
00105     param.weight_label = new int[1];
00106     param.weight_label[0] = 0;
00107     param.weight = new double[1];
00108     param.weight[0] = 1;
00109 
00110     ROS_INFO("Train in");
00111     svm_model* model = svm_train(prob, &param);
00112     ROS_INFO("Train out");
00113     int cm[2][2];
00114     cm[0][0] = 0; cm[1][0] = 0; cm[0][1] = 0; cm[1][1] = 0; 
00115     for(uint32_t i=0;i<dataset->size();i++) {
00116       svm_node* inst = new svm_node[num_feats+1];
00117       for(uint32_t j=0;j<num_feats;j++) {
00118         inst[j].index = j;
00119         inst[j].value = dataset->at(i)->features[j];
00120       }
00121       inst[num_feats].index = -1;
00122       int pred = (int) svm_predict(model, inst);
00123       cm[dataset->at(i)->label != 0][pred > 0]++;
00124       delete[] inst;
00125     }
00126     printf("%4d %4d\n", cm[1][1], cm[0][1]);
00127     printf("%4d %4d\n", cm[1][0], cm[0][0]);
00128     /*double* target = new double[dataset->size()];
00129     svm_cross_validation(prob, &param, 2, target);
00130     for(uint32_t i=0;i<dataset->size();i++) 
00131       printf("%f, ", target[i]); */
00132     ROS_INFO("Test out");
00133   }
00134 
00135   void SVMOneClass::onInit() {
00136     nh = new ros::NodeHandle;
00137     nh_priv = new ros::NodeHandle("~");
00138     
00139     std::string results_topic, classify_topic, data_bag, forest_bag, loaded_topic;
00140     bool training_mode, is_validation, is_data_summary;
00141 
00142     nh_priv->getParam("is_validation", is_validation);
00143     if(is_validation) {
00144       return;
00145     }
00146 
00147     nh_priv->getParam("training_mode", training_mode);
00148     if(training_mode) {
00149       generateModel();
00150       return;
00151     }
00152 
00153     /*
00154     nh_priv->getParam("classify_topic", classify_topic);
00155     nh_priv->getParam("results_topic", results_topic);
00156     nh_priv->getParam("loaded_topic", loaded_topic);
00157     nh_priv->getParam("classifier_id", classifier_id);
00158     nh_priv->getParam("classifier_name", classifier_name);
00159 
00160     classify_sub = nh->subscribe(classify_topic.c_str(), 2, 
00161                         &SVMOneClass::classifyCallback, this);
00162     ROS_INFO("[mahalanobis_dist] Subscribed to %s", classify_topic.c_str());
00163     results_pub = nh->advertise<Float32>(results_topic, 1);
00164     ROS_INFO("[mahalanobis_dist] Publishing on %s", results_topic.c_str());
00165     loaded_pub = nh->advertise<Bool>(loaded_topic, 1);
00166     ROS_INFO("[mahalanobis_dist] Publishing on %s", loaded_topic.c_str());
00167 
00168     classifier_loaded = false;
00169     loadCovMat();
00170     */
00171   }
00172 
00173 }
00174 
00175 using namespace collision_detection;
00176 
void INTHandler(int sig);

/**
 * Interactive SIGINT handler: prints "Want to exit?" and reads one answer
 * line from stdin. An answer starting with 'y' or 'Y' exits with status 0,
 * and so does end-of-input, since no answer can ever arrive then and Ctrl-C
 * would otherwise be permanently ineffective. Any other answer re-arms the
 * handler for SIGINT.
 *
 * NOTE(review): printf/getchar/exit are not async-signal-safe, so this is
 * only suitable as an interactive debugging aid (its registration in main()
 * is commented out).
 */
void INTHandler(int sig) {
  // Ignore further deliveries of this signal while waiting for the answer.
  signal(sig, SIG_IGN);
  printf("Want to exit?\n");
  fflush(stdout);  // make the prompt visible even when stdout is not a tty

  // getchar() returns int so that EOF is distinguishable from a character.
  int c = getchar();

  // Discard the rest of the answer line (at least the trailing newline) so it
  // is not taken as the answer to the next prompt.
  int rest = c;
  while(rest != '\n' && rest != EOF)
    rest = getchar();

  if(c == 'y' || c == 'Y' || c == EOF)
    exit(0);
  signal(SIGINT, INTHandler);
}
00189 
00190 int main(int argc, char** argv) {
00191   ros::init(argc, argv, "mahalanobis_dist", ros::init_options::AnonymousName);
00192   //signal(SIGINT, INTHandler);
00193 
00194   SVMOneClass md;
00195   md.onInit();
00196   ros::spin();
00197   printf("Exiting\n");
00198   return 0;
00199 }


kelsey_sandbox
Author(s): kelsey
autogenerated on Wed Nov 27 2013 11:52:04