00001 #include <jsk_topic_tools/log_utils.h>
00002 #include <jsk_perception/sliding_window_object_detector_trainer.h>
00003
00004 #include <iostream>
00005 namespace jsk_perception
00006 {
00007 SlidingWindowObjectDetectorTrainer::SlidingWindowObjectDetectorTrainer()
00008 #if CV_MAJOR_VERSION < 3
00009 : supportVectorMachine_(new cv::SVM)
00010 #endif
00011 {
00012 #if CV_MAJOR_VERSION >= 3
00013 this->supportVectorMachine_ = cv::ml::SVM::create();
00014 #endif
00015 nh_.getParam("dataset_path", this->dataset_path_);
00016 nh_.getParam("object_dataset_filename", this->object_dataset_filename_);
00017 nh_.getParam("nonobject_dataset_filename", this->nonobject_dataset_filename_);
00018 nh_.getParam("classifier_name", this->trained_classifier_name_);
00019 nh_.getParam("swindow_x", this->swindow_x_);
00020 nh_.getParam("swindow_y", this->swindow_y_);
00021
00022 ROS_INFO("--Training Classifier");
00023 std::string pfilename = dataset_path_ + this->object_dataset_filename_;
00024 std::string nfilename = dataset_path_ + this->nonobject_dataset_filename_;
00025 trainObjectClassifier(pfilename, nfilename);
00026 ROS_INFO("--Trained Successfully..");
00027
00028
00029 std::string mainfest_filename = "sliding_window_trainer_manifest.xml";
00030 cv::FileStorage fs = cv::FileStorage(
00031 mainfest_filename, cv::FileStorage::WRITE);
00032 this->writeTrainingManifestToDirectory(fs);
00033 fs.release();
00034
00035 cv::destroyAllWindows();
00036 nh_.shutdown();
00037 }
00038
00039 void SlidingWindowObjectDetectorTrainer::trainObjectClassifier(
00040 std::string pfilename, std::string nfilename)
00041 {
00042 cv::Mat featureMD;
00043 cv::Mat labelMD;
00044 std::string topic_name = "/dataset/roi";
00045 this->readDataset(pfilename, topic_name, featureMD, labelMD, true, 1);
00046 ROS_INFO("Info: Total Object Sample: %d", featureMD.rows);
00047
00048 topic_name = "/dataset/background/roi";
00049 this->readDataset(nfilename, topic_name, featureMD, labelMD, true, -1);
00050 ROS_INFO("Info: Total Training Features: %d", featureMD.rows);
00051
00052 try {
00053 this->trainBinaryClassSVM(featureMD, labelMD);
00054 this->supportVectorMachine_->save(
00055 this->trained_classifier_name_.c_str());
00056 } catch(std::exception &e) {
00057 ROS_ERROR("--ERROR: PLEASE CHECK YOUR DATA \n%s", e.what());
00058 std::_Exit(EXIT_FAILURE);
00059 }
00060 }
00061
00062 void SlidingWindowObjectDetectorTrainer::readDataset(
00063 std::string filename, std::string topic_name, cv::Mat &featureMD,
00064 cv::Mat &labelMD, bool is_usr_label, const int usr_label) {
00065 ROS_INFO("--READING DATASET IMAGE");
00066 try {
00067 rosbag_ = boost::shared_ptr<rosbag::Bag>(new rosbag::Bag);
00068 this->rosbag_->open(filename, rosbag::bagmode::Read);
00069 ROS_INFO("Bag Found and Opened Successfully...");
00070 std::vector<std::string> topics;
00071 topics.push_back(std::string(topic_name));
00072 rosbag::View view(*rosbag_, rosbag::TopicQuery(topics));
00073 BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00074 sensor_msgs::Image::ConstPtr img_msg = m.instantiate<
00075 sensor_msgs::Image>();
00076 cv_bridge::CvImagePtr cv_ptr = cv_bridge::toCvCopy(
00077 img_msg, sensor_msgs::image_encodings::BGR8);
00078 if (cv_ptr->image.data) {
00079 cv::Mat image = cv_ptr->image.clone();
00080 float label = static_cast<float>(usr_label);
00081 labelMD.push_back(label);
00082 this->extractFeatures(image, featureMD);
00083 cv::imshow("image", image);
00084 cv::waitKey(3);
00085 } else {
00086 ROS_WARN("-> NO IMAGE");
00087 }
00088 }
00089 this->rosbag_->close();
00090 } catch (ros::Exception &e) {
00091 ROS_ERROR("ERROR: Bag File:%s not found..\n%s",
00092 filename.c_str(), e.what());
00093 std::_Exit(EXIT_FAILURE);
00094 }
00095 }
00096
00100 void SlidingWindowObjectDetectorTrainer::extractFeatures(
00101 cv::Mat &img, cv::Mat &featureMD) {
00102 ROS_INFO("--EXTRACTING IMAGE FEATURES.");
00103 if (img.data) {
00104 cv::resize(img, img, cv::Size(this->swindow_x_, this->swindow_y_));
00105 cv::Mat hog_feature = this->computeHOG(img);
00106 cv::Mat hsv_feature;
00107 this->computeHSHistogram(img, hsv_feature, 16, 16, true);
00108 hsv_feature = hsv_feature.reshape(1, 1);
00109 cv::Mat _feature;
00110 this->concatenateCVMat(hog_feature, hsv_feature, _feature, true);
00111 featureMD.push_back(_feature);
00112 }
00113 cv::imshow("image", img);
00114 cv::waitKey(3);
00115 }
00116
00117 void SlidingWindowObjectDetectorTrainer::trainBinaryClassSVM(
00118 const cv::Mat &featureMD, const cv::Mat &labelMD)
00119 {
00120 ROS_INFO("--TRAINING CLASSIFIER");
00121 #if CV_MAJOR_VERSION >= 3
00122 this->supportVectorMachine_->setType(cv::ml::SVM::NU_SVC);
00123
00124 this->supportVectorMachine_->setDegree(0.0);
00125 this->supportVectorMachine_->setGamma(0.90);
00126 this->supportVectorMachine_->setCoef0(0.50);
00127 this->supportVectorMachine_->setC(1);
00128 this->supportVectorMachine_->setNu(0.70);
00129 this->supportVectorMachine_->setP(1.0);
00130
00131 cv::TermCriteria term_crit;
00132 term_crit.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
00133 term_crit.maxCount = 1e6;
00134 term_crit.epsilon = 1e-6;
00135 this->supportVectorMachine_->setTermCriteria(term_crit);
00136 cv::ml::ParamGrid paramGrid = cv::ml::ParamGrid();
00137 paramGrid.minVal = 0;
00138 paramGrid.maxVal = 0;
00139 paramGrid.logStep = 1;
00140
00141 cv::Ptr<cv::ml::TrainData> train = cv::ml::TrainData::create(featureMD, cv::ml::ROW_SAMPLE, labelMD, cv::Mat(), cv::Mat());
00142 this->supportVectorMachine_->trainAuto
00143 (train, 10,
00144 paramGrid, cv::ml::SVM::getDefaultGrid(cv::ml::SVM::GAMMA),
00145 cv::ml::SVM::getDefaultGrid(cv::ml::SVM::P),
00146 cv::ml::SVM::getDefaultGrid(cv::ml::SVM::NU),
00147 cv::ml::SVM::getDefaultGrid(cv::ml::SVM::COEF),
00148 cv::ml::SVM::getDefaultGrid(cv::ml::SVM::DEGREE),
00149 true);
00150 #else
00151 cv::SVMParams svm_param = cv::SVMParams();
00152 svm_param.svm_type = cv::SVM::NU_SVC;
00153 svm_param.kernel_type = cv::SVM::RBF;
00154 svm_param.degree = 0.0;
00155 svm_param.gamma = 0.90;
00156 svm_param.coef0 = 0.50;
00157 svm_param.C = 1;
00158 svm_param.nu = 0.70;
00159 svm_param.p = 1.0;
00160 svm_param.class_weights = NULL;
00161 svm_param.term_crit.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
00162 svm_param.term_crit.max_iter = 1e6;
00163 svm_param.term_crit.epsilon = 1e-6;
00164 cv::ParamGrid paramGrid = cv::ParamGrid();
00165 paramGrid.min_val = 0;
00166 paramGrid.max_val = 0;
00167 paramGrid.step = 1;
00168
00169
00170
00171 this->supportVectorMachine_->train_auto
00172 (featureMD, labelMD, cv::Mat(), cv::Mat(), svm_param, 10,
00173 paramGrid, cv::SVM::get_default_grid(cv::SVM::GAMMA),
00174 cv::SVM::get_default_grid(cv::SVM::P),
00175 cv::SVM::get_default_grid(cv::SVM::NU),
00176 cv::SVM::get_default_grid(cv::SVM::COEF),
00177 cv::SVM::get_default_grid(cv::SVM::DEGREE),
00178 true);
00179 #endif
00180 }
00181
00182 void SlidingWindowObjectDetectorTrainer::concatenateCVMat(
00183 const cv::Mat &mat_1, const cv::Mat &mat_2,
00184 cv::Mat &featureMD, bool iscolwise)
00185 {
00186 if (iscolwise) {
00187 featureMD = cv::Mat(mat_1.rows, (mat_1.cols + mat_2.cols), CV_32F);
00188 for (int i = 0; i < featureMD.rows; i++) {
00189 for (int j = 0; j < mat_1.cols; j++) {
00190 featureMD.at<float>(i, j) = mat_1.at<float>(i, j);
00191 }
00192 for (int j = mat_1.cols; j < featureMD.cols; j++) {
00193 featureMD.at<float>(i, j) = mat_2.at<float>(i, j - mat_1.cols);
00194 }
00195 }
00196 } else {
00197 featureMD = cv::Mat((mat_1.rows + mat_2.rows), mat_1.cols, CV_32F);
00198 for (int i = 0; i < featureMD.cols; i++) {
00199 for (int j = 0; j < mat_1.rows; j++) {
00200 featureMD.at<float>(j, i) = mat_1.at<float>(j, i);
00201 }
00202 for (int j = mat_1.rows; j < featureMD.rows; j++) {
00203 featureMD.at<float>(j, i) = mat_2.at<float>(j - mat_1.rows, i);
00204 }
00205 }
00206 }
00207 }
00208
00209 void SlidingWindowObjectDetectorTrainer::writeTrainingManifestToDirectory(
00210 cv::FileStorage &fs)
00211 {
00212 fs << "TrainerInfo" << "{";
00213 fs << "trainer_type" << "cv::SVM";
00214 fs << "trainer_path" << this->trained_classifier_name_;
00215 fs << "}";
00216
00217 fs << "FeatureInfo" << "{";
00218 fs << "HOG" << 1;
00219
00220
00221
00222 fs << "COLOR_HISTOGRAM" << 1;
00223 fs << "}";
00224
00225 fs << "SlidingWindowInfo" << "{";
00226 fs << "swindow_x" << this->swindow_x_;
00227 fs << "swindow_y" << this->swindow_y_;
00228 fs << "}";
00229
00230 fs << "TrainingDatasetDirectoryInfo" << "{";
00231 fs << "object_dataset_filename" << this->object_dataset_filename_;
00232 fs << "nonobject_dataset_filename" << this->nonobject_dataset_filename_;
00233 fs << "dataset_path" << this->dataset_path_;
00234 fs << "}";
00235 }
00236
00237 void SlidingWindowObjectDetectorTrainer::computeHSHistogram(
00238 cv::Mat &src, cv::Mat &hist, const int hBin, const int sBin, bool is_norm)
00239 {
00240 if (src.empty()) {
00241 return;
00242 }
00243 cv::Mat hsv;
00244 cv::cvtColor(src, hsv, CV_BGR2HSV);
00245 int histSize[] = {hBin, sBin};
00246 float h_ranges[] = {0, 180};
00247 float s_ranges[] = {0, 256};
00248 const float* ranges[] = {h_ranges, s_ranges};
00249 int channels[] = {0, 1};
00250 cv::calcHist(
00251 &hsv, 1, channels, cv::Mat(), hist, 2, histSize, ranges, true, false);
00252 if (is_norm) {
00253 cv::normalize(hist, hist, 0, 1, cv::NORM_MINMAX, -1, cv::Mat());
00254 }
00255 }
00256 }
00257
00258 int main(int argc, char *argv[]) {
00259 ros::init(argc, argv, "sliding_window_object_detector_trainer_node");
00260 ROS_INFO("RUNNING NODELET %s", "sliding_window_object_detector_trainer");
00261 jsk_perception::SlidingWindowObjectDetectorTrainer run_trainer;
00262 ros::spin();
00263 return 0;
00264 }