probability_image_classifier.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from label_image_classifier import LabelImageClassifier
00004 import numpy as np
00005 import rospy
00006 
00007 
class ProbabilityImageClassifier(LabelImageClassifier):
    """Classify a whole image from a per-pixel probability image.

    Per-pixel scores are summed over the first two axes, labels listed in
    ``self.ignore_labels`` are zeroed out, and the highest-scoring label is
    reported together with the normalized score vector.
    """

    classifier_name = 'probability_image_classifier'

    def __init__(self):
        super(ProbabilityImageClassifier, self).__init__()

    def _classify(self, proba_img):
        """Return ``(label, proba)`` for one probability image.

        Args:
            proba_img: array summed with ``np.nansum`` over axes (0, 1); the
                remaining axis is treated as the label axis (presumably
                H x W x n_labels -- TODO confirm against the caller).

        Returns:
            label: ``np.argmax`` of the filtered scores.
            proba: float32 score per label, divided by its sum when that sum
                is positive; left all-zero otherwise (instead of NaN).

        Side effect:
            ``self.ignore_labels`` is replaced by a 1-D integer ndarray.
        """
        proba = np.nansum(proba_img, axis=(0, 1)).astype(np.float32)

        # Force an integer dtype: np.asarray([]) is float64, and a float
        # array (even an empty one) cannot be used as an index below.
        # reshape(-1) also accepts a single scalar label.
        self.ignore_labels = np.asarray(
            self.ignore_labels, dtype=int).reshape(-1)

        # validation
        assert proba.ndim == 1
        n_labels = proba.shape[0]
        too_large = self.ignore_labels >= n_labels
        negative = self.ignore_labels < 0
        if too_large.any():
            rospy.logwarn_throttle(
                10, "The max label value in '~ignore_labels' exceeds "
                    "the number of labels of input probability image.")
        if negative.any():
            # Negative indices would silently wrap around (-1 -> last label).
            rospy.logwarn_throttle(
                10, "Negative values in '~ignore_labels' are skipped.")
        mask_valid = ~(too_large | negative)

        # filtering unrelated label probabilities
        proba[self.ignore_labels[mask_valid]] = 0
        # NOTE: if every remaining score is 0, argmax returns index 0, which
        # may itself be an ignored label.
        label = np.argmax(proba)

        total = proba.sum()
        if total > 0:
            proba = proba / total
        # else: keep the all-zero vector rather than producing NaN from 0/0.

        return label, proba
00036 
00037 
def main():
    """Start the ROS node, create the classifier and spin until shutdown."""
    rospy.init_node('probability_image_classifier')
    # Bound to a local so the instance lives for the duration of spin().
    app = ProbabilityImageClassifier()  # noqa: F841
    rospy.spin()


if __name__ == '__main__':
    main()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Sun Oct 8 2017 02:43:23