draw_classification_result.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from __future__ import absolute_import
00004 from __future__ import division
00005 from __future__ import print_function
00006 
00007 import textwrap
00008 
00009 import cv2
00010 import cv_bridge
00011 from jsk_recognition_msgs.msg import ClassificationResult
00012 from jsk_topic_tools import ConnectionBasedTransport
00013 import message_filters
00014 import numpy as np
00015 import rospy
00016 from sensor_msgs.msg import Image
00017 from skimage.color import rgb_colors
00018 from skimage.color.colorlabel import DEFAULT_COLORS
00019 from jsk_recognition_utils.color import labelcolormap
00020 
00021 
class DrawClassificationResult(ConnectionBasedTransport):
    """Paint a classification result as a coloured banner on an image.

    Subscribes to ``~input`` (``ClassificationResult``) and ``~input/image``
    (``sensor_msgs/Image``), pairs them with a ``TimeSynchronizer`` and
    publishes the annotated image on ``~output``.
    """

    def __init__(self):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(DrawClassificationResult, self).__init__()
        self.pub = self.advertise('~output', Image, queue_size=1)
        # Per-label colour table. Assumed to hold float RGB in [0, 1]
        # (what jsk_recognition_utils.color.labelcolormap returns) -- the
        # drawing code below scales it to 8-bit exactly once.
        self.cmap = labelcolormap(255)

    def subscribe(self):
        """Create the synchronized input subscribers.

        ConnectionBasedTransport hook (presumably invoked when ``~output``
        gains a subscriber -- see jsk_topic_tools).
        """
        self.sub = message_filters.Subscriber('~input', ClassificationResult)
        self.sub_img = message_filters.Subscriber('~input/image', Image)
        # TimeSynchronizer only fires for messages with identical stamps.
        sync = message_filters.TimeSynchronizer(
            [self.sub, self.sub_img], queue_size=10)
        sync.registerCallback(self._draw)

    def unsubscribe(self):
        """Tear down the subscribers created by :meth:`subscribe`."""
        self.sub.unregister()
        self.sub_img.unregister()

    def _draw(self, cls_msg, imgmsg):
        """Draw ``cls_msg`` onto ``imgmsg`` and publish the result.

        :param cls_msg: ``ClassificationResult``; ``labels``, ``label_names``
            and ``label_proba`` are read index-by-index over ``labels``.
        :param imgmsg: ``sensor_msgs/Image`` to annotate; its header is
            copied onto the published image.
        """
        bridge = cv_bridge.CvBridge()
        rgb = bridge.imgmsg_to_cv2(imgmsg, desired_encoding='rgb8')

        # Banner across the top 10% of the image (loop invariant).
        legend_size = int(rgb.shape[0] * 0.1)
        # Measure with the very font used for drawing so that the computed
        # text size matches what putText actually renders.
        font = cv2.FONT_HERSHEY_PLAIN + cv2.FONT_ITALIC
        # cv2.CV_AA was removed in OpenCV 3; cv2.LINE_AA is its replacement.
        line_type = getattr(cv2, 'LINE_AA', None)
        if line_type is None:
            line_type = cv2.CV_AA

        # NOTE(review): every result is painted into the same banner, so
        # with several results only the last one remains visible.
        for i, label in enumerate(cls_msg.labels):
            color = self.cmap[label % len(self.cmap)]
            # Scale the [0, 1] colour to 8-bit once (it used to be scaled
            # twice, overflowing uint8 and producing a wrong colour).
            rgb[:legend_size, :] = (np.array(color) * 255).astype(np.uint8)

            label_name = cls_msg.label_names[i]
            if len(label_name) > 16:
                # Keep head and tail so long names stay recognisable.
                label_name = label_name[:10] + '..' + label_name[-4:]
            label_proba = cls_msg.label_proba[i]
            title = '{0}: {1:.2%}'.format(label_name, label_proba)

            # Largest scale at which the title fits both banner height
            # (glyph height + descender) and image width.
            (text_w, text_h), baseline = cv2.getTextSize(title, font, 1, 1)
            scale = min(legend_size / (text_h + baseline),
                        rgb.shape[1] / text_w)
            (_, text_h), _ = cv2.getTextSize(title, font, scale, 1)
            # putText's origin is the text baseline: the glyphs span
            # [y - text_h, y + baseline], so y = text_h puts the top at row 0.
            cv2.putText(rgb, title, (0, text_h), font, scale,
                        (255, 255, 255), 1, line_type)

        out_msg = bridge.cv2_to_imgmsg(rgb, encoding='rgb8')
        out_msg.header = imgmsg.header
        self.pub.publish(out_msg)
00072 
00073 
def main():
    """Run the ``draw_classification_result`` node until ROS shuts down."""
    rospy.init_node('draw_classification_result')
    # Keep a reference to the node object alive while spinning.
    node = DrawClassificationResult()  # noqa: F841
    rospy.spin()


if __name__ == '__main__':
    main()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Sun Oct 8 2017 02:43:23