00001 package edu.tum.cs.srl.bayesnets.inference;
00002
00003 import java.util.HashMap;
00004 import java.util.Vector;
00005
00006 import weka.classifiers.trees.j48.Rule;
00007 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00008 import edu.tum.cs.bayesnets.inference.SATIS_BSampler;
00009 import edu.tum.cs.logic.Formula;
00010 import edu.tum.cs.logic.Negation;
00011 import edu.tum.cs.logic.TrueFalse;
00012 import edu.tum.cs.logic.sat.ClausalKB;
00013 import edu.tum.cs.srl.bayesnets.RelationalBeliefNetwork;
00014 import edu.tum.cs.srl.bayesnets.RelationalNode;
00015 import edu.tum.cs.srl.bayesnets.CPT2MLNFormulas.CPT2Rules;
00016 import edu.tum.cs.srl.bayesnets.bln.GroundBLN;
00017 import edu.tum.cs.srl.bayesnets.bln.coupling.VariableLogicCoupling;
00018 import edu.tum.cs.util.datastruct.Map2D;
00019
00026 public class SATISEx extends SATIS {
00030 boolean exploitCSI = false;
00031
00032 public SATISEx(GroundBLN bln) throws Exception {
00033 super(bln);
00034 this.paramHandler.add("useCSI", "useCSI");
00035 }
00036
00037 public void useCSI(boolean active) {
00038 exploitCSI = active;
00039 }
00040
00041 @Override
00042 protected ClausalKB getClausalKB() throws Exception {
00043 ClausalKB ckb = super.getClausalKB();
00044
00045
00046 if(!exploitCSI)
00047 SATIS_BSampler.extendKBWithDeterministicConstraintsInCPTs(gbln.getGroundNetwork(), gbln.getCoupling(), ckb, gbln.getDatabase());
00048 else {
00049
00050 System.out.println("CSI analysis...");
00051 Map2D<RelationalNode, String, Vector<Formula>> constraints = new Map2D<RelationalNode, String, Vector<Formula>>();
00052 int numFormulas = 0;
00053 int numZeros = 0;
00054 int numDirectTranslations = 0;
00055 RelationalBeliefNetwork rbn = this.gbln.getRBN();
00056 for(RelationalNode relNode : rbn.getRelationalNodes()) {
00057 if(!relNode.isFragment())
00058 continue;
00059
00060 CPT2Rules cpt2rules = null;
00061 for(HashMap<String,String> constantAssignment : relNode.getConstantAssignments()) {
00062 Vector<Formula> v = new Vector<Formula>();
00063 if(relNode.hasAggregator()) {
00064 Formula f = relNode.toFormula(constantAssignment);
00065 if(f == null)
00066 throw new Exception("Relational node " + relNode + " could not be translated to a formula");
00067
00068 v.add(f);
00069 numDirectTranslations++;
00070 }
00071 else {
00072 if(cpt2rules == null) {
00073 cpt2rules = new CPT2Rules(relNode);
00074 numZeros += cpt2rules.getZerosInCPT();
00075 }
00076
00077 Rule[] rules = cpt2rules.learnRules(constantAssignment);
00078 for(Rule rule : rules) {
00079 if(cpt2rules.getProbability(rule) == 0.0) {
00080 Formula f = cpt2rules.getConjunction(rule, constantAssignment);
00081 v.add(new Negation(f));
00082 numFormulas++;
00083 }
00084 }
00085 }
00086
00087 StringBuffer sb = new StringBuffer();
00088 for(Integer i : relNode.getIndicesOfConstantParams())
00089 sb.append(constantAssignment.get(relNode.params[i]));
00090 String constantKey = sb.toString();
00091
00092 constraints.put(relNode, constantKey, v);
00093 }
00094 }
00095 System.out.printf("reduced %d zeros in CPTs to %d formulas; %d direct translations\n", numZeros, numFormulas, numDirectTranslations);
00096
00097
00098 System.out.println("grounding constraints...");
00099 VariableLogicCoupling coupling = gbln.getCoupling();
00100 int sizeBefore = ckb.size();
00101 for(BeliefNode node : gbln.getRegularVariables()) {
00102 RelationalNode template = gbln.getTemplateOf(node);
00103
00104 Iterable<String> params = coupling.getOriginalParams(node);
00105
00106 StringBuffer sb = new StringBuffer();
00107 int i = 0;
00108 Vector<Integer> constIndices = template.getIndicesOfConstantParams();
00109 for(String p : params) {
00110 if(constIndices.contains(i))
00111 sb.append(p);
00112 i++;
00113 }
00114 String constantKey = sb.toString();
00115
00116 Vector<Formula> vf = constraints.get(template, constantKey);
00117 if(vf != null) {
00118
00119 i = 0;
00120 String[] actualParams = new String[template.params.length];
00121 for(String param : params)
00122 actualParams[i++] = param;
00123 HashMap<String,String> binding = template.getParameterBinding(actualParams, gbln.getDatabase());
00124
00125 for(Formula f : vf) {
00126
00127 Formula gf = f.ground(binding, coupling.getWorldVars(), gbln.getDatabase());
00128 Formula gfs = gf.simplify(gbln.getDatabase());
00129 if(gfs instanceof TrueFalse) {
00130 TrueFalse tf = (TrueFalse)gfs;
00131 if(!tf.isTrue())
00132 System.err.println("unsatisfiable formula" + gf);
00133 continue;
00134 }
00135 ckb.addFormula(gfs);
00136 }
00137 }
00138 }
00139 System.out.printf("added %d constraints\n", ckb.size()-sizeBefore);
00140 }
00141
00142 return ckb;
00143 }
00144 }