Go to the documentation of this file.00001
00002
00003
00004 from __future__ import division
00005 import gzip
00006 import cPickle as pickle
00007
00008 import numpy as np
00009 from sklearn.preprocessing import normalize
00010
00011 import rospy
00012 from jsk_recognition_msgs.msg import Histogram
00013 from std_msgs.msg import String
00014
00015
class SimpleClassifier(object):
    """Classify incoming histograms with a pre-trained, pickled classifier.

    Subscribes to ``~input`` (``jsk_recognition_msgs/Histogram``) and publishes
    the predicted label name on ``~output`` (``std_msgs/String``).
    """

    def __init__(self, clf_path):
        """Load the classifier, then set up the publisher and subscriber.

        Args:
            clf_path: Path to a gzip-compressed pickle of the classifier.
        """
        # Load the model and create the publisher before subscribing, so the
        # callback can only fire once everything it uses exists.
        self._init_classifier(clf_path)
        # Not read anywhere in this file; kept for external users, if any.
        self.input = Histogram()
        self._pub = rospy.Publisher('~output', String, queue_size=1)
        rospy.Subscriber('~input', Histogram, self._cb_predict)

    def _init_classifier(self, clf_path):
        """Unpickle the gzip-compressed classifier at ``clf_path`` into ``self.clf``.

        The loaded object must provide ``predict`` and a ``target_names_``
        sequence; ``_cb_predict`` uses both.
        """
        # NOTE(security): unpickling can execute arbitrary code. Only load
        # classifier files from trusted sources.
        with gzip.open(clf_path, 'rb') as f:
            self.clf = pickle.load(f)

    def _cb_predict(self, msg):
        """Classify one histogram message and publish its label name.

        Empty histograms are ignored. The histogram is row-normalized with
        sklearn's ``normalize`` (L2 by default) before prediction.

        Args:
            msg: ``jsk_recognition_msgs/Histogram`` message.
        """
        # len() works for both tuple/list fields and numpy-backed fields.
        if len(msg.histogram) == 0:
            return
        # Use the returned array instead of relying on copy=False. sklearn
        # only normalizes in place when no conversion or copy is needed, so
        # the side effect is not guaranteed.
        X = normalize(np.array([msg.histogram]))
        index = self.clf.predict(X)[0]
        label = self.clf.target_names_[index]
        self._pub.publish(String(data=str(label)))
00037
00038
def main():
    """Start the ``simple_classifier`` node and block until shutdown.

    Reads the private parameter ``~clf_path`` (path to the gzip-compressed,
    pickled classifier) and builds a ``SimpleClassifier`` from it.
    """
    rospy.init_node('simple_classifier')
    # Keep a reference to the node object for the lifetime of main().
    clf = SimpleClassifier(clf_path=rospy.get_param('~clf_path'))
    rospy.spin()
00044
00045
if __name__ == '__main__':
    # Launch the node only when run as a script, not when imported.
    main()