mask_rcnn_instance_segmentation.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import print_function
4 import os
5 import sys
6 import yaml
7 
8 import itertools, pkg_resources
9 from distutils.version import LooseVersion
# Refuse to start on Python 2 with chainer >= 7.0.0 (see the PR linked in
# the message). The cheap interpreter check runs first so the package
# metadata lookup is skipped entirely on Python 3.
if (sys.version_info.major == 2 and
        LooseVersion(pkg_resources.get_distribution("chainer").version) >=
        LooseVersion('7.0.0')):
    print('''Please install chainer < 7.0.0:

    sudo pip install chainer==6.7.0

c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
    sys.exit(1)


def _cupy_is_installed():
    """Return True if a CuPy distribution is found on any ``sys.path`` entry.

    Accepts both the source package ``cupy`` and the prebuilt wheels named
    ``cupy-<variant>`` (e.g. ``cupy-cuda91``). The comparison is
    case-insensitive.
    """
    for path_entry in sys.path:
        for dist in pkg_resources.find_distributions(path_entry):
            name = dist.project_name.lower()
            if name == 'cupy' or name.startswith('cupy-'):
                return True
    return False


if not _cupy_is_installed():
    # Deliberately only a warning, not an exit: the node can be run with
    # ~gpu set to a negative value, in which case the model is never moved
    # to the GPU.
    print('''Please install CuPy

    sudo pip install cupy-cuda[your cuda version]
e.g.
    sudo pip install cupy-cuda91

''', file=sys.stderr)
28 import chainer
29 from chainercv.datasets.coco.coco_utils \
30  import coco_instance_segmentation_label_names
# The FPN variants (~model_name 'mask_rcnn_fpn_resnet50' and
# 'mask_rcnn_fpn_resnet101') need chainercv >= 0.13.0. The default
# 'mask_rcnn_resnet50' model comes from chainer_mask_rcnn instead, so a
# missing FPN class is only reported, not fatal.
try:
    from chainercv.links import MaskRCNNFPNResNet101
    from chainercv.links import MaskRCNNFPNResNet50
except ImportError:
    # The version specifier is quoted: unquoted, a shell would parse
    # ">=0.13.0" as an output redirection to a file named "=0.13.0".
    # Written to stderr like every other dependency message in this file.
    print('''If you want to use chainercv mask_rcnn, please upgrade chainercv

    sudo pip install 'chainercv>=0.13.0'

''', file=sys.stderr)
40 
41 from chainercv.utils import mask_to_bbox
42 import numpy as np
43 import yaml
44 
# chainer_mask_rcnn is mandatory: it provides the default model and the
# drawing helper used for the visualization output, so exit if missing.
try:
    import chainer_mask_rcnn
except ImportError:
    _install_hint = (
        'Please install chainer_mask_rcnn:\n'
        '\n'
        '    sudo pip install chainer-mask-rcnn\n'
        '\n'
    )
    print(_install_hint, file=sys.stderr)
    sys.exit(1)
54 
55 import cv_bridge
56 from dynamic_reconfigure.server import Server
57 from jsk_perception.cfg import MaskRCNNInstanceSegmentationConfig as Config
58 from jsk_recognition_msgs.msg import ClusterPointIndices
59 from jsk_recognition_msgs.msg import Label
60 from jsk_recognition_msgs.msg import LabelArray
61 from jsk_recognition_msgs.msg import Rect
62 from jsk_recognition_msgs.msg import RectArray
63 from jsk_recognition_msgs.msg import ClassificationResult
64 from jsk_topic_tools import ConnectionBasedTransport
65 from pcl_msgs.msg import PointIndices
66 import rospkg
67 import rospy
68 from sensor_msgs.msg import Image
69 
70 
class MaskRCNNInstanceSegmentation(ConnectionBasedTransport):
    """Mask R-CNN instance segmentation node for ``sensor_msgs/Image`` input.

    Each image received on ``~input`` is converted to rgb8 and segmented.
    The node then publishes, with the input header:

    - ``~output/cluster_indices``: flat pixel indices of every instance
    - ``~output/labels``: class id and name of every instance
    - ``~output/label_cls`` / ``~output/label_ins``: int32 images holding the
      class id / instance index of each pixel, -1 for background
    - ``~output/rects``: bounding box of every instance
    - ``~output/class``: labels, names and scores as a ClassificationResult
    - ``~output/viz``: an rgb8 visualization, only while it has subscribers

    ``~model_name`` selects the network: ``mask_rcnn_resnet50``
    (chainer_mask_rcnn), ``mask_rcnn_fpn_resnet50`` or
    ``mask_rcnn_fpn_resnet101`` (chainercv).
    """

    def __init__(self):
        """Read parameters, build the model and advertise the outputs.

        Raises:
            KeyError: if the required ``~pretrained_model`` parameter is not
                set (it is read with ``rospy.get_param`` without a default).
            ValueError: if ``~model_name`` is unsupported, or if no class
                names are available for the selected pretrained model.
        """
        rospy.logwarn('This node is experimental, and its interface '
                      'can be changed in the future.')

        super(MaskRCNNInstanceSegmentation, self).__init__()
        # One bridge reused by every callback instead of one per message.
        self.bridge = cv_bridge.CvBridge()
        # GPU device id; a negative value keeps the model on the CPU.
        self.gpu = rospy.get_param('~gpu', 0)
        # Inference only.
        chainer.global_config.train = False
        chainer.global_config.enable_backprop = False

        # ~fg_class_names is either the list of names itself or a path to a
        # YAML file containing it.
        fg_class_names = rospy.get_param('~fg_class_names', None)
        if isinstance(fg_class_names, str) and os.path.exists(fg_class_names):
            rospy.loginfo(
                'Loading class names from file: {}'.format(fg_class_names))
            fg_class_names = self._load_yaml(fg_class_names)
        self.fg_class_names = fg_class_names

        pretrained_model = rospy.get_param('~pretrained_model')
        self.classifier_name = rospy.get_param(
            "~classifier_name", rospy.get_name())
        self.model_name = rospy.get_param('~model_name', 'mask_rcnn_resnet50')

        self.model = self._create_model(pretrained_model)
        self.model.score_thresh = rospy.get_param('~score_thresh', 0.7)
        if self.gpu >= 0:
            self.model.to_gpu(self.gpu)

        self.srv = Server(Config, self.config_callback)

        self.pub_indices = self.advertise(
            '~output/cluster_indices', ClusterPointIndices, queue_size=1)
        self.pub_labels = self.advertise(
            '~output/labels', LabelArray, queue_size=1)
        self.pub_lbl_cls = self.advertise(
            '~output/label_cls', Image, queue_size=1)
        self.pub_lbl_ins = self.advertise(
            '~output/label_ins', Image, queue_size=1)
        self.pub_viz = self.advertise(
            '~output/viz', Image, queue_size=1)
        self.pub_rects = self.advertise(
            "~output/rects", RectArray,
            queue_size=1)
        self.pub_class = self.advertise(
            "~output/class", ClassificationResult,
            queue_size=1)

    @staticmethod
    def _load_yaml(path):
        """Return the parsed contents of the YAML file at ``path``.

        ``yaml.safe_load`` is used because ``yaml.load`` without an explicit
        Loader warns on PyYAML >= 5.1 and raises TypeError on PyYAML >= 6.
        """
        with open(path, 'r') as f:
            return yaml.safe_load(f)

    def _check_fg_class_names(self, pretrained_model):
        """Raise ValueError when no class names are known for the model."""
        if self.fg_class_names is None:
            msg = ('~fg_class_names must be set when ~pretrained_model is '
                   '{!r} (model_name: {})'.format(pretrained_model,
                                                  self.model_name))
            rospy.logerr(msg)
            raise ValueError(msg)

    def _create_model(self, pretrained_model):
        """Build and return the network selected by ``self.model_name``.

        For the bundled pretrained models ('coco' / 'voc' for
        mask_rcnn_resnet50, 'coco' for the FPN models), this may also fill
        in ``self.fg_class_names`` with that dataset's class names.

        Raises:
            ValueError: for an unsupported ``self.model_name`` or when
                ``self.fg_class_names`` is still unknown.
        """
        if self.model_name == 'mask_rcnn_resnet50':
            # Weights and class-name files shipped inside jsk_perception,
            # as paths relative to the package root.
            bundled = {
                'coco': ('trained_data/mask_rcnn_resnet50_coco_20180730.npz',
                         'sample/config/coco_class_names.yaml'),
                'voc': ('trained_data/mask_rcnn_resnet50_voc_20180516.npz',
                        'sample/config/voc_class_names.yaml'),
            }
            if pretrained_model in bundled:
                pkg_path = rospkg.RosPack().get_path('jsk_perception')
                weights_file, names_file = bundled[pretrained_model]
                pretrained_model = os.path.join(pkg_path, weights_file)
                if self.fg_class_names is None:
                    self.fg_class_names = self._load_yaml(
                        os.path.join(pkg_path, names_file))
            self._check_fg_class_names(pretrained_model)
            return chainer_mask_rcnn.models.MaskRCNNResNet(
                n_layers=50,
                n_fg_class=len(self.fg_class_names),
                pretrained_model=pretrained_model,
                anchor_scales=rospy.get_param('~anchor_scales', [4, 8, 16, 32]),
                min_size=rospy.get_param('~min_size', 600),
                max_size=rospy.get_param('~max_size', 1000),
            )

        if self.model_name in ('mask_rcnn_fpn_resnet50',
                               'mask_rcnn_fpn_resnet101'):
            if pretrained_model == 'coco':
                self.fg_class_names = coco_instance_segmentation_label_names
            self._check_fg_class_names(pretrained_model)
            if self.model_name == 'mask_rcnn_fpn_resnet50':
                model_class = MaskRCNNFPNResNet50
            else:
                model_class = MaskRCNNFPNResNet101
            model = model_class(
                n_fg_class=len(self.fg_class_names),
                pretrained_model=pretrained_model)
            model.use_preset('visualize')
            return model

        msg = 'Unsupported model_name: {}'.format(self.model_name)
        rospy.logerr(msg)
        raise ValueError(msg)

    def subscribe(self):
        self.sub = rospy.Subscriber('~input', Image, self.callback,
                                    queue_size=1, buff_size=2**24)

    def unsubscribe(self):
        self.sub.unregister()

    def config_callback(self, config, level):
        """dynamic_reconfigure callback: apply the new score threshold."""
        self.model.score_thresh = config.score_thresh
        return config

    def _predict(self, img_chw):
        """Run the model on one CHW rgb image.

        Returns:
            (bboxes, masks, labels, scores) for that single image; bboxes
            rows are ordered (y_min, x_min, y_max, x_max).
        """
        if self.model_name == 'mask_rcnn_resnet50':
            bboxes, masks, labels, scores = self.model.predict([img_chw])
            return bboxes[0], masks[0], labels[0], scores[0]
        # The chainercv FPN models are fed float32 and return only masks,
        # labels and scores, so the boxes are derived from the masks.
        masks, labels, scores = self.model.predict(
            [img_chw.astype(np.float32)])
        masks, labels, scores = masks[0], labels[0], scores[0]
        return mask_to_bbox(masks), masks, labels, scores

    def callback(self, imgmsg):
        """Segment one image and publish every output with its header."""
        header = imgmsg.header
        img = self.bridge.imgmsg_to_cv2(imgmsg, desired_encoding='rgb8')
        img_chw = img.transpose((2, 0, 1))  # C, H, W

        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
        bboxes, masks, labels, scores = self._predict(img_chw)

        msg_indices = ClusterPointIndices(header=header)
        msg_labels = LabelArray(header=header)
        # -1: label for background
        lbl_cls = -np.ones(img.shape[:2], dtype=np.int32)
        lbl_ins = -np.ones(img.shape[:2], dtype=np.int32)
        for ins_id, (mask, label) in enumerate(zip(masks, labels)):
            msg_indices.cluster_indices.append(
                PointIndices(header=header, indices=np.flatnonzero(mask)))
            msg_labels.labels.append(
                Label(id=int(label), name=self.fg_class_names[label]))
            lbl_cls[mask] = label
            lbl_ins[mask] = ins_id  # instance_id
        self.pub_indices.publish(msg_indices)
        self.pub_labels.publish(msg_labels)

        msg_lbl_cls = self.bridge.cv2_to_imgmsg(lbl_cls)
        msg_lbl_ins = self.bridge.cv2_to_imgmsg(lbl_ins)
        msg_lbl_cls.header = msg_lbl_ins.header = header
        self.pub_lbl_cls.publish(msg_lbl_cls)
        self.pub_lbl_ins.publish(msg_lbl_ins)

        cls_msg = ClassificationResult(
            header=header,
            classifier=self.classifier_name,
            target_names=self.fg_class_names,
            labels=labels,
            label_names=[self.fg_class_names[l] for l in labels],
            label_proba=scores,
        )

        rects_msg = RectArray(header=header)
        for y_min, x_min, y_max, x_max in bboxes:
            # jsk_recognition_msgs/Rect holds int32 fields; the boxes are
            # numpy floats, which do not serialize into int32 on Python 3.
            rects_msg.rects.append(Rect(
                x=int(x_min), y=int(y_min),
                width=int(x_max - x_min), height=int(y_max - y_min)))
        self.pub_rects.publish(rects_msg)
        self.pub_class.publish(cls_msg)

        if self.pub_viz.get_num_connections() > 0:
            n_fg_class = len(self.fg_class_names)
            captions = ['{:d}: {:s}'.format(l, self.fg_class_names[l])
                        for l in labels]
            # Labels are shifted by one (and n_class grown by one),
            # presumably so that drawing class 0 stays reserved for
            # background.
            viz = chainer_mask_rcnn.utils.draw_instance_bboxes(
                img, bboxes, labels + 1, n_class=n_fg_class + 1,
                masks=masks, captions=captions)
            msg_viz = self.bridge.cv2_to_imgmsg(viz, encoding='rgb8')
            msg_viz.header = header
            self.pub_viz.publish(msg_viz)
247 
248 
if __name__ == '__main__':
    rospy.init_node('mask_rcnn_instance_segmentation')
    # The node must actually be constructed; otherwise spin() runs with no
    # publishers or subscribers. The reference is kept for spin()'s lifetime.
    app = MaskRCNNInstanceSegmentation()  # NOQA
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27