Go to the documentation of this file.00001
00002
00003 from __future__ import absolute_import
00004 from __future__ import division
00005 from __future__ import print_function
00006
00007 import textwrap
00008
00009 import cv2
00010 import cv_bridge
00011 from jsk_recognition_msgs.msg import ClassificationResult
00012 from jsk_topic_tools import ConnectionBasedTransport
00013 import message_filters
00014 import numpy as np
00015 import rospy
00016 from sensor_msgs.msg import Image
00017 from skimage.color import rgb_colors
00018 from skimage.color.colorlabel import DEFAULT_COLORS
00019 from jsk_recognition_utils.color import labelcolormap
00020
00021
class DrawClassificationResult(ConnectionBasedTransport):
    """Draw a classification result as a colored legend on its image.

    Subscribes (exact-time synchronized) to:
        ~input        jsk_recognition_msgs/ClassificationResult
        ~input/image  sensor_msgs/Image
    Publishes:
        ~output       sensor_msgs/Image (rgb8) -- the input image with a band
                      over its top 10% colored by label and captioned with
                      "<label name>: <probability>".
    """

    def __init__(self):
        # Name the class explicitly: ``super(self.__class__, self)`` recurses
        # infinitely as soon as this class is subclassed.
        super(DrawClassificationResult, self).__init__()
        # Create the state _draw relies on before advertising, in case the
        # base class starts subscribing as soon as the topic is advertised.
        self.bridge = cv_bridge.CvBridge()
        # Label-indexed color table; _draw scales entries by 255, i.e. it
        # assumes floats in [0, 1] -- TODO confirm against labelcolormap.
        self.cmap = labelcolormap(255)
        self.pub = self.advertise('~output', Image, queue_size=1)

    def subscribe(self):
        """Subscribe to both inputs and call _draw on exact-stamp pairs."""
        self.sub = message_filters.Subscriber('~input', ClassificationResult)
        self.sub_img = message_filters.Subscriber('~input/image', Image)
        sync = message_filters.TimeSynchronizer(
            [self.sub, self.sub_img], queue_size=10)
        sync.registerCallback(self._draw)

    def unsubscribe(self):
        """Tear down the subscribers created by subscribe()."""
        self.sub.unregister()
        self.sub_img.unregister()

    def _draw(self, cls_msg, imgmsg):
        """Paint the legend for ``cls_msg`` onto ``imgmsg`` and publish it.

        Args:
            cls_msg: ClassificationResult; its ``labels``, ``label_names`` and
                ``label_proba`` are read element-wise in parallel.
            imgmsg: sensor_msgs/Image with the same stamp as ``cls_msg``; its
                header is copied onto the published image.
        """
        rgb = self.bridge.imgmsg_to_cv2(imgmsg, desired_encoding='rgb8')
        if not rgb.flags.writeable:
            # cv_bridge may hand back a view of the message's immutable
            # buffer (e.g. ``bytes`` on Python 3); drawing needs a copy.
            rgb = rgb.copy()

        legend_size = int(rgb.shape[0] * 0.1)  # band height: top 10%
        font = cv2.FONT_HERSHEY_PLAIN + cv2.FONT_ITALIC
        # cv2.CV_AA no longer exists in OpenCV >= 3; LINE_AA replaces it.
        line_type = getattr(cv2, 'LINE_AA', None)
        if line_type is None:
            line_type = cv2.CV_AA

        # NOTE(review): each result repaints the same band, so only the last
        # one stays visible when ``labels`` holds several entries.
        for label, label_name, label_proba in zip(
                cls_msg.labels, cls_msg.label_names, cls_msg.label_proba):
            # Scale the colormap entry to [0, 255] exactly once.
            color = self.cmap[label % len(self.cmap)]
            rgb[:legend_size, :] = (np.array(color) * 255).astype(np.uint8)

            if len(label_name) > 16:
                # Keep head and tail so long names stay recognizable.
                label_name = label_name[:10] + '..' + label_name[-4:]
            title = '{0}: {1:.2%}'.format(label_name, label_proba)

            # Measure the title at unit scale in the font it is drawn with,
            # then take the largest scale fitting both band height and width.
            (text_w, text_h), baseline = cv2.getTextSize(title, font, 1, 1)
            scale = min(legend_size / (text_h + baseline),
                        rgb.shape[1] / text_w)
            (_, text_h), baseline = cv2.getTextSize(title, font, scale, 1)
            # putText's origin is the left end of the baseline: y = text_h
            # puts glyph tops at row 0, with descenders below the baseline.
            cv2.putText(rgb, title, (0, text_h), font, scale,
                        (255, 255, 255), 1, line_type)

        out_msg = self.bridge.cv2_to_imgmsg(rgb, encoding='rgb8')
        out_msg.header = imgmsg.header
        self.pub.publish(out_msg)
00072
00073
def main():
    """Start the ``draw_classification_result`` node and block until shutdown."""
    rospy.init_node('draw_classification_result')
    # Hold the transport for the node's lifetime; it wires up its own
    # publisher and subscribers.
    node = DrawClassificationResult()  # noqa: F841
    rospy.spin()


if __name__ == '__main__':
    main()