00001
00002
00003 from __future__ import print_function
00004 import sys
00005
00006 import chainer
00007 import numpy as np
00008
00009 try:
00010 import chainer_mask_rcnn
00011 except ImportError:
00012 print('''Please install chainer_mask_rcnn:
00013
00014 sudo pip install chainer-mask-rcnn
00015
00016 ''', file=sys.stderr)
00017 sys.exit(1)
00018
00019 import cv_bridge
00020 from jsk_recognition_msgs.msg import ClusterPointIndices
00021 from jsk_recognition_msgs.msg import Label
00022 from jsk_recognition_msgs.msg import LabelArray
00023 from jsk_recognition_msgs.msg import Rect
00024 from jsk_recognition_msgs.msg import RectArray
00025 from jsk_recognition_msgs.msg import ClassificationResult
00026 from jsk_topic_tools import ConnectionBasedTransport
00027 from pcl_msgs.msg import PointIndices
00028 import rospy
00029 from sensor_msgs.msg import Image
00030
00031
class MaskRCNNInstanceSegmentation(ConnectionBasedTransport):
    """Instance segmentation node backed by chainer_mask_rcnn's Mask R-CNN.

    ``subscribe()`` / ``unsubscribe()`` are driven by ConnectionBasedTransport.
    For every image received on ``~input`` (sensor_msgs/Image) it publishes:

    * ``~output/cluster_indices`` (ClusterPointIndices): flattened pixel
      indices covered by each instance mask.
    * ``~output/labels`` (LabelArray): class id and class name per instance.
    * ``~output/label_cls`` / ``~output/label_ins`` (Image): int32 per-pixel
      class id / instance id, ``-1`` where no instance covers the pixel.
    * ``~output/rects`` (RectArray): integer bounding box of each instance.
    * ``~output/class`` (ClassificationResult): labels, names and scores.
    * ``~output/viz`` (Image, rgb8): visualization, rendered only while it
      has subscribers.

    ROS parameters: ``~gpu`` (default 0; negative runs on CPU),
    ``~fg_class_names`` (required), ``~pretrained_model`` (required),
    ``~classifier_name`` (default: node name), ``~anchor_scales``
    (default [4, 8, 16, 32]), ``~min_size`` (default 600), ``~max_size``
    (default 1000), ``~score_thresh`` (default 0.7).
    """

    def __init__(self):
        rospy.logwarn('This node is experimental, and its interface '
                      'can be changed in the future.')

        super(MaskRCNNInstanceSegmentation, self).__init__()

        # A negative id keeps the model on the CPU (see to_gpu below).
        self.gpu = rospy.get_param('~gpu', 0)
        # Inference only: disable train-mode behavior and graph construction.
        chainer.global_config.train = False
        chainer.global_config.enable_backprop = False

        self.fg_class_names = rospy.get_param('~fg_class_names')
        pretrained_model = rospy.get_param('~pretrained_model')
        self.classifier_name = rospy.get_param(
            "~classifier_name", rospy.get_name())

        n_fg_class = len(self.fg_class_names)
        self.model = chainer_mask_rcnn.models.MaskRCNNResNet(
            n_layers=50,
            n_fg_class=n_fg_class,
            pretrained_model=pretrained_model,
            anchor_scales=rospy.get_param('~anchor_scales', [4, 8, 16, 32]),
            min_size=rospy.get_param('~min_size', 600),
            max_size=rospy.get_param('~max_size', 1000),
        )
        self.model.score_thresh = rospy.get_param('~score_thresh', 0.7)
        if self.gpu >= 0:
            self.model.to_gpu(self.gpu)

        # Shared by all callbacks instead of being rebuilt per message.
        # Created before advertising so an early subscription cannot run
        # callback() before this attribute exists.
        self.bridge = cv_bridge.CvBridge()

        self.pub_indices = self.advertise(
            '~output/cluster_indices', ClusterPointIndices, queue_size=1)
        self.pub_labels = self.advertise(
            '~output/labels', LabelArray, queue_size=1)
        self.pub_lbl_cls = self.advertise(
            '~output/label_cls', Image, queue_size=1)
        self.pub_lbl_ins = self.advertise(
            '~output/label_ins', Image, queue_size=1)
        self.pub_viz = self.advertise(
            '~output/viz', Image, queue_size=1)
        self.pub_rects = self.advertise(
            "~output/rects", RectArray,
            queue_size=1)
        self.pub_class = self.advertise(
            "~output/class", ClassificationResult,
            queue_size=1)

    def subscribe(self):
        """Start listening to ``~input``."""
        # Large buff_size: presumably so big images are not split across
        # reads and queue_size=1 actually drops stale frames.
        self.sub = rospy.Subscriber('~input', Image, self.callback,
                                    queue_size=1, buff_size=2**24)

    def unsubscribe(self):
        """Stop listening to ``~input``."""
        self.sub.unregister()

    def callback(self, imgmsg):
        """Segment one image and publish every output topic.

        Args:
            imgmsg (sensor_msgs.msg.Image): input image; converted to rgb8.
        """
        bridge = self.bridge
        img = bridge.imgmsg_to_cv2(imgmsg, desired_encoding='rgb8')
        # The model is fed channel-first (C, H, W) images.
        img_chw = img.transpose(2, 0, 1)

        if self.gpu >= 0:
            # Selected per call rather than once in __init__: presumably the
            # current CUDA device is per-thread and rospy invokes callbacks
            # from its own threads.
            chainer.cuda.get_device_from_id(self.gpu).use()
        bboxes, masks, labels, scores = self.model.predict([img_chw])

        # predict() is batched; unpack the results for the single image.
        bboxes = bboxes[0]
        masks = masks[0]
        labels = labels[0]
        scores = scores[0]

        msg_indices = ClusterPointIndices(header=imgmsg.header)
        msg_labels = LabelArray(header=imgmsg.header)

        # -1 marks pixels covered by no instance; where masks overlap, the
        # later instance in detection order overwrites the earlier one.
        lbl_cls = np.full(img.shape[:2], -1, dtype=np.int32)
        lbl_ins = np.full(img.shape[:2], -1, dtype=np.int32)
        for ins_id, (mask, label) in enumerate(zip(masks, labels)):
            indices = np.where(mask.flatten())[0]
            indices_msg = PointIndices(header=imgmsg.header, indices=indices)
            msg_indices.cluster_indices.append(indices_msg)
            class_name = self.fg_class_names[label]
            msg_labels.labels.append(Label(id=int(label), name=class_name))
            lbl_cls[mask] = label
            lbl_ins[mask] = ins_id
        self.pub_indices.publish(msg_indices)
        self.pub_labels.publish(msg_labels)

        msg_lbl_cls = bridge.cv2_to_imgmsg(lbl_cls)
        msg_lbl_ins = bridge.cv2_to_imgmsg(lbl_ins)
        msg_lbl_cls.header = msg_lbl_ins.header = imgmsg.header
        self.pub_lbl_cls.publish(msg_lbl_cls)
        self.pub_lbl_ins.publish(msg_lbl_ins)

        cls_msg = ClassificationResult(
            header=imgmsg.header,
            classifier=self.classifier_name,
            target_names=self.fg_class_names,
            labels=labels,
            label_names=[self.fg_class_names[l] for l in labels],
            label_proba=scores,
        )

        # Each bbox row is ordered (y_min, x_min, y_max, x_max).
        rects_msg = RectArray(header=imgmsg.header)
        for y_min, x_min, y_max, x_max in bboxes:
            # Rect's fields are int32 while the model's coordinates are
            # (presumably) floats, which genpy cannot pack into integer
            # fields on Python 3; convert the corners explicitly.
            x = int(round(x_min))
            y = int(round(y_min))
            rect = Rect(x=x, y=y,
                        width=int(round(x_max)) - x,
                        height=int(round(y_max)) - y)
            rects_msg.rects.append(rect)
        self.pub_rects.publish(rects_msg)
        self.pub_class.publish(cls_msg)

        # Drawing is comparatively costly; skip it when nobody listens.
        if self.pub_viz.get_num_connections() > 0:
            n_fg_class = len(self.fg_class_names)
            captions = ['{:d}: {:s}'.format(l, self.fg_class_names[l])
                        for l in labels]
            # Shift by one so index 0 is reserved for background in the
            # drawing utility's color map.
            viz = chainer_mask_rcnn.utils.draw_instance_bboxes(
                img, bboxes, labels + 1, n_class=n_fg_class + 1,
                masks=masks, captions=captions)
            msg_viz = bridge.cv2_to_imgmsg(viz, encoding='rgb8')
            msg_viz.header = imgmsg.header
            self.pub_viz.publish(msg_viz)
00150
00151
if __name__ == '__main__':
    # Initialize first so the node's private (~) parameters resolve.
    rospy.init_node('mask_rcnn_instance_segmentation')
    # Keep the instance referenced for the process lifetime; it owns the
    # publishers and the on-demand input subscription.
    segmenter = MaskRCNNInstanceSegmentation()
    # Hand control to rospy until shutdown.
    rospy.spin()