draw_classification_result.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 
7 import cv2
8 import cv_bridge
9 from distutils.version import LooseVersion
10 from jsk_recognition_msgs.msg import ClassificationResult
11 from jsk_topic_tools import ConnectionBasedTransport
12 import message_filters
13 import numpy as np
14 import rospy
15 from sensor_msgs.msg import Image
16 from jsk_recognition_utils.color import labelcolormap
17 
18 
class DrawClassificationResult(ConnectionBasedTransport):
    """Overlay a ClassificationResult onto an image as a colored legend band.

    Topics:
        ~input (jsk_recognition_msgs/ClassificationResult)
        ~input/image (sensor_msgs/Image)
        ~output (sensor_msgs/Image): the input image with the top 10% of
            rows painted in the label color and the label text on top.

    Subscriptions are created lazily by ConnectionBasedTransport
    (``subscribe``/``unsubscribe``), and the two inputs are paired with a
    message_filters.TimeSynchronizer.
    """

    def __init__(self):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(DrawClassificationResult, self).__init__()
        self.pub = self.advertise('~output', Image, queue_size=1)
        # Label -> RGB lookup table, indexed modulo its length in _draw.
        # NOTE(review): presumably float RGB in [0, 1]; confirm against
        # jsk_recognition_utils.color.labelcolormap.
        self.cmap = labelcolormap(255)
        # One bridge for the node's lifetime instead of one per callback.
        self.bridge = cv_bridge.CvBridge()

    def subscribe(self):
        """Subscribe to both inputs and pair them by header stamp."""
        self.sub = message_filters.Subscriber('~input', ClassificationResult)
        self.sub_img = message_filters.Subscriber('~input/image', Image)
        sync = message_filters.TimeSynchronizer(
            [self.sub, self.sub_img], queue_size=10)
        sync.registerCallback(self._draw)

    def unsubscribe(self):
        """Tear down the subscriptions created in ``subscribe``."""
        self.sub.unregister()
        self.sub_img.unregister()

    def _draw(self, cls_msg, imgmsg):
        """Draw the classification legend onto ``imgmsg`` and publish it.

        Args:
            cls_msg: ClassificationResult; ``labels``, ``label_names`` and
                ``label_proba`` are read in parallel by index.
            imgmsg: sensor_msgs/Image the legend is drawn on (converted to
                rgb8 and republished with the same header).
        """
        rgb = self.bridge.imgmsg_to_cv2(imgmsg, desired_encoding='rgb8')

        # The legend occupies the top 10% of the image rows.
        legend_size = int(rgb.shape[0] * 0.1)
        font = cv2.FONT_HERSHEY_PLAIN + cv2.FONT_ITALIC
        # cv2.CV_AA was renamed to cv2.LINE_AA in OpenCV 3; decide once.
        if LooseVersion(cv2.__version__).version[0] < 3:
            line_type = cv2.CV_AA
        else:
            line_type = cv2.LINE_AA

        # NOTE(review): every result is drawn into the same band, so only
        # the last one stays visible when several results are present.
        for i, label in enumerate(cls_msg.labels):
            color = self.cmap[label % len(self.cmap)]
            # Scale to 0..255 exactly once and clip, so the uint8 cast can
            # never wrap (the color used to be multiplied by 255 twice).
            rgb[:legend_size, :] = np.clip(
                np.asarray(color) * 255, 0, 255).astype(np.uint8)

            label_name = cls_msg.label_names[i]
            # Keep long names readable: 10 leading + '..' + 4 trailing chars.
            if len(label_name) > 16:
                label_name = label_name[:10] + '..' + label_name[-4:]
            label_proba = cls_msg.label_proba[i]
            title = '{0}: {1:.2%}'.format(label_name, label_proba)

            # Choose the largest scale at which the title fits in the band
            # both vertically and horizontally, measuring the exact string
            # and font that putText will draw.
            (text_w, text_h), baseline = cv2.getTextSize(title, font, 1, 1)
            scale = min(legend_size / (text_h + baseline),
                        rgb.shape[1] / text_w)
            (_, text_h), baseline = cv2.getTextSize(title, font, scale, 1)
            # putText's origin is the baseline-left corner; y = text_h puts
            # the glyph tops at row 0 and the descent inside the band.
            cv2.putText(rgb, title, (0, text_h), font,
                        scale, (255, 255, 255), 1, line_type)

        out_msg = self.bridge.cv2_to_imgmsg(rgb, encoding='rgb8')
        out_msg.header = imgmsg.header
        self.pub.publish(out_msg)
72 
73 
if __name__ == '__main__':
    rospy.init_node('draw_classification_result')
    # Instantiate the transport (advertises ~output and hooks lazy
    # subscription); keep a reference so it is not garbage-collected.
    app = DrawClassificationResult()  # noqa: F841
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27