probability_image_classifier.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from label_image_classifier import LabelImageClassifier
4 import numpy as np
5 import rospy
6 
7 
class ProbabilityImageClassifier(LabelImageClassifier):
    """Classify a whole image from a per-pixel, per-label score image.

    The scores of every pixel are summed per label, labels listed in
    ``self.ignore_labels`` are zeroed out, and the label with the largest
    remaining mass is reported together with the normalized per-label scores.
    (``self.ignore_labels`` is presumably populated from the
    ``~ignore_labels`` parameter by the base class -- not visible here.)
    """

    classifier_name = 'probability_image_classifier'

    def __init__(self):
        super(ProbabilityImageClassifier, self).__init__()

    def _classify(self, proba_img):
        """Return the dominant label and the normalized label scores.

        Args:
            proba_img: Array-like of shape (H, W, C) holding one score per
                pixel and label. NaN entries are treated as zero.

        Returns:
            (label, proba): ``label`` is the index of the label with the
            largest summed score after ignored labels are zeroed. ``proba``
            is a float32 vector of length C holding the summed scores divided
            by their total, or all zeros when that total is zero.

        Raises:
            ValueError: if ``proba_img`` is not 3-dimensional, or (from
                ``np.argmax``) if it has zero channels.
        """
        proba_img = np.asarray(proba_img)
        if proba_img.ndim != 3:
            raise ValueError(
                'Input probability image must be 3-dimensional (H, W, C), '
                'but got ndim={}'.format(proba_img.ndim))

        # Sum each label's score over all pixels; NaN pixels contribute 0.
        proba = np.nansum(proba_img, axis=(0, 1)).astype(np.float32)
        n_labels = proba.shape[0]

        # Work on a local integer copy instead of mutating the attribute.
        # The explicit dtype matters: np.asarray([]) is float64, and a float
        # array cannot be used as an index. ravel() also accepts a scalar.
        raw_ignore = self.ignore_labels
        if raw_ignore is None:
            raw_ignore = []
        ignore_labels = np.asarray(raw_ignore, dtype=np.intp).ravel()

        # validation
        too_large = ignore_labels >= n_labels
        if too_large.any():
            rospy.logwarn_throttle(
                10, "The max label value in '~ignore_labels' exceeds "
                "the number of labels of input probability image.")
        negative = ignore_labels < 0
        if negative.any():
            # Negative indices would otherwise wrap around and silently zero
            # labels counted from the end of the vector.
            rospy.logwarn_throttle(
                10, "'~ignore_labels' contains negative label values, "
                "which are ignored.")

        # filtering unrelated label probabilities (in-range labels only)
        proba[ignore_labels[~(too_large | negative)]] = 0
        label = np.argmax(proba)

        # Avoid 0/0 -> NaN when every label was ignored or the image is empty
        # of mass; the scores are then left as all zeros.
        total = proba.sum()
        if total != 0:
            proba = proba / total

        return label, proba
36 
37 
if __name__ == '__main__':
    rospy.init_node('probability_image_classifier')
    # The node object must be constructed before spinning; the constructor
    # (base class, not shown here) presumably advertises/subscribes topics.
    # Without it rospy.spin() would just block on an idle node.
    node = ProbabilityImageClassifier()  # keep a reference while spinning
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27