sklearn_classifier.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 #
4 from __future__ import division
5 import gzip
6 import cPickle as pickle
7 
8 import numpy as np
9 from sklearn.preprocessing import normalize
10 
11 import rospy
12 from jsk_topic_tools import ConnectionBasedTransport
13 from jsk_recognition_msgs.msg import VectorArray, ClassificationResult
14 
15 
class ScikitLearnClassifier(ConnectionBasedTransport):
    """Classify feature vectors with a pickled scikit-learn classifier.

    Reads vectors from ``~input`` (jsk_recognition_msgs/VectorArray) and
    publishes one jsk_recognition_msgs/ClassificationResult on ``~output``
    per processed input message.

    ROS parameters:
        ~clf_path: path to a gzip-compressed pickle of a fitted classifier
            (required; no default is passed to rospy.get_param). The object
            must provide ``predict_proba`` and a ``target_names_`` attribute
            (``target_names_`` is not a standard scikit-learn attribute;
            presumably it is set by the training script -- confirm).

    subscribe()/unsubscribe() are hooks for ConnectionBasedTransport, which
    presumably calls them as ``~output`` gains/loses subscribers.
    """

    def __init__(self):
        super(ScikitLearnClassifier, self).__init__()
        # Load the model before advertising, so the subscribe() hook (and so
        # _predict) presumably cannot run while self.clf is still unset.
        self._init_classifier()
        self._pub = self.advertise('~output', ClassificationResult,
                                   queue_size=1)

    def _init_classifier(self):
        """Load ``self.clf`` from the gzip-compressed pickle at ~clf_path."""
        clf_path = rospy.get_param('~clf_path')
        # SECURITY: unpickling can execute arbitrary code. ~clf_path must only
        # ever point at a trusted model file.
        with gzip.open(clf_path, 'rb') as f:
            self.clf = pickle.load(f)

    def subscribe(self):
        """Start receiving feature vectors on ~input."""
        self.sub_hist = rospy.Subscriber('~input', VectorArray, self._predict)

    def unsubscribe(self):
        """Stop receiving feature vectors on ~input."""
        self.sub_hist.unregister()

    def _predict(self, msg):
        """Classify every vector in ``msg`` and publish the result.

        ``msg.data`` is a flat sequence of consecutive vectors, each of
        length ``msg.vector_dim``. An empty ``msg.data`` yields an empty
        result. A length that is not a multiple of a positive
        ``vector_dim`` is logged and the message is dropped.
        """
        target_names = np.array(self.clf.target_names_)
        classifier_name = str(self.clf)

        if len(msg.data) == 0:
            # scikit-learn rejects 0-sample input. Publish an empty result
            # instead of raising inside the callback.
            self._pub.publish(ClassificationResult(
                header=msg.header,
                classifier=classifier_name,
                target_names=target_names,
            ))
            return
        if msg.vector_dim <= 0 or len(msg.data) % msg.vector_dim != 0:
            rospy.logerr(
                'Invalid VectorArray: len(data)=%d is not a multiple of a '
                'positive vector_dim=%d', len(msg.data), msg.vector_dim)
            return

        X = np.asarray(msg.data, dtype=np.float64).reshape(
            (-1, msg.vector_dim))
        # Row-wise normalization with sklearn's defaults. Use the returned
        # array rather than relying on copy=False mutating X in place, which
        # only happens when no dtype conversion/copy is needed.
        X = normalize(X)
        y_proba = self.clf.predict_proba(X)
        # Column index (into predict_proba's output) of the most likely class.
        y_pred = np.argmax(y_proba, axis=-1)
        # NOTE(review): y_pred is a column index of predict_proba (ordered
        # like clf.classes_) but is published as the label and used to index
        # target_names_. This assumes classes_ == range(n_classes) and that
        # target_names_ follows the same order -- confirm with the trainer.
        label_proba = y_proba[np.arange(len(y_pred)), y_pred]
        out_msg = ClassificationResult(
            header=msg.header,
            labels=y_pred,
            label_names=target_names[y_pred],
            label_proba=label_proba,
            probabilities=y_proba.reshape(-1),
            classifier=classifier_name,
            target_names=target_names,
        )
        self._pub.publish(out_msg)
51 
52 
if __name__ == '__main__':
    # Register with the ROS master before any publisher/subscriber is made.
    rospy.init_node('sklearn_classifier')
    # Hold a reference to the node object while the process spins.
    node = ScikitLearnClassifier()
    # Block and process callbacks until ROS shuts down.
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27