Go to the documentation of this file.00001
00002
00003 try:
00004 from ml_classifiers.srv import *
00005 except:
00006 import roslib;roslib.load_manifest("ml_classifiers")
00007 from ml_classifiers.srv import *
00008
00009 import rospy
00010 import numpy as np
00011 from sklearn.cross_validation import cross_val_score
00012 from sklearn.ensemble import RandomForestClassifier
00013 from sklearn.ensemble import ExtraTreesClassifier
00014 from sklearn.externals import joblib
00015
class RandomForestServer:
    """ROS node that answers ``ClassifyData`` requests with a scikit-learn classifier.

    Construction advertises a service named ``predict`` whose requests are
    handled by :meth:`classifyData`.
    """

    def __init__(self, clf):
        """Wrap an already-fitted classifier and advertise the ``predict`` service.

        :param clf: fitted estimator exposing ``predict`` (a RandomForestClassifier
            when built via :meth:`initWithData`, or whatever object the model file
            holds when built via :meth:`initWithFileModel`).
        """
        self.clf = clf
        # Keep the service handle on the instance instead of an unused local.
        self.service = rospy.Service('predict', ClassifyData, self.classifyData)

    @classmethod
    def initWithData(cls, data_x, data_y):
        """Train a random forest on ``data_x``/``data_y`` and return a server using it.

        :param data_x: training samples, one row per sample.
        :param data_y: training labels, one per row of ``data_x``.
        :returns: a new ``RandomForestServer`` wrapping the fitted forest.

        Terminates the process with exit status 1 if the sample and label
        counts differ.
        """
        if len(data_x) != len(data_y):
            import sys
            rospy.logerr("Length of datas are different (data_x: %d, data_y: %d)"
                         % (len(data_x), len(data_y)))
            # Non-zero status so launch files / supervisors see a failure.
            sys.exit(1)
        rospy.loginfo("InitWithData please wait..")
        # min_samples_split=2 is the smallest integer accepted by scikit-learn >= 0.18.
        # The original value 1 behaved identically on older releases, because a
        # node holding a single sample can never be split.
        clf = RandomForestClassifier(n_estimators=250, max_features=2, max_depth=29,
                                     min_samples_split=2, random_state=0)
        clf.fit(data_x, data_y)
        return cls(clf)

    @classmethod
    def initWithFileModel(cls, filename):
        """Load a persisted classifier from ``filename`` with joblib and return a server using it.

        SECURITY NOTE(review): joblib.load unpickles arbitrary objects; only
        point the node at trusted model files.
        """
        rospy.loginfo("InitWithFileModel with %s please wait.." % filename)
        clf = joblib.load(filename)
        return cls(clf)

    def classifyData(self, req):
        """Service callback: classify every data point in the request.

        Each element of ``req.data`` is classified independently, with its
        ``point`` treated as one sample. The response contains one string per
        element: the predicted label(s) for that sample, joined by spaces.
        """
        ret = []
        for data in req.data:
            # predict() expects a 2-D (n_samples, n_features) array. Recent
            # scikit-learn rejects 1-D input, and old releases treated a 1-D
            # vector as a single sample, so make that single sample explicit.
            # Assumes ``point`` is a flat sequence of numeric features — confirm
            # against the message definition.
            sample = np.asarray(data.point).reshape(1, -1)
            answer = " ".join(str(predict_data) for predict_data in self.clf.predict(sample))
            ret.append(answer)
            # Log this request's answer only, not the whole accumulated list.
            rospy.loginfo("req : " + str(data.point) + "-> answer : " + answer)
        return ClassifyDataResponse(ret)

    def run(self):
        """Log that the node is up and block in ``rospy.spin()`` serving requests."""
        rospy.loginfo("RandomForestServer is running!")
        rospy.spin()
00050
00051
if __name__ == "__main__":
    # Script entry point: build a RandomForestServer either from a saved model
    # ("...pkl") or by training on CSV feature/label files, then serve requests.
    import sys

    rospy.init_node('random_forest_cloth_classifier')

    try:
        train_file = rospy.get_param('~random_forest_train_file')
    except KeyError:
        rospy.logerr("Train File is not Set. Set train_data file or tree model file as ~random_forest_train_file.")
        sys.exit(1)

    if train_file.endswith("pkl"):
        node = RandomForestServer.initWithFileModel(train_file)
    else:
        # Only the parameter lookup is guarded, so a KeyError raised while
        # parsing or training is not misreported as a missing parameter.
        try:
            class_file = rospy.get_param('~random_forest_train_class_file')
        except KeyError:
            rospy.logerr("Train Class File is not Set. Set train_data file or tree model file.")
            rospy.logerr("Or Did you expect Extension to be pkl?.")
            sys.exit(1)

        # Features: one comma-separated sample per line. Blank lines (e.g. a
        # trailing newline) are skipped instead of crashing float().
        with open(train_file) as f:
            data_x = [[float(field) for field in line.split(",")]
                      for line in f if line.strip()]

        # Labels: one float per line. initWithData exits if the label count
        # differs from the feature row count.
        with open(class_file) as f:
            data_y = [float(line) for line in f if line.strip()]

        node = RandomForestServer.initWithData(np.array(data_x), np.array(data_y))

    node.run()