classifier_server.cpp
Go to the documentation of this file.
00001 /*********************************************************************
00002  *
00003  * Software License Agreement (BSD License)
00004  *
00005  *  Copyright (c) 2012, Scott Niekum
00006  *  All rights reserved.
00007  *
00008  *  Redistribution and use in source and binary forms, with or without
00009  *  modification, are permitted provided that the following conditions
00010  *  are met:
00011  *
00012  *   * Redistributions of source code must retain the above copyright
00013  *     notice, this list of conditions and the following disclaimer.
00014  *   * Redistributions in binary form must reproduce the above
00015  *     copyright notice, this list of conditions and the following
00016  *     disclaimer in the documentation and/or other materials provided
00017  *     with the distribution.
00018  *   * Neither the name of the Willow Garage nor the names of its
00019  *     contributors may be used to endorse or promote products derived
00020  *     from this software without specific prior written permission.
00021  *
00022  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00023  *  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00024  *  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
00025  *  FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
00026  *  COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
00027  *  INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
00028  *  BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00029  *  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
00030  *  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00031  *  LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
00032  *  ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00033  *  POSSIBILITY OF SUCH DAMAGE.
00034  *
00035  *********************************************************************/
00036 
00041 #include "ml_classifiers/zero_classifier.h"
00042 #include "ml_classifiers/nearest_neighbor_classifier.h"
00043 #include "ml_classifiers/CreateClassifier.h"
00044 #include "ml_classifiers/AddClassData.h"
00045 #include "ml_classifiers/TrainClassifier.h"
00046 #include "ml_classifiers/ClearClassifier.h"
00047 #include "ml_classifiers/SaveClassifier.h"
00048 #include "ml_classifiers/LoadClassifier.h"
00049 #include "ml_classifiers/ClassifyData.h"
00050 #include <pluginlib/class_loader.h>
00051 
00052 using namespace ml_classifiers;
00053 using namespace std;
00054 
00055 
00056 map<string,Classifier*> classifier_list;
00057 pluginlib::ClassLoader<Classifier> c_loader("ml_classifiers", "ml_classifiers::Classifier");
00058 
00059 bool createHelper(string class_type, Classifier* &c)
00060 {
00061     try{
00062         c = c_loader.createClassInstance(class_type);
00063     }
00064     catch(pluginlib::PluginlibException& ex){
00065         ROS_ERROR("Classifer plugin failed to load! Error: %s", ex.what());
00066     }
00067     
00068     return true;
00069 }
00070 
00071 
00072 bool createCallback(CreateClassifier::Request &req,
00073                     CreateClassifier::Response &res )
00074 {
00075     string id = req.identifier;
00076     Classifier *c;
00077     
00078     if(!createHelper(req.class_type, c)){
00079         res.success = false;
00080         return false;
00081     }
00082     
00083     if(classifier_list.find(id) != classifier_list.end()){
00084         cout << "WARNING: ID already exists, overwriting: " << req.identifier << endl;
00085         delete classifier_list[id];
00086     }
00087     classifier_list[id] = c;
00088     
00089     res.success = true;
00090     return true;
00091 }
00092 
00093 
00094 bool addCallback(AddClassData::Request &req,
00095                  AddClassData::Response &res )
00096 {
00097     string id = req.identifier;
00098     if(classifier_list.find(id) == classifier_list.end()){    
00099         res.success = false;
00100         return false;
00101     }
00102     
00103     for(size_t i=0; i<req.data.size(); i++)
00104         classifier_list[id]->addTrainingPoint(req.data[i].target_class, req.data[i].point);
00105     
00106     res.success = true;
00107     return true;
00108 }
00109 
00110 
00111 bool trainCallback(TrainClassifier::Request &req,
00112                    TrainClassifier::Response &res )
00113 {
00114     string id = req.identifier;
00115     if(classifier_list.find(id) == classifier_list.end()){    
00116         res.success = false;
00117         return false;
00118     }
00119     
00120     cout << "Training " << id << endl;
00121     
00122     classifier_list[id]->train();
00123     res.success = true;
00124     return true;
00125 }
00126 
00127 
00128 bool clearCallback(ClearClassifier::Request &req,
00129                    ClearClassifier::Response &res )
00130 {
00131     string id = req.identifier;
00132     if(classifier_list.find(id) == classifier_list.end()){    
00133         res.success = false;
00134         return false;
00135     }
00136     
00137     classifier_list[id]->clear();
00138     res.success = true;
00139     return true;
00140 }
00141 
00142 
00143 bool saveCallback(SaveClassifier::Request &req,
00144                   SaveClassifier::Response &res )
00145 {
00146     string id = req.identifier;
00147     if(classifier_list.find(id) == classifier_list.end()){    
00148         res.success = false;
00149         return false;
00150     }
00151     
00152     classifier_list[id]->save(req.filename);
00153     res.success = true;
00154     return true;
00155 }
00156 
00157 
00158 bool loadCallback(LoadClassifier::Request &req,
00159                   LoadClassifier::Response &res )
00160 {
00161     string id = req.identifier;
00162     
00163     Classifier *c;
00164     if(!createHelper(req.class_type, c)){
00165         res.success = false;
00166         return false;
00167     }
00168     
00169     if(!c->load(req.filename)){
00170         res.success = false;
00171         return false;
00172     }
00173     
00174     if(classifier_list.find(id) != classifier_list.end()){
00175         cout << "WARNING: ID already exists, overwriting: " << req.identifier << endl;
00176         delete classifier_list[id];
00177     }
00178     classifier_list[id] = c;
00179     
00180     res.success = true;
00181     return true;
00182 }
00183 
00184 
00185 bool classifyCallback(ClassifyData::Request &req,
00186                       ClassifyData::Response &res )
00187 {
00188     string id = req.identifier;
00189     for(size_t i=0; i<req.data.size(); i++){
00190         string class_num = classifier_list[id]->classifyPoint(req.data[i].point);
00191         res.classifications.push_back(class_num);
00192     }
00193     
00194     return true;
00195 }
00196 
00197 
00198 int main(int argc, char **argv)
00199 {
00200     ros::init(argc, argv, "classifier_server");
00201     ros::NodeHandle n;
00202     
00203     ros::ServiceServer service1 = n.advertiseService("/ml_classifiers/create_classifier", createCallback);
00204     ros::ServiceServer service2 = n.advertiseService("/ml_classifiers/add_class_data", addCallback);
00205     ros::ServiceServer service3 = n.advertiseService("/ml_classifiers/train_classifier", trainCallback);
00206     ros::ServiceServer service4 = n.advertiseService("/ml_classifiers/clear_classifier", clearCallback);
00207     ros::ServiceServer service5 = n.advertiseService("/ml_classifiers/save_classifier", saveCallback);
00208     ros::ServiceServer service6 = n.advertiseService("/ml_classifiers/load_classifier", loadCallback);
00209     ros::ServiceServer service7 = n.advertiseService("/ml_classifiers/classify_data", classifyCallback);
00210     
00211     ROS_INFO("Classifier services now ready");
00212     ros::spin();
00213 
00214     return 0;
00215 }


ml_classifiers
Author(s): Scott Niekum
autogenerated on Thu Aug 27 2015 13:59:04