sklearn_classifier.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 #
4 from __future__ import division
5 import gzip
6 import sys
7 if sys.version_info.major <= 2:
8  import cPickle as pickle
9 else: # for python3
10  import _pickle as pickle
11 
12 import numpy as np
13 from sklearn.preprocessing import normalize
14 
15 import rospy
16 from jsk_topic_tools import ConnectionBasedTransport
17 from jsk_recognition_msgs.msg import VectorArray, ClassificationResult
18 
19 
class ScikitLearnClassifier(ConnectionBasedTransport):
    """Classify feature vectors with a pickled scikit-learn classifier.

    Reads row vectors from ``~input`` (jsk_recognition_msgs/VectorArray),
    L2-normalizes each row, runs ``predict_proba`` and publishes a
    jsk_recognition_msgs/ClassificationResult on ``~output``.
    ``subscribe``/``unsubscribe`` are hooks driven by
    ConnectionBasedTransport (presumably called on demand as ``~output``
    gains/loses subscribers -- confirm against jsk_topic_tools).

    ROS parameters:
        ~clf_path (str): path to a gzip-compressed pickle of a fitted
            classifier. The loaded object must provide ``predict_proba``
            and a ``target_names_`` sequence indexed by class index.
    """

    def __init__(self):
        super(ScikitLearnClassifier, self).__init__()
        self._init_classifier()
        self._pub = self.advertise('~output', ClassificationResult,
                                   queue_size=1)

    def _init_classifier(self):
        """Load ``self.clf`` from the gzip-compressed pickle at ``~clf_path``.

        ``~clf_path`` is read without a default, so a missing parameter
        makes ``rospy.get_param`` raise.
        """
        clf_path = rospy.get_param('~clf_path')
        # NOTE(review): pickle.load can execute arbitrary code embedded in
        # the file; only point ~clf_path at trusted classifier files.
        with gzip.open(clf_path) as f:
            if sys.version_info.major <= 2:
                self.clf = pickle.load(f)
            else:
                # latin1 lets Python 3 read pickles presumably written under
                # Python 2 (e.g. ones containing numpy arrays).
                self.clf = pickle.load(f, encoding='latin1')

    def subscribe(self):
        """Start listening for feature vectors on ``~input``."""
        self.sub_hist = rospy.Subscriber('~input', VectorArray, self._predict)

    def unsubscribe(self):
        """Stop listening on ``~input``."""
        self.sub_hist.unregister()

    def _predict(self, msg):
        """Classify every vector in ``msg`` and publish one result message.

        ``msg.data`` is a flat sequence that is reshaped into rows of
        length ``msg.vector_dim``; each row is one sample. Messages with no
        data or a non-positive ``vector_dim`` are skipped with a warning,
        because they cannot be reshaped/normalized into samples.
        """
        if msg.vector_dim <= 0 or len(msg.data) == 0:
            rospy.logwarn('Skipping input with no vectors '
                          '(vector_dim=%d, len(data)=%d)',
                          msg.vector_dim, len(msg.data))
            return
        X = np.array(msg.data, dtype=np.float64).reshape(
            (-1, msg.vector_dim))
        # Use the return value: sklearn does not guarantee that
        # copy=False normalizes in place.
        X = normalize(X, copy=False)
        y_proba = self.clf.predict_proba(X)
        y_pred = np.argmax(y_proba, axis=-1)
        target_names = np.array(self.clf.target_names_)
        # Probability of the predicted (highest-scoring) class per sample.
        label_proba = y_proba[np.arange(len(y_pred)), y_pred]
        out_msg = ClassificationResult(
            header=msg.header,
            labels=y_pred,
            label_names=target_names[y_pred],
            label_proba=label_proba,
            # Row-major flattening: n_samples x n_classes values.
            probabilities=y_proba.reshape(-1),
            classifier=str(self.clf),
            target_names=target_names,
        )
        self._pub.publish(out_msg)
58 
59 
if __name__ == '__main__':
    # The node has to exist before the classifier resolves its private
    # '~' parameters and topics.
    rospy.init_node('sklearn_classifier')
    classifier_node = ScikitLearnClassifier()
    # Block here until the node is shut down.
    rospy.spin()
node_scripts.sklearn_classifier.ScikitLearnClassifier._init_classifier
def _init_classifier(self)
Definition: sklearn_classifier.py:27
node_scripts.sklearn_classifier.ScikitLearnClassifier._pub
_pub
Definition: sklearn_classifier.py:24
node_scripts.sklearn_classifier.ScikitLearnClassifier.unsubscribe
def unsubscribe(self)
Definition: sklearn_classifier.py:38
node_scripts.sklearn_classifier.ScikitLearnClassifier._predict
def _predict(self, msg)
Definition: sklearn_classifier.py:41
node_scripts.sklearn_classifier.ScikitLearnClassifier
Definition: sklearn_classifier.py:20
node_scripts.sklearn_classifier.ScikitLearnClassifier.__init__
def __init__(self)
Definition: sklearn_classifier.py:21
node_scripts.sklearn_classifier.ScikitLearnClassifier.sub_hist
sub_hist
Definition: sklearn_classifier.py:36
node_scripts.sklearn_classifier.ScikitLearnClassifier.clf
clf
Definition: sklearn_classifier.py:31
node_scripts.sklearn_classifier.ScikitLearnClassifier.subscribe
def subscribe(self)
Definition: sklearn_classifier.py:35


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Fri May 16 2025 03:11:17