Go to the documentation of this file.00001
00002 import roslib; roslib.load_manifest('iri_grasping_learner')
00003 import rospy
00004
00005 import numpy as np
00006 from scipy.spatial.distance import cdist
00007 import sys
00008 from os.path import exists
00009 import cPickle
00010 import cv
00011 from cv_bridge import CvBridge, CvBridgeError
00012
00013 from iri_grasping_learner.srv import *
00014 from iri_learning_msgs.msg import GraspInfo
00015
00016 import ipdb
00017
def cv_select_point(event, x, y, flags, param):
    """OpenCV mouse callback that records right-button presses.

    Registered through ``cv.SetMouseCallback`` with a list as ``param``; each
    right-button press appends the pixel ``(x, y)`` to that list.

    :param event: OpenCV mouse event code (``cv.CV_EVENT_*``).
    :param x: column of the mouse pointer in image coordinates.
    :param y: row of the mouse pointer in image coordinates.
    :param flags: OpenCV event flags (unused).
    :param param: list collecting the selected points; mutated in place.
    """
    # The former test ``event and cv.CV_EVENT_RBUTTONDOWN`` was true for any
    # non-zero event (left clicks, button releases, ...) since the constant is
    # non-zero itself; compare for equality so only a right press selects.
    if event == cv.CV_EVENT_RBUTTONDOWN:
        param.append((x, y))
00022
00023
def user_select_point(cvimage):
    """Display *cvimage* and block until the user has clicked a point.

    The click is captured by ``cv_select_point``; the GUI loop is pumped with
    ``cv.WaitKey(100)`` until the callback has recorded something, then the
    window is destroyed.

    :param cvimage: image to display (old-style ``cv`` image/matrix).
    :returns: list of ``(x, y)`` tuples recorded by the mouse callback; it
        holds at least one element when this function returns.
    """
    window = 'select_point'
    clicks = []
    cv.NamedWindow(window)
    cv.ShowImage(window, cvimage)
    cv.SetMouseCallback(window, cv_select_point, clicks)
    # Keep servicing GUI events until the callback fills ``clicks``.
    while not clicks:
        cv.WaitKey(100)
    cv.DestroyWindow(window)
    return clicks
00035
00036
class GraspsMemory:
    """Store of learned grasps indexed by the descriptor of the grasped point.

    ``LearnedGrasps`` pickles whole instances to disk, so the attribute names
    below are part of the saved format and must not be renamed.
    """

    def __init__(self):
        # Descriptors of the stored grasps, one row per grasp once two or more
        # are stored; ``None`` until the first grasp is added (which is then
        # kept exactly as passed in).
        self.descriptors = None
        # Grasp information objects, aligned with the rows of ``descriptors``.
        self.grasps = []
        # A query matches a stored grasp when its descriptor distance is
        # strictly below this value.
        self.threshold = 0.3

    def add_grasp(self, desc, ginfo):
        """Append a grasp and its descriptor to the memory.

        :param desc: descriptor vector of the grasped point.
        :param ginfo: grasp information associated with ``desc``.
        :returns: ``True`` (the memory was updated).
        """
        self.grasps.append(ginfo)
        # ``is None`` instead of ``== None``: once ``descriptors`` is a numpy
        # array, ``==`` compares element-wise and the ``if`` would raise
        # "truth value of an array ... is ambiguous".
        if self.descriptors is None:
            self.descriptors = desc
        else:
            self.descriptors = np.vstack((self.descriptors, desc))
        return True

    def compute_threshold(self, desc_set):
        """Placeholder for estimating ``threshold`` from a set of descriptors.

        Not implemented: does nothing and returns ``None``.
        """
        pass
00053
00054
class LearnedGrasps:
    """ROS node that remembers grasps and proposes them for similar points.

    Services advertised by ``listener``:

    * ``get_grasping_point`` (``FindGoodPoint``): the user clicks a point on
      the request image; the nearest image descriptor is compared with the
      grasp memory and a stored grasp is returned if it is close enough.
    * ``learn_new_grasping`` (``LearnNewGrasp``): stores a new
      (descriptor, grasp) pair and saves the memory to ``path_memory``.
    """

    def __init__(self):
        # ``GraspsMemory`` instance; populated by ``load_memory``.
        self.grasps_memory = None
        # File the memory is loaded from / saved to; ``None`` disables saving.
        self.path_memory = None

    def select_learned_grasp_point_cb(self, req):
        """Handle a ``FindGoodPoint`` request.

        Shows ``req.ima`` to the user, picks the request descriptor whose
        ``(u, v)`` is nearest to the clicked pixel and looks it up in memory.

        :param req: ``FindGoodPoint`` request carrying an image (``ima``) and
            descriptors (``descs.descriptors``, each with ``descriptor``,
            ``u`` and ``v``).
        :returns: ``FindGoodPointResponse`` with ``knows`` (a stored grasp
            matched), ``grasp_info`` (that grasp, or an empty ``GraspInfo``)
            and ``desc_index`` (index of the selected request descriptor).
        :raises rospy.ServiceException: if the image cannot be converted or
            the request contains no descriptors.
        """
        bridge = CvBridge()
        try:
            cv_image = bridge.imgmsg_to_cv(req.ima, "bgr8")
        except CvBridgeError as e:
            # Previously the error was only printed and the handler then died
            # with a NameError on ``cv_image``; fail the call explicitly.
            msg = "Could not convert the request image: %s" % e
            rospy.logerr(msg)
            raise rospy.ServiceException(msg)

        # Check before asking the user to click: nothing could be selected.
        if len(req.descs.descriptors) == 0:
            raise rospy.ServiceException("Request contains no descriptors")

        selected_point = np.atleast_2d(user_select_point(cv_image)[0])
        descs = np.array([d.descriptor for d in req.descs.descriptors])
        uvs = np.array([(d.u, d.v) for d in req.descs.descriptors])

        # Descriptor whose image position is nearest to the clicked pixel.
        pos_image = int(np.argmin(cdist(selected_point, uvs)[0]))
        selected_desc = descs[pos_image, :]

        knows, grasp_info = self._query_memory(selected_desc)

        res = FindGoodPointResponse()
        res.knows = knows
        res.grasp_info = grasp_info
        res.desc_index = pos_image
        return res

    def _query_memory(self, desc):
        """Return ``(knows, grasp_info)`` for the stored grasp nearest *desc*.

        ``knows`` is True only when the memory is non-empty and the nearest
        stored descriptor lies strictly closer than the memory threshold;
        otherwise ``(False, GraspInfo())`` is returned.
        """
        memory = self.grasps_memory
        if len(memory.grasps) == 0:
            return False, GraspInfo()
        # ``atleast_2d`` also covers the single-grasp case, where the stored
        # descriptor is still a 1-D vector.
        dists = cdist(np.atleast_2d(desc), np.atleast_2d(memory.descriptors))[0]
        pos_memory = int(np.argmin(dists))
        if dists[pos_memory] < memory.threshold:
            return True, memory.grasps[pos_memory]
        return False, GraspInfo()

    def learn_new_grasp_cb(self, req):
        """Handle a ``LearnNewGrasp`` request: store the grasp and save memory.

        :param req: ``LearnNewGrasp`` request; the descriptor at
            ``req.desc_index`` of ``req.desc.descriptors`` is stored together
            with ``req.grasp_info``.
        :returns: ``LearnNewGraspResponse`` whose ``update_ok`` reports
            whether the memory was updated.
        """
        the_desc = np.array(req.desc.descriptors[req.desc_index].descriptor)
        update_ok = self.grasps_memory.add_grasp(the_desc, req.grasp_info)
        if update_ok:
            self._save_memory()
        res = LearnNewGraspResponse()
        res.update_ok = update_ok
        return res

    def _save_memory(self):
        """Pickle the grasp memory to ``path_memory`` (no-op when ``None``)."""
        if self.path_memory is None:
            return
        # ``with`` closes the file even when pickling fails; binary mode is
        # correct for every pickle protocol.
        with open(self.path_memory, 'wb') as pf:
            cPickle.dump(self.grasps_memory, pf)

    def load_memory(self, path_memory=None):
        """Load the grasp memory from *path_memory*, or start an empty one.

        :param path_memory: pickle file written by this node. If it does not
            exist an empty memory is created (the file is written on the first
            learned grasp). ``None`` keeps the memory in RAM only.
        """
        self.path_memory = path_memory
        if path_memory is not None and exists(path_memory):
            # NOTE(review): unpickling runs code from the file; only point
            # ``~path_memory`` at files written by this node.
            with open(path_memory, 'rb') as pf:
                self.grasps_memory = cPickle.load(pf)
        else:
            self.grasps_memory = GraspsMemory()

    def listener(self):
        """Initialise the node, load the memory and serve requests forever.

        Requires the private parameter ``~path_memory``; exits with status 1
        when it is missing. Blocks in ``rospy.spin()``.
        """
        rospy.init_node('grasping_learner')

        if not rospy.has_param("~path_memory"):
            rospy.logerr("Node has no param path_memory. Please set it up on launch file")
            sys.exit(1)
        path_memory = rospy.get_param("~path_memory")
        rospy.loginfo("Launching grasp learner. Loading/saving from/to %s." % path_memory)
        # Load before advertising: a request served earlier would find
        # ``grasps_memory`` still set to None.
        self.load_memory(path_memory)

        serv = rospy.Service('get_grasping_point', FindGoodPoint,
                             self.select_learned_grasp_point_cb)
        serv2 = rospy.Service('learn_new_grasping', LearnNewGrasp,
                              self.learn_new_grasp_cb)
        rospy.spin()
00139
00140
00141
if __name__ == "__main__":
    # Build the learner and hand control to listener(), which blocks in
    # rospy.spin() for the lifetime of the node.
    LearnedGrasps().listener()
00145
00146