00001
00002
00003
00004
00005
00006
00007 package edu.tum.cs.bayesnets.inference;
00008
00009 import java.util.Arrays;
00010 import java.util.Collection;
00011 import java.util.HashSet;
00012 import java.util.PriorityQueue;
00013 import java.util.Vector;
00014
00015 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00016 import edu.ksu.cis.bnj.ver3.core.CPF;
00017 import edu.ksu.cis.bnj.ver3.core.Discrete;
00018 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00019 import edu.tum.cs.bayesnets.util.TopologicalOrdering;
00020 import edu.tum.cs.bayesnets.util.TopologicalSort;
00021 import edu.tum.cs.logic.Disjunction;
00022 import edu.tum.cs.logic.Formula;
00023 import edu.tum.cs.logic.GroundLiteral;
00024 import edu.tum.cs.logic.PossibleWorld;
00025 import edu.tum.cs.logic.TrueFalse;
00026 import edu.tum.cs.logic.WorldVariables;
00027 import edu.tum.cs.logic.sat.ClausalKB;
00028 import edu.tum.cs.logic.sat.Clause;
00029 import edu.tum.cs.logic.sat.SampleSAT;
00030 import edu.tum.cs.srl.AbstractVariable;
00031 import edu.tum.cs.srl.Database;
00032 import edu.tum.cs.srl.bayesnets.bln.coupling.VariableLogicCoupling;
00033
00034 public class SATIS_BSampler extends BackwardSampling {
00035
00036 VariableLogicCoupling coupling;
00040 SampleSAT sat;
00044 Collection<BeliefNode> determinedVars;
00048 ClausalKB ckb;
00049
00059 public SATIS_BSampler(BeliefNetworkEx bn, SampleSAT sat, VariableLogicCoupling coupling, Collection<BeliefNode> determinedVars) throws Exception {
00060 super(bn);
00061 this.coupling = coupling;
00062 this.sat = sat;
00063 this.ckb = null;
00064 this.determinedVars = determinedVars;
00065 }
00066
00072 public SATIS_BSampler(BeliefNetworkEx bn) throws Exception {
00073 super(bn);
00074
00075 coupling = new VariableLogicCoupling();
00076 for(BeliefNode n : nodes) {
00077 coupling.addBlockVariable(n, (Discrete)n.getDomain(), n.getName(), new String[0]);
00078 }
00079
00080 ckb = new ClausalKB();
00081 extendKBWithDeterministicConstraintsInCPTs(bn, coupling, ckb, null);
00082
00083 determinedVars = new HashSet<BeliefNode>();
00084 for(Clause c : ckb) {
00085 for(GroundLiteral lit : c.lits) {
00086 BeliefNode var = coupling.getVariable(lit.gndAtom);
00087 if(var == null)
00088 throw new Exception("Could not find node corresponding to ground atom '" + lit.gndAtom.toString() + "' with index " + lit.gndAtom.index + "; set of mapped ground atoms is " + coupling.getCoupledGroundAtoms());
00089 determinedVars.add(var);
00090 }
00091 }
00092
00093 sat = null;
00094 }
00095
00096 @Override
00097 public void setEvidence(int[] evidenceDomainIndices) throws Exception {
00098 super.setEvidence(evidenceDomainIndices);
00099
00100 if(this.sat == null) {
00101
00102 Vector<PropositionalVariable> evidence = new Vector<PropositionalVariable>();
00103 for(int i = 0; i < evidenceDomainIndices.length; i++)
00104 if(evidenceDomainIndices[i] != -1) {
00105 evidence.add(new PropositionalVariable(nodes[i].getName(), nodes[i].getDomain().getName(evidenceDomainIndices[i])));
00106 }
00107
00108 WorldVariables worldVars = this.coupling.getWorldVars();
00109 sat = new SampleSAT(ckb, new PossibleWorld(worldVars), worldVars, evidence);
00110 }
00111
00112 sat.setDebugMode(this.debug);
00113 }
00114
00123 public static void extendKBWithDeterministicConstraintsInCPTs(BeliefNetworkEx bn, VariableLogicCoupling coupling, ClausalKB ckb, Database db) throws Exception {
00124 int size = ckb.size();
00125 System.out.print("gathering deterministic constraints from CPDs... ");
00126 for(BeliefNode node : bn.bn.getNodes()) {
00127 if(!coupling.hasCoupling(node))
00128 continue;
00129 CPF cpf = node.getCPF();
00130 BeliefNode[] domProd = cpf.getDomainProduct();
00131 int[] addr = new int[domProd.length];
00132 walkCPF4HardConstraints(coupling, cpf, addr, 0, ckb, db);
00133 }
00134 System.out.println((ckb.size()-size) + " constraints added");
00135 }
00136
00137 protected static void walkCPF4HardConstraints(VariableLogicCoupling coupling, CPF cpf, int[] addr, int i, ClausalKB ckb, Database db) throws Exception {
00138 BeliefNode[] domProd = cpf.getDomainProduct();
00139 if(i == addr.length) {
00140 double p = cpf.getDouble(addr);
00141 if(p == 0.0) {
00142 GroundLiteral[] lits = new GroundLiteral[domProd.length];
00143 for(int k = 0; k < domProd.length; k++) {
00144 lits[k] = coupling.getGroundLiteral(domProd[k], addr[k]);
00145 lits[k].negate();
00146 }
00147 Formula f = new Disjunction(lits);
00148 if(db != null) {
00149 f = f.simplify(db);
00150 if(f instanceof TrueFalse)
00151 return;
00152 }
00153 ckb.addFormula(f);
00154 }
00155 return;
00156 }
00157 for(int k = 0; k < domProd[i].getDomain().getOrder(); k++) {
00158 addr[i] = k;
00159 walkCPF4HardConstraints(coupling, cpf, addr, i+1, ckb, db);
00160 }
00161 }
00162
00163 protected static class PropositionalVariable extends AbstractVariable {
00164
00165 public PropositionalVariable(String varName, String value) {
00166 super(varName, new String[0], value);
00167 }
00168
00169 @Override
00170 public String getPredicate() {
00171 return this.functionName + "(" + value + ")";
00172 }
00173
00174 @Override
00175 public boolean isBoolean() {
00176 return false;
00177 }
00178 }
00179
00180 @Override
00181 public void initSample(WeightedSample s) throws Exception {
00182 super.initSample(s);
00183
00184
00185 sat.run();
00186 PossibleWorld state = sat.getState();
00187
00188
00189 for(BeliefNode var : determinedVars) {
00190 int domIdx = coupling.getVariableValue(var, state);
00191 s.nodeDomainIndices[this.getNodeIndex(var)] = domIdx;
00192
00193
00194
00195 }
00196 }
00197
00203 protected void getOrdering(int[] evidenceDomainIndices) throws Exception {
00204 HashSet<BeliefNode> uninstantiatedNodes = new HashSet<BeliefNode>(Arrays.asList(nodes));
00205 backwardSampledNodes = new Vector<BeliefNode>();
00206 forwardSampledNodes = new Vector<BeliefNode>();
00207 outsideSamplingOrder = new HashSet<BeliefNode>();
00208 TopologicalOrdering topOrder = new TopologicalSort(bn.bn).run(true);
00209 PriorityQueue<BeliefNode> backSamplingCandidates = new PriorityQueue<BeliefNode>(1, new TierComparator(topOrder));
00210
00211
00212
00213 for(BeliefNode n : determinedVars) {
00214 uninstantiatedNodes.remove(n);
00215 outsideSamplingOrder.add(n);
00216 }
00217
00218
00219 for(int i = 0; i < evidenceDomainIndices.length; i++) {
00220 if(evidenceDomainIndices[i] >= 0) {
00221 backSamplingCandidates.add(nodes[i]);
00222 uninstantiatedNodes.remove(nodes[i]);
00223 }
00224 }
00225
00226
00227 while(!backSamplingCandidates.isEmpty()) {
00228 BeliefNode node = backSamplingCandidates.remove();
00229
00230 BeliefNode[] domProd = node.getCPF().getDomainProduct();
00231 boolean doBackSampling = false;
00232 for(int j = 1; j < domProd.length; j++) {
00233 BeliefNode parent = domProd[j];
00234
00235 if(uninstantiatedNodes.remove(parent)) {
00236 doBackSampling = true;
00237 backSamplingCandidates.add(parent);
00238 }
00239 }
00240 if(doBackSampling)
00241 backwardSampledNodes.add(node);
00242
00243
00244 else
00245 outsideSamplingOrder.add(node);
00246 }
00247
00248
00249 for(int i : topOrder) {
00250 if(uninstantiatedNodes.contains(nodes[i]))
00251 forwardSampledNodes.add(nodes[i]);
00252 }
00253
00254 out.println("node ordering: " + outsideSamplingOrder.size() + " outside order, " + backwardSampledNodes.size() + " backward, " + forwardSampledNodes.size() + " forward");
00255 }
00256
00257 @Override
00258 public String getAlgorithmName() {
00259 return String.format("%s[%s]", getClass().getSimpleName(), sat.getAlgorithmName());
00260 }
00261 }