7 from matplotlib
import patches
8 import matplotlib.pyplot
as plt
13 from sensor_msgs.msg
import Image
19 import roslib;roslib.load_manifest(
"ml_classifiers")
# Script entry point: start the ROS node, create the debug-image publisher and
# connect to the 'predict' classification service.
# NOTE(review): rospy, cv_bridge, ClassifyData and friends are not imported in
# the visible lines; presumably they are imported in lines missing from this
# excerpt.
32 if __name__ ==
"__main__":
33 rospy.init_node(
"random_forest_client")
# Used further down to convert the rendered plot (numpy RGB array) into a
# sensor_msgs/Image via cv2_to_imgmsg.
34 br = cv_bridge.CvBridge()
# Private topic for the plotted classification results; queue_size=1 bounds
# rospy's outgoing queue to a single message.
35 pub_img = rospy.Publisher(
'~output/debug_image', Image, queue_size=1)
# Block until the 'predict' service is advertised before sending any request.
36 rospy.wait_for_service(
'predict')
37 rospy.loginfo(
"Start Request Service!!")
# Client proxy for the 'predict' service (type ClassifyData); the call that
# uses it is not visible in this excerpt.
39 predict_data = rospy.ServiceProxy(
'predict', ClassifyData)
# Main loop: each iteration samples a random 2-D point, sends it to the
# classifier and reports the expected label next to the returned result.
44 while not rospy.is_shutdown():
45 req = ClassifyDataRequest()
46 req_point = ClassDataPoint()
# random.random() is uniform on [0.0, 1.0), so the point lies in the unit square.
47 target = [random.random(), random.random()]
# Tests whether the point lies outside the unit circle (x^2 + y^2 > 1).
# NOTE(review): `answer` is compared with the prediction below but is never
# assigned in the visible lines; the assignment(s) around this `if` appear to
# be missing from this excerpt.
50 if target[0] * target[0] + target[1] * target[1] > 1:
52 req_point.point = target
53 req.data.append(req_point)
# Colored console output; OKGREEN/WARNING/OKBLUE/ENDC are escape-code constants
# defined outside the visible lines (ENDC presumably resets the color).
54 print(OKGREEN,
"Send Request ====================> Answer",ENDC)
55 print(OKGREEN,
" ",req_point.point,
" : ",str(answer),ENDC)
# NOTE(review): `response` is not assigned in the visible lines; presumably a
# `response = predict_data(req)` line was dropped from this excerpt.
57 print(WARNING,
"Get the result : ",ENDC)
58 print(WARNING,response.classifications,ENDC)
# The first returned classification is parsed as a number and compared with
# the expected integer label.
59 succeed =
int(float(response.classifications[0])) == answer
# NOTE(review): the `if succeed:` line guarding this message (and the matching
# `else:` for the failure message) is not visible in this excerpt.
61 print(OKBLUE,
"Succeed!!!",ENDC)
# Failure message. Terminate with ENDC, as every other colored print in this
# script does. Emitting FAIL again here would leave the terminal stuck in the
# failure color for all subsequent output.
print(FAIL, "FAIL...", ENDC)
# Separator printed after each result report.
print("--- --- --- ---")
# Redraw the scatter plot of every classification result so far.
# NOTE(review): `fig`, `old_targets_ok`, `old_targets_fail` and `succeed` are
# used here but not assigned in the visible lines; confirm they are created in
# lines missing from this excerpt.
68 ax = fig.add_subplot(1, 1, 1)
69 ax.set_aspect(
'equal')
# Show the unit square with a 0.1 margin so points on its border stay visible.
71 ax.set_xlim(-0.1, 1.1)
72 ax.set_ylim(-0.1, 1.1)
73 ax.set_xticks([0.0, 0.5, 1.0])
74 ax.set_yticks([0.0, 0.5, 1.0])
75 ax.set_title(
'Random Forest Classification Result')
# Unfilled green outline of the unit circle centred at the origin, i.e. the
# x^2 + y^2 > 1 boundary tested when the point is generated.
# NOTE(review): the visible lines never attach `circle` to the axes
# (e.g. ax.add_patch(circle)); confirm that call exists in the missing lines.
78 circle = patches.Circle(xy=(0, 0), radius=1.0, fill=
False, ec=
'g')
# Previous points: blue circles were classified correctly, red crosses were not.
82 old_targets_ok_nparr = np.array(old_targets_ok)
83 old_targets_fail_nparr = np.array(old_targets_fail)
# Guard against empty history: np.array([]) is 1-D, so [:, 0] would raise.
84 if old_targets_ok_nparr.size > 0:
85 ax.plot(old_targets_ok_nparr[:, 0], old_targets_ok_nparr[:, 1],
86 'bo', label=
'Successfully classified')
87 if old_targets_fail_nparr.size > 0:
88 ax.plot(old_targets_fail_nparr[:, 0], old_targets_fail_nparr[:, 1],
89 'rx', label=
'Failed to classify')
# Highlight the current point with a larger marker and append it to the
# matching history list.
# NOTE(review): the branch lines choosing between the two cases below
# (presumably on `succeed`) are not visible in this excerpt.
93 ax.plot(target[0], target[1],
'bo', markersize=12)
94 old_targets_ok.append(target)
96 ax.plot(target[0], target[1],
'rx', markersize=12)
97 old_targets_fail.append(target)
# Anchor the legend just outside the right edge of the axes.
102 ax.legend(bbox_to_anchor=(1.05, 0.5), loc=
'center left')
# Rasterize the figure into an (h, w, 3) uint8 RGB array and publish it as a
# sensor_msgs/Image on the debug-image topic.
# NOTE(review): tostring_rgb() reads the already-rendered Agg buffer, so the
# canvas presumably has to be drawn (fig.canvas.draw()) before this point; that
# call is not visible in this excerpt. tostring_rgb() is also deprecated in
# Matplotlib 3.8+ (buffer_rgba() is the replacement). Confirm the installed
# version before relying on it.
w, h = fig.canvas.get_width_height()
# np.frombuffer interprets the raw bytes directly. np.fromstring's binary mode
# (the default sep='') is deprecated since NumPy 1.14 and emits a warning.
img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape((h, w, 3))
img_msg = br.cv2_to_imgmsg(img, 'rgb8')
pub_img.publish(img_msg)