00001
00002
00003 import cv2
00004 import matplotlib
00005 matplotlib.use('Agg')
00006 import matplotlib.cm
00007 import numpy as np
00008 import scipy.ndimage
00009
00010 import cv_bridge
00011 from jsk_recognition_utils import bounding_rect_of_mask
00012 from jsk_recognition_utils import get_tile_image
00013 from jsk_recognition_utils.color import labelcolormap
00014 from jsk_topic_tools import ConnectionBasedTransport
00015 from jsk_topic_tools import warn_no_remap
00016 import message_filters
00017 import rospy
00018 from sensor_msgs.msg import Image
00019
00020
def get_text_color(color):
    """Choose black or white text so it stays readable on ``color``.

    ``color`` is indexed as (R, G, B); only its first three entries are read.
    The ITU-R BT.601 luma of the color is compared against 170: bright
    backgrounds get black text, everything else gets white text.

    Returns:
        ``(0, 0, 0)`` if the luma exceeds 170, otherwise ``(255, 255, 255)``.
    """
    red, green, blue = color[0], color[1], color[2]
    luma = red * 0.299 + green * 0.587 + blue * 0.114
    black, white = (0, 0, 0), (255, 255, 255)
    return black if luma > 170 else white
00025
00026
def label2rgb(lbl, img=None, label_names=None, alpha=0.3):
    """Colorize a label image, optionally blended over a grayscale image.

    Args:
        lbl: Integer label image. Each pixel value is used as an index into
            the label colormap. Label 0 is never annotated with text.
        img: Optional RGB image with the same height/width as ``lbl``. It is
            converted to grayscale and blended under the label colors.
        label_names: Optional sequence of names indexed by label value. When
            given, the name of every non-zero label present in ``lbl`` is
            drawn near that label's center of mass. Labels that have no entry
            (value >= ``len(label_names)``) are drawn with their numeric value.
        alpha: Weight of the label colors in the blend with ``img``
            (``1 - alpha`` is the weight of the grayscale image).

    Returns:
        A uint8 RGB visualization with the same height/width as ``lbl``.
    """
    n_labels = lbl.max() + 1
    if label_names is not None:
        # Never size the colormap below the largest label actually present,
        # otherwise ``cmap[lbl]`` goes out of range when names are missing.
        n_labels = max(len(label_names), n_labels)
    cmap = labelcolormap(n_labels)
    cmap = (cmap * 255).astype(np.uint8)

    lbl_viz = cmap[lbl]

    if img is not None:
        # Gray (replicated to 3 channels) keeps label colors dominant.
        img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img_gray = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)
        lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
        lbl_viz = lbl_viz.astype(np.uint8)

    if label_names is None:
        return lbl_viz

    # A private RNG with the historical seed keeps text placement
    # reproducible without reseeding NumPy's global random state.
    rng = np.random.RandomState(1234)
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    thickness = 2

    labels = np.unique(lbl)
    labels = labels[labels != 0]
    for label in labels:
        mask = lbl == label
        mask = (mask * 255).astype(np.uint8)
        y, x = scipy.ndimage.center_of_mass(mask)
        y, x = int(y), int(x)

        # The center of mass of a non-convex region may lie outside of it;
        # fall back to a (reproducibly) random pixel inside the region.
        if lbl[y, x] != label:
            ys, xs = np.where(mask)
            point_index = rng.randint(0, len(ys))
            y, x = ys[point_index], xs[point_index]

        if label < len(label_names):
            text = label_names[label]
        else:
            text = str(label)
        text_size, _ = cv2.getTextSize(text, font_face, font_scale, thickness)

        color = get_text_color(lbl_viz[y, x])
        # Center the text horizontally on the chosen point.
        cv2.putText(lbl_viz, text,
                    (x - text_size[0] // 2, y),
                    font_face, font_scale, color, thickness)
    return lbl_viz
00072
00073
class LabelImageDecomposer(ConnectionBasedTransport):
    """Decompose an image by a synchronized label image.

    Subscribes to ``~input`` (image) and ``~input/label`` (label image),
    synchronized exactly or approximately (``~approximate_sync``), and
    publishes:

    * ``~output``: the input image with every label-0 pixel set to 0.
    * ``~output/label_viz``: an rgb8 colorization of the labels blended over
      a grayscale version of the input (names drawn if ``~label_names``).
    * ``~output/fg_mask`` / ``~output/bg_mask`` (if ``~publish_mask``):
      mono8 masks (0/255) of non-zero / zero labels.
    * ``~output/tile`` (if ``~publish_tile``): one crop per non-zero label,
      tiled into a single bgr8 image.
    """

    def __init__(self):
        super(LabelImageDecomposer, self).__init__()
        self.pub_img = self.advertise('~output', Image, queue_size=5)
        self.pub_label_viz = self.advertise('~output/label_viz', Image,
                                            queue_size=5)
        self._label_names = rospy.get_param('~label_names', None)

        self._publish_mask = rospy.get_param('~publish_mask', False)
        if self._publish_mask:
            self.pub_fg_mask = self.advertise('~output/fg_mask', Image,
                                              queue_size=5)
            self.pub_bg_mask = self.advertise('~output/bg_mask', Image,
                                              queue_size=5)

        self._publish_tile = rospy.get_param('~publish_tile', False)
        rospy.loginfo('~publish_tile: {}'.format(self._publish_tile))
        if self._publish_tile:
            self.pub_tile = self.advertise('~output/tile', Image, queue_size=5)

    def subscribe(self):
        """Subscribe to the inputs and register the synchronized callbacks."""
        self.sub_img = message_filters.Subscriber('~input', Image)
        self.sub_label = message_filters.Subscriber('~input/label', Image)
        warn_no_remap('~input', '~input/label')
        use_async = rospy.get_param('~approximate_sync', False)
        queue_size = rospy.get_param('~queue_size', 10)
        rospy.loginfo('~approximate_sync: {}, queue_size: {}'
                      .format(use_async, queue_size))
        subscribers = [self.sub_img, self.sub_label]
        # NOTE: do not name this variable ``async``; it is a reserved
        # keyword since Python 3.7 and would be a SyntaxError.
        if use_async:
            slop = rospy.get_param('~slop', 0.1)
            rospy.loginfo('~slop: {}'.format(slop))
            sync = message_filters.ApproximateTimeSynchronizer(
                subscribers, queue_size=queue_size, slop=slop)
        else:
            sync = message_filters.TimeSynchronizer(
                subscribers, queue_size=queue_size)
        sync.registerCallback(self._apply)
        if self._publish_tile:
            sync.registerCallback(self._apply_tile)

    def unsubscribe(self):
        """Unregister both input subscribers."""
        self.sub_img.sub.unregister()
        self.sub_label.sub.unregister()

    @staticmethod
    def _scale_to_rgb8(img):
        """Linearly rescale a single-channel image to a [0, 255] rgb8 image.

        The arithmetic is done in float64 so that integer inputs can neither
        overflow nor be integer-divided. A constant image (max == min) maps to
        all zeros instead of dividing by zero.
        """
        img = img.astype(np.float64)
        min_value, max_value = img.min(), img.max()
        value_range = max_value - min_value
        if value_range > 0:
            img = (img - min_value) / value_range * 255
        else:
            img = np.zeros_like(img)
        img = img.astype(np.uint8)
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    def _apply(self, img_msg, label_msg):
        """Publish the masked image, the label visualization and masks."""
        bridge = cv_bridge.CvBridge()
        img = bridge.imgmsg_to_cv2(img_msg)
        label_img = bridge.imgmsg_to_cv2(label_msg)

        # Zero out every pixel whose label is 0.
        applied = img.copy()
        applied[label_img == 0] = 0
        applied_msg = bridge.cv2_to_imgmsg(applied, encoding=img_msg.encoding)
        applied_msg.header = img_msg.header
        self.pub_img.publish(applied_msg)

        if img_msg.encoding in {'16UC1', '32SC1'}:
            # Single-channel integer images are rescaled to 8-bit RGB so they
            # can be blended with the label colors.
            img = self._scale_to_rgb8(img)

        label_viz = label2rgb(label_img, img, label_names=self._label_names)
        label_viz_msg = bridge.cv2_to_imgmsg(label_viz, encoding='rgb8')
        label_viz_msg.header = img_msg.header
        self.pub_label_viz.publish(label_viz_msg)

        if self._publish_mask:
            bg_mask = (label_img == 0)
            fg_mask = ~bg_mask
            bg_mask = (bg_mask * 255).astype(np.uint8)
            fg_mask = (fg_mask * 255).astype(np.uint8)
            fg_mask_msg = bridge.cv2_to_imgmsg(fg_mask, encoding='mono8')
            fg_mask_msg.header = img_msg.header
            bg_mask_msg = bridge.cv2_to_imgmsg(bg_mask, encoding='mono8')
            bg_mask_msg.header = img_msg.header
            self.pub_fg_mask.publish(fg_mask_msg)
            self.pub_bg_mask.publish(bg_mask_msg)

    def _apply_tile(self, img_msg, label_msg):
        """Publish a bgr8 tile of per-label crops of the input image."""
        bridge = cv_bridge.CvBridge()
        # The tile is published as bgr8, so request the input in that
        # encoding; otherwise e.g. rgb8 input would come out channel-swapped.
        img = bridge.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        label_img = bridge.imgmsg_to_cv2(label_msg)

        imgs = []
        for label in np.unique(label_img):
            if label == 0:
                # Label 0 is not tiled.
                continue
            img_tmp = img.copy()
            mask = label_img == label
            img_tmp[~mask] = 0
            img_tmp = bounding_rect_of_mask(img_tmp, mask)
            imgs.append(img_tmp)
        if not imgs:
            # No non-zero label in this frame: nothing to tile.
            return
        tile_img = get_tile_image(imgs)
        tile_msg = bridge.cv2_to_imgmsg(tile_img, encoding='bgr8')
        tile_msg.header = img_msg.header
        self.pub_tile.publish(tile_msg)
00180
00181
if __name__ == '__main__':
    # Node name must be registered before any publisher/parameter access.
    rospy.init_node('label_image_decomposer')
    # Hold a reference to the transport for the lifetime of the node.
    decomposer_node = LabelImageDecomposer()
    rospy.spin()