Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00041 #include "ml_classifiers/zero_classifier.h"
00042 #include "ml_classifiers/nearest_neighbor_classifier.h"
00043 #include "ml_classifiers/CreateClassifier.h"
00044 #include "ml_classifiers/AddClassData.h"
00045 #include "ml_classifiers/TrainClassifier.h"
00046 #include "ml_classifiers/ClearClassifier.h"
00047 #include "ml_classifiers/SaveClassifier.h"
00048 #include "ml_classifiers/LoadClassifier.h"
00049 #include "ml_classifiers/ClassifyData.h"
00050 #include <pluginlib/class_loader.h>
00051
00052 using namespace ml_classifiers;
00053 using namespace std;
00054
00055
00056 map<string,Classifier*> classifier_list;
00057 pluginlib::ClassLoader<Classifier> c_loader("ml_classifiers", "ml_classifiers::Classifier");
00058
00059 bool createHelper(string class_type, Classifier* &c)
00060 {
00061 try{
00062 c = c_loader.createClassInstance(class_type);
00063 }
00064 catch(pluginlib::PluginlibException& ex){
00065 ROS_ERROR("Classifer plugin failed to load! Error: %s", ex.what());
00066 }
00067
00068 return true;
00069 }
00070
00071
00072 bool createCallback(CreateClassifier::Request &req,
00073 CreateClassifier::Response &res )
00074 {
00075 string id = req.identifier;
00076 Classifier *c;
00077
00078 if(!createHelper(req.class_type, c)){
00079 res.success = false;
00080 return false;
00081 }
00082
00083 if(classifier_list.find(id) != classifier_list.end()){
00084 cout << "WARNING: ID already exists, overwriting: " << req.identifier << endl;
00085 delete classifier_list[id];
00086 }
00087 classifier_list[id] = c;
00088
00089 res.success = true;
00090 return true;
00091 }
00092
00093
00094 bool addCallback(AddClassData::Request &req,
00095 AddClassData::Response &res )
00096 {
00097 string id = req.identifier;
00098 if(classifier_list.find(id) == classifier_list.end()){
00099 res.success = false;
00100 return false;
00101 }
00102
00103 for(size_t i=0; i<req.data.size(); i++)
00104 classifier_list[id]->addTrainingPoint(req.data[i].target_class, req.data[i].point);
00105
00106 res.success = true;
00107 return true;
00108 }
00109
00110
00111 bool trainCallback(TrainClassifier::Request &req,
00112 TrainClassifier::Response &res )
00113 {
00114 string id = req.identifier;
00115 if(classifier_list.find(id) == classifier_list.end()){
00116 res.success = false;
00117 return false;
00118 }
00119
00120 cout << "Training " << id << endl;
00121
00122 classifier_list[id]->train();
00123 res.success = true;
00124 return true;
00125 }
00126
00127
00128 bool clearCallback(ClearClassifier::Request &req,
00129 ClearClassifier::Response &res )
00130 {
00131 string id = req.identifier;
00132 if(classifier_list.find(id) == classifier_list.end()){
00133 res.success = false;
00134 return false;
00135 }
00136
00137 classifier_list[id]->clear();
00138 res.success = true;
00139 return true;
00140 }
00141
00142
00143 bool saveCallback(SaveClassifier::Request &req,
00144 SaveClassifier::Response &res )
00145 {
00146 string id = req.identifier;
00147 if(classifier_list.find(id) == classifier_list.end()){
00148 res.success = false;
00149 return false;
00150 }
00151
00152 classifier_list[id]->save(req.filename);
00153 res.success = true;
00154 return true;
00155 }
00156
00157
00158 bool loadCallback(LoadClassifier::Request &req,
00159 LoadClassifier::Response &res )
00160 {
00161 string id = req.identifier;
00162
00163 Classifier *c;
00164 if(!createHelper(req.class_type, c)){
00165 res.success = false;
00166 return false;
00167 }
00168
00169 if(!c->load(req.filename)){
00170 res.success = false;
00171 return false;
00172 }
00173
00174 if(classifier_list.find(id) != classifier_list.end()){
00175 cout << "WARNING: ID already exists, overwriting: " << req.identifier << endl;
00176 delete classifier_list[id];
00177 }
00178 classifier_list[id] = c;
00179
00180 res.success = true;
00181 return true;
00182 }
00183
00184
00185 bool classifyCallback(ClassifyData::Request &req,
00186 ClassifyData::Response &res )
00187 {
00188 string id = req.identifier;
00189 for(size_t i=0; i<req.data.size(); i++){
00190 string class_num = classifier_list[id]->classifyPoint(req.data[i].point);
00191 res.classifications.push_back(class_num);
00192 }
00193
00194 return true;
00195 }
00196
00197
00198 int main(int argc, char **argv)
00199 {
00200 ros::init(argc, argv, "classifier_server");
00201 ros::NodeHandle n;
00202
00203 ros::ServiceServer service1 = n.advertiseService("/ml_classifiers/create_classifier", createCallback);
00204 ros::ServiceServer service2 = n.advertiseService("/ml_classifiers/add_class_data", addCallback);
00205 ros::ServiceServer service3 = n.advertiseService("/ml_classifiers/train_classifier", trainCallback);
00206 ros::ServiceServer service4 = n.advertiseService("/ml_classifiers/clear_classifier", clearCallback);
00207 ros::ServiceServer service5 = n.advertiseService("/ml_classifiers/save_classifier", saveCallback);
00208 ros::ServiceServer service6 = n.advertiseService("/ml_classifiers/load_classifier", loadCallback);
00209 ros::ServiceServer service7 = n.advertiseService("/ml_classifiers/classify_data", classifyCallback);
00210
00211 ROS_INFO("Classifier services now ready");
00212 ros::spin();
00213
00214 return 0;
00215 }