label_image_decomposer.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 import cv2
00004 import matplotlib
00005 matplotlib.use('Agg')  # NOQA
00006 import matplotlib.cm
00007 import numpy as np
00008 import scipy.ndimage
00009 
00010 import cv_bridge
00011 from jsk_recognition_utils import bounding_rect_of_mask
00012 from jsk_recognition_utils import get_tile_image
00013 from jsk_recognition_utils.color import labelcolormap
00014 from jsk_topic_tools import ConnectionBasedTransport
00015 from jsk_topic_tools import warn_no_remap
00016 import message_filters
00017 import rospy
00018 from sensor_msgs.msg import Image
00019 
00020 
def get_text_color(color):
    """Choose black or white text so it stays readable on ``color``.

    The first three components of ``color`` are weighted as (R, G, B) with
    the ITU-R BT.601 luma coefficients (0.299, 0.587, 0.114). Backgrounds
    whose luma exceeds 170 get black text; all others get white text.

    Returns:
        tuple: ``(0, 0, 0)`` for bright backgrounds, else ``(255, 255, 255)``.
    """
    red, green, blue = color[0], color[1], color[2]
    luma = red * 0.299 + green * 0.587 + blue * 0.114
    is_bright = luma > 170
    return (0, 0, 0) if is_bright else (255, 255, 255)
00025 
00026 
def label2rgb(lbl, img=None, label_names=None, alpha=0.3):
    """Colorize a label image, optionally blending it and naming the labels.

    Every label value is mapped to a color from ``labelcolormap``. If ``img``
    is given, the colors are alpha-blended over a grayscale copy of it. If
    ``label_names`` is given, each non-zero label present in ``lbl`` is
    annotated with its name.

    Args:
        lbl: Integer label array indexed as ``lbl[y, x]``. Label 0 is treated
            as background and is never annotated.
        img: Optional image with the same height/width as ``lbl``. A
            multi-channel image is converted with ``COLOR_RGB2GRAY`` (i.e.
            treated as RGB). A 2-D image is used as gray directly.
            NOTE(review): blending assumes uint8-range pixel values.
        label_names: Optional sequence of names indexed by label value. Its
            length sets the colormap size, so every value in ``lbl`` must be
            smaller than ``len(label_names)``.
        alpha: Weight of the label colors in the blend (the gray image gets
            ``1 - alpha``).

    Returns:
        numpy.ndarray: uint8 visualization, presumably of shape
        ``lbl.shape + (3,)`` (RGB) given a (N, 3) colormap.
    """
    if label_names is None:
        n_labels = lbl.max() + 1  # +1 for bg_label 0
    else:
        n_labels = len(label_names)
    cmap = labelcolormap(n_labels)
    cmap = (cmap * 255).astype(np.uint8)

    # Fancy indexing maps each label value to its colormap row.
    lbl_viz = cmap[lbl]

    if img is not None:
        if img.ndim == 2:
            # Already single-channel: COLOR_RGB2GRAY would reject it.
            img_gray = img
        else:
            img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img_gray = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)
        lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
        lbl_viz = lbl_viz.astype(np.uint8)

    if label_names is None:
        return lbl_viz

    # A private RNG seeded like before keeps text placement reproducible
    # without resetting NumPy's process-wide global random state.
    random_state = np.random.RandomState(1234)
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    thickness = 2

    labels = np.unique(lbl)
    labels = labels[labels != 0]
    for label in labels:
        mask = lbl == label
        mask = (mask * 255).astype(np.uint8)
        y, x = scipy.ndimage.center_of_mass(mask)
        y, x = map(int, [y, x])

        # The centroid of a non-convex region can fall outside it; fall back
        # to a random pixel of the region so the text lands on its color.
        if lbl[y, x] != label:
            Y, X = np.where(mask)
            point_index = random_state.randint(0, len(Y))
            y, x = Y[point_index], X[point_index]

        text = label_names[label]
        text_size, _baseline = cv2.getTextSize(
            text, font_face, font_scale, thickness)

        color = get_text_color(lbl_viz[y, x])
        # putText's origin is the text's bottom-left corner; shifting x by
        # half the text width centers the text horizontally on (x, y).
        cv2.putText(lbl_viz, text,
                    (x - text_size[0] // 2, y),
                    font_face, font_scale, color, thickness)
    return lbl_viz
00072 
00073 
class LabelImageDecomposer(ConnectionBasedTransport):
    """Split an image by a synchronized label image and visualize the labels.

    Inputs (time-synchronized, exactly or approximately):
        ``~input``: the image; ``~input/label``: the label image.

    Outputs:
        ``~output``: the input image with every pixel whose label is 0 zeroed.
        ``~output/label_viz``: rgb8 label visualization (see ``label2rgb``).
        ``~output/fg_mask`` / ``~output/bg_mask`` (``~publish_mask``):
            mono8 masks of ``label != 0`` / ``label == 0``.
        ``~output/tile`` (``~publish_tile``): one sub-image per non-zero
            label, combined with ``get_tile_image``.

    Parameters: ``~label_names``, ``~publish_mask``, ``~publish_tile``,
    ``~approximate_sync``, ``~queue_size``, ``~slop``.
    """

    def __init__(self):
        super(LabelImageDecomposer, self).__init__()
        self.pub_img = self.advertise('~output', Image, queue_size=5)
        self.pub_label_viz = self.advertise('~output/label_viz', Image,
                                            queue_size=5)
        self._label_names = rospy.get_param('~label_names', None)
        # publish masks of fg/bg by decomposing each label
        self._publish_mask = rospy.get_param('~publish_mask', False)
        if self._publish_mask:
            self.pub_fg_mask = self.advertise('~output/fg_mask', Image,
                                              queue_size=5)
            self.pub_bg_mask = self.advertise('~output/bg_mask', Image,
                                              queue_size=5)
        # publish each region image. this can take time so optional.
        self._publish_tile = rospy.get_param('~publish_tile', False)
        rospy.loginfo('~publish_tile: {}'.format(self._publish_tile))
        if self._publish_tile:
            self.pub_tile = self.advertise('~output/tile', Image, queue_size=5)

    def subscribe(self):
        """Subscribe to the inputs and wire them through a synchronizer.

        Presumably invoked by ``ConnectionBasedTransport`` when an output
        gains its first subscriber.
        """
        self.sub_img = message_filters.Subscriber('~input', Image)
        self.sub_label = message_filters.Subscriber('~input/label', Image)
        warn_no_remap('~input', '~input/label')
        use_async = rospy.get_param('~approximate_sync', False)
        queue_size = rospy.get_param('~queue_size', 10)
        rospy.loginfo('~approximate_sync: {}, queue_size: {}'
                      .format(use_async, queue_size))
        # NOTE: the synchronizer must not be named ``async``: that is a
        # reserved keyword (SyntaxError) from Python 3.7 on.
        if use_async:
            slop = rospy.get_param('~slop', 0.1)
            rospy.loginfo('~slop: {}'.format(slop))
            sync = message_filters.ApproximateTimeSynchronizer(
                [self.sub_img, self.sub_label],
                queue_size=queue_size, slop=slop)
        else:
            sync = message_filters.TimeSynchronizer(
                [self.sub_img, self.sub_label], queue_size=queue_size)
        sync.registerCallback(self._apply)
        if self._publish_tile:
            sync.registerCallback(self._apply_tile)

    def unsubscribe(self):
        """Unregister the input subscribers created in ``subscribe``."""
        self.sub_img.sub.unregister()
        self.sub_label.sub.unregister()

    def _apply(self, img_msg, label_msg):
        """Publish the masked image, label visualization and optional masks."""
        bridge = cv_bridge.CvBridge()
        img = bridge.imgmsg_to_cv2(img_msg)
        label_img = bridge.imgmsg_to_cv2(label_msg)
        # publish only valid label region
        applied = img.copy()
        applied[label_img == 0] = 0
        applied_msg = bridge.cv2_to_imgmsg(applied, encoding=img_msg.encoding)
        applied_msg.header = img_msg.header
        self.pub_img.publish(applied_msg)
        # publish visualized label
        if img_msg.encoding in {'16UC1', '32SC1', 'mono16'}:
            # Dynamic min/max scaling to uint8 so it looks nice. Work in
            # float: integer '/' floor-divides on Python 2 (collapsing the
            # image to 0/1), and a constant image would divide by zero.
            img = img.astype(np.float64)
            min_value, max_value = img.min(), img.max()
            value_range = max_value - min_value
            if value_range > 0:
                img = (img - min_value) / value_range * 255
            else:
                img = np.zeros_like(img)
            img = img.astype(np.uint8)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        label_viz = label2rgb(label_img, img, label_names=self._label_names)
        label_viz_msg = bridge.cv2_to_imgmsg(label_viz, encoding='rgb8')
        label_viz_msg.header = img_msg.header
        self.pub_label_viz.publish(label_viz_msg)

        # publish mask
        if self._publish_mask:
            bg_mask = (label_img == 0)
            fg_mask = ~bg_mask
            bg_mask = (bg_mask * 255).astype(np.uint8)
            fg_mask = (fg_mask * 255).astype(np.uint8)
            fg_mask_msg = bridge.cv2_to_imgmsg(fg_mask, encoding='mono8')
            fg_mask_msg.header = img_msg.header
            bg_mask_msg = bridge.cv2_to_imgmsg(bg_mask, encoding='mono8')
            bg_mask_msg.header = img_msg.header
            self.pub_fg_mask.publish(fg_mask_msg)
            self.pub_bg_mask.publish(bg_mask_msg)

    def _apply_tile(self, img_msg, label_msg):
        """Publish one tiled image made of a sub-image per non-zero label."""
        bridge = cv_bridge.CvBridge()
        img = bridge.imgmsg_to_cv2(img_msg)
        label_img = bridge.imgmsg_to_cv2(label_msg)

        imgs = []
        labels = np.unique(label_img)
        for label in labels:
            if label == 0:
                # Skip label 0: it marks the black (masked-out) region.
                continue
            img_tmp = img.copy()
            mask = label_img == label
            img_tmp[~mask] = 0
            # presumably crops to the mask's bounding rectangle — confirm
            img_tmp = bounding_rect_of_mask(img_tmp, mask)
            imgs.append(img_tmp)
        tile_img = get_tile_image(imgs)
        # NOTE(review): encoding is fixed to bgr8 regardless of the input
        # encoding; an rgb8 input would come out channel-swapped — confirm.
        tile_msg = bridge.cv2_to_imgmsg(tile_img, encoding='bgr8')
        tile_msg.header = img_msg.header
        self.pub_tile.publish(tile_msg)
00180 
00181 
def _main():
    """Start the ``label_image_decomposer`` node and spin until shutdown."""
    rospy.init_node('label_image_decomposer')
    # Keep a reference to the node object for as long as spin() runs.
    decomposer = LabelImageDecomposer()  # noqa: F841
    rospy.spin()


if __name__ == '__main__':
    _main()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Sun Oct 8 2017 02:43:23