ros2_classifier_server.cpp
Go to the documentation of this file.
1 // Copyright (c) 2012, 2019 Scott Niekum, Joshua Whitley
2 // All rights reserved.
3 //
4 // Software License Agreement (BSD License 2.0)
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions
8 // are met:
9 //
10 // * Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 // * Redistributions in binary form must reproduce the above
13 // copyright notice, this list of conditions and the following
14 // disclaimer in the documentation and/or other materials provided
15 // with the distribution.
16 // * Neither the name of {copyright_holder} nor the names of its
17 // contributors may be used to endorse or promote products derived
18 // from this software without specific prior written permission.
19 //
20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23 // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24 // COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26 // BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29 // LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31 // POSSIBILITY OF SUCH DAMAGE.
32 
34 
35 #include <string>
36 #include <map>
37 #include <memory>
38 
39 #include "rclcpp/rclcpp.hpp"
40 
43 #include "ml_classifiers/srv/add_class_data.hpp"
44 #include "ml_classifiers/srv/classify_data.hpp"
45 #include "ml_classifiers/srv/clear_classifier.hpp"
46 #include "ml_classifiers/srv/create_classifier.hpp"
47 #include "ml_classifiers/srv/load_classifier.hpp"
48 #include "ml_classifiers/srv/save_classifier.hpp"
49 #include "ml_classifiers/srv/train_classifier.hpp"
50 
51 using namespace ml_classifiers; // NOLINT
52 using std::string;
53 using std::cout;
54 using std::endl;
55 using std::placeholders::_1;
56 using std::placeholders::_2;
57 using std::placeholders::_3;
58 
59 class ClassifierServer : public rclcpp::Node
60 {
61 public:
63  : Node("classifier_server"),
64  c_loader("ml_classifiers", "ml_classifiers::Classifier")
65  {
66  create_srv = this->create_service<ml_classifiers::srv::CreateClassifier>(
67  "create_classifier",
68  std::bind(&ClassifierServer::createCallback, this, _1, _2, _3));
69  add_srv = this->create_service<ml_classifiers::srv::AddClassData>(
70  "add_class_data",
71  std::bind(&ClassifierServer::addCallback, this, _1, _2, _3));
72  train_srv = this->create_service<ml_classifiers::srv::TrainClassifier>(
73  "train_classifier",
74  std::bind(&ClassifierServer::trainCallback, this, _1, _2, _3));
75  clear_srv = this->create_service<ml_classifiers::srv::ClearClassifier>(
76  "clear_classifier",
77  std::bind(&ClassifierServer::clearCallback, this, _1, _2, _3));
78  save_srv = this->create_service<ml_classifiers::srv::SaveClassifier>(
79  "save_classifier",
80  std::bind(&ClassifierServer::saveCallback, this, _1, _2, _3));
81  load_srv = this->create_service<ml_classifiers::srv::LoadClassifier>(
82  "load_classifier",
83  std::bind(&ClassifierServer::loadCallback, this, _1, _2, _3));
84  classify_srv = this->create_service<ml_classifiers::srv::ClassifyData>(
85  "classify_data",
86  std::bind(&ClassifierServer::classifyCallback, this, _1, _2, _3));
87 
88  RCLCPP_INFO(this->get_logger(), "Classifier services now ready");
89  }
90 
91 private:
93  std::map<string, std::shared_ptr<Classifier>> classifier_list;
94 
95  rclcpp::Service<ml_classifiers::srv::CreateClassifier>::SharedPtr create_srv;
96  rclcpp::Service<ml_classifiers::srv::AddClassData>::SharedPtr add_srv;
97  rclcpp::Service<ml_classifiers::srv::TrainClassifier>::SharedPtr train_srv;
98  rclcpp::Service<ml_classifiers::srv::ClearClassifier>::SharedPtr clear_srv;
99  rclcpp::Service<ml_classifiers::srv::SaveClassifier>::SharedPtr save_srv;
100  rclcpp::Service<ml_classifiers::srv::LoadClassifier>::SharedPtr load_srv;
101  rclcpp::Service<ml_classifiers::srv::ClassifyData>::SharedPtr classify_srv;
102 
103  bool createHelper(string class_type, std::shared_ptr<Classifier> & c)
104  {
105  try {
106  c = std::shared_ptr<Classifier>(c_loader.createUnmanagedInstance(class_type));
107  } catch (pluginlib::PluginlibException & ex) {
108  RCLCPP_ERROR(
109  this->get_logger(),
110  "Classifer plugin failed to load! Error: %s",
111  ex.what());
112  return false;
113  }
114 
115  return true;
116  }
117 
119  const std::shared_ptr<rmw_request_id_t> req_hdr,
120  const std::shared_ptr<ml_classifiers::srv::CreateClassifier::Request> req,
121  std::shared_ptr<ml_classifiers::srv::CreateClassifier::Response> res)
122  {
123  (void)req_hdr;
124  string id = req->identifier;
125  std::shared_ptr<Classifier> c;
126 
127  if (!createHelper(req->class_type, c)) {
128  res->success = false;
129  } else {
130  if (classifier_list.find(id) != classifier_list.end()) {
131  RCLCPP_INFO(
132  this->get_logger(),
133  "WARNING: ID already exists, overwriting: %s",
134  req->identifier.c_str());
135  classifier_list.erase(id);
136  }
137 
138  classifier_list[id] = c;
139 
140  res->success = true;
141  }
142  }
143 
145  const std::shared_ptr<rmw_request_id_t> req_hdr,
146  const std::shared_ptr<ml_classifiers::srv::AddClassData::Request> req,
147  std::shared_ptr<ml_classifiers::srv::AddClassData::Response> res)
148  {
149  (void)req_hdr;
150  string id = req->identifier;
151 
152  if (classifier_list.find(id) == classifier_list.end()) {
153  res->success = false;
154  } else {
155  for (size_t i = 0; i < req->data.size(); i++) {
156  classifier_list[id]->addTrainingPoint(req->data[i].target_class, req->data[i].point);
157  }
158 
159  res->success = true;
160  }
161  }
162 
164  const std::shared_ptr<rmw_request_id_t> req_hdr,
165  const std::shared_ptr<ml_classifiers::srv::TrainClassifier::Request> req,
166  std::shared_ptr<ml_classifiers::srv::TrainClassifier::Response> res)
167  {
168  (void)req_hdr;
169  string id = req->identifier;
170 
171  if (classifier_list.find(id) == classifier_list.end()) {
172  res->success = false;
173  } else {
174  RCLCPP_INFO(
175  this->get_logger(),
176  "Training %s",
177  id.c_str());
178 
179  classifier_list[id]->train();
180  res->success = true;
181  }
182  }
183 
185  const std::shared_ptr<rmw_request_id_t> req_hdr,
186  const std::shared_ptr<ml_classifiers::srv::ClearClassifier::Request> req,
187  std::shared_ptr<ml_classifiers::srv::ClearClassifier::Response> res)
188  {
189  (void)req_hdr;
190  string id = req->identifier;
191 
192  if (classifier_list.find(id) == classifier_list.end()) {
193  res->success = false;
194  } else {
195  classifier_list[id]->clear();
196  res->success = true;
197  }
198  }
199 
201  const std::shared_ptr<rmw_request_id_t> req_hdr,
202  const std::shared_ptr<ml_classifiers::srv::SaveClassifier::Request> req,
203  std::shared_ptr<ml_classifiers::srv::SaveClassifier::Response> res)
204  {
205  (void)req_hdr;
206  string id = req->identifier;
207 
208  if (classifier_list.find(id) == classifier_list.end()) {
209  res->success = false;
210  } else {
211  classifier_list[id]->save(req->filename);
212  res->success = true;
213  }
214  }
215 
217  const std::shared_ptr<rmw_request_id_t> req_hdr,
218  const std::shared_ptr<ml_classifiers::srv::LoadClassifier::Request> req,
219  std::shared_ptr<ml_classifiers::srv::LoadClassifier::Response> res)
220  {
221  (void)req_hdr;
222  string id = req->identifier;
223 
224  std::shared_ptr<Classifier> c;
225 
226  if (!createHelper(req->class_type, c)) {
227  res->success = false;
228  } else {
229  if (!c->load(req->filename)) {
230  res->success = false;
231  } else {
232  if (classifier_list.find(id) != classifier_list.end()) {
233  RCLCPP_WARN(
234  this->get_logger(),
235  "WARNING: ID already exists, overwriting: %s",
236  req->identifier.c_str());
237  classifier_list.erase(id);
238  }
239  classifier_list[id] = c;
240 
241  res->success = true;
242  }
243  }
244  }
245 
247  const std::shared_ptr<rmw_request_id_t> req_hdr,
248  const std::shared_ptr<ml_classifiers::srv::ClassifyData::Request> req,
249  std::shared_ptr<ml_classifiers::srv::ClassifyData::Response> res)
250  {
251  (void)req_hdr;
252  string id = req->identifier;
253 
254  for (size_t i = 0; i < req->data.size(); i++) {
255  string class_num = classifier_list[id]->classifyPoint(req->data[i].point);
256  res->classifications.push_back(class_num);
257  }
258  }
259 };
260 
261 int main(int argc, char ** argv)
262 {
263  rclcpp::init(argc, argv);
264 
265  auto node = std::make_shared<ClassifierServer>();
266 
267  rclcpp::spin(node);
268 
269  rclcpp::shutdown();
270  return 0;
271 }
bool createHelper(string class_type, boost::shared_ptr< Classifier > &c)
int main(int argc, char **argv)
void createCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::CreateClassifier::Request > req, std::shared_ptr< ml_classifiers::srv::CreateClassifier::Response > res)
rclcpp::Service< ml_classifiers::srv::LoadClassifier >::SharedPtr load_srv
void classifyCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::ClassifyData::Request > req, std::shared_ptr< ml_classifiers::srv::ClassifyData::Response > res)
void loadCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::LoadClassifier::Request > req, std::shared_ptr< ml_classifiers::srv::LoadClassifier::Response > res)
void trainCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::TrainClassifier::Request > req, std::shared_ptr< ml_classifiers::srv::TrainClassifier::Response > res)
void saveCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::SaveClassifier::Request > req, std::shared_ptr< ml_classifiers::srv::SaveClassifier::Response > res)
T * createUnmanagedInstance(const std::string &lookup_name)
rclcpp::Service< ml_classifiers::srv::AddClassData >::SharedPtr add_srv
pluginlib::ClassLoader< Classifier > c_loader
rclcpp::Service< ml_classifiers::srv::ClassifyData >::SharedPtr classify_srv
void addCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::AddClassData::Request > req, std::shared_ptr< ml_classifiers::srv::AddClassData::Response > res)
pluginlib::ClassLoader< Classifier > c_loader("ml_classifiers", "ml_classifiers::Classifier")
rclcpp::Service< ml_classifiers::srv::TrainClassifier >::SharedPtr train_srv
rclcpp::Service< ml_classifiers::srv::ClearClassifier >::SharedPtr clear_srv
c
Definition: easy.py:61
void clearCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::ClearClassifier::Request > req, std::shared_ptr< ml_classifiers::srv::ClearClassifier::Response > res)
rclcpp::Service< ml_classifiers::srv::SaveClassifier >::SharedPtr save_srv
std::map< string, std::shared_ptr< Classifier > > classifier_list
bool createHelper(string class_type, std::shared_ptr< Classifier > &c)
rclcpp::Service< ml_classifiers::srv::CreateClassifier >::SharedPtr create_srv


ml_classifiers
Author(s): Scott Niekum, Joshua Whitley
autogenerated on Mon Feb 28 2022 22:46:49