random_forest_server.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
import sys

try:
    from ml_classifiers.srv import *
except ImportError:
    # Outside a catkin workspace the generated srv module is only importable
    # after rosbuild's manifest loader has extended sys.path.
    import roslib
    roslib.load_manifest("ml_classifiers")
    from ml_classifiers.srv import *

import rospy
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier
try:
    from sklearn.externals import joblib
except ImportError:
    # scikit-learn >= 0.23 no longer vendors joblib; use the standalone package.
    import joblib
14 
15 
class RandomForestServer(object):
    """ROS node exposing a tree-ensemble classifier as the ``predict`` service.

    The service type is ``ClassifyData``. For each item in a request, the
    item's ``point`` is classified and the prediction is returned as a
    string, in request order.
    """

    def __init__(self, clf):
        """Wrap a fitted classifier and advertise the ``predict`` service.

        Args:
            clf: a fitted estimator exposing ``predict`` (as produced by
                ``initWithData`` or loaded by ``initWithFileModel``).
        """
        self.clf = clf
        # Keep the service handle on the instance rather than in a throwaway local.
        self.service = rospy.Service('predict', ClassifyData, self.classifyData)

    @classmethod
    def initWithData(cls, data_x, data_y, n_estimators=250, max_features=2,
                     max_depth=29, min_samples_split=2, random_state=0):
        """Train a ``RandomForestClassifier`` and return a server wrapping it.

        Args:
            data_x: training samples, one feature vector per row.
            data_y: one class label per row of ``data_x``.
            n_estimators, max_features, max_depth, min_samples_split,
            random_state: forwarded to ``RandomForestClassifier``. The
                defaults are 250 trees, max_features=2, max_depth=29,
                min_samples_split=2 and a fixed random_state=0.

        Logs an error and exits the process with status 1 if ``data_x`` and
        ``data_y`` have different lengths.
        """
        if len(data_x) != len(data_y):
            rospy.logerr("Length of data_x (%d) and data_y (%d) differ"
                         % (len(data_x), len(data_y)))
            sys.exit(1)
        rospy.loginfo("InitWithData please wait..")
        clf = RandomForestClassifier(
            n_estimators=n_estimators, max_features=max_features,
            max_depth=max_depth, min_samples_split=min_samples_split,
            random_state=random_state)
        clf.fit(data_x, data_y)
        return cls(clf)

    @classmethod
    def initWithFileModel(cls, filename):
        """Load a saved classifier with ``joblib.load`` and wrap it.

        NOTE(security): ``joblib.load`` unpickles the file, which can run
        arbitrary code -- only point this at trusted model files.
        """
        rospy.loginfo("InitWithFileModel with %s please wait.." % filename)
        clf = joblib.load(filename)
        return cls(clf)

    def classifyData(self, req):
        """Handle a ``predict`` request; return one prediction string per item.

        All request points are classified in a single batched ``predict``
        call; the response lists ``str(prediction)`` in request order.
        """
        points = [data.point for data in req.data]
        if not points:
            # scikit-learn's predict() raises on zero samples; answer with
            # an empty result instead.
            return ClassifyDataResponse([])
        answers = [str(label) for label in self.clf.predict(points)]
        for point, answer in zip(points, answers):
            rospy.loginfo("req : " + str(point) + "-> answer : " + answer)
        return ClassifyDataResponse(answers)

    def run(self):
        """Block in ``rospy.spin()`` serving requests until shutdown."""
        rospy.loginfo("RandomForestServer is running!")
        rospy.spin()
54 
55 
def _read_feature_rows(path):
    """Return the comma-separated float rows of ``path``, skipping blank lines."""
    with open(path) as f:
        # A list (not a lazy ``map``) per row so np.array builds a 2-D float
        # array on Python 3 as well as Python 2.
        return [[float(value) for value in line.split(",")]
                for line in f if line.strip()]


def _read_labels(path):
    """Return one float class label per non-blank line of ``path``."""
    with open(path) as f:
        return [float(line) for line in f if line.strip()]


if __name__ == "__main__":
    rospy.init_node('random_forest_cloth_classifier')

    try:
        train_file = rospy.get_param('~random_forest_train_file')
    except KeyError:
        rospy.logerr("Train File is not Set. Set train_data file or tree model file as ~random_forest_train_file.")
        sys.exit(1)

    if train_file.endswith("pkl"):
        # A pickled model: load it instead of training.
        node = RandomForestServer.initWithFileModel(train_file)
    else:
        # Raw training data: the class labels live in a separate file.
        try:
            class_file = rospy.get_param('~random_forest_train_class_file')
        except KeyError:
            rospy.logerr("Train Class File is not Set. Set train_data file or tree model file.")
            rospy.logerr("Or Did you expect Extension to be pkl?.")
            sys.exit(1)

        data_x = _read_feature_rows(train_file)
        data_y = _read_labels(class_file)

        # build service server
        node = RandomForestServer.initWithData(np.array(data_x), np.array(data_y))

    node.run()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27