random_forest_client_sample.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 import random
00004 
00005 import matplotlib
00006 matplotlib.use('Agg')  # NOQA
00007 from matplotlib import patches
00008 import matplotlib.pyplot as plt
00009 import numpy as np
00010 
00011 import cv_bridge
00012 import rospy
00013 from sensor_msgs.msg import Image
00014 
# Import the ClassifyData service and ClassDataPoint message types.
try:
    from ml_classifiers.srv import *
    from ml_classifiers.msg import *
except ImportError:
    # Fallback for rosbuild-style setups: roslib.load_manifest() puts the
    # package (and its dependencies) on sys.path so the generated modules
    # become importable. Only ImportError is caught so that unrelated
    # failures (and KeyboardInterrupt/SystemExit) are not masked.
    import roslib
    roslib.load_manifest("ml_classifiers")
    from ml_classifiers.srv import *
    from ml_classifiers.msg import *
00022 
00023 
# ANSI SGR escape sequences used to colour the console output below.
HEADER, OKBLUE, OKGREEN, WARNING, FAIL, ENDC = (
    '\033[95m',  # bright magenta
    '\033[94m',  # bright blue
    '\033[92m',  # bright green
    '\033[93m',  # bright yellow
    '\033[91m',  # bright red
    '\033[0m',   # reset all attributes
)
00030 
00031 
def _expected_answer(target):
    """Return the ground-truth label for a 2-D point.

    The sample's decision boundary is the unit circle centred at the origin:
    points with x*x + y*y <= 1 are labelled 1, all others 0.

    :param target: sequence ``[x, y]`` of floats.
    :returns: ``1`` if the point lies inside or on the circle, else ``0``.
    """
    return 1 if target[0] * target[0] + target[1] * target[1] <= 1 else 0


def _request_classification(predict_data, target):
    """Send *target* as a single data point to the ``predict`` service.

    :param predict_data: ``rospy.ServiceProxy`` for the ``ClassifyData`` srv.
    :param target: ``[x, y]`` point to classify.
    :returns: the ``ClassifyData`` service response.
    :raises rospy.ServiceException: if the service call fails.
    """
    req = ClassifyDataRequest()
    req_point = ClassDataPoint()
    req_point.point = target
    req.data.append(req_point)
    return predict_data(req)


def _render_result_image(target, succeed, old_targets_ok, old_targets_fail):
    """Plot the classification history and return it as an RGB image array.

    Draws the unit-circle boundary, previously classified points (blue
    circles for correct, red crosses for wrong) and the current point with
    a larger marker, then rasterises the figure.

    :param target: current ``[x, y]`` point.
    :param succeed: whether the current point was classified correctly.
    :param old_targets_ok: list of earlier correctly classified points.
    :param old_targets_fail: list of earlier misclassified points.
    :returns: ``numpy.uint8`` array of shape ``(height, width, 3)``.
    """
    fig = plt.figure()
    try:
        # Config for plotting
        ax = fig.add_subplot(1, 1, 1)
        ax.set_aspect('equal')
        ax.grid()
        ax.set_xlim(-0.1, 1.1)
        ax.set_ylim(-0.1, 1.1)
        ax.set_xticks([0.0, 0.5, 1.0])
        ax.set_yticks([0.0, 0.5, 1.0])
        ax.set_title('Random Forest Classification Result')

        # Draw boundary
        ax.add_patch(
            patches.Circle(xy=(0, 0), radius=1.0, fill=False, ec='g'))

        # Plot old classification results (rows are [x, y] points).
        old_ok = np.array(old_targets_ok)
        old_fail = np.array(old_targets_fail)
        if old_ok.size > 0:
            ax.plot(old_ok[:, 0], old_ok[:, 1],
                    'bo', label='Successfully classified')
        if old_fail.size > 0:
            ax.plot(old_fail[:, 0], old_fail[:, 1],
                    'rx', label='Failed to classify')

        # Plot current classification result
        marker = 'bo' if succeed else 'rx'
        ax.plot(target[0], target[1], marker, markersize=12)

        # Somehow we have to call tight_layout() some times
        # to keep the legend from protruding from the image.
        ax.legend(bbox_to_anchor=(1.05, 0.5), loc='center left')
        for _ in range(3):
            fig.tight_layout()
        fig.canvas.draw()
        w, h = fig.canvas.get_width_height()
        # np.frombuffer replaces the deprecated np.fromstring.
        img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        img = img.reshape((h, w, 3))
    finally:
        # Close this specific figure (not merely the "current" one) so that
        # one figure per loop iteration is not leaked, even on error.
        plt.close(fig)
    return img


def main():
    """Classify random points via the ``predict`` service, forever.

    Each iteration draws a random point in [0, 1) x [0, 1), asks the service
    to classify it, compares the result to the unit-circle ground truth,
    prints a coloured report and publishes a debug plot on
    ``~output/debug_image``.
    """
    rospy.init_node("random_forest_client")
    br = cv_bridge.CvBridge()
    pub_img = rospy.Publisher('~output/debug_image', Image, queue_size=1)
    rospy.wait_for_service('predict')
    rospy.loginfo("Start Request Service!!")

    predict_data = rospy.ServiceProxy('predict', ClassifyData)

    old_targets_ok = []
    old_targets_fail = []

    while not rospy.is_shutdown():
        target = [random.random(), random.random()]
        answer = _expected_answer(target)
        # Single-argument print() works identically on Python 2 and 3.
        print('%s Send Request         ====================>         Answer %s'
              % (OKGREEN, ENDC))
        print('%s      %s        :  %s %s' % (OKGREEN, target, answer, ENDC))

        try:
            response = _request_classification(predict_data, target)
        except rospy.ServiceException as e:
            rospy.logerr('predict service call failed: %s', e)
            rospy.sleep(1)
            continue

        print('%s Get the result :  %s' % (WARNING, ENDC))
        print('%s %s %s' % (WARNING, response.classifications, ENDC))
        if not response.classifications:
            rospy.logwarn('predict service returned no classification')
            rospy.sleep(1)
            continue

        # Classifications arrive as strings such as "1.0".
        succeed = int(float(response.classifications[0])) == answer
        if succeed:
            print('%s Succeed!!! %s' % (OKBLUE, ENDC))
        else:
            # Reset with ENDC so the terminal colour does not stay red.
            print('%s FAIL... %s' % (FAIL, ENDC))
        print('--- --- --- ---')

        # Render before recording the current point so it is drawn only
        # once (as the enlarged "current" marker).
        img = _render_result_image(
            target, succeed, old_targets_ok, old_targets_fail)
        if succeed:
            old_targets_ok.append(target)
        else:
            old_targets_fail.append(target)

        pub_img.publish(br.cv2_to_imgmsg(img, 'rgb8'))

        rospy.sleep(1)


if __name__ == "__main__":
    try:
        main()
    except rospy.ROSInterruptException:
        # Raised by rospy.sleep() on shutdown (e.g. Ctrl-C); exit quietly.
        pass


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07