sklearn_classifier.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 # -*- coding: utf-8 -*-
00003 #
00004 from __future__ import division
00005 import gzip
00006 import cPickle as pickle
00007 
00008 import numpy as np
00009 from sklearn.preprocessing import normalize
00010 
00011 import rospy
00012 from jsk_topic_tools import ConnectionBasedTransport
00013 from jsk_recognition_msgs.msg import VectorArray, ClassificationResult
00014 
00015 
class ScikitLearnClassifier(ConnectionBasedTransport):
    """ROS node that classifies feature vectors with a pickled classifier.

    Subscribes to ``~input`` (``jsk_recognition_msgs/VectorArray``), treats
    each row of the incoming array as one sample, and publishes a
    ``jsk_recognition_msgs/ClassificationResult`` on ``~output``.

    ROS parameters:
        ~clf_path (str): path to a gzip-compressed pickle of a fitted
            classifier. The loaded object must provide ``predict_proba`` and a
            ``target_names_`` sequence. ``target_names_`` is not a standard
            scikit-learn attribute (presumably set by the training script).
            The code assumes ``target_names_[i]`` names column ``i`` of
            ``predict_proba``'s output.
    """

    def __init__(self):
        super(ScikitLearnClassifier, self).__init__()
        # Load the model before advertising so that no subscription (and thus
        # no callback) can happen while self.clf is still missing.
        self._init_classifier()
        self._pub = self.advertise('~output', ClassificationResult,
                                   queue_size=1)

    def _init_classifier(self):
        """Load ``self.clf`` from the gzip-compressed pickle at ``~clf_path``.

        Raises:
            KeyError: from ``rospy.get_param`` if ``~clf_path`` is not set.
        """
        clf_path = rospy.get_param('~clf_path')
        # NOTE: unpickling can execute arbitrary code; only point ~clf_path
        # at classifier files from a trusted source.
        with gzip.open(clf_path) as f:
            self.clf = pickle.load(f)

    def subscribe(self):
        """Start listening on ``~input``; called when ``~output`` gains a subscriber."""
        self.sub_hist = rospy.Subscriber('~input', VectorArray, self._predict)

    def unsubscribe(self):
        """Stop listening on ``~input``; called when ``~output`` loses its subscribers."""
        self.sub_hist.unregister()

    def _predict(self, msg):
        """Classify every vector in ``msg`` and publish a ClassificationResult.

        ``msg.data`` is a flat sequence that is reshaped into rows of
        ``msg.vector_dim`` values. Each row is normalized with
        ``sklearn.preprocessing.normalize`` (default L2 norm, per row) and
        classified independently. A message with no vectors yields a result
        with empty per-sample fields instead of an exception.

        Args:
            msg (VectorArray): incoming feature vectors.
        """
        target_names = np.array(self.clf.target_names_)
        X = np.array(msg.data).reshape((-1, msg.vector_dim))
        if X.shape[0] == 0:
            # scikit-learn rejects 0-sample input; publish an empty result so
            # downstream consumers still receive one message per input.
            y_proba = np.zeros((0, len(target_names)))
        else:
            # copy=False only *tries* to normalize in place (a copy is made
            # e.g. for non-float input), so always use the returned array.
            X = normalize(X, copy=False)
            y_proba = self.clf.predict_proba(X)
        y_pred = np.argmax(y_proba, axis=-1)
        # Probability of the predicted (most likely) class for each sample.
        label_proba = y_proba[np.arange(len(y_pred)), y_pred]
        out_msg = ClassificationResult(
            header=msg.header,
            labels=y_pred,
            label_names=target_names[y_pred],
            label_proba=label_proba,
            probabilities=y_proba.reshape(-1),
            classifier=str(self.clf),
            target_names=target_names,
        )
        self._pub.publish(out_msg)
00051 
00052 
if __name__ == '__main__':
    # The node must be initialized first: the classifier's constructor reads
    # private ('~') parameters, which are resolved against the node name.
    rospy.init_node('sklearn_classifier')
    # Bind the node object to a name so it is not discarded.
    node = ScikitLearnClassifier()
    # Hand control to rospy until shutdown.
    rospy.spin()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07