Go to the documentation of this file.00001
00002
00003 from __future__ import absolute_import
00004 from __future__ import division
00005 from __future__ import print_function
00006
00007 import cv2
00008 import cv_bridge
00009 from distutils.version import LooseVersion
00010 from jsk_recognition_msgs.msg import ClassificationResult
00011 from jsk_topic_tools import ConnectionBasedTransport
00012 import message_filters
00013 import numpy as np
00014 import rospy
00015 from sensor_msgs.msg import Image
00016 from jsk_recognition_utils.color import labelcolormap
00017
00018
class DrawClassificationResult(ConnectionBasedTransport):
    """Draw a classification result as a colored legend on top of an image.

    Subscribes to ``~input`` (ClassificationResult) and ``~input/image``
    (Image), pairs them with a ``message_filters.TimeSynchronizer``, and
    publishes the annotated image on ``~output``.
    """

    def __init__(self):
        # Name the class explicitly: ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed.
        super(DrawClassificationResult, self).__init__()
        # Per-label colors. ``_draw_legend`` multiplies them by 255 before
        # casting to uint8, so they are presumably floats in [0, 1].
        self.cmap = labelcolormap(255)
        # One bridge reused for every message instead of one per callback.
        self.bridge = cv_bridge.CvBridge()
        # The anti-aliased line type constant was renamed in OpenCV 3;
        # resolve it once rather than on every drawn result.
        if LooseVersion(cv2.__version__).version[0] < 3:
            self.line_type = cv2.CV_AA
        else:
            self.line_type = cv2.LINE_AA
        self.pub = self.advertise('~output', Image, queue_size=1)

    def subscribe(self):
        """Subscribe to the result and image topics and synchronize them."""
        self.sub = message_filters.Subscriber('~input', ClassificationResult)
        self.sub_img = message_filters.Subscriber('~input/image', Image)
        sync = message_filters.TimeSynchronizer(
            [self.sub, self.sub_img], queue_size=10)
        sync.registerCallback(self._draw)

    def unsubscribe(self):
        """Unregister both input subscribers."""
        self.sub.unregister()
        self.sub_img.unregister()

    def _draw(self, cls_msg, imgmsg):
        """Draw ``cls_msg`` onto ``imgmsg`` and publish it on ``~output``.

        Every result repaints the same legend band at the top of the image,
        so when ``cls_msg`` carries several results only the last one stays
        visible. The output keeps the input image's header.

        Args:
            cls_msg: ClassificationResult whose ``labels``, ``label_names``
                and ``label_proba`` arrays are read in parallel.
            imgmsg: Image message to annotate.
        """
        rgb = self.bridge.imgmsg_to_cv2(imgmsg, desired_encoding='rgb8')
        if not rgb.flags.writeable:
            # NOTE(review): the converted array may share the message's
            # read-only buffer; copy so the in-place drawing below works.
            rgb = rgb.copy()

        results = zip(cls_msg.labels, cls_msg.label_names,
                      cls_msg.label_proba)
        for label, label_name, label_proba in results:
            self._draw_legend(rgb, label, label_name, label_proba)

        out_msg = self.bridge.cv2_to_imgmsg(rgb, encoding='rgb8')
        out_msg.header = imgmsg.header
        self.pub.publish(out_msg)

    def _draw_legend(self, rgb, label, label_name, label_proba):
        """Fill the top 10% of ``rgb`` with the label color and write a title.

        ``rgb`` is modified in place. The title is the label name (names
        longer than 16 characters are abbreviated to 16) followed by the
        probability formatted as a percentage, scaled to fit the band.
        """
        legend_size = int(rgb.shape[0] * 0.1)
        if legend_size < 1:
            # Image too short to hold a legend band; nothing to draw.
            return

        # Label ids beyond the colormap size wrap around.
        color = self.cmap[label % len(self.cmap)]
        rgb[:legend_size, :] = (np.array(color) * 255).astype(np.uint8)

        if len(label_name) > 16:
            label_name = label_name[:10] + '..' + label_name[-4:]
        title = '{0}: {1:.2%}'.format(label_name, label_proba)

        font = cv2.FONT_HERSHEY_PLAIN + cv2.FONT_ITALIC
        # Measure at scale 1, then take the largest scale that fits the band
        # both vertically and horizontally.
        (text_w, text_h), baseline = cv2.getTextSize(title, font, 1, 1)
        scale_h = legend_size / (text_h + baseline)
        scale_w = rgb.shape[1] / text_w
        scale = min(scale_h, scale_w)
        # Re-measure the string that is actually drawn, in the font that is
        # actually used, at the chosen scale.
        (_, text_h), _ = cv2.getTextSize(title, font, scale, 1)
        # putText's origin is the text baseline: placing it at ``text_h``
        # puts the glyph tops at row 0 instead of clipping them.
        cv2.putText(rgb, title, (0, text_h), font, scale, (255, 255, 255), 1,
                    self.line_type)
00072
00073
if __name__ == '__main__':
    # Register the ROS node, build the drawer, then block in rospy.spin()
    # so subscription callbacks keep being processed.
    rospy.init_node('draw_classification_result')
    node = DrawClassificationResult()
    rospy.spin()