label_image_classifier.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 import cv_bridge
4 from jsk_recognition_msgs.msg import ClassificationResult
5 from jsk_topic_tools import ConnectionBasedTransport
6 import numpy as np
7 import rospy
8 from sensor_msgs.msg import Image
9 
10 
class LabelImageClassifier(ConnectionBasedTransport):
    """Classify a label image by the label that covers the most pixels.

    Subscribes to ``~input`` (a sensor_msgs/Image whose pixels are integer
    labels) and publishes a single-label ``ClassificationResult`` on
    ``~output``.

    ROS parameters:
        ~ignore_labels (list of int): labels whose pixel counts are set to
            zero before voting. Entries outside the range of counted labels
            (including negative values) are skipped.
        ~target_names (list of str): name for each label; indexed by the
            winning label and copied into the output message.
    """

    classifier_name = 'label_image_classifier'

    def __init__(self):
        super(LabelImageClassifier, self).__init__()
        self.ignore_labels = rospy.get_param('~ignore_labels', [])
        self.target_names = rospy.get_param('~target_names', [])
        # Reuse one bridge for every message instead of building one per
        # callback. Created before advertise() so it already exists if a
        # subscription (and thus _cb) is triggered by a connecting client.
        self._bridge = cv_bridge.CvBridge()
        self.pub = self.advertise(
            '~output', ClassificationResult, queue_size=1)

    def subscribe(self):
        self.sub = rospy.Subscriber('~input', Image, self._cb)

    def unsubscribe(self):
        self.sub.unregister()

    def _cb(self, imgmsg):
        """Classify one incoming label image and publish the result."""
        img = self._bridge.imgmsg_to_cv2(imgmsg)
        label, proba = self._classify(img)
        # Convert numpy scalars/arrays to plain Python types for the message.
        label = int(label)
        msg = ClassificationResult()
        msg.header = imgmsg.header
        msg.labels = [label]
        # NOTE(review): raises IndexError if the winning label has no entry
        # in ~target_names (e.g. when the parameter is left empty).
        msg.label_names = [self.target_names[label]]
        msg.label_proba = [float(proba[label])]
        msg.probabilities = proba.tolist()
        msg.classifier = self.classifier_name
        msg.target_names = self.target_names
        self.pub.publish(msg)

    def _classify(self, label_img):
        """Vote for the dominant label in ``label_img``.

        Args:
            label_img: integer array of per-pixel labels (any shape).
                Negative values are treated as unlabeled and not counted.

        Returns:
            (label, proba): ``label`` is the index with the highest pixel
            count after ignored labels are zeroed. ``proba`` is a float32
            array with one entry per label index (at least
            ``len(target_names)`` entries) holding each label's share of the
            counted pixels. If no pixels remain after ignoring, ``proba`` is
            all zeros and ``label`` is 0.

        Raises:
            ValueError: from ``np.argmax`` if there are no label bins at all
                (no non-negative pixels and an empty ``target_names``).
        """
        labels = np.asarray(label_img).ravel()
        # np.bincount rejects negative input; drop such "unlabeled" pixels.
        labels = labels[labels >= 0]
        counts = np.bincount(labels, minlength=len(self.target_names))
        # Only zero bins that actually exist. A negative entry would otherwise
        # be taken as a Python negative index and wipe out the last real
        # class, and an entry past the end would raise IndexError.
        n_bins = len(counts)
        ignore = [lbl for lbl in self.ignore_labels if 0 <= lbl < n_bins]
        counts[ignore] = 0
        label = np.argmax(counts)
        total = counts.sum()
        if total == 0:
            # Everything was ignored: report zero probabilities, not NaN.
            proba = np.zeros(n_bins, dtype=np.float32)
        else:
            proba = counts.astype(np.float32) / total
        return label, proba
49 
50 
if __name__ == '__main__':
    rospy.init_node('label_image_classifier')
    # Construct the classifier: its __init__ reads parameters and advertises
    # ~output. Without this line, spin() would run a node that does nothing.
    app = LabelImageClassifier()
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27