00001
00002
00003 import invariant_finder
00004 import pddl
00005
def expand_group(group, task, reachable_facts):
    """Instantiate the free "?X" argument of the facts in one invariant group.

    A fact whose arguments do not contain "?X" is passed through unchanged.
    For a fact that does contain "?X", that position is filled in turn with
    every object of ``task`` and with a synthetic "undefined" object. Each
    resulting atom is kept only if it is contained in ``reachable_facts``.

    Returns a new list with the kept facts/atoms in production order.
    """
    expanded = []
    for fact in group:
        arguments = list(fact.args)
        if "?X" not in arguments:
            # Nothing to instantiate: keep the fact as is.
            expanded.append(fact)
            continue
        free_index = arguments.index("?X")
        candidates = task.objects + [pddl.ObjectTerm("undefined")]
        for candidate in candidates:
            substituted = list(arguments)
            substituted[free_index] = pddl.ObjectTerm(candidate.name)
            atom = pddl.Atom(fact.predicate, substituted)
            if atom in reachable_facts:
                expanded.append(atom)
    return expanded
00025
def instantiate_groups(groups, task, reachable_facts):
    """Apply expand_group to every group; returns one expanded list per group."""
    def expand(single_group):
        return expand_group(single_group, task, reachable_facts)
    return [expand(group) for group in groups]
00028
class GroupCoverQueue:
    """Queue that repeatedly hands out the currently largest group of facts.

    Groups are bucketed by size. ``pop`` returns the facts of the largest
    remaining group. With ``partial_encoding`` those facts are then removed
    from every registered group, including any ``unused_groups``, so the
    remaining groups shrink. The queue is truthy while a group with more
    than one fact is available.
    """

    def __init__(self, groups, partial_encoding, unused_groups=None):
        """Build the queue.

        groups: collection of fact collections; each one is copied into a set.
        partial_encoding: if true, popped facts are removed from all groups.
        unused_groups: groups that are never popped themselves but also lose
            facts covered by popped groups. They are mutated in place, so
            they must support ``remove`` (e.g. sets). Defaults to none.
        """
        self.partial_encoding = partial_encoding
        self.groups_by_fact = {}
        self.top = None
        if groups:
            self.max_size = max(len(group) for group in groups)
            self.groups_by_size = [[] for _ in range(self.max_size + 1)]
            for group in groups:
                group = set(group)
                self.groups_by_size[len(group)].append(group)
                for fact in group:
                    self.groups_by_fact.setdefault(fact, []).append(group)
            for group in unused_groups or []:
                for fact in group:
                    self.groups_by_fact.setdefault(fact, []).append(group)
            self._update_top()
        else:
            self.max_size = 0
            self.groups_by_size = [[]]

    def __nonzero__(self):
        """True while a group with more than one fact remains."""
        return self.max_size > 1

    # Python 3 consults __bool__ rather than __nonzero__. Without this alias
    # every instance would be truthy and `while queue:` would loop forever.
    __bool__ = __nonzero__

    def pop(self):
        """Return the facts of the current largest group and advance.

        Only call this while the queue is truthy.
        """
        result = list(self.top)
        if self.partial_encoding:
            # Remove the chosen facts from every group containing them,
            # including the chosen group itself and any unused groups.
            for fact in result:
                for group in self.groups_by_fact[fact]:
                    group.remove(fact)
        self._update_top()
        return result

    def _update_top(self):
        """Set ``self.top`` to a group of size ``max_size``, shrinking as needed."""
        while self.max_size > 1:
            max_list = self.groups_by_size[self.max_size]
            while max_list:
                candidate = max_list.pop()
                if len(candidate) == self.max_size:
                    self.top = candidate
                    return
                # The candidate shrank since it was bucketed; re-file it.
                self.groups_by_size[len(candidate)].append(candidate)
            self.max_size -= 1
00067
def choose_groups(groups, reachable_facts, partial_encoding=True):
    """Greedily cover ``reachable_facts`` with the largest available groups.

    Groups are taken from a GroupCoverQueue until it reports no group with
    more than one fact. Every fact still uncovered afterwards becomes its
    own singleton group.

    Returns the chosen groups followed by the singleton groups.
    """
    remaining = reachable_facts.copy()
    chosen = []
    cover_queue = GroupCoverQueue(groups, partial_encoding=partial_encoding)
    while cover_queue:
        picked = cover_queue.pop()
        remaining.difference_update(picked)
        chosen.append(picked)
    print("%d uncovered facts" % len(remaining))
    chosen.extend([fact] for fact in remaining)
    return chosen
00079
def choose_groups_with_object_fluents_first(synthesis_groups, object_fluent_groups,
                                            reachable_facts, partial_encoding=True):
    """Greedy cover that picks object-fluent groups before synthesized ones.

    Phase 1 draws from ``object_fluent_groups``. Set copies of
    ``synthesis_groups`` are registered as unused groups, so when
    ``partial_encoding`` is true the facts covered in phase 1 are also
    removed from them. Phase 2 then draws from those (possibly shrunken)
    copies. Facts still uncovered become singleton groups.

    Returns the chosen groups in pick order, followed by the singletons.
    """
    remaining = reachable_facts.copy()
    chosen = []

    def drain(cover_queue):
        # Take groups until the queue reports no group of size > 1.
        while cover_queue:
            picked = cover_queue.pop()
            remaining.difference_update(picked)
            chosen.append(picked)

    # Mutable copies of the synthesized groups; phase 1 may shrink them.
    shrink_groups = [set(group) for group in synthesis_groups]
    drain(GroupCoverQueue(object_fluent_groups, partial_encoding=partial_encoding,
                          unused_groups=shrink_groups))
    print("%d uncovered facts (before using the results from the invariant synthesis)"
          % len(remaining))

    drain(GroupCoverQueue(shrink_groups, partial_encoding=partial_encoding))
    print("%d uncovered facts" % len(remaining))
    chosen.extend([fact] for fact in remaining)
    return chosen
00107
def build_translation_key(groups):
    """Return, per group, the string form of each fact plus a trailing label.

    For every group the result holds ``str(fact)`` for each of its facts,
    in order, followed by the literal entry "<none of those>".
    """
    none_label = "<none of those>"
    return [[str(fact) for fact in group] + [none_label] for group in groups]
00115
def collect_all_mutex_groups(groups, atoms):
    """Return ``groups`` followed by a singleton group for each atom they miss.

    ``atoms`` is expected to be a set (``copy``/``difference_update`` are
    used on it). It is not modified. The group objects are reused as-is.
    """
    collected = list(groups)
    leftover = atoms.copy()
    for group in collected:
        leftover.difference_update(group)
    collected.extend([atom] for atom in leftover)
    return collected
00127
def create_groups_from_object_fluents(atoms):
    """Group object-fluent atoms that share predicate and all-but-last args.

    An atom is treated as object-fluent when its predicate name contains
    "!val". Such atoms that have the same predicate and agree on every
    argument except the last are collected into one group. Presumably the
    last argument is the fluent's value; TODO confirm against whatever
    produces the "!val" predicates.

    Returns a list of groups (lists of atoms). Group order follows dict
    iteration order and is otherwise unspecified.
    """
    groups_by_key = {}
    for atom in atoms:
        if "!val" not in atom.predicate:
            continue
        # Keep predicate and arguments separate in a tuple key. A single
        # joined string would let distinct pairs collide, e.g.
        # "p!val1" + "a" == "p!val" + "1a", or args ("a_b", "c") vs ("a", "b_c").
        key = (atom.predicate, tuple(str(arg) for arg in atom.args[:-1]))
        groups_by_key.setdefault(key, []).append(atom)
    return [list(group) for group in groups_by_key.values()]
00146
def compute_groups(task, atoms, reachable_action_params,
                   return_mutex_groups=True, partial_encoding=True, safe=True):
    """Compute fact groups, optional mutex groups and the translation key.

    Builds groups from object-fluent atoms ("!val" predicates) and from
    invariant synthesis (``invariant_finder.get_groups``), chooses a cover
    of ``atoms`` from them, and builds the translation key for the chosen
    groups.

    Returns ``(groups, mutex_groups, translation_key)``. ``mutex_groups`` is
    an empty list unless ``return_mutex_groups`` is true.
    """
    # Hard-coded strategy switches (each is consulted below).
    no_objectfluents = False         # True: build no object-fluent groups
    only_objectfluents = False       # True: skip invariant synthesis
    choose_groups_by_size = False    # only used if use_objectfluents_first is False
    use_objectfluents_first = True   # pick object-fluent groups before synthesized ones

    objectfluent_groups = (
        [] if no_objectfluents else create_groups_from_object_fluents(atoms))

    groups = []
    if not only_objectfluents:
        print("Finding invariants...")
        synthesized = invariant_finder.get_groups(task, safe, reachable_action_params)
        # Order groups by the string form of their atom lists.
        synthesized = sorted(synthesized,
                             key=lambda group: str([str(a) for a in group]))
        print("Instantiating groups...")
        groups = instantiate_groups(synthesized, task, atoms)

    mutex_groups = []
    if return_mutex_groups:
        mutex_groups = collect_all_mutex_groups(groups, atoms) + objectfluent_groups

    print("Choosing groups...")
    if use_objectfluents_first:
        groups = choose_groups_with_object_fluents_first(
            groups, objectfluent_groups, atoms, partial_encoding=partial_encoding)
    else:
        if only_objectfluents:
            groups = objectfluent_groups
        elif choose_groups_by_size:
            groups = objectfluent_groups + groups
        groups = choose_groups(groups, atoms, partial_encoding=partial_encoding)

    print("Building translation key...")
    translation_key = build_translation_key(groups)
    return groups, mutex_groups, translation_key