ssd_object_detector.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # Author: Furushchev <furushchev@jsk.imi.i.u-tokyo.ac.jp>
4 
5 from __future__ import print_function
6 
7 import matplotlib
8 matplotlib.use("Agg")
9 import matplotlib.pyplot as plt
10 
11 from cv_bridge import CvBridge
12 from jsk_topic_tools import ConnectionBasedTransport
13 import numpy as np
14 import rospy
15 import time
16 import yaml
17 
18 from dynamic_reconfigure.server import Server
19 from jsk_perception.cfg import SSDObjectDetectorConfig as Config
20 
21 from sensor_msgs.msg import Image
22 from jsk_recognition_msgs.msg import Rect, RectArray
23 from jsk_recognition_msgs.msg import ClassificationResult
24 from jsk_recognition_msgs.msg import Label
25 from jsk_recognition_msgs.msg import LabelArray
26 from jsk_recognition_msgs.msg import ClusterPointIndices
27 from pcl_msgs.msg import PointIndices
28 
29 import itertools, pkg_resources, sys
30 from distutils.version import LooseVersion
# Abort early on Python 2 with Chainer >= 7.0.0: this combination is rejected
# here, and the message tells the user which Chainer version to install.
if LooseVersion(pkg_resources.get_distribution("chainer").version) >= LooseVersion('7.0.0') and \
        sys.version_info.major == 2:
    print('''Please install chainer < 7.0.0:

    sudo pip install chainer==6.7.0

c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
    sys.exit(1)

# Only warn (no exit) when CuPy is missing: the model is moved to the GPU only
# when ~gpu >= 0, so CPU-only use still works.  Accept both the prebuilt wheels
# ("cupy-cudaXX") and a source install ("cupy"), and stop scanning at the first
# match instead of listing every distribution on sys.path.
if not any(
        dist.project_name.lower() == "cupy" or
        dist.project_name.lower().startswith("cupy-")
        for dist in itertools.chain.from_iterable(
            pkg_resources.find_distributions(path) for path in sys.path)):
    print('''Please install CuPy

    sudo pip install cupy-cuda[your cuda version]
i.e.
    sudo pip install cupy-cuda91

''', file=sys.stderr)
48 import chainer
49 from chainercv.links import SSD300
50 from chainercv.links import SSD512
51 from chainercv.visualizations import vis_bbox
52 
53 
# Tell ChainerCV to use OpenCV ('cv2') for image resizing; presumably this
# affects the SSD preprocessing done inside model.predict -- confirm.
chainer.config.cv_resize_backend = 'cv2'
55 
56 
class SSDObjectDetector(ConnectionBasedTransport):
    """ROS node wrapping a ChainerCV SSD300/SSD512 object detector.

    For every image received on ``~input`` it runs ``model.predict`` and
    publishes:

    * ``~output/labels`` (LabelArray): one Label per detection.
    * ``~output/cluster_indices`` (ClusterPointIndices): row-major pixel
      indices covered by each detection's bounding box.
    * ``~output/rect`` (RectArray): each bounding box as an integer rect.
    * ``~output/class`` (ClassificationResult): labels, names and scores.
    * ``~output/image`` (Image): detections drawn on the input, only while
      that topic has subscribers.

    Private parameters: ``~gpu`` (default -1; negative keeps the model on the
    CPU), ``~classifier_name`` (default: node name), ``~label_names`` (a list,
    or a path to a YAML file holding the list; default: ChainerCV's VOC label
    names), ``~model_path`` (pretrained model name or weight file passed to
    ChainerCV, default None) and ``~model`` ('ssd300' or 'ssd512').
    ``nms_thresh``, ``score_thresh`` and ``profiling`` come from
    dynamic_reconfigure.
    """

    def __init__(self):
        super(SSDObjectDetector, self).__init__()
        self.gpu = rospy.get_param("~gpu", -1)
        self.classifier_name = rospy.get_param("~classifier_name", rospy.get_name())
        # image_cb reads this flag; config_callback overwrites it once
        # dynamic_reconfigure delivers a configuration.
        self.profiling = False

        self.cv_bridge = CvBridge()

        # load label names (the model's number of classes depends on them)
        self.label_names = self.load_label_names()
        rospy.loginfo("Loaded %d labels" % len(self.label_names))

        # model_path: name of pretrained model or path to model file
        model_path = rospy.get_param("~model_path", None)
        model_name = rospy.get_param('~model', 'ssd300')
        model_classes = {'ssd300': SSD300, 'ssd512': SSD512}
        if model_name not in model_classes:
            # Fail loudly instead of hitting an unbound variable below.
            err = 'Unsupported ~model: {}'.format(model_name)
            rospy.logerr(err)
            raise ValueError(err)
        self.model = model_classes[model_name](
            n_fg_class=len(self.label_names),
            pretrained_model=model_path)
        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
            self.model.to_gpu()
        rospy.loginfo("Loaded model: %s" % model_path)

        # dynamic reconfigure (needs self.model: config_callback writes to it)
        self.srv = Server(Config, self.config_callback)

        # advertise
        self.pub_labels = self.advertise("~output/labels", LabelArray,
                                         queue_size=1)
        self.pub_indices = self.advertise("~output/cluster_indices", ClusterPointIndices,
                                          queue_size=1)
        self.pub_rects = self.advertise("~output/rect", RectArray,
                                        queue_size=1)
        self.pub_class = self.advertise("~output/class", ClassificationResult,
                                        queue_size=1)
        self.pub_image = self.advertise("~output/image", Image,
                                        queue_size=1)

    def subscribe(self):
        """Start listening on ``~input`` (invoked by ConnectionBasedTransport)."""
        self.sub_image = rospy.Subscriber("~input", Image, self.image_cb,
                                          queue_size=1, buff_size=2**26)

    def unsubscribe(self):
        """Stop listening on ``~input`` (invoked by ConnectionBasedTransport)."""
        self.sub_image.unregister()

    @property
    def visualize(self):
        """True when someone subscribes to ``~output/image``."""
        return self.pub_image.get_num_connections() > 0

    def load_label_names(self):
        """Return the label names to use for the detector's classes.

        ``~label_names`` may be a list of names or a path (str) to a YAML file
        containing that list. If unset or empty, ChainerCV's VOC label names
        are used.
        """
        label_names = rospy.get_param("~label_names", tuple())
        if not label_names:
            try:
                from chainercv.datasets import voc_detection_label_names
                label_names = voc_detection_label_names
            except ImportError:
                # ChainerCV versions that only export the newer name.
                from chainercv.datasets import voc_bbox_label_names
                label_names = voc_bbox_label_names
        elif isinstance(label_names, str):
            with open(label_names, "r") as f:
                # safe_load: a label file only needs plain YAML types.
                label_names = tuple(yaml.safe_load(f))
        return label_names

    def config_callback(self, config, level):
        """Apply dynamic_reconfigure thresholds and the profiling switch."""
        self.model.nms_thresh = config.nms_thresh
        self.model.score_thresh = config.score_thresh
        self.profiling = config.profiling
        return config

    def _log_elapsed(self, stage, tprev):
        """Log time spent in ``stage`` since ``tprev`` when profiling; return the new reference time."""
        if not self.profiling:
            return tprev
        tcur = time.time()
        rospy.loginfo("%s: elapsed %f msec" % (stage, (tcur - tprev) * 1000))
        return tcur

    @staticmethod
    def _bbox_to_indices(bbox, height, width):
        """Return row-major flat pixel indices (int32) inside ``bbox``.

        ``bbox`` is (y_min, x_min, y_max, x_max); it is expanded outward to
        whole pixels and clipped to a ``height`` x ``width`` image. An empty
        array is returned when nothing of the box lies inside the image.
        """
        ymin = max(0, int(np.floor(bbox[0])))
        xmin = max(0, int(np.floor(bbox[1])))
        ymax = min(height, int(np.ceil(bbox[2])))
        xmax = min(width, int(np.ceil(bbox[3])))
        ys = np.arange(ymin, ymax, dtype=np.int32)
        xs = np.arange(xmin, xmax, dtype=np.int32)
        return (ys[:, None] * width + xs[None, :]).ravel().astype(np.int32)

    def image_cb(self, msg):
        """Run the detector on one Image message and publish all outputs.

        Messages that cannot be converted to rgb8 are logged and dropped.
        """
        if self.profiling:
            rospy.loginfo("callback start: incomming msg is %s msec behind" % ((rospy.Time.now() - msg.header.stamp).to_sec() * 1000.0))
        # Always take a reference time so a profiling toggle mid-callback is safe.
        tprev = time.time()
        try:
            # transform image to RGB, float, CHW
            img = self.cv_bridge.imgmsg_to_cv2(msg, desired_encoding="rgb8")
            img = np.asarray(img, dtype=np.float32)
            img = img.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
        except Exception as e:
            rospy.logerr("Failed to convert image: %s" % str(e))
            return
        tprev = self._log_elapsed("convert", tprev)

        if self.gpu >= 0:
            # Re-select the GPU here: this callback presumably runs in a
            # different thread than __init__ did.
            chainer.cuda.get_device_from_id(self.gpu).use()
        bboxes, labels, scores = self.model.predict([img])
        bboxes, labels, scores = bboxes[0], labels[0], scores[0]
        tprev = self._log_elapsed("predict", tprev)

        labels_msg = LabelArray(header=msg.header)
        labels_msg.labels = [Label(id=l, name=self.label_names[l])
                             for l in labels]
        tprev = self._log_elapsed("make labels msg", tprev)

        height, width = img.shape[1:]
        cluster_indices_msg = ClusterPointIndices(header=msg.header)
        cluster_indices_msg.cluster_indices = [
            PointIndices(header=msg.header,
                         indices=self._bbox_to_indices(bbox, height, width))
            for bbox in bboxes]
        tprev = self._log_elapsed("make cluster_indices msg", tprev)

        rect_msg = RectArray(header=msg.header)
        for y_min, x_min, y_max, x_max in bboxes:
            # Rect fields are int32: cast explicitly (truncating) because
            # serializing numpy floats into integer fields fails on Python 3.
            rect_msg.rects.append(Rect(x=int(x_min), y=int(y_min),
                                       width=int(x_max - x_min),
                                       height=int(y_max - y_min)))
        tprev = self._log_elapsed("make rect msg", tprev)

        cls_msg = ClassificationResult(
            header=msg.header,
            classifier=self.classifier_name,
            target_names=self.label_names,
            labels=labels,
            label_names=[self.label_names[l] for l in labels],
            label_proba=scores,
        )
        tprev = self._log_elapsed("make cls msg", tprev)

        self.pub_labels.publish(labels_msg)
        self.pub_indices.publish(cluster_indices_msg)
        self.pub_rects.publish(rect_msg)
        self.pub_class.publish(cls_msg)
        tprev = self._log_elapsed("publish msg", tprev)

        if self.visualize:
            self.publish_bbox_image(img, bboxes, labels, scores, msg.header)

        self._log_elapsed("callback end", tprev)

    def publish_bbox_image(self, img, bbox, label, score, header):
        """Draw detections on ``img`` (CHW float) and publish it as rgb8.

        The matplotlib figure is always closed, even if rendering fails, so
        repeated callbacks do not accumulate open figures.
        """
        ax = vis_bbox(img, bbox, label, score,
                      label_names=self.label_names)
        fig = ax.figure
        try:
            fig.canvas.draw()
            w, h = fig.canvas.get_width_height()
            vis_img = np.frombuffer(fig.canvas.tostring_rgb(),
                                    dtype=np.uint8).reshape(h, w, 3)
        finally:
            plt.close(fig)
        try:
            msg = self.cv_bridge.cv2_to_imgmsg(vis_img, "rgb8")
        except Exception as e:
            rospy.logerr("Failed to convert bbox image: %s" % str(e))
            return
        msg.header = header
        self.pub_image.publish(msg)
247  self.pub_image.publish(msg)
248 
249 
if __name__ == '__main__':
    rospy.init_node("ssd_object_detector")
    # Keep a reference to the node object: its publishers, subscriber and
    # dynamic_reconfigure server must exist for rospy.spin() to do any work.
    ssd = SSDObjectDetector()
    rospy.spin()
def publish_bbox_image(self, img, bbox, label, score, header)


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27