draw_classification_result.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from __future__ import absolute_import
00004 from __future__ import division
00005 from __future__ import print_function
00006 
00007 import cv2
00008 import cv_bridge
00009 from distutils.version import LooseVersion
00010 from jsk_recognition_msgs.msg import ClassificationResult
00011 from jsk_topic_tools import ConnectionBasedTransport
00012 import message_filters
00013 import numpy as np
00014 import rospy
00015 from sensor_msgs.msg import Image
00016 from jsk_recognition_utils.color import labelcolormap
00017 
00018 
class DrawClassificationResult(ConnectionBasedTransport):
    """Overlay a ClassificationResult onto its matching input image.

    Subscribes to ``~input`` (ClassificationResult) and ``~input/image``
    (sensor_msgs/Image), pairs them with a message_filters
    TimeSynchronizer, and then does two things:

    * paints a colored legend band over the top 10% of the image rows;
    * writes ``"<label_name>: <probability>"`` into that band.

    The result is republished as an ``rgb8`` Image on ``~output``.
    """

    def __init__(self):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(DrawClassificationResult, self).__init__()
        self.pub = self.advertise('~output', Image, queue_size=1)
        # Per-label legend colors. _draw scales them by 255 into uint8, so
        # they are assumed to be floats in [0, 1].
        self.cmap = labelcolormap(255)
        # Anti-aliased line type: OpenCV >= 3 provides cv2.LINE_AA, while 2.x
        # only has cv2.CV_AA. Feature-detect once here instead of parsing
        # cv2.__version__ (via deprecated distutils) for every drawn label.
        if hasattr(cv2, 'LINE_AA'):
            self._line_type = cv2.LINE_AA
        else:
            self._line_type = cv2.CV_AA

    def subscribe(self):
        """Create the synchronized input subscribers.

        NOTE(review): presumably invoked by ConnectionBasedTransport when
        ``~output`` gains subscribers -- confirm against jsk_topic_tools.
        """
        self.sub = message_filters.Subscriber('~input', ClassificationResult)
        self.sub_img = message_filters.Subscriber('~input/image', Image)
        # Pair a ClassificationResult with the Image of the same timestamp;
        # _draw receives them as (cls_msg, imgmsg).
        sync = message_filters.TimeSynchronizer(
            [self.sub, self.sub_img], queue_size=10)
        sync.registerCallback(self._draw)

    def unsubscribe(self):
        """Tear down the subscribers created in subscribe()."""
        self.sub.unregister()
        self.sub_img.unregister()

    def _draw(self, cls_msg, imgmsg):
        """Draw the classification legend onto ``imgmsg`` and publish it.

        :param cls_msg: ClassificationResult; ``labels``, ``label_names`` and
            ``label_proba`` are read in lockstep by index.
        :param imgmsg: Image paired with ``cls_msg``. It is converted to
            rgb8 for drawing, and its header is copied to the output.
        """
        bridge = cv_bridge.CvBridge()
        rgb = bridge.imgmsg_to_cv2(imgmsg, desired_encoding='rgb8')

        # Height of the legend band: the top 10% of the image rows.
        legend_size = int(rgb.shape[0] * 0.1)
        # The same font is used to measure and to draw the title, so the
        # computed scale and baseline match the rendered glyphs.
        font_face = cv2.FONT_HERSHEY_PLAIN + cv2.FONT_ITALIC

        # NOTE(review): every result repaints the same band, so only the
        # last entry of cls_msg stays visible -- confirm this is intended.
        for i, label in enumerate(cls_msg.labels):
            color = self.cmap[label % len(self.cmap)]
            # Scale [0, 1] floats to 0-255 exactly once. Scaling twice
            # overflowed and wrapped around in the uint8 cast.
            rgb[:legend_size, :] = (np.asarray(color) * 255).astype(np.uint8)

            label_name = cls_msg.label_names[i]
            if len(label_name) > 16:
                # Shorten long names to 16 chars: 10-char head, '..',
                # 4-char tail.
                label_name = label_name[:10] + '..' + label_name[-4:]
            label_proba = cls_msg.label_proba[i]
            title = '{0}: {1:.2%}'.format(label_name, label_proba)

            # Choose the largest scale that fits the title in the band both
            # vertically (glyph height + descender) and horizontally.
            (text_w, text_h), baseline = cv2.getTextSize(
                title, font_face, 1, 1)
            scale = min(legend_size / (text_h + baseline),
                        rgb.shape[1] / text_w)
            # Re-measure the drawn string at that scale. putText's origin is
            # the baseline's left end, so y = text_h puts the glyph tops at
            # the top edge and keeps descenders inside the band.
            (_, text_h), baseline = cv2.getTextSize(
                title, font_face, scale, 1)
            cv2.putText(rgb, title, (0, text_h), font_face, scale,
                        (255, 255, 255), 1, self._line_type)

        out_msg = bridge.cv2_to_imgmsg(rgb, encoding='rgb8')
        out_msg.header = imgmsg.header
        self.pub.publish(out_msg)
00072 
00073 
if __name__ == '__main__':
    # Script entry point: register the ROS node, build the transport, then
    # hand control to rospy until the node is shut down.
    rospy.init_node('draw_classification_result')
    node = DrawClassificationResult()
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07