00001
00002
00003 import random
00004
00005 import matplotlib
00006 matplotlib.use('Agg')
00007 from matplotlib import patches
00008 import matplotlib.pyplot as plt
00009 import numpy as np
00010
00011 import cv_bridge
00012 import rospy
00013 from sensor_msgs.msg import Image
00014
00015 try:
00016 from ml_classifiers.srv import *
00017 from ml_classifiers.msg import *
00018 except:
00019 import roslib;roslib.load_manifest("ml_classifiers")
00020 from ml_classifiers.srv import *
00021 from ml_classifiers.msg import *
00022
00023
# ANSI SGR escape sequences used to colour the console output of this script.
# Each colour is the Control Sequence Introducer (ESC + '[') followed by a
# Select Graphic Rendition code; ENDC resets all attributes back to default.
_CSI = '\033['
HEADER = _CSI + '95m'   # bright magenta
OKBLUE = _CSI + '94m'   # bright blue
OKGREEN = _CSI + '92m'  # bright green
WARNING = _CSI + '93m'  # bright yellow
FAIL = _CSI + '91m'     # bright red
ENDC = _CSI + '0m'      # reset
00030
00031
def _ground_truth(target):
    """Return the expected label for *target*: 1 inside or on the unit circle, 0 outside."""
    x, y = target
    return 0 if x * x + y * y > 1 else 1


def _colored(color, *parts):
    """Build one console line wrapped in an ANSI colour code and reset.

    Items are converted with ``str`` and joined by single spaces, matching what
    the Python 2 statement ``print color, part, ..., ENDC`` emitted.  Returning a
    single string lets ``print(...)`` produce the same output on Python 2 and 3.
    """
    return ' '.join([color] + [str(part) for part in parts] + [ENDC])


def _render_debug_image(old_targets_ok, old_targets_fail, target, succeed):
    """Plot past results plus the newest sample and return the plot as an RGB image.

    Args:
        old_targets_ok: previously correctly classified points, each ``[x, y]``.
        old_targets_fail: previously misclassified points, each ``[x, y]``.
        target: the newest ``[x, y]`` sample, drawn with a larger marker.
        succeed: whether *target* was classified correctly.

    Returns:
        ``numpy.uint8`` array of shape ``(height, width, 3)`` in RGB order.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.set_aspect('equal')
        ax.grid()
        ax.set_xlim(-0.1, 1.1)
        ax.set_ylim(-0.1, 1.1)
        ax.set_xticks([0.0, 0.5, 1.0])
        ax.set_yticks([0.0, 0.5, 1.0])
        ax.set_title('Random Forest Classification Result')

        # Ground-truth decision boundary used by _ground_truth: the unit circle.
        ax.add_patch(patches.Circle(xy=(0, 0), radius=1.0, fill=False, ec='g'))

        # History first (labelled, so it shows up in the legend) ...
        for points, style, label in (
                (old_targets_ok, 'bo', 'Successfully classified'),
                (old_targets_fail, 'rx', 'Failed to classify')):
            if points:
                arr = np.array(points)
                ax.plot(arr[:, 0], arr[:, 1], style, label=label)

        # ... then the newest sample, unlabelled and drawn larger.
        ax.plot(target[0], target[1], 'bo' if succeed else 'rx', markersize=12)

        # On the first iteration there are no labelled artists yet, and
        # ax.legend() would only emit a "no handles" warning.
        handles, _labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend(bbox_to_anchor=(1.05, 0.5), loc='center left')
        # Repeated tight_layout presumably lets the layout settle with the
        # legend placed outside the axes; kept as in the original script.
        for _ in range(3):
            fig.tight_layout()
        fig.canvas.draw()

        width, height = fig.canvas.get_width_height()
        # np.fromstring is deprecated for binary input; frombuffer is its replacement.
        img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        return img.reshape(height, width, 3)
    finally:
        # Close this specific figure (not just "the current one") even on error,
        # so pyplot does not accumulate figures across loop iterations.
        plt.close(fig)


if __name__ == "__main__":
    rospy.init_node("random_forest_client")
    br = cv_bridge.CvBridge()
    pub_img = rospy.Publisher('~output/debug_image', Image, queue_size=1)
    rospy.wait_for_service('predict')
    rospy.loginfo("Start Request Service!!")

    predict_data = rospy.ServiceProxy('predict', ClassifyData)

    # Accumulated results, so every debug image shows all past samples.
    old_targets_ok = []
    old_targets_fail = []

    try:
        while not rospy.is_shutdown():
            # Random point in the unit square; label 1 iff inside the unit circle.
            target = [random.random(), random.random()]
            answer = _ground_truth(target)

            req = ClassifyDataRequest()
            req_point = ClassDataPoint()
            req_point.point = target
            req.data.append(req_point)
            print(_colored(OKGREEN, "Send Request ====================> Answer"))
            print(_colored(OKGREEN, " ", req_point.point, " : ", str(answer)))

            try:
                response = predict_data(req)
            except rospy.ServiceException as e:
                # One failed call should not end the demo; retry next cycle.
                rospy.logerr("predict service call failed: %s", e)
                rospy.sleep(1)
                continue

            print(_colored(WARNING, "Get the result : "))
            print(_colored(WARNING, response.classifications))
            # Exactly one point was sent, so the first classification is its label.
            succeed = int(float(response.classifications[0])) == answer
            if succeed:
                print(_colored(OKBLUE, "Succeed!!!"))
            else:
                # _colored appends ENDC, so the terminal is not left red.
                print(_colored(FAIL, "FAIL..."))
            print("--- --- --- ---")

            # Render before recording the new sample so it is drawn only once
            # (large marker) rather than also as part of the history.
            img = _render_debug_image(old_targets_ok, old_targets_fail, target, succeed)
            if succeed:
                old_targets_ok.append(target)
            else:
                old_targets_fail.append(target)
            pub_img.publish(br.cv2_to_imgmsg(img, 'rgb8'))

            rospy.sleep(1)
    except rospy.ROSInterruptException:
        # rospy.sleep raises this when the node is shut down (e.g. Ctrl-C).
        pass