00001
00002
00003 from collections import defaultdict
00004 import itertools
00005
00006 import constraints
00007 import pddl
00008 import tools
00009
00010
00011
00012
00013
00014
00015
00016
00017
def invert_list(alist):
    """Map every element of *alist* to the ascending list of positions at
    which it occurs.

    The result is a ``defaultdict(list)``: looking up an element that does
    not occur yields (and inserts) an empty list."""
    positions_by_value = defaultdict(list)
    for index, value in enumerate(alist):
        positions_by_value[value].append(index)
    return positions_by_value
00023
00024
def instantiate_factored_mapping(pairs):
    """Expand a factored mapping into the concrete mappings it denotes.

    *pairs* is a sequence of ``(preimg, img)`` pairs.  For each pair, every
    ordering of *img* yielded by ``tools.permutations`` is zipped against
    *preimg*, giving the candidate sub-mappings for that pair;
    ``tools.cartesian_product`` then combines one candidate per pair and its
    result is returned unchanged.

    NOTE(review): relies on Python 2 ``zip`` returning a list; under Python 3
    each sub-mapping would be a one-shot iterator."""
    part_mappings = []
    for preimg, img in pairs:
        choices = [zip(preimg, permuted) for permuted in tools.permutations(img)]
        part_mappings.append(choices)
    return tools.cartesian_product(part_mappings)
00029
00030
def find_unique_variables(action, invariant):
    """Return ``invariant.arity()`` fresh ``pddl.Variable`` objects.

    Names are drawn in order from ``?v0``, ``?v1``, ... skipping any name
    already used by the action's parameters or by the parameters of an effect
    in ``action.effects[0]`` or ``action.effects[1]`` (presumably the at-start
    and at-end effect lists of a durative action)."""
    taken = set(p.name for p in action.parameters)
    for eff in itertools.chain(action.effects[0], action.effects[1]):
        taken.update(p.name for p in eff.parameters)

    needed = invariant.arity()
    fresh = []
    numbers = itertools.count()
    while len(fresh) < needed:
        name = "?v%i" % next(numbers)
        if name not in taken:
            fresh.append(pddl.Variable(name))
    return fresh
00047
00048
def get_literals(condition):
    """Yield the literals of *condition*.

    A bare ``pddl.Literal`` is yielded itself.  For a ``pddl.Conjunction``,
    those parts that are literals are yielded in order; other parts are
    skipped.  Any other kind of condition yields nothing."""
    if isinstance(condition, pddl.Literal):
        yield condition
        return
    if not isinstance(condition, pddl.Conjunction):
        return
    for part in condition.parts:
        if isinstance(part, pddl.Literal):
            yield part
00056
00057
def ensure_conjunction_sat(system, *parts):
    """Modifies the constraint system such that it is only solvable if the
    conjunction of all parts is satisfiable.

    Each part must be an iterator, generator, or an iterable over literals.

    Equality literals (predicate ``"="``) become constraints directly: a
    negated one adds a negative clause over its arguments, a positive one
    adds a one-element assignment disjunction.  For every other predicate,
    each (positive, negative) pair of literals with that predicate adds a
    negative clause over their corresponding argument pairs (presumably
    "not all of these pairs are equal", so the same atom is not required to
    be both true and false)."""
    positive = defaultdict(set)
    negative = defaultdict(set)
    for literal in itertools.chain(*parts):
        if literal.predicate != "=":
            bucket = negative if literal.negated else positive
            bucket[literal.predicate].add(literal)
        elif literal.negated:
            system.add_negative_clause(constraints.NegativeClause([literal.args]))
        else:
            system.add_assignment_disjunction(
                [constraints.Assignment([literal.args])])

    for pred, pos_atoms in positive.items():
        # .get avoids inserting into the defaultdict; present entries are
        # never empty because they are only created by an add().
        neg_atoms = negative.get(pred)
        if not neg_atoms:
            continue
        for pos_atom in pos_atoms:
            for neg_atom in neg_atoms:
                arg_pairs = list(zip(neg_atom.args, pos_atom.args))
                if arg_pairs:
                    system.add_negative_clause(constraints.NegativeClause(arg_pairs))
00088
00089
def ensure_cover(system, literal, invariant, inv_vars):
    """Modifies the constraint system such that it is only solvable if the
    invariant covers the literal.

    The assignments returned by
    ``invariant.get_covering_assignments(inv_vars, literal)`` are added to
    *system* as a single disjunction."""
    covering = invariant.get_covering_assignments(inv_vars, literal)
    system.add_assignment_disjunction(covering)
00095
00096
def ensure_inequality(system, literal1, literal2):
    """Modifies the constraint system such that it is only solvable if the
    literal instantiations are not equal (ignoring whether one is negated and
    the other is not).

    Literals over different predicates can never be equal, so nothing is
    added for them.  For literals over the same predicate, a negative clause
    over the pairs of corresponding arguments is added.  Nullary literals are
    left alone because there is no argument pair to build a clause from.
    """
    if literal1.predicate == literal2.predicate and literal1.args:
        # Compare the literals' *arguments* (``args``, as used everywhere else
        # in this module); ``parts`` is the sub-condition list of compound
        # conditions such as pddl.Conjunction, not the argument list.
        arg_pairs = list(zip(literal1.args, literal2.args))
        system.add_negative_clause(constraints.NegativeClause(arg_pairs))
00105
00106
class InvariantPart:
    """The share of an invariant that belongs to a single predicate.

    ``order[i]`` is the argument position of ``predicate`` that holds the
    i-th invariant parameter.  ``omitted_pos`` is the argument position that
    is not bound to any invariant parameter (-1 if there is none); it is
    ignored by equality and hashing.
    """

    def __init__(self, predicate, order, omitted_pos=-1):
        self.predicate = predicate
        self.order = order
        self.omitted_pos = omitted_pos

    def __eq__(self, other):
        # omitted_pos takes no part in equality, consistent with __hash__.
        if self.predicate != other.predicate:
            return False
        return self.order == other.order

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        return hash((self.predicate, tuple(self.order)))

    def __str__(self):
        text = "%s %s" % (self.predicate, " ".join(str(pos) for pos in self.order))
        if self.omitted_pos != -1:
            text += " [%d]" % self.omitted_pos
        return text

    def arity(self):
        """Number of invariant parameters this part binds."""
        return len(self.order)

    def get_assignment(self, parameters, literal):
        """Assignment equating each of *parameters* with the literal argument
        at the position this part assigns to it."""
        pairs = [(param, literal.args[pos])
                 for param, pos in zip(parameters, self.order)]
        return constraints.Assignment(pairs)

    def get_parameters(self, literal):
        """The arguments of *literal* in invariant-parameter order."""
        return [literal.args[pos] for pos in self.order]

    def instantiate(self, parameters):
        """Atom of this predicate with *parameters* placed according to
        ``order``; the omitted position (if any) is filled with ``"?X"``."""
        width = len(self.order)
        if self.omitted_pos != -1:
            width += 1
        args = ["?X"] * width
        for param, pos in zip(parameters, self.order):
            args[pos] = param
        return pddl.Atom(self.predicate, args)

    def possible_mappings(self, own_literal, other_literal):
        """Factored argument-position mappings from *other_literal* onto the
        parameters of *own_literal*.

        *other_literal* may have at most one argument more than this part has
        parameters; that extra argument is mapped to -1 (omitted).  Returns
        [] when no mapping exists, otherwise the expansion produced by
        instantiate_factored_mapping."""
        allowed_omissions = len(other_literal.args) - len(self.order)
        if allowed_omissions not in (0, 1):
            return []
        own_positions_by_arg = invert_list(self.get_parameters(own_literal))
        other_positions_by_arg = invert_list(other_literal.args)

        factored_mapping = []
        for arg, other_positions in other_positions_by_arg.items():
            own_positions = own_positions_by_arg.get(arg, [])
            surplus = len(own_positions) - len(other_positions)
            if surplus == -1 and allowed_omissions:
                # The single surplus occurrence in other_literal is omitted.
                own_positions.append(-1)
                allowed_omissions = 0
            elif surplus != 0:
                return []
            factored_mapping.append((other_positions, own_positions))
        return instantiate_factored_mapping(factored_mapping)

    def possible_matches(self, own_literal, other_literal):
        """New InvariantParts for ``other_literal.predicate`` that bind the
        same parameters as this part does for *own_literal*."""
        assert self.predicate == own_literal.predicate
        candidates = []
        for mapping in self.possible_mappings(own_literal, other_literal):
            new_order = [None] * len(self.order)
            omitted = -1
            for other_pos, own_pos in mapping:
                if own_pos == -1:
                    omitted = other_pos
                else:
                    new_order[own_pos] = other_pos
            candidates.append(InvariantPart(other_literal.predicate, new_order, omitted))
        return candidates

    def matches(self, other, own_literal, other_literal):
        """True if *own_literal* (via this part) and *other_literal* (via
        *other*) yield the same invariant parameters."""
        own_params = self.get_parameters(own_literal)
        return own_params == other.get_parameters(other_literal)
00183
00184
class Invariant(object):
    """A candidate invariant: a set of parts over pairwise distinct predicates.

    ``arity`` reads the arity of an arbitrary part, so all parts presumably
    share one arity.  Equality and hashing consider the concrete class and
    the set of parts.
    """

    def __init__(self, parts):
        # Materialize once so ``parts`` may be any iterable, including a
        # one-shot iterator/generator; the lookups below are derived from the
        # deduplicated frozenset so they always agree with it.
        self.parts = frozenset(parts)
        self.predicates = set(part.predicate for part in self.parts)
        self.predicate_to_part = dict((part.predicate, part) for part in self.parts)
        # Each predicate may occur in at most one part.
        assert len(self.parts) == len(self.predicates)

    def __eq__(self, other):
        return self.__class__ == other.__class__ and self.parts == other.parts

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        return hash(self.parts)

    def __str__(self):
        return "{%s}" % ", ".join(map(str, self.parts))

    def arity(self):
        """Arity of an arbitrary part (raises StopIteration if there are no parts)."""
        # Builtin next() works on Python 2.6+ and 3, unlike iterator.next().
        return next(iter(self.parts)).arity()

    def get_parameters(self, atom):
        """Invariant parameters of *atom* according to the part for its predicate."""
        return self.predicate_to_part[atom.predicate].get_parameters(atom)

    def instantiate(self, parameters):
        """One instantiated atom per part, in frozenset iteration order."""
        return [part.instantiate(parameters) for part in self.parts]
00218
00219
class SafeInvariant(Invariant):
    """Invariant candidate checked with constraint systems against actions
    that have two effect lists.

    Actions are indexed as ``effects[time]`` and ``condition[2 * time]`` for
    ``time`` in (0, 1); presumably these are the at-start / at-end effect
    lists and at-start / at-end conditions of a durative action -- TODO
    confirm against the pddl module.
    """

    def get_covering_assignments(self, parameters, atom):
        """One-element list: the assignment under which the part for
        ``atom.predicate`` covers *atom* with *parameters*."""
        part = self.predicate_to_part[atom.predicate]
        return [part.get_assignment(parameters, atom)]

    def check_balance(self, balance_checker, enqueue_func):
        """Return True if this candidate passes the balance check for every
        action that threatens one of its predicates.

        On failure, refined candidates may be passed to *enqueue_func*.
        """
        actions_to_check = set()
        for part in self.parts:
            actions_to_check |= balance_checker.get_threats(part.predicate)

        # Filled by operator_unbalanced with (action, at-end add effect,
        # tuple of candidates) for add effects that only an at-start delete
        # of the same action balances.
        temp_unbalanced_actions = set()
        for action in actions_to_check:
            heavy_action = balance_checker.get_heavy_action(action.name)
            if self.operator_too_heavy(heavy_action):
                return False
            unbalanced, new_candidates = self.operator_unbalanced(
                action, temp_unbalanced_actions)
            if unbalanced:
                for candidate in new_candidates:
                    enqueue_func(candidate)
                return False

        if all(self.conditions_require_weight_1(action, effect)
               for action, effect, _ in temp_unbalanced_actions):
            return True
        if len(temp_unbalanced_actions) > 1:
            # A single temporarily unbalanced action is tolerated; with
            # several, presumably they could overlap in time.
            for _, _, candidates in temp_unbalanced_actions:
                for candidate in candidates:
                    enqueue_func(candidate)
            return False
        return True

    def operator_too_heavy(self, h_action):
        """Return True if, at either time point of *h_action*, two distinct
        covered add effects can fire together while both of their atoms are
        false beforehand."""
        inv_vars = find_unique_variables(h_action, self)
        for time in (0, 1):
            cond_time = 2 * time
            add_effects = [eff for eff in h_action.effects[time]
                           if isinstance(eff.peffect, pddl.Literal) and
                           not eff.peffect.negated and
                           self.predicate_to_part.get(eff.peffect.predicate)]
            # combinations() of fewer than two effects is empty.
            for eff1, eff2 in itertools.combinations(add_effects, 2):
                system = constraints.ConstraintSystem()
                ensure_inequality(system, eff1.peffect, eff2.peffect)
                ensure_cover(system, eff1.peffect, self, inv_vars)
                ensure_cover(system, eff2.peffect, self, inv_vars)
                ensure_conjunction_sat(system,
                                       get_literals(h_action.condition[cond_time]),
                                       get_literals(eff1.condition[cond_time]),
                                       get_literals(eff2.condition[cond_time]),
                                       [eff1.peffect.negate()],
                                       [eff2.peffect.negate()])
                if system.is_solvable():
                    return True
        return False

    def operator_unbalanced(self, action, temp_unbalanced_actions):
        """Check whether a covered add effect of *action* is left unbalanced.

        Returns ``(True, candidates)`` with refined candidates on failure and
        ``(False, None)`` otherwise.  An at-end add effect that only an
        at-start delete balances (possible only if the action has no covered
        at-start add effects) is recorded in *temp_unbalanced_actions* as
        ``(action, add_effect, tuple_of_candidates)`` instead of failing.
        """
        inv_vars = find_unique_variables(action, self)
        add_effects = [[], []]
        del_effects = [[], []]
        for time in (0, 1):
            for eff in action.effects[time]:
                if not (isinstance(eff.peffect, pddl.Literal) and
                        self.predicate_to_part.get(eff.peffect.predicate)):
                    continue
                if eff.peffect.negated:
                    del_effects[time].append(eff)
                else:
                    add_effects[time].append(eff)

        for time in (0, 1):
            poss_temporary_cand = time == 1 and not add_effects[0]
            for eff in add_effects[time]:
                unbal, new_candidates = self.add_effect_unbalanced(
                    action, eff, del_effects[time], inv_vars, time)
                if not unbal:
                    continue
                if not poss_temporary_cand:
                    return unbal, new_candidates
                unbal, new_cands = self.add_effect_temporarily_unbalanced(
                    action, eff, del_effects[0], inv_vars)
                if unbal:
                    new_candidates += new_cands
                    return unbal, new_candidates
                temp_unbalanced_actions.add((action, eff, tuple(new_candidates)))
        return False, None

    def minimal_covering_renamings(self, action, add_effect, inv_vars):
        """Computes the minimal renamings of the action parameters such
        that the add effect is covered by the action.

        Each renaming is a constraint system.  Two action parameters may
        only be identified if the covering assignment already maps them to
        the same value.
        """
        assigs = self.get_covering_assignments(inv_vars, add_effect.peffect)
        params = [p.name for p in action.parameters]
        minimal_renamings = []
        for assignment in assigs:
            system = constraints.ConstraintSystem()
            system.add_assignment(assignment)
            mapping = assignment.get_mapping()
            for n1, n2 in itertools.combinations(params, 2):
                if mapping.get(n1, n1) != mapping.get(n2, n2):
                    system.add_negative_clause(constraints.NegativeClause([(n1, n2)]))
            minimal_renamings.append(system)
        return minimal_renamings

    def _lhs_by_predicate(self, action, add_effect, cond_time):
        """Group by predicate the literals of the action condition and the
        add-effect condition at *cond_time*, plus the negated add effect."""
        lhs_by_pred = defaultdict(list)
        for lit in itertools.chain(get_literals(action.condition[cond_time]),
                                   get_literals(add_effect.condition[cond_time]),
                                   get_literals(add_effect.peffect.negate())):
            lhs_by_pred[lit.predicate].append(lit)
        return lhs_by_pred

    def add_effect_unbalanced(self, action, add_effect, del_effects,
                              inv_vars, time):
        """Return ``(False, None)`` if the covered *del_effects* of the same
        time point balance *add_effect* under every minimal covering
        renaming, else ``(True, refined candidates)``.

        At the end time point (time == 1), conditional delete effects are
        not used for balancing.
        """
        cond_time = 2 * time
        minimal_renamings = self.minimal_covering_renamings(action, add_effect,
                                                            inv_vars)
        lhs_by_pred = self._lhs_by_predicate(action, add_effect, cond_time)

        for del_effect in del_effects:
            if time == 1 and (del_effect.condition[0] or del_effect.condition[1]):
                continue
            minimal_renamings = self.unbalanced_renamings(
                del_effect, add_effect, inv_vars, lhs_by_pred, time,
                minimal_renamings)
            if not minimal_renamings:
                return False, None

        # Refine with the delete effects of the add effect's own time point.
        return True, self.refine_candidate(add_effect, action, time)

    def add_effect_temporarily_unbalanced(self, action, add_effect,
                                          start_del_effects, inv_vars):
        """At-end add effect has corresponding at-start del effect, so it
        could be balanced if no other action interferes.

        Returns ``(False, None)`` if the at-start deletes balance every
        minimal covering renaming, else ``(True, candidates refined with the
        at-start deletes)``.
        """
        minimal_renamings = self.minimal_covering_renamings(action, add_effect,
                                                            inv_vars)
        lhs_by_pred = self._lhs_by_predicate(action, add_effect, 0)

        for del_effect in start_del_effects:
            minimal_renamings = self.temp_unbalanced_renamings(
                del_effect, add_effect, inv_vars, lhs_by_pred, minimal_renamings)
            if not minimal_renamings:
                return False, None

        return True, self.refine_candidate(add_effect, action, 0)

    def refine_candidate(self, add_effect, action, time):
        """Return (does not enqueue) new candidates that extend this invariant
        with a part for the predicate of an uncovered delete effect in
        ``action.effects[time]``, matched against *add_effect*."""
        part = self.predicate_to_part[add_effect.peffect.predicate]
        new_candidates = []
        for del_eff in action.effects[time]:
            if not (isinstance(del_eff.peffect, pddl.Literal) and
                    del_eff.peffect.negated):
                continue
            if del_eff.peffect.predicate in self.predicate_to_part:
                continue
            for match in part.possible_matches(add_effect.peffect,
                                               del_eff.peffect):
                new_candidates.append(SafeInvariant(self.parts.union((match,))))
        return new_candidates

    def _filter_unbalanced(self, system, del_effect, lhs_by_pred, time,
                           renamings):
        """Keep the renamings for which *del_effect* (constrained by
        *system*) does not balance the add effect: the left-hand side is
        satisfiable but never implies the delete effect, or the combined
        constraint system is unsolvable."""
        still_unbalanced = []
        for renaming in renamings:
            new_sys = system.combine(renaming)
            if self.lhs_satisfiable(renaming, lhs_by_pred):
                implies_system = self.imply_del_effect(del_effect, lhs_by_pred,
                                                       time)
                if not implies_system:
                    still_unbalanced.append(renaming)
                    continue
                new_sys = new_sys.combine(implies_system)
            if not new_sys.is_solvable():
                still_unbalanced.append(renaming)
        return still_unbalanced

    def temp_unbalanced_renamings(self, del_effect, add_effect,
                                  inv_vars, lhs_by_pred, unbalanced_renamings):
        """Returns the renamings from unbalanced_renamings for which the
        at-start del_effect does not balance the at-end add_effect.

        Unlike unbalanced_renamings, no inequality between the two atoms is
        required; *add_effect* is accepted for interface symmetry only.
        """
        system = constraints.ConstraintSystem()
        ensure_cover(system, del_effect.peffect, self, inv_vars)
        return self._filter_unbalanced(system, del_effect, lhs_by_pred, 0,
                                       unbalanced_renamings)

    def unbalanced_renamings(self, del_effect, add_effect,
                             inv_vars, lhs_by_pred, time, unbalanced_renamings):
        """Returns the renamings from unbalanced_renamings for which
        the del_effect does not balance the add_effect."""
        system = constraints.ConstraintSystem()
        ensure_inequality(system, add_effect.peffect, del_effect.peffect)
        ensure_cover(system, del_effect.peffect, self, inv_vars)
        return self._filter_unbalanced(system, del_effect, lhs_by_pred, time,
                                       unbalanced_renamings)

    def lhs_satisfiable(self, renaming, lhs_by_pred):
        """True if the grouped lhs literals are jointly satisfiable under
        *renaming* (which is copied, not modified)."""
        system = renaming.copy()
        ensure_conjunction_sat(system, *lhs_by_pred.values())
        return system.is_solvable()

    def imply_del_effect(self, del_effect, lhs_by_pred, time):
        """Returns a constraint system that is solvable if lhs implies
        the del effect (only if lhs is satisfiable).  If a solvable
        lhs never implies the del effect, return None.

        Note: indexing the defaultdict *lhs_by_pred* inserts empty lists for
        predicates that do not occur in it.
        """
        implies_system = constraints.ConstraintSystem()
        del_eff_condition = del_effect.condition[time * 2]
        for literal in itertools.chain(get_literals(del_eff_condition),
                                       [del_effect.peffect.negate()]):
            poss_assignments = [
                constraints.Assignment(list(zip(literal.args, match.args)))
                for match in lhs_by_pred[literal.predicate]
                if match.negated == literal.negated]
            if not poss_assignments:
                return None
            implies_system.add_assignment_disjunction(poss_assignments)
        return implies_system

    def conditions_require_weight_1(self, action, add_effect):
        """Return True if, under every minimal covering renaming of
        *add_effect*, some positive at-start condition (of the action or of
        the effect) covered by this invariant is forced to have the same
        invariant parameters as the add effect.  Also True as soon as such a
        condition's covering assignment has no equalities."""
        inv_vars = find_unique_variables(action, self)
        minimal_renamings = self.minimal_covering_renamings(action, add_effect,
                                                            inv_vars)
        relevant_conditions = set(get_literals(action.condition[0]))
        relevant_conditions |= set(get_literals(add_effect.condition[0]))
        relevant_conditions = [atom for atom in relevant_conditions
                               if not atom.negated and
                               self.predicate_to_part.get(atom.predicate)]

        negative_clauses = []
        for atom in relevant_conditions:
            a = self.get_covering_assignments(inv_vars, atom)[0]
            if not len(a.equalities):
                return True
            negative_clauses.append(constraints.NegativeClause(a.equalities))

        for renaming in minimal_renamings:
            for clause in negative_clauses:
                system = renaming.copy()
                system.add_negative_clause(clause)
                if not system.is_solvable():
                    break
            else:
                return False
        return True
00504
00505
class UnsafeInvariant(Invariant):
    """Invariant candidate checked syntactically: every covered add effect
    must be matched by a covered delete effect of the same action that has
    the same invariant parameters."""

    def check_balance(self, balance_checker, enqueue_func):
        """Return True if every action threatening one of this invariant's
        predicates passes check_action_balance (stops at the first failure)."""
        actions_to_check = set()
        for part in self.parts:
            actions_to_check |= balance_checker.get_threats(part.predicate)
        return all(self.check_action_balance(balance_checker, action, enqueue_func)
                   for action in actions_to_check)

    def check_action_balance(self, balance_checker, action, enqueue_func):
        """Return True if every covered add effect of *action* is balanced.

        For a ``pddl.Action`` all effects form one list.  Otherwise
        ``action.effects[0]`` / ``action.effects[1]`` are used: at-start add
        effects must be matched by at-start deletes; an at-end add effect may
        be matched by an at-start or an at-end delete, but only by an at-end
        one if an at-start add effect has the same invariant parameters (and
        is a different atom).  *balance_checker* is not used here.
        """
        if isinstance(action, pddl.Action):
            return self._balance_add_effects(action.effects, enqueue_func) is not None

        start_effects = action.effects[0]
        end_effects = action.effects[1]
        matched_start = self._balance_add_effects(start_effects, enqueue_func)
        if matched_start is None:
            return False

        start_del_effects = [eff for eff in start_effects
                             if isinstance(eff.peffect, pddl.NegatedAtom)]
        end_del_effects = [eff for eff in end_effects
                           if isinstance(eff.peffect, pddl.NegatedAtom)]
        matched_end = []
        for eff in end_effects:
            if not isinstance(eff.peffect, pddl.Atom):
                continue
            part = self.predicate_to_part.get(eff.peffect.predicate)
            if not part:
                continue
            if self._clashes_with(matched_end, part, eff.peffect):
                return False
            # If an at-start add effect already has these invariant parameters,
            # only at-end deletes are considered (presumably the at-start
            # delete is needed to balance that at-start add effect).
            check_start = not self._clashes_with(matched_start, part, eff.peffect)
            found_start_del = check_start and self.find_matching_del_effect(
                part, eff, start_del_effects, enqueue_func, False)
            found_end_del = self.find_matching_del_effect(
                part, eff, end_del_effects, enqueue_func, False)
            if not (found_start_del or found_end_del):
                self.generate_new_candidates(part, eff, end_del_effects, enqueue_func)
                if check_start:
                    self.generate_new_candidates(part, eff, start_del_effects,
                                                 enqueue_func)
                return False
            matched_end.append((part, eff.peffect))
        return True

    def _balance_add_effects(self, effects, enqueue_func):
        """Match the covered add effects of one flat effect list against the
        delete effects of that same list.

        Returns the list of ``(part, atom)`` pairs of the covered add effects,
        or None if two of them share invariant parameters while being
        different atoms, or if one has no matching delete effect (in which
        case refined candidates are enqueued).
        """
        del_effects = [eff for eff in effects
                       if isinstance(eff.peffect, pddl.NegatedAtom)]
        matched = []
        for eff in effects:
            if not isinstance(eff.peffect, pddl.Atom):
                continue
            part = self.predicate_to_part.get(eff.peffect.predicate)
            if not part:
                continue
            if self._clashes_with(matched, part, eff.peffect):
                return None
            if not self.find_matching_del_effect(part, eff, del_effects, enqueue_func):
                return None
            matched.append((part, eff.peffect))
        return matched

    @staticmethod
    def _clashes_with(matched_add_effects, part, peffect):
        """True if a previously matched add effect has the same invariant
        parameters as *peffect* but is a different atom."""
        for previous_part, previous_peffect in matched_add_effects:
            if (previous_part.matches(part, previous_peffect, peffect) and
                    previous_peffect != peffect):
                return True
        return False

    def find_matching_del_effect(self, part, add_effect, del_effects, enqueue_func, generate_new=True):
        """Return True if a covered delete effect in *del_effects* has the
        same invariant parameters as *add_effect*.  Otherwise, if
        *generate_new*, enqueue refined candidates, and return False."""
        for del_eff in del_effects:
            del_part = self.predicate_to_part.get(del_eff.peffect.predicate)
            if del_part and part.matches(del_part, add_effect.peffect, del_eff.peffect):
                return True
        if generate_new:
            self.generate_new_candidates(part, add_effect, del_effects, enqueue_func)
        return False

    def generate_new_candidates(self, part, add_effect, del_effects, enqueue_func):
        """Enqueue this invariant extended by each possible match of
        *add_effect* against a delete effect whose predicate it does not
        cover yet."""
        for del_eff in del_effects:
            if del_eff.peffect.predicate in self.predicate_to_part:
                continue
            for match in part.possible_matches(add_effect.peffect, del_eff.peffect):
                enqueue_func(UnsafeInvariant(self.parts.union((match,))))
00597