sliding_window_object_detector_trainer_node.cpp
Go to the documentation of this file.
00001 #include <jsk_topic_tools/log_utils.h>
00002 #include <jsk_perception/sliding_window_object_detector_trainer.h>
00003 
00004 #include <iostream>
00005 namespace jsk_perception
00006 {
00007    SlidingWindowObjectDetectorTrainer::SlidingWindowObjectDetectorTrainer()
00008 #if CV_MAJOR_VERSION < 3
00009      : supportVectorMachine_(new cv::SVM)
00010 #endif
00011    {
00012 #if CV_MAJOR_VERSION >= 3
00013       this->supportVectorMachine_ = cv::ml::SVM::create();
00014 #endif
00015       nh_.getParam("dataset_path", this->dataset_path_);
00016       nh_.getParam("object_dataset_filename", this->object_dataset_filename_);
00017       nh_.getParam("nonobject_dataset_filename", this->nonobject_dataset_filename_);
00018       nh_.getParam("classifier_name", this->trained_classifier_name_);
00019       nh_.getParam("swindow_x", this->swindow_x_);
00020       nh_.getParam("swindow_y", this->swindow_y_);
00021 
00022       ROS_INFO("--Training Classifier");
00023       std::string pfilename = dataset_path_ + this->object_dataset_filename_;
00024       std::string nfilename = dataset_path_ + this->nonobject_dataset_filename_;
00025       trainObjectClassifier(pfilename, nfilename);
00026       ROS_INFO("--Trained Successfully..");
00027 
00028       /*write the training manifest*/
00029       std::string mainfest_filename = "sliding_window_trainer_manifest.xml";
00030       cv::FileStorage fs = cv::FileStorage(
00031          mainfest_filename, cv::FileStorage::WRITE);
00032       this->writeTrainingManifestToDirectory(fs);
00033       fs.release();
00034 
00035       cv::destroyAllWindows();
00036       nh_.shutdown();
00037    }
00038    
00039    void SlidingWindowObjectDetectorTrainer::trainObjectClassifier(
00040       std::string pfilename, std::string nfilename)
00041    {
00042       cv::Mat featureMD;
00043       cv::Mat labelMD;
00044       std::string topic_name = "/dataset/roi";
00045       this->readDataset(pfilename, topic_name, featureMD, labelMD, true, 1);
00046       ROS_INFO("Info: Total Object Sample: %d", featureMD.rows);
00047       
00048       topic_name = "/dataset/background/roi";
00049       this->readDataset(nfilename, topic_name, featureMD, labelMD, true, -1);
00050       ROS_INFO("Info: Total Training Features: %d", featureMD.rows);
00051     
00052       try {
00053          this->trainBinaryClassSVM(featureMD, labelMD);
00054          this->supportVectorMachine_->save(
00055             this->trained_classifier_name_.c_str());
00056       } catch(std::exception &e) {
00057          ROS_ERROR("--ERROR: PLEASE CHECK YOUR DATA \n%s", e.what());
00058          std::_Exit(EXIT_FAILURE);
00059       }
00060    }
00061 
00062    void SlidingWindowObjectDetectorTrainer::readDataset(
00063       std::string filename, std::string topic_name, cv::Mat &featureMD,
00064       cv::Mat &labelMD, bool is_usr_label, const int usr_label) {
00065       ROS_INFO("--READING DATASET IMAGE");
00066       try {
00067          rosbag_ = boost::shared_ptr<rosbag::Bag>(new rosbag::Bag);
00068          this->rosbag_->open(filename, rosbag::bagmode::Read);
00069          ROS_INFO("Bag Found and Opened Successfully...");
00070          std::vector<std::string> topics;
00071          topics.push_back(std::string(topic_name));
00072          rosbag::View view(*rosbag_, rosbag::TopicQuery(topics));
00073          BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00074             sensor_msgs::Image::ConstPtr img_msg = m.instantiate<
00075                sensor_msgs::Image>();
00076             cv_bridge::CvImagePtr cv_ptr = cv_bridge::toCvCopy(
00077                img_msg, sensor_msgs::image_encodings::BGR8);
00078             if (cv_ptr->image.data) {
00079                cv::Mat image = cv_ptr->image.clone();
00080                float label = static_cast<float>(usr_label);
00081                labelMD.push_back(label);
00082                this->extractFeatures(image, featureMD);
00083                cv::imshow("image", image);
00084                cv::waitKey(3);
00085             } else {
00086                ROS_WARN("-> NO IMAGE");
00087             }
00088          }
00089          this->rosbag_->close();
00090       } catch (ros::Exception &e) {
00091          ROS_ERROR("ERROR: Bag File:%s not found..\n%s",
00092                    filename.c_str(), e.what());
00093          std::_Exit(EXIT_FAILURE);
00094       }
00095    }
00096 
00100    void SlidingWindowObjectDetectorTrainer::extractFeatures(
00101       cv::Mat &img, cv::Mat &featureMD) {
00102       ROS_INFO("--EXTRACTING IMAGE FEATURES.");
00103       if (img.data) {
00104          cv::resize(img, img, cv::Size(this->swindow_x_, this->swindow_y_));
00105          cv::Mat hog_feature = this->computeHOG(img);
00106          cv::Mat hsv_feature;
00107          this->computeHSHistogram(img, hsv_feature, 16, 16, true);
00108          hsv_feature = hsv_feature.reshape(1, 1);
00109          cv::Mat _feature;
00110          this->concatenateCVMat(hog_feature, hsv_feature, _feature, true);
00111          featureMD.push_back(_feature);
00112       }
00113       cv::imshow("image", img);
00114       cv::waitKey(3);
00115    }
00116 
00117    void SlidingWindowObjectDetectorTrainer::trainBinaryClassSVM(
00118       const cv::Mat &featureMD, const cv::Mat &labelMD)
00119    {
00120       ROS_INFO("--TRAINING CLASSIFIER");
00121 #if CV_MAJOR_VERSION >= 3
00122       this->supportVectorMachine_->setType(cv::ml::SVM::NU_SVC);
00123       //this->supportVectorMachine_->setKernelType(cv::ml::SVM::RBF);
00124       this->supportVectorMachine_->setDegree(0.0);
00125       this->supportVectorMachine_->setGamma(0.90);
00126       this->supportVectorMachine_->setCoef0(0.50);
00127       this->supportVectorMachine_->setC(1);
00128       this->supportVectorMachine_->setNu(0.70);
00129       this->supportVectorMachine_->setP(1.0);
00130       //this->supportVectorMachine_->setClassWeights(NULL);
00131       cv::TermCriteria term_crit;
00132       term_crit.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
00133       term_crit.maxCount = 1e6;
00134       term_crit.epsilon = 1e-6;
00135       this->supportVectorMachine_->setTermCriteria(term_crit);
00136       cv::ml::ParamGrid paramGrid = cv::ml::ParamGrid();
00137       paramGrid.minVal = 0;
00138       paramGrid.maxVal = 0;
00139       paramGrid.logStep = 1;
00140 
00141       cv::Ptr<cv::ml::TrainData> train = cv::ml::TrainData::create(featureMD, cv::ml::ROW_SAMPLE, labelMD, cv::Mat(), cv::Mat()); // ROW_SAMPLE ? COL_SAMPLE ?
00142       this->supportVectorMachine_->trainAuto
00143         (train, 10,
00144          paramGrid, cv::ml::SVM::getDefaultGrid(cv::ml::SVM::GAMMA),
00145          cv::ml::SVM::getDefaultGrid(cv::ml::SVM::P),
00146          cv::ml::SVM::getDefaultGrid(cv::ml::SVM::NU),
00147          cv::ml::SVM::getDefaultGrid(cv::ml::SVM::COEF),
00148          cv::ml::SVM::getDefaultGrid(cv::ml::SVM::DEGREE),
00149           true);
00150 #else
00151       cv::SVMParams svm_param = cv::SVMParams();
00152       svm_param.svm_type = cv::SVM::NU_SVC;
00153       svm_param.kernel_type = cv::SVM::RBF;
00154       svm_param.degree = 0.0;
00155       svm_param.gamma = 0.90;
00156       svm_param.coef0 = 0.50;
00157       svm_param.C = 1;
00158       svm_param.nu = 0.70;
00159       svm_param.p = 1.0;
00160       svm_param.class_weights = NULL;
00161       svm_param.term_crit.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
00162       svm_param.term_crit.max_iter = 1e6;
00163       svm_param.term_crit.epsilon = 1e-6;
00164       cv::ParamGrid paramGrid = cv::ParamGrid();
00165       paramGrid.min_val = 0;
00166       paramGrid.max_val = 0;
00167       paramGrid.step = 1;
00168 
00169       /*this->supportVectorMachine_->train(
00170         featureMD, labelMD, cv::Mat(), cv::Mat(), svm_param);*/
00171       this->supportVectorMachine_->train_auto
00172          (featureMD, labelMD, cv::Mat(), cv::Mat(), svm_param, 10,
00173           paramGrid, cv::SVM::get_default_grid(cv::SVM::GAMMA),
00174           cv::SVM::get_default_grid(cv::SVM::P),
00175           cv::SVM::get_default_grid(cv::SVM::NU),
00176           cv::SVM::get_default_grid(cv::SVM::COEF),
00177           cv::SVM::get_default_grid(cv::SVM::DEGREE),
00178           true);
00179 #endif
00180    }
00181 
00182    void SlidingWindowObjectDetectorTrainer::concatenateCVMat(
00183       const cv::Mat &mat_1, const cv::Mat &mat_2,
00184       cv::Mat &featureMD, bool iscolwise)
00185    {
00186       if (iscolwise) {
00187          featureMD = cv::Mat(mat_1.rows, (mat_1.cols + mat_2.cols), CV_32F);
00188          for (int i = 0; i < featureMD.rows; i++) {
00189             for (int j = 0; j < mat_1.cols; j++) {
00190                featureMD.at<float>(i, j) = mat_1.at<float>(i, j);
00191             }
00192             for (int j = mat_1.cols; j < featureMD.cols; j++) {
00193                featureMD.at<float>(i, j) = mat_2.at<float>(i, j - mat_1.cols);
00194             }
00195          }
00196       } else {
00197          featureMD = cv::Mat((mat_1.rows + mat_2.rows), mat_1.cols, CV_32F);
00198          for (int i = 0; i < featureMD.cols; i++) {
00199             for (int j = 0; j < mat_1.rows; j++) {
00200                featureMD.at<float>(j, i) = mat_1.at<float>(j, i);
00201             }
00202             for (int j = mat_1.rows; j < featureMD.rows; j++) {
00203                featureMD.at<float>(j, i) = mat_2.at<float>(j - mat_1.rows, i);
00204             }
00205          }
00206       }
00207    }
00208    
00209    void SlidingWindowObjectDetectorTrainer::writeTrainingManifestToDirectory(
00210       cv::FileStorage &fs)
00211    {
00212       fs <<  "TrainerInfo" << "{";
00213       fs << "trainer_type" << "cv::SVM";
00214       fs << "trainer_path" << this->trained_classifier_name_;
00215       fs << "}";
00216 
00217       fs <<  "FeatureInfo" << "{";
00218       fs << "HOG" << 1;
00219       // fs << "LBP" << 0;
00220       // fs << "SIFT" << 0;
00221       // fs << "SURF" << 0;
00222       fs << "COLOR_HISTOGRAM" << 1;
00223       fs << "}";
00224     
00225       fs <<  "SlidingWindowInfo" << "{";
00226       fs << "swindow_x" << this->swindow_x_;
00227       fs << "swindow_y" << this->swindow_y_;
00228       fs << "}";
00229     
00230       fs << "TrainingDatasetDirectoryInfo" << "{";
00231       fs << "object_dataset_filename" << this->object_dataset_filename_;
00232       fs << "nonobject_dataset_filename" << this->nonobject_dataset_filename_;
00233       fs << "dataset_path" << this->dataset_path_;  // only path to neg
00234       fs << "}";
00235    }
00236    
00237    void SlidingWindowObjectDetectorTrainer::computeHSHistogram(
00238       cv::Mat &src, cv::Mat &hist, const int hBin, const int sBin, bool is_norm)
00239    {
00240       if (src.empty()) {
00241          return;
00242       }
00243       cv::Mat hsv;
00244       cv::cvtColor(src, hsv, CV_BGR2HSV);
00245       int histSize[] = {hBin, sBin};
00246       float h_ranges[] = {0, 180};
00247       float s_ranges[] = {0, 256};
00248       const float* ranges[] = {h_ranges, s_ranges};
00249       int channels[] = {0, 1};
00250       cv::calcHist(
00251          &hsv, 1, channels, cv::Mat(), hist, 2, histSize, ranges, true, false);
00252       if (is_norm) {
00253          cv::normalize(hist, hist, 0, 1, cv::NORM_MINMAX, -1, cv::Mat());
00254       }
00255    }
00256 }  // namespace jsk_perception
00257 
00258 int main(int argc, char *argv[]) {
00259     ros::init(argc, argv, "sliding_window_object_detector_trainer_node");
00260     ROS_INFO("RUNNING NODELET %s", "sliding_window_object_detector_trainer");
00261     jsk_perception::SlidingWindowObjectDetectorTrainer run_trainer;
00262     ros::spin();
00263     return 0;    
00264 }


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07