Go to the documentation of this file.00001
00002
00003 from label_image_classifier import LabelImageClassifier
00004 import numpy as np
00005 import rospy
00006
00007
class ProbabilityImageClassifier(LabelImageClassifier):
    """Classify a whole probability image into a single label.

    Subclass of ``LabelImageClassifier`` that overrides ``_classify`` so the
    input is a per-pixel probability image instead of a label image: the
    probabilities are summed over the image and the label with the largest
    total mass wins.
    """

    classifier_name = 'probability_image_classifier'

    def __init__(self):
        super(ProbabilityImageClassifier, self).__init__()

    def _classify(self, proba_img):
        """Return the dominant label and the normalized per-label mass.

        Args:
            proba_img: Array whose first two axes are summed over; the result
                of that reduction must be 1-D, i.e. presumably an (H, W, C)
                per-pixel probability image -- confirm against the caller.

        Returns:
            Tuple ``(label, proba)``: ``label`` is the index of the largest
            summed probability after the labels in ``self.ignore_labels``
            are zeroed, and ``proba`` is the float32 per-label mass divided
            by its total.

        Raises:
            ValueError: if summing over axes (0, 1) does not yield a 1-D
                array.
        """
        # NaN pixels contribute zero instead of turning the whole sum NaN.
        proba = np.nansum(proba_img, axis=(0, 1)).astype(np.float32)
        if proba.ndim != 1:
            raise ValueError(
                'Expected a 1-D per-label vector after summing over axes '
                '(0, 1), got shape {}'.format(proba.shape))
        n_labels = proba.shape[0]

        # Force an integer dtype: np.asarray([]) is float64, and NumPy refuses
        # to index with a float array even when it is empty. Kept local so
        # self.ignore_labels is not rewritten on every call.
        # (self.ignore_labels is presumably set by the parent class from the
        # '~ignore_labels' parameter -- confirm.)
        ignore_labels = np.asarray(self.ignore_labels, dtype=np.intp)
        mask_valid = ignore_labels < n_labels
        if not mask_valid.all():
            rospy.logwarn_throttle(
                10, "The max label value in '~ignore_labels' exceeds "
                "the number of labels of input probability image.")
        # Only zero the in-range labels; out-of-range ones would IndexError.
        proba[ignore_labels[mask_valid]] = 0

        label = np.argmax(proba)
        # NOTE(review): if every label is ignored (or the image sums to zero)
        # this divides by zero and returns NaNs, matching prior behavior.
        proba = proba / proba.sum()

        return label, proba
00036
00037
if __name__ == '__main__':
    # Register with ROS before building the classifier; the constructor
    # presumably reads private (~) parameters, which needs an initialized node.
    rospy.init_node('probability_image_classifier')
    # Hold a reference so the classifier object stays alive while spinning.
    classifier_node = ProbabilityImageClassifier()
    # Block and service callbacks until ROS shuts the node down.
    rospy.spin()