00001
00002
00003
00004 from __future__ import with_statement
00005 from collections import deque, defaultdict
00006 import itertools
00007 import time
00008
00009 import invariants
00010 import pddl
00011 import timers
00012
class BalanceChecker(object):
    """Per-task lookup tables used when checking candidate invariants.

    ``predicates_to_add_actions`` maps a predicate name to the set of durative
    actions that have a positive (``pddl.Atom``) effect on that predicate in
    either of their two timed effect lists.

    ``action_name_to_heavy_action`` maps an action name to the action used for
    "too heavy" checks.  It is only populated when ``safe`` is true; with
    ``safe=False`` it stays empty and ``get_heavy_action`` raises ``KeyError``.
    """

    def __init__(self, task, reachable_action_params, safe=True):
        self.predicates_to_add_actions = defaultdict(set)
        self.action_name_to_heavy_action = {}
        for action in task.durative_actions:
            if safe:
                # Strengthen the conditions with parameter inequalities that
                # no reachable instantiation violates.
                action = self.add_inequality_preconds(
                    action, reachable_action_params)
            heavy_effects = [[], []]
            needs_heavy_copy = False
            # action.effects holds two timed effect lists (presumably
            # at-start / at-end -- confirm against pddl.DurativeAction).
            for moment in range(2):
                for effect in action.effects[moment]:
                    if isinstance(effect.peffect, pddl.Atom):
                        added = effect.peffect.predicate
                        self.predicates_to_add_actions[added].add(action)
                    if not safe:
                        continue
                    heavy_effects[moment].append(effect)
                    if effect.parameters:
                        # Quantified effect: include it a second time in the
                        # heavy action (the copy is presumably meant to get
                        # fresh variable names -- confirm in pddl).
                        needs_heavy_copy = True
                        heavy_effects[moment].append(effect.copy())
            if not safe:
                continue
            if needs_heavy_copy:
                heavy = pddl.DurativeAction(action.name,
                                            action.parameters,
                                            action.duration,
                                            action.condition,
                                            heavy_effects)
            else:
                heavy = action
            self.action_name_to_heavy_action[action.name] = heavy

    def get_threats(self, predicate):
        """Return the actions with a positive effect on *predicate*.

        The result is an empty set if no action adds *predicate*.
        """
        return self.predicates_to_add_actions.get(predicate, set())

    def get_heavy_action(self, action_name):
        """Return the heavy variant registered for *action_name*.

        Raises ``KeyError`` if none was registered, which is always the case
        when the checker was built with ``safe=False``.
        """
        return self.action_name_to_heavy_action[action_name]

    def add_inequality_preconds(self, action, reachable_action_params):
        """Return *action* with extra ``(not (= ?a ?b))`` conditions.

        For every pair of parameter positions whose values are unequal in all
        instantiations listed in ``reachable_action_params[action.name]``, an
        inequality literal is appended to conditions 0 and 2 of the action.

        *action* itself is returned unchanged in three cases: when
        *reachable_action_params* is ``None``, when the action has fewer
        than two parameters, or when no pair qualifies.  Otherwise a new
        ``pddl.DurativeAction`` is returned.
        """
        if reachable_action_params is None or len(action.parameters) < 2:
            return action
        instantiations = reachable_action_params[action.name]
        inequalities = []
        positions = range(len(action.parameters))
        for first, second in itertools.combinations(positions, 2):
            if any(args[first] == args[second] for args in instantiations):
                continue
            left = pddl.Variable(action.parameters[first].name)
            right = pddl.Variable(action.parameters[second].name)
            inequalities.append(pddl.NegatedAtom("=", (left, right)))
        if not inequalities:
            return action
        conditions = list(action.condition)
        # Only condition slots 0 and 2 are strengthened (presumably at-start
        # and at-end; slot 1 is left untouched -- confirm).
        for slot in (0, 2):
            current = action.condition[slot]
            conjuncts = list(current.parts)
            if isinstance(current, pddl.Literal):
                conjuncts = [current]
            # NOTE(review): for any non-Literal condition only its .parts are
            # kept, so this assumes conditions are already conjunctions of
            # literals (a disjunction would be flattened) -- confirm.
            conjuncts.extend(inequalities)
            conditions[slot] = pddl.Conjunction(conjuncts)
        return pddl.DurativeAction(action.name, action.parameters,
                                   action.duration, conditions,
                                   action.effects)
00079
def get_fluents(task):
    """Return the predicates of *task* that some durative action changes.

    A predicate counts as a fluent when it appears as a ``pddl.Literal``
    effect in any timed effect list of any durative action.  The returned
    list keeps the order of ``task.predicates``.
    """
    changed_names = set(
        effect.peffect.predicate
        for durative in task.durative_actions
        for effects_at_time in durative.effects
        for effect in effects_at_time
        if isinstance(effect.peffect, pddl.Literal)
    )
    return [candidate for candidate in task.predicates
            if candidate.name in changed_names]
00088
def get_initial_invariants(task, safe):
    """Yield the single-part seed candidates for the invariant search.

    For every fluent predicate (see ``get_fluents``) one candidate is yielded
    per choice of omitted argument position, plus one with nothing omitted
    (position ``-1``).  With ``safe`` true the candidates are
    ``invariants.SafeInvariant`` instances, otherwise
    ``invariants.UnsafeInvariant``.
    """
    for fluent in get_fluents(task):
        positions = list(range(len(fluent.arguments)))
        for omitted in [-1] + positions:
            kept = [pos for pos in positions if pos != omitted]
            part = invariants.InvariantPart(fluent.name, kept, omitted)
            invariant_class = (invariants.SafeInvariant if safe
                               else invariants.UnsafeInvariant)
            yield invariant_class((part,))
00099
00100
# Upper bound on the number of distinct candidate invariants the search in
# find_invariants will ever enqueue; further candidates are dropped.
MAX_CANDIDATES = 100000
# Time budget, in seconds, for the candidate search in find_invariants
# (five minutes).
MAX_TIME = 5 * 60
00103
def find_invariants(task, safe=True, reachable_action_params=None,
                    max_candidates=None, max_time=None):
    """Yield the candidate invariants of *task* that pass the balance check.

    Candidates start as the seeds from ``get_initial_invariants`` and are
    processed breadth-first.  ``check_balance`` may enqueue refined
    candidates through ``enqueue_func``; each distinct candidate is enqueued
    at most once.

    :param safe: forwarded to ``get_initial_invariants`` and
        ``BalanceChecker`` (defaults to ``True``, as in ``get_groups``).
    :param reachable_action_params: forwarded to ``BalanceChecker``;
        ``None`` (the default) skips the inequality strengthening.
    :param max_candidates: cap on the number of distinct candidates ever
        enqueued; ``None`` means ``MAX_CANDIDATES``.
    :param max_time: time budget in seconds; ``None`` means ``MAX_TIME``.
        Once exceeded, a message is printed and the generator stops.
    """
    # Resolve at call time so changes to the module constants are honored.
    if max_candidates is None:
        max_candidates = MAX_CANDIDATES
    if max_time is None:
        max_time = MAX_TIME

    candidates = deque(get_initial_invariants(task, safe))
    print("%d initial candidates" % len(candidates))
    seen_candidates = set(candidates)

    balance_checker = BalanceChecker(task, reachable_action_params, safe)

    def enqueue_func(invariant):
        if (len(seen_candidates) < max_candidates
                and invariant not in seen_candidates):
            candidates.append(invariant)
            seen_candidates.add(invariant)

    # time.clock() was deprecated in Python 3.3 and removed in 3.8; use
    # time.process_time() where it exists, otherwise fall back to clock().
    cpu_clock = getattr(time, "process_time", None) or time.clock
    start_time = cpu_clock()
    while candidates:
        candidate = candidates.popleft()
        if cpu_clock() - start_time > max_time:
            print("Time limit reached, aborting invariant generation")
            return
        if candidate.check_balance(balance_checker, enqueue_func):
            yield candidate
00124
def useful_groups(invariants, initial_facts):
    """Yield the instantiated invariant groups that hold exactly one initial fact.

    Each invariant is instantiated with the parameters extracted from every
    initial fact whose predicate it mentions.  Initial facts that are
    ``pddl.FunctionAssignment`` instances are skipped.  A group (invariant
    plus parameter tuple) is kept only if exactly one initial fact falls into
    it.  Groups matched by two or more facts are discarded.

    Groups are yielded in the order in which they are first hit while
    scanning *initial_facts*.  The previous implementation iterated a hash
    set, so its output order could vary between runs.

    Each yielded value is ``[part.instantiate(parameters) for part in
    invariant.parts]``.
    """
    predicate_to_invariants = defaultdict(list)
    for invariant in invariants:
        for predicate in invariant.predicates:
            predicate_to_invariants[predicate].append(invariant)

    hit_counts = {}
    first_seen_order = []
    for atom in initial_facts:
        if isinstance(atom, pddl.FunctionAssignment):
            continue
        for invariant in predicate_to_invariants.get(atom.predicate, ()):
            group_key = (invariant, tuple(invariant.get_parameters(atom)))
            if group_key in hit_counts:
                hit_counts[group_key] += 1
            else:
                hit_counts[group_key] = 1
                first_seen_order.append(group_key)

    for group_key in first_seen_order:
        if hit_counts[group_key] != 1:
            # Two or more initial facts in one group: not usable.
            continue
        invariant, parameters = group_key
        yield [part.instantiate(parameters) for part in invariant.parts]
00144
def get_groups(task, safe=True, reachable_action_params=None):
    """Compute the fact groups of *task* from its proven invariants.

    Invariants are found with ``find_invariants`` and sorted.  They are then
    instantiated against ``task.init`` by ``useful_groups``.  Both phases are
    wrapped in ``timers.timing`` sections.  Returns a list with one entry per
    group, each entry being a list of instantiated invariant parts.
    """
    with timers.timing("Finding invariants"):
        proven = sorted(find_invariants(task, safe, reachable_action_params))
    with timers.timing("Checking invariant weight"):
        groups = list(useful_groups(proven, task.init))
    return groups
00152
if __name__ == "__main__":
    # Command-line demo: parse the task with pddl.open(), print every
    # invariant found, then print the resulting fact groups.
    print("Parsing...")
    task = pddl.open()
    print("Finding invariants...")
    # Pass safe / reachable_action_params explicitly, using the same values
    # that get_groups() uses by default (safe mode, no reachability info).
    for invariant in find_invariants(task, True, None):
        print(invariant)
    print("Finding fact groups...")
    groups = get_groups(task)
    for group in groups:
        print("[%s]" % ", ".join(map(str, group)))