random_forest_server.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 try:
00004     from ml_classifiers.srv import *
00005 except:
00006     import roslib;roslib.load_manifest("ml_classifiers")
00007     from ml_classifiers.srv import *
00008 
00009 import rospy
00010 import numpy as np
00011 from sklearn.ensemble import RandomForestClassifier
00012 from sklearn.ensemble import ExtraTreesClassifier
00013 from sklearn.externals import joblib
00014 
00015 
class RandomForestServer:
    """ROS node wrapper that answers ``predict`` service calls with a classifier.

    The wrapped estimator only needs a ``predict`` method.  The alternate
    constructors below either fit a new ``RandomForestClassifier`` from
    in-memory data or load a previously saved model with ``joblib``.
    """

    def __init__(self, clf):
        """Store *clf* and advertise the ``predict`` service.

        :param clf: fitted estimator exposing ``predict``.
        """
        self.clf = clf
        # Keep the service handle on the instance instead of discarding it
        # into an unused local, so it stays reachable (e.g. for shutdown()).
        self.service = rospy.Service('predict', ClassifyData, self.classifyData)

    @classmethod
    def initWithData(cls, data_x, data_y):
        """Fit a RandomForestClassifier on (data_x, data_y) and wrap it.

        :param data_x: training samples, one feature vector per row.
        :param data_y: class label for each row of *data_x*.
        :returns: a new ``RandomForestServer`` serving the fitted model.

        Logs an error and terminates the process with exit status 1 when
        the two inputs have different lengths.
        """
        import sys

        if len(data_x) != len(data_y):
            rospy.logerr("Length of data are different")
            # Non-zero status so launch files / supervisors see a failure
            # (the builtin exit() used before reported success).
            sys.exit(1)
        rospy.loginfo("InitWithData please wait..")
        # Fixed hyper-parameters; random_state=0 makes training reproducible.
        clf = RandomForestClassifier(
            n_estimators=250, max_features=2, max_depth=29,
            min_samples_split=2, random_state=0)
        clf.fit(data_x, data_y)
        return cls(clf)

    @classmethod
    def initWithFileModel(cls, filename):
        """Load a saved model from *filename* with ``joblib`` and wrap it.

        NOTE(security): ``joblib.load`` unpickles the file, which can execute
        arbitrary code -- only load model files from a trusted source.
        """
        rospy.loginfo("InitWithFileModel with %s please wait.." % filename)
        clf = joblib.load(filename)
        return cls(clf)

    def classifyData(self, req):
        """Service callback: predict a class for every element of ``req.data``.

        Each element's ``point`` is one feature vector.  The response holds
        one string per element, in request order, with the predicted label
        rendered with ``str``.
        """
        points = [data.point for data in req.data]
        if not points:
            # sklearn's predict raises on a 0-sample input; an empty request
            # simply gets an empty answer list.
            return ClassifyDataResponse([])
        # One batched predict call instead of one call per point.
        answers = [str(label) for label in self.clf.predict(points)]
        for point, answer in zip(points, answers):
            rospy.loginfo("req : " + str(point) + "-> answer : " + answer)
        return ClassifyDataResponse(answers)

    def run(self):
        """Log that the server is up and block in ``rospy.spin()``."""
        rospy.loginfo("RandomForestServer is running!")
        rospy.spin()
00054 
00055 
if __name__ == "__main__":
    import sys  # for a non-zero exit status on configuration errors

    rospy.init_node('random_forest_cloth_classifier')

    # ~random_forest_train_file is either a saved model (*.pkl, loaded with
    # joblib) or a comma-separated training-feature file that must be paired
    # with ~random_forest_train_class_file.
    try:
        train_file = rospy.get_param('~random_forest_train_file')
    except KeyError:
        rospy.logerr("Train File is not Set. Set train_data file or tree model file as ~random_forest_train_file.")
        sys.exit(1)

    if train_file.endswith(".pkl"):
        node = RandomForestServer.initWithFileModel(train_file)
    else:
        # Only the parameter lookup can raise KeyError; keep the try narrow
        # so parsing errors are not misreported as a missing parameter.
        try:
            class_file = rospy.get_param('~random_forest_train_class_file')
        except KeyError:
            rospy.logerr("Train Class File is not Set. Set train_data file or tree model file.")
            rospy.logerr("Or Did you expect  Extension to be pkl?.")
            sys.exit(1)

        # One sample per non-blank line: comma-separated float features.
        # A list comprehension (not map) so np.array gets real lists on
        # both Python 2 and Python 3.
        with open(train_file) as f:
            data_x = [[float(value) for value in line.split(",")]
                      for line in f if line.strip()]

        # One class label per non-blank line, row-aligned with train_file.
        with open(class_file) as f:
            data_y = [float(line) for line in f if line.strip()]

        # Train the forest and build the service server.
        node = RandomForestServer.initWithData(np.array(data_x), np.array(data_y))

    # Block serving requests until ROS shuts down.
    node.run()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07