Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040 #include <limits>
00041 #include <jsk_pcl_ros/color_histogram_classifier.h>
00042 #include <jsk_recognition_msgs/ClassificationResult.h>
00043
00044 namespace jsk_pcl_ros
00045 {
00046 void ColorHistogramClassifier::onInit()
00047 {
00048 DiagnosticNodelet::onInit();
00049 classifier_name_ = "color_histogram";
00050
00051 if (!loadReference()) return;
00052
00053 srv_ = boost::make_shared<dynamic_reconfigure::Server<Config> >(*pnh_);
00054 dynamic_reconfigure::Server<Config>::CallbackType f =
00055 boost::bind(&ColorHistogramClassifier::configCallback, this, _1, _2);
00056 srv_->setCallback(f);
00057 pub_class_ = advertise<jsk_recognition_msgs::ClassificationResult>(*pnh_, "output", 1);
00058 onInitPostProcess();
00059 }
00060
00061 bool ColorHistogramClassifier::loadReference()
00062 {
00063
00064
00065 std::vector<std::string> labels;
00066 pnh_->param("label_names", labels, std::vector<std::string>());
00067 if (labels.empty()) {
00068 NODELET_FATAL_STREAM("param ~label_names must not be empty");
00069 return false;
00070 }
00071
00072
00073 for (size_t i = 0; i < labels.size(); ++i) {
00074 std::string name = "histograms/" + labels[i];
00075 NODELET_INFO_STREAM("Loading " << name);
00076 std::vector<float> hist;
00077 pnh_->param(name, hist, std::vector<float>());
00078 if (hist.empty()) {
00079 NODELET_ERROR_STREAM("Failed to load " << name);
00080 } else {
00081 label_names_.push_back(labels[i]);
00082 reference_histograms_.push_back(hist);
00083 }
00084 }
00085
00086
00087 bin_size_ = reference_histograms_[0].size();
00088 for (size_t i = 0; i < label_names_.size(); ++i) {
00089 if (reference_histograms_[i].size() != bin_size_) {
00090 NODELET_FATAL_STREAM("size of histogram " << label_names_[i] << " is different from " << label_names_[0]);
00091 return false;
00092 }
00093 }
00094
00095 NODELET_INFO_STREAM("Loaded " << label_names_.size() << " references");
00096 return true;
00097 }
00098
00099 void ColorHistogramClassifier::configCallback(Config &config, uint32_t level)
00100 {
00101 boost::mutex::scoped_lock lock(mutex_);
00102 compare_policy_ = jsk_recognition_utils::ComparePolicy(config.compare_policy);
00103 detection_threshold_ = config.detection_threshold;
00104
00105 if (queue_size_ != config.queue_size) {
00106 queue_size_ = config.queue_size;
00107 if (isSubscribed()) {
00108 unsubscribe();
00109 subscribe();
00110 }
00111 }
00112 }
00113
00114 void ColorHistogramClassifier::subscribe()
00115 {
00116 sub_hist_ = pnh_->subscribe("input", 1, &ColorHistogramClassifier::feature, this);
00117 sub_hists_ = pnh_->subscribe("input/array", 1, &ColorHistogramClassifier::features, this);
00118 }
00119
00120 void ColorHistogramClassifier::unsubscribe()
00121 {
00122 sub_hist_.shutdown();
00123 sub_hists_.shutdown();
00124 }
00125
00126 void ColorHistogramClassifier::computeDistance(const std::vector<float>& histogram,
00127 std::vector<double>& distances) {
00128 distances.resize(reference_histograms_.size());
00129 for (size_t i = 0; i < reference_histograms_.size(); ++i) {
00130 jsk_recognition_utils::compareHistogram(
00131 histogram, reference_histograms_[i],
00132 compare_policy_, distances[i]);
00133 }
00134 }
00135
00136 void ColorHistogramClassifier::feature(const jsk_recognition_msgs::ColorHistogram::ConstPtr& histogram)
00137 {
00138 boost::mutex::scoped_lock lock(mutex_);
00139
00140 jsk_recognition_msgs::ClassificationResult result;
00141 result.header = histogram->header;
00142 result.classifier = classifier_name_;
00143 result.target_names = label_names_;
00144
00145 std::vector<double> distances;
00146 computeDistance(histogram->histogram, distances);
00147
00148 double max_prob = 0.0;
00149 int label;
00150 for (size_t i = 0; i < distances.size(); ++i) {
00151 double prob = distances[i];
00152 result.probabilities.push_back(prob);
00153 if (prob > max_prob) {
00154 max_prob = prob;
00155 label = i;
00156 }
00157 }
00158
00159 if (max_prob >= detection_threshold_) {
00160 result.labels.push_back(label);
00161 result.label_names.push_back(label_names_[label]);
00162 result.label_proba.push_back(max_prob);
00163 } else {
00164 result.labels.push_back(-1);
00165 result.label_names.push_back(std::string());
00166 result.label_proba.push_back(0.0);
00167 }
00168
00169 pub_class_.publish(result);
00170 }
00171
00172 void ColorHistogramClassifier::features(const jsk_recognition_msgs::ColorHistogramArray::ConstPtr& histograms)
00173 {
00174 boost::mutex::scoped_lock lock(mutex_);
00175
00176 jsk_recognition_msgs::ClassificationResult result;
00177 result.header = histograms->header;
00178 result.classifier = classifier_name_;
00179 result.target_names = label_names_;
00180
00181 for (size_t i = 0; i < histograms->histograms.size(); ++i) {
00182 std::vector<double> distances;
00183 computeDistance(histograms->histograms[i].histogram, distances);
00184
00185 double max_prob = 0.0;
00186 int label;
00187 for (size_t i = 0; i < distances.size(); ++i) {
00188 double prob = distances[i];
00189 result.probabilities.push_back(prob);
00190 if (prob > max_prob) {
00191 max_prob = prob;
00192 label = i;
00193 }
00194 }
00195
00196 if (max_prob >= detection_threshold_) {
00197 result.labels.push_back(label);
00198 result.label_names.push_back(label_names_[label]);
00199 result.label_proba.push_back(max_prob);
00200 } else {
00201 result.labels.push_back(-1);
00202 result.label_names.push_back(std::string());
00203 result.label_proba.push_back(0.0);
00204 }
00205 }
00206
00207 pub_class_.publish(result);
00208 }
00209 }
00210
00211 #include <pluginlib/class_list_macros.h>
// Export ColorHistogramClassifier through pluginlib with base class
// nodelet::Nodelet so it can be loaded as a nodelet.
PLUGINLIB_EXPORT_CLASS(jsk_pcl_ros::ColorHistogramClassifier, nodelet::Nodelet);