Go to the documentation of this file.00001
00002
00003
00004 from __future__ import division
00005 import gzip
00006 import cPickle as pickle
00007
00008 import numpy as np
00009 from sklearn.preprocessing import normalize
00010
00011 import rospy
00012 from jsk_topic_tools import ConnectionBasedTransport
00013 from jsk_recognition_msgs.msg import VectorArray, ClassificationResult
00014
00015
class ScikitLearnClassifier(ConnectionBasedTransport):
    """Classify incoming feature vectors with a pickled scikit-learn model.

    Receives jsk_recognition_msgs/VectorArray on ``~input``, L2-normalizes
    each row, runs ``predict_proba`` on the loaded classifier and publishes a
    jsk_recognition_msgs/ClassificationResult on ``~output``.

    ROS parameters:
        ~clf_path (str): path to a gzip-compressed pickle of the classifier.
            The unpickled object must provide ``predict_proba`` and a
            ``target_names_`` attribute.
    """

    def __init__(self):
        super(ScikitLearnClassifier, self).__init__()
        # Load the model before advertising so a bad ~clf_path fails at startup.
        self._init_classifier()
        self._pub = self.advertise('~output', ClassificationResult,
                                   queue_size=1)

    def _init_classifier(self):
        """Load the classifier object from the file at ``~clf_path``."""
        clf_path = rospy.get_param('~clf_path')
        # SECURITY: unpickling can execute arbitrary code. Only point
        # ~clf_path at model files from a trusted source.
        with gzip.open(clf_path, 'rb') as f:
            self.clf = pickle.load(f)

    def subscribe(self):
        """ConnectionBasedTransport hook: start listening on ``~input``."""
        self.sub_hist = rospy.Subscriber('~input', VectorArray, self._predict)

    def unsubscribe(self):
        """ConnectionBasedTransport hook: stop listening on ``~input``."""
        self.sub_hist.unregister()

    def _predict(self, msg):
        """Classify every vector in ``msg`` and publish one result message.

        Args:
            msg (VectorArray): ``msg.data`` is reshaped into rows of length
                ``msg.vector_dim``, one row per sample.
        """
        target_names = np.array(self.clf.target_names_)
        if len(msg.data) == 0:
            # predict_proba rejects 0-sample input. Publish an empty result
            # so downstream consumers still get a message for this header.
            self._pub.publish(ClassificationResult(
                header=msg.header,
                classifier=str(self.clf),
                target_names=target_names,
            ))
            return
        # Force a float dtype and use normalize's return value.
        # normalize(copy=False) only works in place for float arrays, so
        # integer input used to be left silently unnormalized.
        X = np.asarray(msg.data, dtype=np.float64).reshape(
            (-1, msg.vector_dim))
        X = normalize(X)
        y_proba = self.clf.predict_proba(X)
        # Column index of the most probable class for each sample.
        # NOTE(review): sklearn orders predict_proba columns by clf.classes_,
        # so this index equals the class label only when classes_ is 0..K-1.
        # Confirm target_names_ is indexed the same way.
        y_pred = np.argmax(y_proba, axis=-1)
        # Probability of the selected class for each sample.
        label_proba = y_proba[np.arange(len(y_pred)), y_pred]
        out_msg = ClassificationResult(
            header=msg.header,
            labels=y_pred,
            label_names=target_names[y_pred],
            label_proba=label_proba,
            probabilities=y_proba.reshape(-1),
            classifier=str(self.clf),
            target_names=target_names,
        )
        self._pub.publish(out_msg)
00051
00052
if __name__ == '__main__':
    # The node must be initialized before the transport advertises topics.
    rospy.init_node('sklearn_classifier')
    # Keep a named reference to the node object for the life of the process.
    classifier_node = ScikitLearnClassifier()
    # Block here while rospy dispatches subscriber callbacks.
    rospy.spin()