00001
00002 import rospy
00003 import tool_utils as tu
00004 import glob
00005 import os.path as pt
00006 import cPickle as pk
00007 import os
00008 import smach
00009 import outcome_tool as ot
00010 import graph
00011 import sm_thread_runner as smtr
00012 import time
00013 import shutil
00014
def is_container(node):
    """Report whether ``node`` wraps a child state machine.

    Containers are recognised by duck typing only: any object exposing a
    ``get_child_name`` attribute counts as one.
    """
    has_child_accessor = hasattr(node, 'get_child_name')
    return has_child_accessor
00017
class FSMDocument:
    """Filename bookkeeping for one FSM document.

    Attributes:
        count: class-wide counter used to number untitled documents.
        filename: path of the document.
        modified: flag supplied by the creator; this class only stores it.
        real_filename: False for the ``untitled<N>`` placeholders produced
            by new_document(); callers pass True once a real path is known.
    """
    count = 0

    @staticmethod
    def new_document():
        """Return a fresh ``untitled<N>`` document and advance the counter."""
        number = FSMDocument.count
        document = FSMDocument('untitled' + str(number), False, False)
        FSMDocument.count = number + 1
        return document

    def __init__(self, filename, modified, real_filename=False):
        self.filename = filename
        self.modified = modified
        self.real_filename = real_filename

    def get_name(self):
        """Return the last path component of the filename."""
        _head, tail = pt.split(self.filename)
        return tail

    def get_filename(self):
        """Return the full filename as stored."""
        return self.filename

    def set_filename(self, fn):
        """Replace the stored filename."""
        self.filename = fn

    def has_real_filename(self):
        """Return the ``real_filename`` flag."""
        return self.real_filename
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
class GraphModel:
    """Editable state-machine graph: states, outcome placeholders and edges.

    ``states_dict`` maps node names to state objects; outcome nodes are
    ``tu.EmptyState`` instances (those created with ``temporary=True`` are
    auto-generated placeholders for not-yet-connected outcomes). ``gve``
    holds the matching graph, where each edge's label is the source
    state's outcome name. The model can be saved to / loaded from a
    directory of pickles and converted into a ``smach.StateMachine``.
    """

    # Pickled list of [from_node, to_node, outcome_label] triples, written
    # inside the save directory.
    EDGES_FILE = 'edges.graph'

    # Pickled dict with keys 'start_state' and 'state_names', written inside
    # the save directory.
    NODES_FILE = 'nodes.graph'

    # Radius passed to gve.add_node() for nodes this model creates.
    NODE_RADIUS = 14

    # Length passed to gve.add_edge() for edges this model creates.
    EDGE_LENGTH = 2.
00067
00068 def __init__(self):
00069
00070 self.gve = graph.create(depth=True)
00071
00072 self.states_dict = {}
00073
00074 self.document = FSMDocument.new_document()
00075 self.start_state = None
00076
00077 self.node = self.gve.node
00078 self.edge = self.gve.edge
00079
00080 self.sm_thread = None
00081 self.status_cb_func = None
00082 self.last_outcome = None
00083
00084
00085 def get_start_state(self):
00086 return self.start_state
00087
00088 def set_start_state(self, state):
00089
00090
00091
00092 self.start_state = state
00093
00094 def set_document(self, document):
00095 self.document = document
00096
00097 def get_document(self):
00098 return self.document
00099
00100 @staticmethod
00101 def load(name):
00102 state_pkl_names = glob.glob(pt.join(name, '*.state'))
00103
00104 gm = GraphModel()
00105 gm.states_dict = {}
00106
00107
00108 nodes_fn = pt.join(name, GraphModel.NODES_FILE)
00109 pickle_file = open(nodes_fn, 'r')
00110 info = pk.load(pickle_file)
00111 gm.start_state = info['start_state']
00112 states_to_load = set(info['state_names'])
00113
00114
00115 for fname in state_pkl_names:
00116 sname = pt.splitext(pt.split(fname)[1])[0]
00117 if not states_to_load.issuperset([sname]):
00118 continue
00119
00120 pickle_file = open(fname, 'r')
00121 rospy.loginfo('Loading state %s' % sname)
00122 gm.states_dict[sname] = pk.load(pickle_file)
00123 gm.gve.add_node(sname, GraphModel.NODE_RADIUS)
00124 pickle_file.close()
00125
00126 rospy.loginfo('Got an instance of %s' % str(gm.states_dict[sname].__class__))
00127
00128 if is_container(gm.states_dict[sname]):
00129
00130 gm.states_dict[sname] = gm.states_dict[sname].load_and_recreate(name)
00131
00132
00133 edges_filename = pt.join(name, GraphModel.EDGES_FILE)
00134 edges_pickle_file = open(edges_filename, 'r')
00135 edges = pk.load(edges_pickle_file)
00136 edges_pickle_file.close()
00137 for node1, node2, n1_outcome in edges:
00138 gm.gve.add_edge(node1, node2, label=n1_outcome, length=GraphModel.EDGE_LENGTH)
00139
00140 gm.set_document(FSMDocument(name, modified=False, real_filename=True))
00141
00142
00143
00144
00145
00146
00147 return gm
00148
00149 def save(self, name):
00150 rospy.loginfo('GraphModel: saving to %s' % name)
00151 if not pt.exists(name):
00152 os.mkdir(name)
00153 else:
00154 shutil.rmtree(name)
00155 os.mkdir(name)
00156
00157
00158 for state_name in self.states_dict.keys():
00159 containerp = is_container(self.states_dict[state_name])
00160 if containerp:
00161 self.states_dict[state_name].save_child(name)
00162 child = self.states_dict[state_name].abort_child()
00163
00164 state_fname = pt.join(name, state_name) + '.state'
00165 pickle_file = open(state_fname, 'w')
00166
00167 pk.dump(self.states_dict[state_name], pickle_file)
00168 pickle_file.close()
00169 if containerp:
00170
00171 self.states_dict[state_name].set_child(child)
00172
00173
00174 edge_list = []
00175 for e in self.gve.edges:
00176 edge_list.append([e.node1.id, e.node2.id, e.label])
00177
00178 edge_fn = pt.join(name, GraphModel.EDGES_FILE)
00179 pickle_file = open(edge_fn, 'w')
00180 pk.dump(edge_list, pickle_file)
00181 pickle_file.close()
00182
00183 nodes_fn = pt.join(name, GraphModel.NODES_FILE)
00184 pickle_file = open(nodes_fn, 'w')
00185 pk.dump({'start_state': self.start_state, 'state_names': self.states_dict.keys()}, pickle_file)
00186 pickle_file.close()
00187
00188 self.document = FSMDocument(name, False, True)
00189
00190 def create_singleton_statemachine(self, state, robot):
00191
00192
00193
00194
00195 temp_gm = GraphModel()
00196 temp_gm.add_node(state)
00197 temp_gm.set_start_state(state.name)
00198
00199 return temp_gm.create_state_machine(robot), temp_gm
00200
00201
00202
00203 def run(self, name="", state_machine=None, userdata=None):
00204
00205
00206
00207
00208
00209 sm = state_machine
00210 sm.register_transition_cb(self._state_machine_transition_cb)
00211 sm.register_start_cb(self._state_machine_start_cb)
00212 sm.register_termination_cb(self._state_machine_termination_cb)
00213
00214 rthread = smtr.ThreadRunSM(name, sm)
00215 rthread.register_termination_cb(self._sm_thread_termination_cb)
00216 self.sm_thread = {}
00217 self.sm_thread['run_sm'] = rthread
00218 self.sm_thread['preempted'] = None
00219 self.sm_thread['current_states'] = None
00220 self.sm_thread['outcome'] = None
00221
00222 rthread.start()
00223 return rthread
00224
00225 def preempt(self):
00226 if self.is_running():
00227 self.sm_thread['run_sm'].preempt()
00228 self.sm_thread['preempted'] = time.time()
00229
00230
00231 def is_running(self):
00232 return self.sm_thread != None
00233
00234 def register_status_cb(self, func):
00235 self.status_cb_func = func
00236
00237 def _sm_thread_termination_cb(self, exception):
00238
00239
00240 if exception != None:
00241 if self.status_cb_func != None:
00242 self.status_cb_func('Error: %s' % str(exception))
00243
00244 elif self.sm_thread['preempted'] != None:
00245 if self.status_cb_func != None:
00246 self.status_cb_func('%s stopped.' % (self.document.get_name()))
00247
00248 self.sm_thread = None
00249
00250 def _state_machine_transition_cb(self, user_data, active_states):
00251 self.sm_thread['current_states'] = active_states
00252 if self.status_cb_func != None:
00253 self.status_cb_func('At state %s' % (active_states[0]))
00254
00255
00256
00257
00258
00259
00260
00261
00262 def _state_machine_start_cb(self, userdata, initial_states):
00263 self.sm_thread['current_states'] = initial_states
00264 if self.status_cb_func != None:
00265 self.status_cb_func('At state %s.' % (initial_states[0]))
00266
00267
00268
00269 def get_last_outcome(self):
00270 return self.last_outcome
00271
00272 def _state_machine_termination_cb(self, userdata, terminal_states, container_outcome):
00273
00274 self.sm_thread['current_states'] = terminal_states
00275 self.sm_thread['outcome'] = container_outcome
00276 self.last_outcome = [container_outcome, time.time()]
00277
00278 if self.status_cb_func != None:
00279 self.status_cb_func('Stopped with outcome %s' % container_outcome)
00280
00281
00282
00283
00284
00285 def outputs_of_type(self, class_filter):
00286 filtered_output_variables = []
00287 for node_name in self.real_states():
00288 node = self.get_state(node_name)
00289 output_names = node.output_names()
00290 for output_name in output_names:
00291 if issubclass(node.output_type(output_name), class_filter):
00292 filtered_output_variables.append(output_name)
00293 return filtered_output_variables
00294
00295 def create_state_machine(self, robot, userdata=None, ignore_start_state=False):
00296
00297 sm = smach.StateMachine(outcomes = self.outcomes())
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311 if userdata != None:
00312
00313 for key in userdata.keys():
00314 exec ("sm.userdata.%s = userdata.%s" % (key, key))
00315
00316 exec ("print 'data in key is', sm.userdata.%s" % (key))
00317
00318 with sm:
00319 for node_name in self.real_states():
00320 node = self.states_dict[node_name]
00321
00322 node_smach = node.get_smach_state()
00323
00324 if hasattr(node_smach, 'set_robot'):
00325
00326 node_smach.set_robot(robot)
00327
00328 transitions = {}
00329 for e in self.gve.node(node_name).edges:
00330 if e.node1.id == node_name:
00331 transitions[e.label] = e.node2.id
00332
00333 input_set = set(node_smach.get_registered_input_keys())
00334 output_set = set(node_smach.get_registered_output_keys())
00335 if len(input_set.intersection(output_set)) > 0:
00336 raise RuntimeError('Input keys has the same name as output_keys.')
00337
00338 remapping = {}
00339 for input_key in node_smach.get_registered_input_keys():
00340 remapping[input_key] = node.remapping_for(input_key)
00341
00342
00343
00344
00345
00346 for output_key in node_smach.get_registered_output_keys():
00347 remapping[output_key] = output_key
00348
00349 smach.StateMachine.add(node_name, node_smach, transitions=transitions, remapping=remapping)
00350
00351 if ignore_start_state:
00352 return sm
00353
00354 if self.start_state == None:
00355 raise RuntimeError('No start state set.')
00356 sm.set_initial_state([self.start_state])
00357 return sm
00358
00359
00360
00361 def current_children_of(self, node_name):
00362 ret_list = []
00363 for edge in self.gve.node(node_name).edges:
00364 if edge.node1.id != node_name:
00365 continue
00366 ret_list.append([edge.label, edge.node2.id])
00367 return ret_list
00368
00369
00370
00371 def real_states(self):
00372 noc = []
00373 for node_name in self.states_dict.keys():
00374 if self.states_dict[node_name].__class__ != tu.EmptyState:
00375 noc.append(node_name)
00376 return noc
00377
00378 def outcomes(self):
00379
00380 oc = []
00381 for node_name in self.states_dict.keys():
00382 if self.states_dict[node_name].__class__ == tu.EmptyState:
00383 oc.append(node_name)
00384 return oc
00385
00386 def pop_state(self, node_name):
00387 return self.states_dict.pop(node_name)
00388
00389 def get_state(self, node_name):
00390
00391 return self.states_dict[node_name]
00392
00393 def set_state(self, node_name, state):
00394 self.states_dict[node_name] = state
00395
00396 def replace_node(self, new_node, old_node_name):
00397 self.states_dict.pop(old_node_name)
00398 self.states_dict[new_node.get_name()] = new_node
00399 new_node_name = new_node.get_name()
00400
00401
00402
00403
00404 if new_node_name != old_node_name:
00405 self.gve.add_node(new_node_name, self.NODE_RADIUS)
00406
00407
00408 new_smach_node = new_node.get_smach_state()
00409 if hasattr(new_smach_node, 'set_robot'):
00410 new_smach_node.set_robot(None)
00411
00412 new_outcomes = new_smach_node.get_registered_outcomes()
00413 for e in self.gve.node(old_node_name).edges:
00414
00415 if e.label in new_outcomes:
00416
00417 if e.node1.id == old_node_name:
00418 self.gve.remove_edge(e.node1.id, e.node2.id, label=e.label)
00419 self.gve.add_edge(new_node_name, e.node2.id, label=e.label, length=GraphModel.EDGE_LENGTH)
00420 elif e.node2.id == old_node_name:
00421 self.gve.remove_edge(e.node1.id, e.node2.id, label=e.label)
00422 self.gve.add_edge(e.node1.id, new_node_name, label=e.label, length=GraphModel.EDGE_LENGTH)
00423
00424
00425 else:
00426 if e.node1.id == old_node_name:
00427 print 'removing edge', e.node1.id, e.node2.id
00428 self.gve.remove_edge(e.node1.id, e.node2.id, label=e.label)
00429 if not self.is_modifiable(e.node2.id) and len(e.node2.edges) < 1:
00430 self.gve.remove_node(e.node2.id)
00431 self.states_dict.pop(e.node2.id)
00432 else:
00433 self.gve.remove_edge(e.node1.id, e.node2.id, label=e.label)
00434 self.gve.add_edge(e.node1.id, new_node_name, label=e.label, length=GraphModel.EDGE_LENGTH)
00435
00436
00437 if new_node_name != old_node_name:
00438 self.gve.remove_node(old_node_name)
00439
00440
00441
00442 self.restore_node_consistency(new_node.get_name())
00443
00444 def connectable_nodes(self, node_name, outcome):
00445
00446
00447 allowed_nodes = []
00448
00449 for state_name in self.states_dict.keys():
00450
00451 if (not self.is_modifiable(state_name)) and (not self._is_type(state_name, outcome)):
00452 continue
00453
00454
00455 if node_name == state_name:
00456 continue
00457
00458 allowed_nodes.append(state_name)
00459
00460 if node_name == None:
00461 return []
00462
00463
00464
00465 else:
00466 for edge in self.gve.node(node_name).edges:
00467
00468 if edge.label == outcome and edge.node1.id == node_name and (not self._is_type(edge.node2.id, outcome)):
00469
00470 allowed_nodes.append(self._create_outcome_name(outcome))
00471
00472 allowed_nodes.sort()
00473 return allowed_nodes
00474
00475
00476
00477 def _create_outcome_name(self, outcome):
00478 idx = 0
00479 name = "%s%d" % (outcome, idx)
00480 while self.states_dict.has_key(name):
00481 idx = idx + 1
00482 name = "%s%d" % (outcome, idx)
00483 return name
00484
00485 def _is_type(self, state_name, outcome):
00486 r = state_name.find(outcome)
00487 if r < 0:
00488 return False
00489 else:
00490 return True
00491
00492 def add_node(self, node):
00493 if self.states_dict.has_key(node.get_name()):
00494 node.set_name(node.get_name() + '_dup')
00495
00496
00497
00498 if not hasattr(node, 'get_child_name') or \
00499 not self.states_dict.has_key(node.get_child_name()):
00500
00501
00502 smach_node = node.get_smach_state()
00503 if hasattr(smach_node, 'set_robot'):
00504 smach_node.set_robot(None)
00505
00506
00507 self.gve.add_node(node.get_name(), radius=self.NODE_RADIUS)
00508 self.states_dict[node.get_name()] = node
00509
00510
00511 for outcome in smach_node.get_registered_outcomes():
00512
00513 outcome_name = self._create_outcome_name(outcome)
00514 self.states_dict[outcome_name] = tu.EmptyState(outcome_name, temporary=True)
00515 self.gve.add_node(outcome_name, radius=self.NODE_RADIUS)
00516 self._add_edge(node.get_name(), outcome_name, outcome)
00517
00518 else:
00519
00520 self.replace_node(node, node.get_child_name())
00521
00522
00523
00524
00525
00526
00527
00528
00529
00530
00531
00532
00533
00534
00535
00536
00537
00538
00539
00540
00541
00542
00543
00544
00545
00546
00547
00548
00549
00550
00551
00552
00553
00554
00555
00556
00557
00558
00559
00560 def add_outcome(self, outcome_name):
00561 self.gve.add_node(outcome_name, radius=self.NODE_RADIUS)
00562 self.states_dict[outcome_name] = tu.EmptyState(outcome_name, False)
00563
    def delete_node(self, node_name):
        """Remove state ``node_name``, splicing its children onto a parent if possible.

        Steps as implemented:
          1. Split the node's edges into outgoing (children) and incoming
             (parents); a self-loop raises.
          2. Delete temporary child nodes that have at most this one edge.
          3. With at least one parent, only the first parent is used. For each
             remaining child edge whose label the parent also has wired to a
             temporary node, the parent's outcome is re-pointed at the child
             and the now-unconnected placeholder is deleted. Each processed
             child edge is then removed from this node.
          4. With no parents, the remaining child edges are simply removed.
          5. Every parent edge is removed and restore_node_consistency() is
             run on that parent. Then the node itself is removed, and
             start_state is cleared if it pointed here.

        Raises:
            Exception: a self-link on ``node_name`` is found.
        """
        node_obj = self.gve.node(node_name)
        children_edges = []
        parent_edges = []

        print 'deleting', node_name

        # Classify every edge touching this node as outgoing or incoming.
        for cn in node_obj.links:
            for edge in self.gve.all_edges_between(node_name, cn.id):
                if (edge.node1.id == node_name) and (edge.node2.id == node_name):
                    raise Exception('Self link detected on node %s! This isn\'t supposed to happen.' % node_name)
                if edge.node1.id == node_name:
                    children_edges.append(edge)
                elif edge.node2.id == node_name:
                    parent_edges.append(edge)

        # Temporary children connected only to us are deleted outright; all
        # other children are kept for re-wiring below.
        filtered_children_edges = []
        for e in children_edges:
            if not self.is_modifiable(e.node2.id) and len(e.node2.edges) <= 1:
                self.gve.remove_edge(node_name, e.node2.id, e.label)
                self.gve.remove_node(e.node2.id)
                self.states_dict.pop(e.node2.id)
            else:
                filtered_children_edges.append(e)

        if len(parent_edges) >= 1:
            # Only the first parent inherits our children.
            parent_node_id = parent_edges[0].node1.id
            # NOTE(review): parent_node is not used below.
            parent_node = self.gve.node(parent_node_id)
            print 'picked parent', parent_node_id

            # Parent's outcome label -> node that outcome currently targets.
            parents_children = {}
            for parent_outcome_name, sibling_node_name in self.current_children_of(parent_node_id):
                parents_children[parent_outcome_name] = sibling_node_name
            print 'siblings', parents_children

            for edge in filtered_children_edges:
                print 'processing child edge', edge.node1.id, edge.label, edge.node2.id

                node_outcome_name = edge.label

                if parents_children.has_key(node_outcome_name):
                    parent_outcome_node = parents_children[node_outcome_name]
                    # Only take over the parent's outcome if it currently
                    # leads to a temporary placeholder.
                    if not self.is_modifiable(parent_outcome_node):
                        self.gve.remove_edge(parent_node_id, parent_outcome_node, label=node_outcome_name)
                        self.gve.add_edge(parent_node_id, edge.node2.id, label=node_outcome_name, length=GraphModel.EDGE_LENGTH)

                        # Drop the placeholder once nothing links to it.
                        if len(self.gve.node(parent_outcome_node).edges) < 1:
                            self.gve.remove_node(parent_outcome_node)
                            self.states_dict.pop(parent_outcome_node)

                self.gve.remove_edge(edge.node1.id, edge.node2.id, edge.label)

        elif len(parent_edges) == 0:
            for e in filtered_children_edges:
                self.gve.remove_edge(node_name, e.node2.id, label=e.label)

        for parent_edge in parent_edges:
            self.gve.remove_edge(parent_edge.node1.id, parent_edge.node2.id, parent_edge.label)
            # The parent just lost an outcome edge; restore_node_consistency()
            # adds a temporary placeholder for any registered outcome left unwired.
            self.restore_node_consistency(parent_edge.node1.id)

        self.gve.remove_node(node_name)
        self.states_dict.pop(node_name)
        if self.start_state == node_name:
            self.start_state = None
00646
00647
00648
00649 def restore_node_consistency(self, node_name):
00650
00651
00652 clist = self.current_children_of(node_name)
00653 cdict = {}
00654
00655 for outcome_name, nn in clist:
00656 cdict[outcome_name] = nn
00657
00658
00659
00660
00661
00662 smach_state = self.states_dict[node_name].get_smach_state()
00663 if hasattr(smach_state, 'set_robot'):
00664 smach_state.set_robot(None)
00665 registered_outcomes = smach_state.get_registered_outcomes()
00666
00667
00668 for outcome in cdict.keys():
00669 if not (outcome in registered_outcomes):
00670 self.gve.remove_edge(node_name, cdict[outcome], outcome)
00671 if (not self.is_modifiable(cdict[outcome])) and len(self.gve.node(cdict[outcome]).edges) < 1:
00672 self.gve.remove_node(cdict[outcome])
00673 self.states_dict.pop(cdict[outcome])
00674
00675
00676
00677
00678 for outcome in registered_outcomes:
00679 if not cdict.has_key(outcome):
00680
00681 new_outcome_name = self._create_outcome_name(outcome)
00682 self._add_temporary_outcome(new_outcome_name)
00683 self._add_edge(node_name, new_outcome_name, outcome)
00684
00685 def _add_temporary_outcome(self, outcome):
00686 self.states_dict[outcome] = tu.EmptyState(outcome, temporary=True)
00687 self.gve.add_node(outcome, self.NODE_RADIUS)
00688
00689 def is_modifiable(self, node_name):
00690 if (self.states_dict[node_name].__class__ == tu.EmptyState) and self.states_dict[node_name].temporary:
00691 return False
00692 else:
00693 return True
00694
00695 def _add_edge(self, n1, n2, n1_outcome):
00696 if not self.states_dict.has_key(n1) or not self.states_dict.has_key(n2):
00697 raise RuntimeError('One of the specified nodes does not exist. Can\'t add edge.')
00698
00699 if self.gve.edge(n1, n2, n1_outcome) != None:
00700 rospy.loginfo("Edge between %s and %s exists, ignoring connnection request" % (n1, n2))
00701 return False
00702
00703
00704 if n1_outcome == None and self.is_modifiable(n2):
00705 raise RuntimeError('Must specify outcome as goal node is not a temporary node.')
00706
00707 self.gve.add_edge(n1, n2, label=n1_outcome, length=GraphModel.EDGE_LENGTH)
00708
00709
00710 return True
00711
00712 def add_edge(self, n1, n2, n1_outcome):
00713 if not self.is_modifiable(n1) or not self.is_modifiable(n2):
00714 return False
00715 else:
00716 return self._add_edge(n1, n2, n1_outcome)
00717
00718 def delete_edge(self, edge):
00719 if not self.is_modifiable(edge.node1.id) or not self.is_modifiable(edge.node2.id):
00720 return False
00721 else:
00722 self.gve.remove_edge(edge.node1.id, edge.node2.id, e.label)
00723 return True
00724
00725 def connection_changed(self, node_name, outcome_name, new_node):
00726
00727 if node_name == None:
00728 return
00729
00730
00731 if not self.states_dict.has_key(new_node):
00732 self.states_dict[new_node] = tu.EmptyState(new_node, temporary=True)
00733 self.gve.add_node(new_node, radius=self.NODE_RADIUS)
00734
00735
00736
00737
00738
00739 old_edge = None
00740 for edge in self.gve.node(node_name).edges:
00741
00742
00743 if edge.label == outcome_name and edge.node1.id == node_name:
00744 if old_edge != None:
00745 raise RuntimeError('Two edges detected for one outcome named %s. %s -> %s and %s -> %s' % (outcome_name, old_edge.node1.id, old_edge.node2.id, edge.node1.id, edge.node2.id))
00746 old_edge = edge
00747
00748
00749 if old_edge.node2.id == new_node:
00750 return
00751
00752
00753
00754 self.gve.remove_edge(node_name, old_edge.node2.id, label=old_edge.label)
00755
00756
00757 if not self.is_modifiable(old_edge.node2.id):
00758
00759
00760 if len(self.gve.node(old_edge.node2.id).edges) <= 0:
00761 self.gve.remove_node(old_edge.node2.id)
00762 self.states_dict.pop(old_edge.node2.id)
00763
00764
00765 if self.gve.node(new_node) == None:
00766
00767 self.states_dict[new_node] = tu.EmptyState(new_node, temporary=True)
00768 self.gve.add_node(new_node, self.NODE_RADIUS)
00769
00770 self._add_edge(node_name, new_node, outcome_name)
00771
00772
00773
00774
00775
00776
00777
00778
00779
00780
00781
00782
00783
00784
00785
00786
00787
00788
00789
00790
00791
00792
00793
00794
00795
00796
00797
00798
00799
00800
00801
00802
00803
00804
00805
00806
00807
00808
00809
00810
00811
00812
00813
00814
00815
00816
00817
00818
00819
00820
00821
00822
00823
00824
00825
00826
00827
00828
00829
00830
00831
00832
00833
00834
00835
00836
00837
00838
00839
00840
00841
00842
00843
00844
00845
00846
00847
00848
00849
00850
00851
00852
00853
00854
00855
00856
00857
00858
00859
00860
00861
00862
00863
00864
00865
00866
00867