39 #include "rclcpp/rclcpp.hpp" 43 #include "ml_classifiers/srv/add_class_data.hpp" 44 #include "ml_classifiers/srv/classify_data.hpp" 45 #include "ml_classifiers/srv/clear_classifier.hpp" 46 #include "ml_classifiers/srv/create_classifier.hpp" 47 #include "ml_classifiers/srv/load_classifier.hpp" 48 #include "ml_classifiers/srv/save_classifier.hpp" 49 #include "ml_classifiers/srv/train_classifier.hpp" 55 using std::placeholders::_1;
56 using std::placeholders::_2;
57 using std::placeholders::_3;
63 : Node(
"classifier_server"),
64 c_loader(
"ml_classifiers",
"ml_classifiers::Classifier")
66 create_srv = this->create_service<ml_classifiers::srv::CreateClassifier>(
69 add_srv = this->create_service<ml_classifiers::srv::AddClassData>(
72 train_srv = this->create_service<ml_classifiers::srv::TrainClassifier>(
75 clear_srv = this->create_service<ml_classifiers::srv::ClearClassifier>(
78 save_srv = this->create_service<ml_classifiers::srv::SaveClassifier>(
81 load_srv = this->create_service<ml_classifiers::srv::LoadClassifier>(
84 classify_srv = this->create_service<ml_classifiers::srv::ClassifyData>(
88 RCLCPP_INFO(this->get_logger(),
"Classifier services now ready");
95 rclcpp::Service<ml_classifiers::srv::CreateClassifier>::SharedPtr
create_srv;
96 rclcpp::Service<ml_classifiers::srv::AddClassData>::SharedPtr
add_srv;
97 rclcpp::Service<ml_classifiers::srv::TrainClassifier>::SharedPtr
train_srv;
98 rclcpp::Service<ml_classifiers::srv::ClearClassifier>::SharedPtr
clear_srv;
99 rclcpp::Service<ml_classifiers::srv::SaveClassifier>::SharedPtr
save_srv;
100 rclcpp::Service<ml_classifiers::srv::LoadClassifier>::SharedPtr
load_srv;
101 rclcpp::Service<ml_classifiers::srv::ClassifyData>::SharedPtr
classify_srv;
110 "Classifer plugin failed to load! Error: %s",
119 const std::shared_ptr<rmw_request_id_t> req_hdr,
120 const std::shared_ptr<ml_classifiers::srv::CreateClassifier::Request> req,
121 std::shared_ptr<ml_classifiers::srv::CreateClassifier::Response> res)
124 string id = req->identifier;
125 std::shared_ptr<Classifier>
c;
128 res->success =
false;
130 if (classifier_list.find(
id) != classifier_list.end()) {
133 "WARNING: ID already exists, overwriting: %s",
134 req->identifier.c_str());
135 classifier_list.erase(
id);
138 classifier_list[id] =
c;
145 const std::shared_ptr<rmw_request_id_t> req_hdr,
146 const std::shared_ptr<ml_classifiers::srv::AddClassData::Request> req,
147 std::shared_ptr<ml_classifiers::srv::AddClassData::Response> res)
150 string id = req->identifier;
152 if (classifier_list.find(
id) == classifier_list.end()) {
153 res->success =
false;
155 for (
size_t i = 0; i < req->data.size(); i++) {
156 classifier_list[id]->addTrainingPoint(req->data[i].target_class, req->data[i].point);
164 const std::shared_ptr<rmw_request_id_t> req_hdr,
165 const std::shared_ptr<ml_classifiers::srv::TrainClassifier::Request> req,
166 std::shared_ptr<ml_classifiers::srv::TrainClassifier::Response> res)
169 string id = req->identifier;
171 if (classifier_list.find(
id) == classifier_list.end()) {
172 res->success =
false;
179 classifier_list[id]->train();
185 const std::shared_ptr<rmw_request_id_t> req_hdr,
186 const std::shared_ptr<ml_classifiers::srv::ClearClassifier::Request> req,
187 std::shared_ptr<ml_classifiers::srv::ClearClassifier::Response> res)
190 string id = req->identifier;
192 if (classifier_list.find(
id) == classifier_list.end()) {
193 res->success =
false;
195 classifier_list[id]->clear();
201 const std::shared_ptr<rmw_request_id_t> req_hdr,
202 const std::shared_ptr<ml_classifiers::srv::SaveClassifier::Request> req,
203 std::shared_ptr<ml_classifiers::srv::SaveClassifier::Response> res)
206 string id = req->identifier;
208 if (classifier_list.find(
id) == classifier_list.end()) {
209 res->success =
false;
211 classifier_list[id]->save(req->filename);
217 const std::shared_ptr<rmw_request_id_t> req_hdr,
218 const std::shared_ptr<ml_classifiers::srv::LoadClassifier::Request> req,
219 std::shared_ptr<ml_classifiers::srv::LoadClassifier::Response> res)
222 string id = req->identifier;
224 std::shared_ptr<Classifier>
c;
227 res->success =
false;
229 if (!c->load(req->filename)) {
230 res->success =
false;
232 if (classifier_list.find(
id) != classifier_list.end()) {
235 "WARNING: ID already exists, overwriting: %s",
236 req->identifier.c_str());
237 classifier_list.erase(
id);
239 classifier_list[id] =
c;
247 const std::shared_ptr<rmw_request_id_t> req_hdr,
248 const std::shared_ptr<ml_classifiers::srv::ClassifyData::Request> req,
249 std::shared_ptr<ml_classifiers::srv::ClassifyData::Response> res)
252 string id = req->identifier;
254 for (
size_t i = 0; i < req->data.size(); i++) {
255 string class_num = classifier_list[id]->classifyPoint(req->data[i].point);
256 res->classifications.push_back(class_num);
261 int main(
int argc,
char ** argv)
263 rclcpp::init(argc, argv);
265 auto node = std::make_shared<ClassifierServer>();
bool createHelper(string class_type, boost::shared_ptr< Classifier > &c)
int main(int argc, char **argv)
void createCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::CreateClassifier::Request > req, std::shared_ptr< ml_classifiers::srv::CreateClassifier::Response > res)
rclcpp::Service< ml_classifiers::srv::LoadClassifier >::SharedPtr load_srv
void classifyCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::ClassifyData::Request > req, std::shared_ptr< ml_classifiers::srv::ClassifyData::Response > res)
void loadCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::LoadClassifier::Request > req, std::shared_ptr< ml_classifiers::srv::LoadClassifier::Response > res)
void trainCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::TrainClassifier::Request > req, std::shared_ptr< ml_classifiers::srv::TrainClassifier::Response > res)
void saveCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::SaveClassifier::Request > req, std::shared_ptr< ml_classifiers::srv::SaveClassifier::Response > res)
T * createUnmanagedInstance(const std::string &lookup_name)
rclcpp::Service< ml_classifiers::srv::AddClassData >::SharedPtr add_srv
pluginlib::ClassLoader< Classifier > c_loader
rclcpp::Service< ml_classifiers::srv::ClassifyData >::SharedPtr classify_srv
void addCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::AddClassData::Request > req, std::shared_ptr< ml_classifiers::srv::AddClassData::Response > res)
pluginlib::ClassLoader< Classifier > c_loader("ml_classifiers", "ml_classifiers::Classifier")
rclcpp::Service< ml_classifiers::srv::TrainClassifier >::SharedPtr train_srv
rclcpp::Service< ml_classifiers::srv::ClearClassifier >::SharedPtr clear_srv
void clearCallback(const std::shared_ptr< rmw_request_id_t > req_hdr, const std::shared_ptr< ml_classifiers::srv::ClearClassifier::Request > req, std::shared_ptr< ml_classifiers::srv::ClearClassifier::Response > res)
rclcpp::Service< ml_classifiers::srv::SaveClassifier >::SharedPtr save_srv
std::map< string, std::shared_ptr< Classifier > > classifier_list
bool createHelper(string class_type, std::shared_ptr< Classifier > &c)
rclcpp::Service< ml_classifiers::srv::CreateClassifier >::SharedPtr create_srv