mask_rcnn_instance_segmentation.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from __future__ import print_function
00004 import sys
00005 
00006 import chainer
00007 import numpy as np
00008 
00009 try:
00010     import chainer_mask_rcnn
00011 except ImportError:
00012     print('''Please install chainer_mask_rcnn:
00013 
00014     sudo pip install chainer-mask-rcnn
00015 
00016 ''', file=sys.stderr)
00017     sys.exit(1)
00018 
00019 import cv_bridge
00020 from jsk_recognition_msgs.msg import ClusterPointIndices
00021 from jsk_recognition_msgs.msg import Label
00022 from jsk_recognition_msgs.msg import LabelArray
00023 from jsk_recognition_msgs.msg import Rect
00024 from jsk_recognition_msgs.msg import RectArray
00025 from jsk_recognition_msgs.msg import ClassificationResult
00026 from jsk_topic_tools import ConnectionBasedTransport
00027 from pcl_msgs.msg import PointIndices
00028 import rospy
00029 from sensor_msgs.msg import Image
00030 
00031 
class MaskRCNNInstanceSegmentation(ConnectionBasedTransport):
    """ROS node that runs Mask R-CNN instance segmentation on images.

    Subscribes to ``~input`` (``sensor_msgs/Image``) and, for each image,
    publishes:

    * ``~output/cluster_indices`` (``ClusterPointIndices``): for every
      detected instance, the indices of its mask pixels in the flattened
      image.
    * ``~output/labels`` (``LabelArray``): class id and name per instance.
    * ``~output/label_cls`` (``Image``): per-pixel class id, -1 where no
      instance mask covers the pixel.
    * ``~output/label_ins`` (``Image``): per-pixel instance id, -1 where no
      instance mask covers the pixel.
    * ``~output/rects`` (``RectArray``): one rectangle per detection.
    * ``~output/class`` (``ClassificationResult``): labels, label names and
      scores of the detections.
    * ``~output/viz`` (``Image``): rendered detections, only computed
      while the topic has at least one connection.

    ROS parameters:

    * ``~gpu`` (default ``0``): device id; the model is moved to the GPU
      only when the value is >= 0.
    * ``~fg_class_names`` (no default): list of foreground class names.
    * ``~pretrained_model`` (no default): passed to the model constructor
      as ``pretrained_model``.
    * ``~classifier_name`` (default: the node name): value of the
      ``classifier`` field of the ``ClassificationResult``.
    * ``~anchor_scales`` (default ``[4, 8, 16, 32]``), ``~min_size``
      (default ``600``), ``~max_size`` (default ``1000``): forwarded to
      the model constructor.
    * ``~score_thresh`` (default ``0.7``): assigned to the model's
      ``score_thresh`` attribute.
    """

    def __init__(self):
        """Read parameters, build the model and advertise all outputs."""
        rospy.logwarn('This node is experimental, and its interface '
                      'can be changed in the future.')

        super(MaskRCNNInstanceSegmentation, self).__init__()
        # gpu
        self.gpu = rospy.get_param('~gpu', 0)
        # This node only performs inference: switch chainer to test mode
        # and disable backprop globally.
        chainer.global_config.train = False
        chainer.global_config.enable_backprop = False

        self.fg_class_names = rospy.get_param('~fg_class_names')
        pretrained_model = rospy.get_param('~pretrained_model')
        self.classifier_name = rospy.get_param(
            "~classifier_name", rospy.get_name())

        n_fg_class = len(self.fg_class_names)
        self.model = chainer_mask_rcnn.models.MaskRCNNResNet(
            n_layers=50,
            n_fg_class=n_fg_class,
            pretrained_model=pretrained_model,
            anchor_scales=rospy.get_param('~anchor_scales', [4, 8, 16, 32]),
            min_size=rospy.get_param('~min_size', 600),
            max_size=rospy.get_param('~max_size', 1000),
        )
        self.model.score_thresh = rospy.get_param('~score_thresh', 0.7)
        if self.gpu >= 0:
            self.model.to_gpu(self.gpu)

        self.pub_indices = self.advertise(
            '~output/cluster_indices', ClusterPointIndices, queue_size=1)
        self.pub_labels = self.advertise(
            '~output/labels', LabelArray, queue_size=1)
        self.pub_lbl_cls = self.advertise(
            '~output/label_cls', Image, queue_size=1)
        self.pub_lbl_ins = self.advertise(
            '~output/label_ins', Image, queue_size=1)
        self.pub_viz = self.advertise(
            '~output/viz', Image, queue_size=1)
        self.pub_rects = self.advertise(
            "~output/rects", RectArray,
            queue_size=1)
        self.pub_class = self.advertise(
            "~output/class", ClassificationResult,
            queue_size=1)

    def subscribe(self):
        """Start listening on ``~input``."""
        # NOTE(review): buff_size is presumably enlarged so that a whole
        # image fits in the receive buffer and queue_size=1 keeps only the
        # latest frame -- confirm against rospy.Subscriber documentation.
        self.sub = rospy.Subscriber('~input', Image, self.callback,
                                    queue_size=1, buff_size=2**24)

    def unsubscribe(self):
        """Stop listening on ``~input``."""
        self.sub.unregister()

    def callback(self, imgmsg):
        """Segment one image and publish every output topic.

        Args:
            imgmsg (sensor_msgs.msg.Image): input image; it is converted
                to ``rgb8`` before being fed to the model. Its header is
                copied to every published message.
        """
        bridge = cv_bridge.CvBridge()
        img = bridge.imgmsg_to_cv2(imgmsg, desired_encoding='rgb8')
        img_chw = img.transpose(2, 0, 1)  # C, H, W

        if self.gpu >= 0:
            # Select the device for the thread running this callback.
            chainer.cuda.get_device_from_id(self.gpu).use()
        bboxes, masks, labels, scores = self.model.predict([img_chw])

        # predict() was given a batch of one image; keep its results only.
        bboxes = bboxes[0]
        masks = masks[0]
        labels = labels[0]
        scores = scores[0]

        msg_indices = ClusterPointIndices(header=imgmsg.header)
        msg_labels = LabelArray(header=imgmsg.header)
        # -1: label for background
        lbl_cls = - np.ones(img.shape[:2], dtype=np.int32)
        lbl_ins = - np.ones(img.shape[:2], dtype=np.int32)
        for ins_id, (mask, label) in enumerate(zip(masks, labels)):
            # Indices of the mask's nonzero pixels in the flattened image.
            indices = np.where(mask.flatten())[0]
            indices_msg = PointIndices(header=imgmsg.header, indices=indices)
            msg_indices.cluster_indices.append(indices_msg)
            class_name = self.fg_class_names[label]
            msg_labels.labels.append(Label(id=label, name=class_name))
            # Later instances overwrite earlier ones where masks overlap.
            lbl_cls[mask] = label
            lbl_ins[mask] = ins_id  # instance_id
        self.pub_indices.publish(msg_indices)
        self.pub_labels.publish(msg_labels)

        msg_lbl_cls = bridge.cv2_to_imgmsg(lbl_cls)
        msg_lbl_ins = bridge.cv2_to_imgmsg(lbl_ins)
        msg_lbl_cls.header = msg_lbl_ins.header = imgmsg.header
        self.pub_lbl_cls.publish(msg_lbl_cls)
        self.pub_lbl_ins.publish(msg_lbl_ins)

        cls_msg = ClassificationResult(
            header=imgmsg.header,
            classifier=self.classifier_name,
            target_names=self.fg_class_names,
            labels=labels,
            label_names=[self.fg_class_names[label_id]
                         for label_id in labels],
            label_proba=scores,
        )

        rects_msg = RectArray(header=imgmsg.header)
        for bbox in bboxes:
            # bbox is indexed as (y_min, x_min, y_max, x_max). Rect fields
            # are integers, so cast explicitly instead of relying on
            # implicit float-to-int coercion during message serialization
            # (int() truncates toward zero).
            rect = Rect(x=int(bbox[1]), y=int(bbox[0]),
                        width=int(bbox[3] - bbox[1]),
                        height=int(bbox[2] - bbox[0]))
            rects_msg.rects.append(rect)
        self.pub_rects.publish(rects_msg)
        self.pub_class.publish(cls_msg)

        # Rendering is skipped unless somebody is listening to ~output/viz.
        if self.pub_viz.get_num_connections() > 0:
            n_fg_class = len(self.fg_class_names)
            captions = ['{:d}: {:s}'.format(label_id,
                                            self.fg_class_names[label_id])
                        for label_id in labels]
            # Labels are shifted by one for drawing, with n_class counting
            # one extra class (presumably index 0 is background there).
            viz = chainer_mask_rcnn.utils.draw_instance_bboxes(
                img, bboxes, labels + 1, n_class=n_fg_class + 1,
                masks=masks, captions=captions)
            msg_viz = bridge.cv2_to_imgmsg(viz, encoding='rgb8')
            msg_viz.header = imgmsg.header
            self.pub_viz.publish(msg_viz)
00150 
00151 
if __name__ == '__main__':
    # The node must be initialised before the class is constructed,
    # because the constructor reads private (~) parameters and logs.
    rospy.init_node('mask_rcnn_instance_segmentation')
    # Hold a reference so the node object stays alive while spinning.
    segmentation_node = MaskRCNNInstanceSegmentation()
    # Block until shutdown; callbacks are dispatched by rospy meanwhile.
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07