Go to the documentation of this file.00001
00002
00003 import cv_bridge
00004 from jsk_recognition_msgs.msg import ClassificationResult
00005 from jsk_topic_tools import ConnectionBasedTransport
00006 import numpy as np
00007 import rospy
00008 from sensor_msgs.msg import Image
00009
00010
class LabelImageClassifier(ConnectionBasedTransport):
    """Classify a label image by its most frequent non-ignored label.

    Each message received on ``~input`` (``sensor_msgs/Image``) is
    converted to an array, a histogram of its label values is computed,
    and a ``ClassificationResult`` describing the dominant label is
    published on ``~output``.

    ROS parameters read at construction:
        ~ignore_labels (list of int, default ``[]``): label values whose
            pixel counts are zeroed before picking the dominant label.
        ~target_names (list of str, default ``[]``): class names indexed
            by label value.
    """

    classifier_name = 'label_image_classifier'

    def __init__(self):
        super(LabelImageClassifier, self).__init__()
        self.ignore_labels = rospy.get_param('~ignore_labels', [])
        self.target_names = rospy.get_param('~target_names', [])
        self.pub = self.advertise(
            '~output', ClassificationResult, queue_size=1)

    def subscribe(self):
        """Start receiving label images from ``~input``."""
        self.sub = rospy.Subscriber('~input', Image, self._cb)

    def unsubscribe(self):
        """Stop receiving label images."""
        self.sub.unregister()

    def _cb(self, imgmsg):
        """Classify one incoming label image and publish the result.

        Nothing is published (an error is logged instead) when the
        dominant label has no entry in ``target_names``.
        """
        bridge = cv_bridge.CvBridge()
        img = bridge.imgmsg_to_cv2(imgmsg)
        label, proba = self._classify(img)

        # Without this guard the name lookup below raises an opaque
        # IndexError inside the subscriber callback (always the case when
        # ~target_names is left at its empty default).
        if label >= len(self.target_names):
            rospy.logerr(
                'Label %d has no entry in ~target_names (%d names given); '
                'skipping publish.' % (label, len(self.target_names)))
            return

        msg = ClassificationResult()
        msg.header = imgmsg.header
        msg.labels = [label]
        msg.label_names = [self.target_names[label]]
        msg.label_proba = [proba[label]]
        msg.probabilities = proba
        msg.classifier = self.classifier_name
        msg.target_names = self.target_names
        self.pub.publish(msg)

    def _classify(self, label_img):
        """Return the dominant label and the per-label pixel ratios.

        Args:
            label_img: array of non-negative integer labels (any shape).

        Returns:
            Tuple ``(label, proba)``: ``label`` is the index with the
            highest pixel count after ignored labels are zeroed, and
            ``proba`` holds each label's share of the remaining pixels.
            When no pixel remains, ``proba`` is all zeros and ``label``
            is 0 (the argmax of an all-zero histogram).
        """
        counts = np.bincount(label_img.ravel(),
                             minlength=len(self.target_names))

        # Only zero labels that exist in the histogram: an out-of-range
        # index would raise IndexError and a negative one would silently
        # wrap around and wipe an unrelated class.
        ignored = [lbl for lbl in self.ignore_labels
                   if 0 <= lbl < len(counts)]
        counts[ignored] = 0

        label = np.argmax(counts)
        total = counts.sum()
        if total > 0:
            proba = counts.astype(np.float32) / total
        else:
            # Every pixel was ignored; avoid 0/0 producing NaN.
            proba = np.zeros(len(counts), dtype=np.float32)
        return label, proba
00049
00050
if __name__ == '__main__':
    # Entry point when run as a script: start the ROS node.
    rospy.init_node('label_image_classifier')
    # Hold a reference to the classifier for the lifetime of the process.
    classifier_node = LabelImageClassifier()
    # Hand control to rospy until the node is shut down.
    rospy.spin()