Go to the documentation of this file.00001
00002 #include <pcl_cloud_algos/cloud_algos.h>
00003 #include <pcl_cloud_algos/svm_classification.h>
00004
00005 using namespace cloud_algos;
00006
00007 void SVMClassification::init (ros::NodeHandle& nh)
00008 {
00009 nh_ = nh;
00010 }
00011
00012 void SVMClassification::pre ()
00013 {
00014
00015
00016 nh_.param("model_file_name", model_file_name_, model_file_name_);
00017 nh_.param("scale_file_name", scale_file_name_, scale_file_name_);
00018 nh_.param("scale_self", scale_self_, scale_self_);
00019 nh_.param("scale_file", scale_file_, scale_file_);
00020 }
00021
00022 void SVMClassification::post ()
00023 {
00024
00025 }
00026
00027 std::vector<std::string> SVMClassification::requires ()
00028 {
00029 std::vector<std::string> requires;
00030
00031 requires.push_back("f1");
00032 return requires;
00033 }
00034
00035 std::vector<std::string> SVMClassification::provides ()
00036 {
00037 std::vector<std::string> provides;
00038 provides.push_back("point_class");
00039 return provides;
00040 }
00041
00042 std::string SVMClassification::process (const boost::shared_ptr<const SVMClassification::InputType>& cloud)
00043 {
00044
00045 int fIdx = -1;
00046 for (unsigned int d = 0; d < cloud->channels.size (); d++)
00047 if (cloud->channels[d].name == "f1")
00048 {
00049 fIdx = d;
00050 break;
00051 }
00052 if (fIdx == -1)
00053 {
00054 if (verbosity_level_ > -2) ROS_ERROR ("[SVMClassification] Provided point cloud does not have features computed. Use PFH or similar first!");
00055 output_valid_ = false;
00056 return std::string("missing features");
00057 }
00058 int nr_values = 1;
00059 for (unsigned int d = fIdx+1; d < cloud->channels.size (); d++)
00060 {
00061 char dim_name[16];
00062 sprintf (dim_name, "f%d", nr_values+1);
00063 if (cloud->channels[d].name == dim_name)
00064 nr_values++;
00065 }
00066 if (verbosity_level_ > 0) ROS_INFO ("[SVMClassification] Found %d feature values in input PCD.", nr_values);
00067
00068
00069 int plIdx = -1;
00070 for (unsigned int d = 0; d < cloud->channels.size (); d++)
00071 if (cloud->channels[d].name == "point_label")
00072 {
00073 plIdx = d;
00074 break;
00075 }
00076 if (plIdx == -1)
00077 if (verbosity_level_ > 0) ROS_INFO ("[SVMClassification] NOTE: Points are not labeled with the expected classification results. If you want to evaluate the results please add point_label channel.");
00078
00080 struct svm_node* node;
00081 struct svm_model* model;
00082 if ((model = svm_load_model (model_file_name_.c_str ())) == 0)
00083 {
00084 if (verbosity_level_ > -2) ROS_ERROR ("[SVMClassification] Couldn't load SVM model from %s", model_file_name_.c_str ());
00085 output_valid_ = false;
00086 return std::string("incorrect model file");
00087 }
00088 node = (struct svm_node*) malloc ((nr_values+1) * sizeof (struct svm_node));
00089 ROS_INFO ("[SVMClassification] SVM model type: %d with %d output classes (read from %s).", svm_get_svm_type (model), svm_get_nr_class (model), model_file_name_.c_str ());
00090
00091
00092 double lower, upper;
00093 double** value_ranges = NULL;
00094 if (scale_self_)
00095 {
00096 lower = -1;
00097 upper = +1;
00098 value_ranges = computeScaleParameters (cloud, fIdx, nr_values);
00099 if (verbosity_level_ > 0) ROS_INFO ("[SVMClassification] Scaling data to the interval (%g,%g) enabled.", lower, upper);
00100 }
00101 else if (scale_file_)
00102 {
00103 value_ranges = parseScaleParameterFile (scale_file_name_.c_str (), lower, upper, nr_values);
00104 if (value_ranges == NULL)
00105 {
00106 if (verbosity_level_ > -2) ROS_ERROR ("[SVMClassification] Scaling requested from file %s but it is not possible!", scale_file_name_.c_str ());
00107 output_valid_ = false;
00108 return std::string("incorrect scale parameter file");
00109 }
00110 else
00111 {
00112 if (verbosity_level_ > 0) ROS_INFO ("[SVMClassification] Scaling according to the limits from %s to the interval (%g,%g) enabled.", scale_file_name_.c_str (), lower, upper);
00113 }
00114 }
00115
00116
00117 ros::Time global_time = ros::Time::now ();
00118
00119
00120
00121 cloud_svm_ = boost::shared_ptr<sensor_msgs::PointCloud> (new sensor_msgs::PointCloud());
00122 cloud_svm_->header = cloud->header;
00123 cloud_svm_->points = cloud->points;
00124 cloud_svm_->channels = cloud->channels;
00125
00126
00127 if (verbosity_level_ > 0) ROS_INFO ("[SVMClassification] Saving classification results to point_class channel.");
00128 int pcIdx = cloud_svm_->channels.size ();
00129 cloud_svm_->channels.resize (pcIdx + 1);
00130 cloud_svm_->channels[pcIdx].name = "point_class";
00131 cloud_svm_->channels[pcIdx].values.resize (cloud_svm_->points.size (), 0.0);
00132 if (verbosity_level_ > 0) ROS_INFO ("[SVMClassification] Added channel: %s", cloud_svm_->channels[pcIdx].name.c_str ());
00133
00134
00135 int success = 0;
00136 for (size_t cp = 0; cp < cloud_svm_->points.size (); cp++)
00137 {
00138
00139 int i = 0;
00140 for (i = 0; i < nr_values; i++)
00141 {
00142 node[i].index = i+1;
00143 double feature_value = cloud_svm_->channels[fIdx + i].values[cp];
00144 if (value_ranges != NULL)
00145 node[i].value = scaleFeature (i, feature_value, value_ranges, lower, upper);
00146 else
00147 node[i].value = feature_value;
00148 }
00149 node[i].index = -1;
00150
00151
00152 cloud_svm_->channels[pcIdx].values[cp] = svm_predict (model, node);
00153
00154
00155 if (plIdx != -1 && cloud_svm_->channels[pcIdx].values[cp] == cloud_svm_->channels[plIdx].values[cp])
00156 success++;
00157 }
00158 if (plIdx != -1)
00159 if (verbosity_level_ > 0) ROS_INFO ("[SVMClassification] Accuracy: %d/%d (%g%%).", success, (int)(cloud_svm_->points.size ()), success * 100.0 / cloud_svm_->points.size ());
00160
00161
00162 svm_destroy_model (model);
00163 free (node);
00164
00165
00166 if (verbosity_level_ > 0) ROS_INFO ("[SVMClassification] SVM classification done in %g seconds.", (ros::Time::now () - global_time).toSec ());
00167 output_valid_ = true;
00168 return std::string("ok");
00169 }
00170
00171 boost::shared_ptr<const SVMClassification::OutputType> SVMClassification::output ()
00172 {return cloud_svm_;}
00173
#ifdef CREATE_NODE
/// Entry point when this algorithm is built as a standalone node. All work is
/// delegated to the generic standalone_node<> driver, which is presumably
/// provided by cloud_algos.h; its return value becomes the process exit code.
int main (int argc, char* argv[])
{
  return standalone_node<SVMClassification> (argc, argv);
}
#endif
00180