greedy_join.py
Go to the documentation of this file.
00001 import sys
00002 
00003 import pddl
00004 import pddl_to_prolog
00005 
class OccurrencesTracker(object):
  """Per-variable occurrence counter over the atoms of a rule.

  Only arguments that are ``pddl.Variable`` instances are counted; all other
  arguments are skipped. A variable is removed from the table as soon as its
  count drops back to zero. As a result, ``variables()`` yields exactly the
  variables that still occur somewhere.
  """
  def __init__(self, rule):
    self.occurrences = {}
    # Seed the table with the effect first, then every condition.
    for atom in [rule.effect] + list(rule.conditions):
      self.update(atom, +1)

  def update(self, symatom, delta):
    """Add ``delta`` to the count of each variable argument of ``symatom``.

    Counts must never become negative (checked with an assert). Entries
    whose count reaches zero are deleted.
    """
    for arg in symatom.args:
      if not isinstance(arg, pddl.Variable):
        continue
      count = self.occurrences.get(arg, 0) + delta
      self.occurrences[arg] = count
      assert count >= 0
      if count == 0:
        del self.occurrences[arg]

  def variables(self):
    """Return a new set holding every variable with a non-zero count."""
    return set(self.occurrences)
00025 
class CostMatrix(object):
  """Pairwise join costs between the atoms that still have to be joined.

  ``cost_matrix`` is lower-triangular. Row ``i`` holds the costs of joining
  ``joinees[i]`` with ``joinees[0]``, ..., ``joinees[i-1]``, in that order.
  Every stored entry ``cost_matrix[i][j]`` therefore has ``i > j``.
  """
  def __init__(self, joinees):
    self.joinees = []
    self.cost_matrix = []
    for joinee in joinees:
      self.add_entry(joinee)

  def add_entry(self, joinee):
    """Append ``joinee`` and compute its cost against every existing joinee."""
    new_row = [self.compute_join_cost(joinee, other) for other in self.joinees]
    self.cost_matrix.append(new_row)
    self.joinees.append(joinee)

  def delete_entry(self, index):
    """Remove the joinee at ``index`` together with its row and column."""
    # Because the matrix is lower-triangular, column ``index`` only exists
    # in the rows below row ``index``.
    for row in self.cost_matrix[index + 1:]:
      del row[index]
    del self.cost_matrix[index]
    del self.joinees[index]

  def find_min_pair(self):
    """Return ``(i, j)`` with ``i > j`` identifying the cheapest pair.

    Costs are tuples compared lexicographically. The comparison is strict,
    so on ties the first pair found in row-major order wins.
    """
    assert len(self.joinees) >= 2
    # sys.maxsize exists on Python 2.6+ and 3, unlike the Python-2-only
    # sys.maxint. Any real cost tuple compares below this sentinel.
    min_cost = (sys.maxsize, sys.maxsize)
    for i, row in enumerate(self.cost_matrix):
      for j, entry in enumerate(row):
        if entry < min_cost:
          min_cost = entry
          left_index, right_index = i, j
    return left_index, right_index

  def remove_min_pair(self):
    """Remove the cheapest pair from the matrix and return ``(left, right)``."""
    left_index, right_index = self.find_min_pair()
    left, right = self.joinees[left_index], self.joinees[right_index]
    assert left_index > right_index
    # Delete the higher index first so that right_index stays valid.
    self.delete_entry(left_index)
    self.delete_entry(right_index)
    return (left, right)

  def compute_join_cost(self, left_joinee, right_joinee):
    """Cost of joining two atoms, as a tuple compared lexicographically.

    Let ``small`` and ``large`` be the variable sets of the two atoms ordered
    by size, and ``common`` their intersection. The cost is
    ``(|small| - |common|, |large| - |common|, -|common|)``. Pairs that bring
    in few unshared variables are preferred; among those, pairs sharing more
    variables win.
    """
    # get_variables is assumed to return a set-like object (it is used with
    # & and len below) -- defined in pddl_to_prolog.
    left_vars = pddl_to_prolog.get_variables([left_joinee])
    right_vars = pddl_to_prolog.get_variables([right_joinee])
    # Put the smaller set first; this makes the cost symmetric in its
    # arguments.
    if len(left_vars) > len(right_vars):
      left_vars, right_vars = right_vars, left_vars
    common_vars = left_vars & right_vars
    return (len(left_vars) - len(common_vars),
            len(right_vars) - len(common_vars),
            -len(common_vars))

  def __bool__(self):
    """True while at least two joinees remain, i.e. another join is possible."""
    return len(self.joinees) >= 2

  # Python 2 looks up __nonzero__ for truthiness; Python 3 uses __bool__.
  __nonzero__ = __bool__
00068 
class ResultList(object):
  """Accumulates the intermediate rules produced while splitting a rule.

  Every rule added through ``add_rule`` gets a fresh effect atom whose
  predicate name comes from ``name_generator``. ``get_result`` then gives
  the last added rule the original rule's effect.
  """
  def __init__(self, rule, name_generator):
    self.final_effect = rule.effect
    self.result = []
    # Expected to be an iterator yielding fresh predicate names.
    self.name_generator = name_generator

  def get_result(self):
    """Return the accumulated rules, with the last one producing the final effect.

    Raises IndexError if no rule has been added yet.
    """
    self.result[-1].effect = self.final_effect
    return self.result

  def add_rule(self, type, conditions, effect_vars):
    """Append a new rule and return its freshly named effect atom.

    ``type`` is stored on the rule as ``rule.type`` (callers in this file
    pass "project" or "join"). ``effect_vars`` become the arguments of the
    new effect atom.
    """
    # Builtin next() works on Python 2.6+ and 3; the .next() method is
    # Python 2 only.
    effect = pddl.Atom(next(self.name_generator), effect_vars)
    rule = pddl_to_prolog.Rule(conditions, effect)
    rule.type = type
    self.result.append(rule)
    return rule.effect
00083 
def greedy_join(rule, name_generator):
  """Split a rule with two or more conditions into a chain of smaller rules.

  The algorithm works on a pool of joinees, which starts as the rule's
  conditions. Each round:

  - takes the cheapest pair of joinees, as ranked by CostMatrix;
  - projects either side onto the variables it still needs, when that drops
    something;
  - joins the two into a fresh intermediate atom;
  - puts that atom back into the pool.

  This repeats until a single joinee remains. The last generated rule is
  then given the original rule's effect.

  Returns the list of generated rules. Their ``type`` is "project" or "join".
  """
  assert len(rule.conditions) >= 2
  matrix = CostMatrix(rule.conditions)
  tracker = OccurrencesTracker(rule)
  results = ResultList(rule, name_generator)

  while matrix:
    pair = list(matrix.remove_min_pair())
    # These two atoms are consumed, so their occurrences stop counting.
    for atom in pair:
      tracker.update(atom, -1)

    shared = set(pair[0].args) & set(pair[1].args)
    in_pair = set(pair[0].args) | set(pair[1].args)
    # Variables of this pair still referenced by the remaining joinees or by
    # the rule's effect must survive into the joined atom.
    needed = tracker.variables() & in_pair
    for pos, atom in enumerate(pair):
      own_args = set(atom.args)
      kept = own_args & (needed | shared)
      # Project away arguments that are neither needed later nor shared.
      if kept != own_args:
        pair[pos] = results.add_rule("project", [atom], list(kept))
    joined = results.add_rule("join", pair, list(needed))
    matrix.add_entry(joined)
    tracker.update(joined, +1)

  return results.get_result()
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines


tfd_modules
Author(s): Maintained by Christian Dornhege (see AUTHORS file).
autogenerated on Tue Jan 22 2013 12:25:03