Go to the documentation of this file.00001
00002
00003 try:
00004 from ml_classifiers.srv import *
00005 except:
00006 import roslib;roslib.load_manifest("ml_classifiers")
00007 from ml_classifiers.srv import *
00008
00009 import rospy
00010 import numpy as np
00011 from sklearn.ensemble import RandomForestClassifier
00012 from sklearn.ensemble import ExtraTreesClassifier
00013 from sklearn.externals import joblib
00014
00015
class RandomForestServer(object):
    """ROS node that serves tree-ensemble predictions over the ``predict`` service.

    The service type is ``ClassifyData``. Build instances with
    :meth:`initWithData` (train in-process from arrays) or
    :meth:`initWithFileModel` (load a joblib-serialized model), then call
    :meth:`run`.
    """

    def __init__(self, clf):
        """Wrap an already-fitted classifier and advertise the ``predict`` service.

        :param clf: fitted estimator exposing a scikit-learn style ``predict``.
        """
        self.clf = clf
        # Keep the service handle on the instance rather than discarding it,
        # so it stays reachable (e.g. for an explicit shutdown()).
        self.service = rospy.Service('predict', ClassifyData, self.classifyData)

    @classmethod
    def initWithData(cls, data_x, data_y, n_estimators=250, max_features=2,
                     max_depth=29, min_samples_split=2, random_state=0):
        """Train a RandomForestClassifier on (data_x, data_y) and wrap it.

        :param data_x: training samples, one feature vector per row.
        :param data_y: class label for each row of ``data_x``.
        :param n_estimators, max_features, max_depth, min_samples_split,
            random_state: forwarded to ``RandomForestClassifier``; the
            defaults are the values this node has always used.
        :returns: a new ``RandomForestServer`` wrapping the fitted model.
        :raises ValueError: if ``data_x`` and ``data_y`` differ in length.
        """
        if len(data_x) != len(data_y):
            # Raise rather than call exit(): the exit() builtin is only
            # provided by the site module, and it exited with status 0,
            # reporting success on a fatal error.
            raise ValueError(
                "Length of data are different: %d samples vs %d labels"
                % (len(data_x), len(data_y)))
        rospy.loginfo("InitWithData please wait..")
        clf = RandomForestClassifier(
            n_estimators=n_estimators, max_features=max_features,
            max_depth=max_depth, min_samples_split=min_samples_split,
            random_state=random_state)
        clf.fit(data_x, data_y)
        return cls(clf)

    @classmethod
    def initWithFileModel(cls, filename):
        """Load a classifier saved with joblib from ``filename`` and wrap it.

        NOTE(review): joblib.load unpickles arbitrary objects; only load model
        files from trusted sources.
        """
        rospy.loginfo("InitWithFileModel with %s please wait.." % filename)
        clf = joblib.load(filename)
        return cls(clf)

    def classifyData(self, req):
        """Handle a ClassifyData request.

        Each element of ``req.data`` supplies a feature vector in ``.point``.
        The response contains one stringified predicted label per element,
        in request order.
        """
        points = [data.point for data in req.data]
        if not points:
            # scikit-learn rejects an empty sample batch; reply with no labels
            # instead of failing the service call.
            return ClassifyDataResponse([])
        # One batched predict() call instead of one call per point.
        ret = [str(label) for label in self.clf.predict(points)]
        rospy.loginfo("req : " + str(points) + "-> answer : " + str(ret))
        return ClassifyDataResponse(ret)

    def run(self):
        """Log that the node is up, then block in rospy.spin() serving requests."""
        rospy.loginfo("RandomForestServer is running!")
        rospy.spin()
00054
00055
def parse_feature_rows(lines):
    """Parse comma-separated numeric rows into a list of float lists.

    Blank (whitespace-only) lines are skipped so a trailing empty line in the
    training file does not crash ``float()``. A list comprehension is used
    instead of ``map`` so the result is a real list on Python 3 as well,
    which ``np.array`` needs to build a 2-D float array.
    """
    return [[float(value) for value in line.split(",")]
            for line in lines if line.strip()]


def parse_labels(lines):
    """Parse one numeric class label per line, skipping blank lines."""
    return [float(line) for line in lines if line.strip()]


def main():
    """Start the node, build the classifier from ROS params and serve requests.

    ``~random_forest_train_file`` is either a joblib model ending in "pkl" or
    a CSV of feature rows. In the CSV case ``~random_forest_train_class_file``
    must hold one label per line.
    """
    import sys

    rospy.init_node('random_forest_cloth_classifier')

    try:
        train_file = rospy.get_param('~random_forest_train_file')
    except KeyError:
        rospy.logerr("Train File is not Set. Set train_data file or tree model file as ~random_forest_train_file.")
        # Non-zero status: a missing parameter is a failure, not a clean exit.
        sys.exit(1)

    if train_file.endswith("pkl"):
        node = RandomForestServer.initWithFileModel(train_file)
    else:
        # Only the parameter lookup belongs in the try; parsing errors should
        # surface as themselves, not as "class file is not set".
        try:
            class_file = rospy.get_param('~random_forest_train_class_file')
        except KeyError:
            rospy.logerr("Train Class File is not Set. Set train_data file or tree model file.")
            rospy.logerr("Or Did you expect Extension to be pkl?.")
            sys.exit(1)

        with open(train_file) as feature_file:
            data_x = parse_feature_rows(feature_file)
        with open(class_file) as label_file:
            data_y = parse_labels(label_file)

        node = RandomForestServer.initWithData(np.array(data_x), np.array(data_y))

    node.run()


if __name__ == "__main__":
    main()