random_forest_server.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 try:
00004     from ml_classifiers.srv import *
00005 except:
00006     import roslib;roslib.load_manifest("ml_classifiers")
00007     from ml_classifiers.srv import *
00008 
00009 import rospy
00010 import numpy as np
00011 from sklearn.cross_validation import cross_val_score
00012 from sklearn.ensemble import RandomForestClassifier
00013 from sklearn.ensemble import ExtraTreesClassifier
00014 from sklearn.externals import joblib
00015 
class RandomForestServer:
    """ROS node that answers ``predict`` service calls with a trained classifier.

    The wrapped classifier is either trained in place (``initWithData``) or
    loaded from a joblib file (``initWithFileModel``).
    """

    def __init__(self, clf):
        """Store *clf* and advertise the ``predict`` service.

        Args:
            clf: fitted estimator whose ``predict`` is used by ``classifyData``.
        """
        self.clf = clf
        # Keep the service handle on the instance rather than in an unused
        # local so it can be inspected or shut down later.
        self.service = rospy.Service('predict', ClassifyData, self.classifyData)

    @classmethod
    def initWithData(cls, data_x, data_y, n_estimators=250, max_features=2,
                     max_depth=29, min_samples_split=2, random_state=0):
        """Train a RandomForestClassifier on (*data_x*, *data_y*) and wrap it.

        Args:
            data_x: training samples, one row per sample.
            data_y: one class label per row of *data_x*.
            n_estimators, max_features, max_depth, min_samples_split,
            random_state: forwarded to ``RandomForestClassifier``; the
                defaults reproduce the node's original configuration.

        Returns:
            A new ``RandomForestServer`` serving the trained classifier.

        Exits the process with status 1 if *data_x* and *data_y* differ in
        length.
        """
        import sys

        if len(data_x) != len(data_y):
            rospy.logerr("Lengths of data_x and data_y are different")
            sys.exit(1)
        rospy.loginfo("InitWithData please wait..")
        # min_samples_split=1 is rejected by current scikit-learn; with the
        # default min_samples_leaf=1 a value of 2 produces the same splits.
        clf = RandomForestClassifier(n_estimators=n_estimators,
                                     max_features=max_features,
                                     max_depth=max_depth,
                                     min_samples_split=min_samples_split,
                                     random_state=random_state)
        clf.fit(data_x, data_y)
        return cls(clf)

    @classmethod
    def initWithFileModel(cls, filename):
        """Load a classifier previously saved with joblib from *filename*.

        Returns:
            A new ``RandomForestServer`` serving the loaded classifier.
        """
        rospy.loginfo("InitWithFileModel with %s please wait.." % filename)
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code -- only point ~random_forest_train_file at trusted
        # model files.
        clf = joblib.load(filename)
        return cls(clf)

    @staticmethod
    def _format_prediction(labels):
        """Join predicted *labels* into one space-separated string."""
        return " ".join(str(label) for label in labels)

    def classifyData(self, req):
        """Handle a ``predict`` service call.

        Each element of ``req.data`` is classified from its ``point`` field.

        Returns:
            ``ClassifyDataResponse`` with one answer string per element of
            ``req.data``, in request order.
        """
        ret = []
        for data in req.data:
            # predict() takes a 2-D (n_samples, n_features) input; wrap the
            # point as a single sample. Assumes data.point is one flat
            # feature vector -- TODO confirm against ClassifyData.srv.
            answer = self._format_prediction(self.clf.predict([data.point]))
            ret.append(answer)
            rospy.loginfo("req : " + str(data.point) + "-> answer : " + answer)
        return ClassifyDataResponse(ret)

    def run(self):
        """Log a start-up message and block in ``rospy.spin()``."""
        rospy.loginfo("RandomForestServer is running!")
        rospy.spin()
00050 
00051 
def read_feature_file(path):
    """Read comma-separated float rows from the file at *path*.

    Blank (whitespace-only) lines are skipped.

    Returns:
        A list with one list of floats per non-blank line.

    Raises:
        ValueError: if a field cannot be parsed as a float.
    """
    with open(path) as f:
        return [[float(value) for value in line.split(",")]
                for line in f if line.strip()]


def read_class_file(path):
    """Read one float class label per non-blank line of the file at *path*.

    Raises:
        ValueError: if a line cannot be parsed as a float.
    """
    with open(path) as f:
        return [float(line) for line in f if line.strip()]


if __name__ == "__main__":
    import sys

    rospy.init_node('random_forest_cloth_classifier')

    try:
        train_file = rospy.get_param('~random_forest_train_file')
    except KeyError:
        rospy.logerr("Train File is not Set. Set train_data file or tree model file as ~random_forest_train_file.")
        sys.exit(1)

    if train_file.endswith("pkl"):
        node = RandomForestServer.initWithFileModel(train_file)
    else:
        # Only the parameter lookup raises KeyError; keep parsing and training
        # outside the try so unrelated errors are not misreported.
        try:
            class_file = rospy.get_param('~random_forest_train_class_file')
        except KeyError:
            rospy.logerr("Train Class File is not Set. Set train_data file or tree model file.")
            rospy.logerr("Or Did you expect  Extension to be pkl?.")
            sys.exit(1)

        data_x = read_feature_file(train_file)
        data_y = read_class_file(class_file)

        # build service server
        node = RandomForestServer.initWithData(np.array(data_x), np.array(data_y))

    # run
    node.run()


sklearn
Author(s):
autogenerated on Thu Oct 8 2015 11:21:03