dnn_detect.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2017, Ubiquity Robotics
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  * 1. Redistributions of source code must retain the above copyright notice,
9  * this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright notice,
11  * this list of conditions and the following disclaimer in the documentation
12  * and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
18  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
24  * POSSIBILITY OF SUCH DAMAGE.
25  *
26  * The views and conclusions contained in the software and documentation are
27  * those of the authors and should not be interpreted as representing official
28  * policies, either expressed or implied, of the FreeBSD Project.
29  *
30  */
31 
#include <assert.h>
#include <sys/time.h>
#include <unistd.h>

#include <ros/ros.h>
#include <image_transport/image_transport.h>
#include <cv_bridge/cv_bridge.h>
#include <sensor_msgs/image_encodings.h>

#include "dnn_detect/DetectedObject.h"
#include "dnn_detect/DetectedObjectArray.h"
#include "dnn_detect/Detect.h"

#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/calib3d.hpp>

#include <exception>
#include <list>
#include <string>
#include <vector>
#include <boost/algorithm/string.hpp>
#include <boost/format.hpp>

#include <thread>
#include <mutex>
#include <condition_variable>
56 
57 using namespace std;
58 using namespace cv;
59 
// Synchronization shared by DnnNode::trigger_callback() (waits on `cond`)
// and DnnNode::image_callback() (updates shared state and notifies `cond`).
// Both are used only in this translation unit, so they get internal linkage
// instead of exporting the generic names `cond`/`mutx` to the whole program.
static std::condition_variable cond;
static std::mutex mutx;
62 
63 class DnnNode {
64  private:
66 
69 
70  // if set, we publish the images that contain objects
72 
73  int frame_num;
75  int im_size;
77  float scale_factor;
78  float mean_val;
79  std::vector<std::string> class_names;
80 
82 
83  cv::dnn::Net net;
84  cv::Mat resized_image;
85  cv::Mat rotated_image;
86 
88  volatile bool triggered;
89  volatile bool processed;
90 
91  dnn_detect::DetectedObjectArray results;
92 
94 
95  bool trigger_callback(dnn_detect::Detect::Request &req,
96  dnn_detect::Detect::Response &res);
97 
98  void image_callback(const sensor_msgs::ImageConstPtr &msg);
99 
100  public:
102 };
103 
104 bool DnnNode:: trigger_callback(dnn_detect::Detect::Request &req,
105  dnn_detect::Detect::Response &res)
106 {
107  ROS_INFO("Got service request");
108  triggered = true;
109 
110  std::unique_lock<std::mutex> lock(mutx);
111 
112  while (!processed) {
113  cond.wait(lock);
114  }
115  res.result = results;
116  processed = false;
117  return true;
118 }
119 
120 
121 void DnnNode::image_callback(const sensor_msgs::ImageConstPtr & msg)
122 {
123  if (single_shot && !triggered) {
124  return;
125  }
126  triggered = false;
127 
128  ROS_INFO("Got image %d", msg->header.seq);
129  frame_num++;
130 
131  cv_bridge::CvImagePtr cv_ptr;
132 
133  try {
135 
136  int w = cv_ptr->image.cols;
137  int h = cv_ptr->image.rows;
138 
139  if (rotate_flag >= 0) {
140  cv::rotate(cv_ptr->image, rotated_image, rotate_flag);
141  rotated_image.copyTo(cv_ptr->image);
142  }
143 
144  cv::resize(cv_ptr->image, resized_image, cvSize(im_size, im_size));
145  cv::Mat blob = cv::dnn::blobFromImage(resized_image, scale_factor,
146  cvSize(im_size, im_size), mean_val, false);
147 
148  net.setInput(blob, "data");
149  cv::Mat objs = net.forward("detection_out");
150 
151  cv::Mat detectionMat(objs.size[2], objs.size[3], CV_32F,
152  objs.ptr<float>());
153 
154  std::unique_lock<std::mutex> lock(mutx);
155  results.header.frame_id = msg->header.frame_id;
156  results.objects.clear();
157 
158  for(int i = 0; i < detectionMat.rows; i++) {
159 
160  float confidence = detectionMat.at<float>(i, 2);
161  if (confidence > min_confidence) {
162  int object_class = (int)(detectionMat.at<float>(i, 1));
163 
164  int x_min = static_cast<int>(detectionMat.at<float>(i, 3) * w);
165  int y_min = static_cast<int>(detectionMat.at<float>(i, 4) * h);
166  int x_max = static_cast<int>(detectionMat.at<float>(i, 5) * w);
167  int y_max = static_cast<int>(detectionMat.at<float>(i, 6) * h);
168 
169  std::string class_name;
170  if (object_class >= class_names.size()) {
171  class_name = "unknown";
172  ROS_ERROR("Object class %d out of range of class names",
173  object_class);
174  }
175  else {
176  class_name = class_names[object_class];
177  }
178  std::string label = str(boost::format{"%1% %2%"} %
179  class_name % confidence);
180 
181  ROS_INFO("%s", label.c_str());
182  dnn_detect::DetectedObject obj;
183  obj.class_name = class_name;
184  obj.confidence = confidence;
185  obj.x_min = x_min;
186  obj.x_max = x_max;
187  obj.y_min = y_min;
188  obj.y_max = y_max;
189  results.objects.push_back(obj);
190 
191  Rect object(x_min, y_min, x_max-x_min, y_max-y_min);
192 
193  rectangle(cv_ptr->image, object, Scalar(0, 255, 0));
194  int baseline=0;
195  cv::Size text_size = cv::getTextSize(label,
196  FONT_HERSHEY_SIMPLEX, 0.75, 2, &baseline);
197  putText(cv_ptr->image, label, Point(x_min, y_min-text_size.height),
198  FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 255, 0));
199  }
200  }
201 
202  results_pub.publish(results);
203 
204  image_pub.publish(cv_ptr->toImageMsg());
205 
206  }
207  catch(cv_bridge::Exception & e) {
208  ROS_ERROR("cv_bridge exception: %s", e.what());
209  }
210  catch(cv::Exception & e) {
211  ROS_ERROR("cv exception: %s", e.what());
212  }
213  ROS_DEBUG("Notifying condition variable");
214  processed = true;
215  cond.notify_all();
216 }
217 
219 {
220  frame_num = 0;
221 
222  std::string dir;
223  std::string proto_net_file;
224  std::string caffe_model_file;
225  std::string classes("background,"
226  "aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,"
227  "cow,diningtable,dog,horse,motorbike,person,pottedplant,"
228  "sheep,sofa,train,tvmonitor");
229 
230  nh.param<bool>("single_shot", single_shot, false);
231 
232  nh.param<bool>("publish_images", publish_images, false);
233  nh.param<string>("data_dir", dir, "");
234  nh.param<string>("protonet_file", proto_net_file,
235  "MobileNetSSD_deploy.prototxt.txt");
236  nh.param<string>("caffe_model_file", caffe_model_file,
237  "MobileNetSSD_deploy.caffemodel");
238  nh.param<float>("min_confidence", min_confidence, 0.2);
239  nh.param<int>("im_size", im_size, 300);
240  nh.param<int>("rotate_flag", rotate_flag, -1);
241  nh.param<float>("scale_factor", scale_factor, 0.007843f);
242  nh.param<float>("mean_val", mean_val, 127.5f);
243  nh.param<std::string>("class_names", classes, classes);
244 
245  boost::split(class_names, classes, boost::is_any_of(","));
246  ROS_INFO("Read %d class names", (int)class_names.size());
247 
248  try {
249  net = cv::dnn::readNetFromCaffe(dir + "/" + proto_net_file,
250  dir + "/" + caffe_model_file);
251  }
252  catch(cv::Exception & e) {
253  ROS_ERROR("cv exception: %s", e.what());
254  exit(1);
255  }
256 
257  triggered = false;
258 
260 
261  results_pub =
262  nh.advertise<dnn_detect::DetectedObjectArray>("/dnn_objects", 20);
263 
264  image_pub = it.advertise("/dnn_images", 1);
265 
266  img_sub = it.subscribe("/camera", 1,
267  &DnnNode::image_callback, this);
268 
269  ROS_INFO("DNN detection ready");
270 }
271 
272 int main(int argc, char ** argv) {
273  ros::init(argc, argv, "dnn_detect");
274  ros::NodeHandle nh("~");
275 
276  DnnNode node = DnnNode(nh);
278  spinner.spin();
279 
280  return 0;
281 }
image_transport::Publisher image_pub
Definition: dnn_detect.cpp:81
Subscriber subscribe(const std::string &base_topic, uint32_t queue_size, const boost::function< void(const sensor_msgs::ImageConstPtr &)> &callback, const ros::VoidPtr &tracked_object=ros::VoidPtr(), const TransportHints &transport_hints=TransportHints())
dnn_detect::DetectedObjectArray results
Definition: dnn_detect.cpp:91
ros::ServiceServer detect_srv
Definition: dnn_detect.cpp:93
int frame_num
Definition: dnn_detect.cpp:73
float mean_val
Definition: dnn_detect.cpp:78
int main(int argc, char **argv)
Definition: dnn_detect.cpp:272
int rotate_flag
Definition: dnn_detect.cpp:76
float scale_factor
Definition: dnn_detect.cpp:77
ROSCPP_DECL void init(int &argc, char **argv, const std::string &name, uint32_t options=0)
Publisher advertise(const std::string &base_topic, uint32_t queue_size, bool latch=false)
DnnNode(ros::NodeHandle &nh)
Definition: dnn_detect.cpp:218
volatile bool triggered
Definition: dnn_detect.cpp:88
image_transport::ImageTransport it
Definition: dnn_detect.cpp:67
ServiceServer advertiseService(const std::string &service, bool(T::*srv_func)(MReq &, MRes &), T *obj)
void image_callback(const sensor_msgs::ImageConstPtr &msg)
Definition: dnn_detect.cpp:121
cv::dnn::Net net
Definition: dnn_detect.cpp:83
void spinner()
std::vector< std::string > class_names
Definition: dnn_detect.cpp:79
int im_size
Definition: dnn_detect.cpp:75
bool publish_images
Definition: dnn_detect.cpp:71
bool single_shot
Definition: dnn_detect.cpp:87
#define ROS_INFO(...)
image_transport::Subscriber img_sub
Definition: dnn_detect.cpp:68
bool param(const std::string &param_name, T &param_val, const T &default_val) const
CvImagePtr toCvCopy(const sensor_msgs::ImageConstPtr &source, const std::string &encoding=std::string())
Publisher advertise(const std::string &topic, uint32_t queue_size, bool latch=false)
std::condition_variable cond
Definition: dnn_detect.cpp:60
std::mutex mutx
Definition: dnn_detect.cpp:61
cv::Mat resized_image
Definition: dnn_detect.cpp:84
virtual void spin(CallbackQueue *queue=0)
cv::Mat rotated_image
Definition: dnn_detect.cpp:85
bool trigger_callback(dnn_detect::Detect::Request &req, dnn_detect::Detect::Response &res)
Definition: dnn_detect.cpp:104
ros::Publisher results_pub
Definition: dnn_detect.cpp:65
float min_confidence
Definition: dnn_detect.cpp:74
#define ROS_ERROR(...)
volatile bool processed
Definition: dnn_detect.cpp:89
#define ROS_DEBUG(...)


dnn_detect
Author(s): Jim Vaughan
autogenerated on Thu Sep 24 2020 03:23:11