random_forest_client_sample.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 import random
4 
5 import matplotlib
6 matplotlib.use('Agg') # NOQA
7 from matplotlib import patches
8 import matplotlib.pyplot as plt
9 import numpy as np
10 
11 import cv_bridge
12 import rospy
13 from sensor_msgs.msg import Image
14 
15 try:
16  from ml_classifiers.srv import *
17  from ml_classifiers.msg import *
18 except:
19  import roslib;roslib.load_manifest("ml_classifiers")
20  from ml_classifiers.srv import *
21  from ml_classifiers.msg import *
22 
23 
# ANSI SGR escape sequences used to colour console output, in order:
# bright magenta, bright blue, bright green, bright yellow, bright red,
# and ENDC, which resets all attributes back to the terminal default.
HEADER, OKBLUE, OKGREEN, WARNING, FAIL, ENDC = (
    '\033[95m',
    '\033[94m',
    '\033[92m',
    '\033[93m',
    '\033[91m',
    '\033[0m',
)
30 
31 
def _render_classification_image(old_targets_ok, old_targets_fail, target,
                                 succeed):
    """Draw the classification history plus the newest query as an RGB image.

    Args:
        old_targets_ok: list of [x, y] points classified correctly so far.
        old_targets_fail: list of [x, y] points misclassified so far.
        target: the [x, y] point that was just sent to the classifier.
        succeed: True if the classifier's answer for ``target`` was correct.

    Returns:
        numpy uint8 array of shape (height, width, 3) in RGB channel order.
    """
    fig = plt.figure()
    try:
        # Config for plotting
        ax = fig.add_subplot(1, 1, 1)
        ax.set_aspect('equal')
        ax.grid()
        ax.set_xlim(-0.1, 1.1)
        ax.set_ylim(-0.1, 1.1)
        ax.set_xticks([0.0, 0.5, 1.0])
        ax.set_yticks([0.0, 0.5, 1.0])
        ax.set_title('Random Forest Classification Result')

        # Draw the ground-truth boundary: unit circle centred at the origin.
        ax.add_patch(
            patches.Circle(xy=(0, 0), radius=1.0, fill=False, ec='g'))

        # Plot old classification results with default-size markers.
        if old_targets_ok:
            ok = np.array(old_targets_ok)
            ax.plot(ok[:, 0], ok[:, 1], 'bo',
                    label='Successfully classified')
        if old_targets_fail:
            fail = np.array(old_targets_fail)
            ax.plot(fail[:, 0], fail[:, 1], 'rx',
                    label='Failed to classify')

        # Plot the current classification result with a large marker.
        ax.plot(target[0], target[1], 'bo' if succeed else 'rx',
                markersize=12)

        # Somehow we have to call tight_layout() some times
        # to keep the legend from protruding from the image.
        ax.legend(bbox_to_anchor=(1.05, 0.5), loc='center left')
        for _ in range(3):
            fig.tight_layout()
        fig.canvas.draw()

        # tostring_rgb() is deprecated/removed in recent matplotlib;
        # buffer_rgba() is the supported accessor. Drop the alpha channel
        # and make the result contiguous for cv_bridge.
        w, h = fig.canvas.get_width_height()
        rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        return np.ascontiguousarray(rgba.reshape(h, w, 4)[:, :, :3])
    finally:
        # Release the figure explicitly, even if rendering raised.
        plt.close(fig)


def main():
    """Repeatedly query the 'predict' service and publish a debug plot.

    Each iteration sends one random point from the unit square to the
    ``ClassifyData`` service 'predict', compares the first returned label
    with the ground truth (1 inside the unit circle, 0 outside), prints the
    outcome, and publishes a plot of all results on ``~output/debug_image``.
    """
    rospy.init_node("random_forest_client")
    br = cv_bridge.CvBridge()
    pub_img = rospy.Publisher('~output/debug_image', Image, queue_size=1)
    rospy.wait_for_service('predict')
    rospy.loginfo("Start Request Service!!")

    predict_data = rospy.ServiceProxy('predict', ClassifyData)

    old_targets_ok = []
    old_targets_fail = []

    while not rospy.is_shutdown():
        req = ClassifyDataRequest()
        req_point = ClassDataPoint()
        target = [random.random(), random.random()]
        # Ground truth: 1 inside (or on) the unit circle, 0 outside.
        answer = 0 if target[0] * target[0] + target[1] * target[1] > 1 else 1
        req_point.point = target
        req.data.append(req_point)
        print(OKGREEN, "Send Request ====================> Answer", ENDC)
        print(OKGREEN, " ", req_point.point, " : ", str(answer), ENDC)

        # A failed or empty service response should not kill the node;
        # log it and retry on the next iteration.
        try:
            response = predict_data(req)
        except rospy.ServiceException as e:
            rospy.logwarn("Service call to 'predict' failed: %s", e)
            rospy.sleep(1)
            continue
        print(WARNING, "Get the result : ", ENDC)
        print(WARNING, response.classifications, ENDC)
        if not response.classifications:
            rospy.logwarn("Service 'predict' returned no classifications")
            rospy.sleep(1)
            continue

        succeed = int(float(response.classifications[0])) == answer
        if succeed:
            print(OKBLUE, "Succeed!!!", ENDC)
        else:
            # End with ENDC so subsequent console output is not left red.
            print(FAIL, "FAIL...", ENDC)
        print("--- --- --- ---")

        # Render before recording the current point, so it is drawn only
        # once (as the large "current" marker).
        img = _render_classification_image(
            old_targets_ok, old_targets_fail, target, succeed)
        if succeed:
            old_targets_ok.append(target)
        else:
            old_targets_fail.append(target)

        # Update and publish the image.
        pub_img.publish(br.cv2_to_imgmsg(img, 'rgb8'))

        rospy.sleep(1)


if __name__ == "__main__":
    try:
        main()
    except rospy.ROSInterruptException:
        # Normal shutdown (e.g. Ctrl-C while waiting for the service).
        pass


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27