simple_classifier.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 # -*- coding: utf-8 -*-
00003 #
00004 from __future__ import division
00005 import gzip
00006 import cPickle as pickle
00007 
00008 import numpy as np
00009 from sklearn.preprocessing import normalize
00010 
00011 import rospy
00012 from jsk_recognition_msgs.msg import Histogram
00013 from std_msgs.msg import String
00014 
00015 
class SimpleClassifier(object):
    """ROS node wrapper that classifies incoming histograms.

    Subscribes to ``~input`` (``jsk_recognition_msgs/Histogram``), feeds each
    non-empty histogram to a pre-trained classifier loaded from a gzipped
    pickle, and publishes the predicted target name on ``~output``
    (``std_msgs/String``).
    """

    def __init__(self, clf_path):
        """Load the classifier and set up the publisher and subscriber.

        :param clf_path: path to a gzip-compressed pickle of the classifier.
        """
        self._init_classifier(clf_path)
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.input = Histogram()
        self._pub = rospy.Publisher('~output', String, queue_size=1)
        rospy.Subscriber('~input', Histogram, self._cb_predict)

    def _init_classifier(self, clf_path):
        """Unpickle the classifier stored (gzip-compressed) at ``clf_path``.

        The result is stored on ``self.clf``.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only point ``clf_path`` at classifier files you trust.
        with gzip.open(clf_path) as f:
            self.clf = pickle.load(f)

    def _cb_predict(self, msg):
        """Classify one histogram message and publish the predicted name.

        Messages with an empty ``histogram`` are ignored. The histogram is
        treated as a single sample, normalized, passed to
        ``self.clf.predict``, and the prediction is used as an index into
        ``self.clf.target_names_``.
        """
        # len() rather than truthiness: a numpy-backed array field would make
        # ``not msg.histogram`` ambiguous.
        if len(msg.histogram) == 0:
            return
        clf = self.clf
        # Use the returned array instead of relying on normalize(copy=False)
        # mutating X in place, which sklearn does not guarantee (e.g. for
        # non-float input it returns a new array and leaves X untouched).
        X = normalize(np.array([msg.histogram], dtype=np.float64))
        index = clf.predict(X)[0]
        label = clf.target_names_[index]
        self._pub.publish(String(data=str(label)))
00037 
00038 
def main():
    """Run the ``simple_classifier`` node until ROS shuts down.

    Reads the classifier location from the private ``~clf_path`` parameter,
    builds a SimpleClassifier from it, then hands control to ``rospy.spin``.
    """
    rospy.init_node('simple_classifier')
    # Hold the node object in a local for the whole lifetime of spin().
    classifier_node = SimpleClassifier(clf_path=rospy.get_param('~clf_path'))
    rospy.spin()


# Start the node only when executed as a script, not when imported.
if __name__ == '__main__':
    main()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Wed Sep 16 2015 04:36:15