trajectory_sampling_gui.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 PKG = 'trajectory_sampling_gui'
00004 import roslib; roslib.load_manifest(PKG)
00005 
00006 import rospy
00007 import rosbag
00008 
00009 import os
00010 import optparse
00011 import sys
00012 import time
00013 import thread
00014 import copy
00015 
00016 from sensor_msgs.msg import JointState
00017 from dmp_motion_learner.srv import LearnJointSpaceDMP
00018 from dmp_motion_controller.srv import AddToExecuteDMPQueue, AddToDualArmExecuteDMPQueue
00019 from pr2_mechanism_msgs.srv import SwitchController
00020 
00021 import wxversion
WXVER = '2.8'
# Pin the wxPython version before ``import wx`` runs; bail out with an
# explanatory message when that version is not installed.
if not wxversion.checkInstalled(WXVER):
    sys.stderr.write('This application requires wxPython version %s' % WXVER + '\n')
    sys.exit(1)
wxversion.select(WXVER)
00028 import wx
00029 
class DMP:
    """Which arm(s) a DMP is learned for / executed on.

    Stored in ``TrajectorySamplingFrame.dmp_type``.
    """
    LEFT, RIGHT, BOTH = 0, 1, 2
00034 
class TrajectorySamplingGuiApp(wx.App):
    """wx application that hosts a single TrajectorySamplingFrame."""

    def __init__(self, options):
        # Everything OnInit needs has to exist before the base constructor
        # runs, because wx.App.__init__ is what triggers OnInit.
        self.options = options
        self.bag = None
        self.sub = None
        wx.App.__init__(self)

    def OnInit(self):
        """Create and show the main frame and make it the top window."""
        main_frame = TrajectorySamplingFrame(None, self.options)
        main_frame.Show()
        self.SetTopWindow(main_frame)
        # True tells wx to continue into the main loop
        return True
00047 
class TrajectorySamplingFrame(wx.Frame):
    """Main window for recording trajectories and learning/executing DMPs.

    The frame records ``/joint_states`` into a bag file in ``bag_dir``, asks
    the ``learn_joint_space_dmp_from_bag_file`` service for a DMP learned from
    a recorded bag, stores that DMP in a second bag file, copies it to
    ``marvin`` via scp and can queue it on the DMP joint position controllers.
    With ``options.rl`` set, extra buttons drive the ``learner_node`` services.

    rospy subscriber callbacks do not run on the wx GUI thread, so they only
    touch widgets through ``wx.CallAfter``.
    """

    def __init__(self, parent, options, id=wx.ID_ANY, title='TrajectorySampling', pos=wx.DefaultPosition, size=(320, 330), style=wx.DEFAULT_FRAME_STYLE):
        """Build the frame, create the ROS service proxies and list the bags.

        ``options`` is the optparse result; ``options.offline`` and
        ``options.rl`` are read.
        """
        wx.Frame.__init__(self, parent, id, title, pos, size, style)

        self.options = options

        self.font = wx.Font(9, wx.FONTFAMILY_MODERN, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL)

        # NOTE(review): colors/robots/robot/recorders/recorder/what/record_pid
        # are not read anywhere in this file; kept for compatibility.
        self.colors = ['red', 'blue', 'purple', 'cyan', 'brown']
        self.robots = ['sim', 'pr2']
        self.robot = 'sim'

        self.recorders = ['joint_state', 'external']
        self.recorder = 'joint_state'

        self.what = 'sample_trajectory_to_bag_file'

        self.filename = None
        self.recording = False
        self.first_time = False
        self.record_pid = None
        self.mark_sub = None
        # Recording state: bag being written and the /joint_states subscriber.
        self.bag = None
        self.sub = None
        # record_callback (rospy thread) writes to self.bag while
        # stop_recording (GUI thread) closes it; this lock serialises both.
        self._bag_lock = thread.allocate_lock()

        self.init_from_rospy()

        # setup services
        # enable rl loop if requested (imported here: only needed with --rl)
        if options.rl:
            from learner_node.srv import start
            from learner_node.srv import stop
            from learner_node.srv import setStartState
            from rtrl.srv import setGoal
            self.learner_start_call = rospy.ServiceProxy('/learner_node/start', start)
            self.learner_stop_call = rospy.ServiceProxy('/learner_node/stop', stop)
            self.learner_setstart_call = rospy.ServiceProxy('/learner_node/set_start_state', setStartState)
            self.dmp_setgoal_call = rospy.ServiceProxy('/rtrl_controller/dmp/set_goal', setGoal)

        self.switch_controller = rospy.ServiceProxy('pr2_controller_manager/switch_controller', SwitchController)
        self.dmp_call = rospy.ServiceProxy('/dmp_motion_learner/learn_joint_space_dmp_from_bag_file', LearnJointSpaceDMP)
        self.execute_dmp_left = rospy.ServiceProxy('/l_arm_dmp_joint_position_controller/add_to_execute_dmp_queue', AddToExecuteDMPQueue)
        self.execute_dmp_right = rospy.ServiceProxy('/r_arm_dmp_joint_position_controller/add_to_execute_dmp_queue', AddToExecuteDMPQueue)
        self.execute_dmp_both = rospy.ServiceProxy('/dual_arm_dmp_joint_position_controller/add_to_execute_dmp_queue', AddToDualArmExecuteDMPQueue)
        self.dmp_execute_call = 1

        self.build_menu()
        self.build_main_gui()
        # close an active recording properly when the window goes away
        self.Bind(wx.EVT_CLOSE, self.on_close)

        # --> try to list all bag files
        self.fill_bag_chooser()

    ## initialization

    def init_from_rospy(self):
        """Set up joint-name lists, directories and the tracked joints.

        Online, ``bag_dir``, ``scp_bag_dir`` and ``tracking_joints`` come from
        the parameter server. When a parameter is missing, or when running
        offline, the defaults are ``~/bag_dir`` and the right-arm joints.
        """
        # default type --> right arm
        self.dmp_type = DMP.RIGHT

        # joint order used when learning / executing DMPs
        self.dmp_joints_r = ['r_shoulder_pan_joint', 'r_shoulder_lift_joint', 'r_upper_arm_roll_joint', 'r_elbow_flex_joint', 'r_forearm_roll_joint', 'r_wrist_flex_joint', 'r_wrist_roll_joint']
        self.dmp_joints_l = ['l_shoulder_pan_joint', 'l_shoulder_lift_joint', 'l_upper_arm_roll_joint', 'l_elbow_flex_joint', 'l_forearm_roll_joint', 'l_wrist_flex_joint', 'l_wrist_roll_joint']
        # joints whose positions are sampled as planner start state while recording
        self.joints_r_arm = ['r_upper_arm_roll_joint', 'r_shoulder_pan_joint', 'r_shoulder_lift_joint', 'r_forearm_roll_joint', 'r_elbow_flex_joint', 'r_wrist_flex_joint', 'r_wrist_roll_joint']
        self.joints_l_arm = ['l_upper_arm_roll_joint', 'l_shoulder_pan_joint', 'l_shoulder_lift_joint', 'l_forearm_roll_joint', 'l_elbow_flex_joint', 'l_wrist_flex_joint', 'l_wrist_roll_joint']

        self.start_state_planner_r = []
        self.start_state_planner_l = []

        default_dir = os.path.join(os.path.expanduser('~'), 'bag_dir')  # sensible default ;)
        self.bag_dir = default_dir
        self.scp_bag_dir = default_dir
        self.tracking_joints = set(self.joints_r_arm)

        if self.options.offline:
            return

        if rospy.has_param('/trajectory_sampling/data_directory_name'):
            self.bag_dir = rospy.get_param('/trajectory_sampling/data_directory_name')
        else:
            rospy.logerr("/trajectory_sampling/data_directory_name not set!")

        if rospy.has_param('/dmp_execution/dmp_dir'):
            self.scp_bag_dir = rospy.get_param('/dmp_execution/dmp_dir')
        else:
            rospy.logerr("/dmp_execution/dmp_dir not set!")

        if rospy.has_param('/learner_node/start_state_names'):
            self.tracking_joints = set(rospy.get_param('/learner_node/start_state_names'))
        elif self.options.rl:
            # tracking_joints is only used by the rl loop (mark_start_callback),
            # so the missing parameter is only worth an error in that mode
            rospy.logerr("/learner_node/start_state_names not set!")

    def build_menu(self):
        """Create the File (exit) and Type (right/left/both arms) menus."""
        self.menu_bar = wx.MenuBar()

        self.file_menu = wx.Menu()
        self.file_exit_item = wx.MenuItem(self.file_menu, -1, 'E&xit', 'Exit the program')
        self.Bind(wx.EVT_MENU, self.on_exit, self.file_exit_item)
        self.file_menu.AppendItem(self.file_exit_item)
        self.menu_bar.Append(self.file_menu, '&File')

        self.dmp_menu = wx.Menu()
        self.type_r_arm_item = wx.MenuItem(self.dmp_menu, -1, '&right arm', 'Use right arm for dmp')
        self.Bind(wx.EVT_MENU, self.on_choose_right, self.type_r_arm_item)
        self.dmp_menu.AppendItem(self.type_r_arm_item)
        self.type_l_arm_item = wx.MenuItem(self.dmp_menu, -1, '&left arm', 'Use left arm for dmp')
        self.Bind(wx.EVT_MENU, self.on_choose_left, self.type_l_arm_item)
        self.dmp_menu.AppendItem(self.type_l_arm_item)
        self.type_b_arm_item = wx.MenuItem(self.dmp_menu, -1, '&both arms', 'Use both arms for dmp')
        self.Bind(wx.EVT_MENU, self.on_choose_both, self.type_b_arm_item)
        self.dmp_menu.AppendItem(self.type_b_arm_item)
        self.menu_bar.Append(self.dmp_menu, '&Type')

        self.SetMenuBar(self.menu_bar)

    def build_main_gui(self):
        """Create the absolutely positioned widgets and bind button handlers.

        With ``options.rl`` the bottom row controls the learner loop
        (start/stop, mark start/goal state). Otherwise it has a single
        'Execute DMP' button.
        """
        self.output_text = wx.TextCtrl(self, -1, '', pos=(1, 100), size=(318, 130), style=wx.TE_MULTILINE)
        self.record_bag_text = wx.TextCtrl(self, -1, 'record.bag', pos=(5, 5), size=(180, 29))
        self.dmp_bag_text = wx.TextCtrl(self, -1, 'dmp.bag', pos=(5, 30), size=(180, 29))
        self.bag_chooser = wx.ComboBox(self, -1, 'Recorded Bag Files', pos=(5, 59), size=(180, 29), style=wx.CB_READONLY)
        self.record_button = wx.Button(self, -1, 'Start Record', pos=(200, 5))
        self.learn_button = wx.Button(self, -1, 'Learn DMP', pos=(200, 35))
        self.distribute_button = wx.Button(self, -1, 'Distribute DMP', pos=(200, 65))
        # read the options handed to this frame, not the module-level global
        # that only exists when the file is run as a script
        if self.options.rl:
            # when rl is activated show more options to control the learning loop
            self.start_button = wx.Button(self, -1, 'Start Execution', pos=(30, 240))
            self.stop_button = wx.Button(self, -1, 'Stop Execution', pos=(170, 240))
            self.mark_start_button = wx.Button(self, -1, 'Mark Start', pos=(30, 270))
            self.mark_end_button = wx.Button(self, -1, 'Mark Goal', pos=(170, 270))
            self.Bind(wx.EVT_BUTTON, lambda e: self.start_learner(), self.start_button)
            self.Bind(wx.EVT_BUTTON, lambda e: self.stop_learner(), self.stop_button)
            self.Bind(wx.EVT_BUTTON, lambda e: self.mark_start(), self.mark_start_button)
            self.Bind(wx.EVT_BUTTON, lambda e: self.mark_end(), self.mark_end_button)
        else:
            # otherwise allow for simple playback of the recorded dmp
            self.start_button = wx.Button(self, -1, 'Execute DMP', pos=(105, 240))
            self.Bind(wx.EVT_BUTTON, lambda e: self.execute_dmp(), self.start_button)

        self.output_text.SetFont(self.font)
        self.record_button.SetBackgroundColour(wx.GREEN)

        self.Bind(wx.EVT_BUTTON, lambda e: self.toggle_recording(), self.record_button)
        self.Bind(wx.EVT_BUTTON, lambda e: self.learn(), self.learn_button)
        self.Bind(wx.EVT_BUTTON, lambda e: self.distribute(), self.distribute_button)

    def fill_bag_chooser(self):
        """(Re)populate the bag combo box with the sorted entries of bag_dir."""
        try:
            bags = os.listdir(self.bag_dir)
        except OSError as err:
            # a missing/unreadable directory should not take the GUI down
            self.append_output('Could not list %s: %s' % (self.bag_dir, err))
            bags = []
        self.bag_chooser.SetItems(sorted(bags))

    ## helpers

    def append_output(self, output):
        """Append one line to the output text box (GUI thread only)."""
        self.output_text.AppendText(output + '\n')

    def on_text(self, event):
        # Scroll to end
        self.output_text.ScrollPages(1)

    def _selected_dmp_joints(self):
        """Return the DMP joint names for the current arm selection.

        For DMP.BOTH the right-arm joints come first, then the left-arm ones.
        execute_dmp relies on this order when it splits a two-arm DMP.
        """
        if self.dmp_type == DMP.LEFT:
            return list(self.dmp_joints_l)
        if self.dmp_type == DMP.BOTH:
            return self.dmp_joints_r + self.dmp_joints_l
        return list(self.dmp_joints_r)

    def _dmp_bag_path(self):
        """Local path of the DMP bag named in the dmp text field."""
        return os.path.join(self.bag_dir, str(self.dmp_bag_text.GetValue()))

    ## callbacks

    def on_exit(self, event):
        self.Close()

    def on_close(self, event):
        """Finish an active recording and drop pending subscribers on close."""
        self.stop_recording()
        if self.mark_sub:
            self.mark_sub.unregister()
            self.mark_sub = None
        # let wx continue with the default close handling (destroy the frame)
        event.Skip()

    def on_choose_right(self, event):
        self.dmp_type = DMP.RIGHT

    def on_choose_left(self, event):
        self.dmp_type = DMP.LEFT

    def on_choose_both(self, event):
        self.dmp_type = DMP.BOTH

    ## Recording

    def toggle_recording(self):
        """Start recording if idle, otherwise stop the running recording."""
        if self.recording:
            self.stop_recording()
        else:
            self.start_recording()

    def learn(self):
        """Learn a DMP from the bag currently selected in the bag chooser."""
        bag_file = str(self.bag_chooser.GetValue())
        # the combo box is created with the placeholder 'Recorded Bag Files',
        # which is not one of the listed bag files
        if not bag_file or self.bag_chooser.FindString(bag_file) == wx.NOT_FOUND:
            self.append_output('Please select a recorded bag file first!')
            return
        self.learn_dmp(bag_file)

    def learn_dmp(self, filename):
        """Learn a joint-space DMP from ``bag_dir/filename`` and save it.

        The service response is echoed to the output box. The learned DMP
        is written under the topic 'dynamic_movement_primitive' to the bag
        named in the dmp text field.
        """
        dmp_joints = self._selected_dmp_joints()

        resp = self.dmp_call(joint_names=dmp_joints,
                             bag_file_name=os.path.join(self.bag_dir, filename),
                             data_directory_name=self.bag_dir,
                             dmp_id=1)

        self.append_output(str(resp))

        # store the dmp to a bag file (closed even if writing fails)
        bag = rosbag.Bag(self._dmp_bag_path(), 'w')
        try:
            bag.write('dynamic_movement_primitive', resp.dmp)
        finally:
            bag.close()

    def execute_dmp(self):
        """Load the stored DMP and queue it on the selected arm's controller.

        The selected arm(s) are switched from the default arm controllers to
        the DMP joint position controllers first. For DMP.BOTH the first 7
        transformation systems are sent as the right-arm DMP and the rest as
        the left-arm DMP.
        """
        exec_dmp = None
        bag = rosbag.Bag(self._dmp_bag_path())
        try:
            for topic, msg, _ in bag.read_messages(topics=['dynamic_movement_primitive']):
                if topic == 'dynamic_movement_primitive':  # safety check
                    exec_dmp = msg  # the last DMP in the bag wins
        finally:
            bag.close()
        if exec_dmp is None:
            self.append_output('Error could not read dmp from bag!')
            return
        exec_dmp.is_setup = True
        exec_dmp.is_start_set = False
        if self.dmp_type == DMP.RIGHT:
            self.switch_controller.call(start_controllers=['r_arm_dmp_joint_position_controller'], stop_controllers=['r_arm_controller'], strictness=2)
            self.execute_dmp_right.call(dmps=[exec_dmp], execution_durations=[exec_dmp.initial_duration], types=[0])
        elif self.dmp_type == DMP.LEFT:
            self.switch_controller.call(start_controllers=['l_arm_dmp_joint_position_controller'], stop_controllers=['l_arm_controller'], strictness=2)
            self.execute_dmp_left.call(dmps=[exec_dmp], execution_durations=[exec_dmp.initial_duration], types=[0])
        elif self.dmp_type == DMP.BOTH:
            self.switch_controller.call(start_controllers=['dual_arm_dmp_joint_position_controller', 'r_arm_dmp_joint_position_controller', 'l_arm_dmp_joint_position_controller'], stop_controllers=['r_arm_controller', 'l_arm_controller'], strictness=2)
            # learned with right-arm joints first (see _selected_dmp_joints)
            right_exec = copy.deepcopy(exec_dmp)
            right_exec.transformation_systems = right_exec.transformation_systems[0:7]
            left_exec = copy.deepcopy(exec_dmp)
            left_exec.transformation_systems = left_exec.transformation_systems[7:]
            self.execute_dmp_both.call(left_arm_dmps=[left_exec], right_arm_dmps=[right_exec], execution_durations=[exec_dmp.initial_duration], left_arm_types=[0], right_arm_types=[0])

    def distribute(self):
        """Copy the DMP bag to ``marvin:<scp_bag_dir>`` in a background thread.

        When online, the remote path is also stored in the parameter
        ``/dmp_execution/dmp_file_name``.
        """
        import subprocess

        def copy_dmp(source, target):
            print("DISTRIBUTING dmp")
            # argument list instead of a shell string: spaces or shell
            # metacharacters in the typed file name are not interpreted
            # locally (NOTE: the remote side of scp still parses the path)
            try:
                subprocess.call(['scp', source, 'marvin:%s' % target])
            except OSError as err:
                print("scp failed: %s" % err)
            print("DONE")

        self.append_output('~~~ DISTRIBUTING dmp ~~~\n--> see command line for instructions')
        dmp_name = self.dmp_bag_text.GetValue()
        source = os.path.join(self.bag_dir, str(dmp_name))
        target = self.scp_bag_dir + '/' + dmp_name  # remote path, always '/'
        # scp may prompt / take a while: keep it off the GUI thread
        thread.start_new_thread(copy_dmp, (source, target))

        if not self.options.offline:
            rospy.set_param('/dmp_execution/dmp_file_name', target)

    def start_learner(self):
        """Call /learner_node/start and echo the response (rl mode only)."""
        call_string = '/learner_node/start'
        self.append_output('calling %s' % (call_string))
        resp = self.learner_start_call(rospy.Time.now())
        self.append_output('Starting learner resp: %s' % (str(resp)))

    def stop_learner(self):
        """Call /learner_node/stop and echo the response (rl mode only)."""
        call_string = '/learner_node/stop'
        self.append_output('calling %s' % (call_string))
        resp = self.learner_stop_call(rospy.Time.now())
        self.append_output('Stopping learner resp: %s' % (str(resp)))

    def mark_start(self):
        """Subscribe to /joint_states to sample the next message as start state."""
        if self.mark_sub:
            self.mark_sub.unregister()
        self.append_output('Registering for sampling start state')
        self.mark_sub = rospy.Subscriber('/joint_states', JointState, self.mark_start_callback)

    def mark_start_callback(self, msg):
        """Send one /joint_states sample to the learner as its start state.

        For every tracked joint (in message order) the state gets
        ``[position, 0.0, 0.0]`` (pos, vel, acc). Runs on a rospy thread.
        """
        start_state = []
        for name, position in zip(msg.name, msg.position):
            if name in self.tracking_joints:
                start_state.extend([position, 0.0, 0.0])  # pos, vel, acc
        # unregister --> we only need to sample one state here
        self.mark_sub.unregister()
        rospy.set_param('/learner_node/start_state', start_state)
        # wx widgets must only be touched from the GUI thread
        wx.CallAfter(self.append_output, 'Done sampling start state --> call learner')
        resp = self.learner_setstart_call(rospy.Time.now(), start_state)
        wx.CallAfter(self.append_output, 'learner response: %s' % (str(resp)))

    def mark_end(self):
        """Subscribe to /joint_states to sample the next message as goal state."""
        if self.mark_sub:
            self.mark_sub.unregister()
        self.append_output('Registering for sampling goal state')
        self.mark_sub = rospy.Subscriber('/joint_states', JointState, self.mark_end_callback)

    def mark_end_callback(self, msg):
        """Send one /joint_states sample to the DMP as its goal.

        Positions are taken in DMP joint order for the current arm selection.
        Joints missing from the message are skipped. Runs on a rospy thread.
        """
        positions = dict(zip(msg.name, msg.position))
        end_state = [positions[name] for name in self._selected_dmp_joints() if name in positions]
        # unregister --> we only need to sample one state here
        self.mark_sub.unregister()
        rospy.set_param('/learner_node/terminal_pose', end_state)
        # wx widgets must only be touched from the GUI thread
        wx.CallAfter(self.append_output, 'Done sampling goal state --> call dmp')
        resp = self.dmp_setgoal_call(rospy.Time.now(), end_state)
        wx.CallAfter(self.append_output, 'dmp response: %s' % (str(resp)))

    def start_recording(self):
        """Open ``bag_dir/<record text>`` and start recording /joint_states."""
        if self.recording:
            return

        self.filename = self.record_bag_text.GetValue()
        # open the bag first so a failure leaves the GUI state untouched
        bag = rosbag.Bag(os.path.join(self.bag_dir, self.filename), 'w')
        with self._bag_lock:
            self.bag = bag

        self.first_time = True
        self.record_button.SetLabel('Stop Record')
        self.record_button.SetBackgroundColour(wx.RED)
        self.recording = True

        self.sub = rospy.Subscriber('/joint_states', JointState, self.record_callback)

    def record_callback(self, msg):
        """Write a /joint_states message to the bag (rospy thread).

        The first message of a recording also sets the planner start states
        (positions of the right/left arm joints, in message order).
        """
        if self.first_time:
            self.first_time = False
            samples = list(zip(msg.name, msg.position))
            self.start_state_planner_r = [p for n, p in samples if n in self.joints_r_arm]
            self.start_state_planner_l = [p for n, p in samples if n in self.joints_l_arm]

        with self._bag_lock:
            # the bag may already have been closed by stop_recording
            if self.bag is not None:
                self.bag.write('/joint_states', msg)

    def stop_recording(self):
        """Stop the subscriber, close the bag and refresh the bag list."""
        if not self.recording:
            return

        self.sub.unregister()
        self.sub = None

        # the lock waits for an in-flight record_callback before closing
        with self._bag_lock:
            self.bag.close()
            self.bag = None

        self.recording = False
        self.record_button.SetLabel('Start Record')
        self.record_button.SetBackgroundColour(wx.GREEN)

        # re check for all bag files
        self.fill_bag_chooser()
00404 
00405 
if __name__ == '__main__':
    # Script entry point: parse options, start the ROS node unless offline,
    # run the GUI and make sure ROS is shut down when the GUI exits.
    parser = optparse.OptionParser()
    parser.add_option('-o', '--offline', action='store_true', default=False, help='run offline - don\'t connect to core')
    parser.add_option('-r', '--rl', action='store_true', default=False, help='enable reinforcement learning loop --> rtrl package required')
    options, args = parser.parse_args(sys.argv[1:])

    if not options.offline:
        rospy.init_node('trajectory_sampling_gui', anonymous=True)

    try:
        app = TrajectorySamplingGuiApp(options)
        app.MainLoop()
    finally:
        # also runs when the GUI fails to start; only a node that was
        # actually initialised above needs shutting down
        if not options.offline:
            rospy.signal_shutdown('GUI shutdown')
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Properties Friends


trajectory_sampling_gui
Author(s):
autogenerated on Wed Dec 26 2012 16:43:15