00001
00002
00003
00004
00005
00006
00007 package edu.tum.cs.srl.bayesnets.inference;
00008
00009 import java.util.Vector;
00010
00011 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00012 import edu.ksu.cis.bnj.ver3.core.CPF;
00013 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00014 import edu.tum.cs.bayesnets.inference.ITimeLimitedInference;
00015 import edu.tum.cs.bayesnets.inference.SampledDistribution;
00016 import edu.tum.cs.logic.Disjunction;
00017 import edu.tum.cs.logic.Formula;
00018 import edu.tum.cs.logic.GroundLiteral;
00019 import edu.tum.cs.logic.sat.weighted.WeightedClausalKB;
00020 import edu.tum.cs.logic.sat.weighted.WeightedClause;
00021 import edu.tum.cs.logic.sat.weighted.WeightedFormula;
00022 import edu.tum.cs.logic.sat.weighted.MCSAT.GroundAtomDistribution;
00023 import edu.tum.cs.srl.bayesnets.bln.GroundBLN;
00024
00029 public class MCSAT extends Sampler implements ITimeLimitedInference {
00030
00031 protected GroundBLN gbln;
00032 protected WeightedClausalKB kb;
00033 protected double maxWeight = 0;
00037 protected Vector<Disjunction> hardConstraintsInCPTs = new Vector<Disjunction>();
00038 protected edu.tum.cs.logic.sat.weighted.MCSAT sampler;
00039
00040 public MCSAT(GroundBLN gbln) throws Exception {
00041 super(gbln);
00042 this.gbln = gbln;
00043 kb = new WeightedClausalKB();
00044
00045 for(BeliefNode n : gbln.getRegularVariables()) {
00046 CPF cpf = n.getCPF();
00047 BeliefNode[] domProd = cpf.getDomainProduct();
00048 walkCPT4ClauseCollection(cpf, domProd, new int[domProd.length], 0);
00049 }
00050
00051 double hardWeight = maxWeight + 100;
00052 for(Formula f : gbln.getKB()) {
00053 kb.addFormula(new WeightedFormula(f, hardWeight, true), false);
00054 }
00055 for(Disjunction f : hardConstraintsInCPTs)
00056 kb.addClause(new WeightedClause(f, hardWeight, true));
00057
00058 hardConstraintsInCPTs = null;
00059 sampler = new edu.tum.cs.logic.sat.weighted.MCSAT(kb, gbln.getWorldVars(), gbln.getDatabase());
00060
00061 paramHandler.addSubhandler(sampler.getParameterHandler());
00062 }
00063
00064 protected void walkCPT4ClauseCollection(CPF cpf, BeliefNode[] domProd, int[] domainIndices, int i) throws Exception {
00065 if(i == domainIndices.length) {
00066
00067 GroundLiteral[] lits = new GroundLiteral[domainIndices.length];
00068 for(int j = 0; j < domainIndices.length; j++) {
00069 lits[j] = gbln.getGroundLiteral(domProd[j], domainIndices[j]);
00070 lits[j].negate();
00071 }
00072 Disjunction f = new Disjunction(lits);
00073
00074 double p = cpf.getDouble(domainIndices);
00075 if(p == 0.0) {
00076 hardConstraintsInCPTs.add(f);
00077 }
00078 else {
00079 double weight = -Math.log(p);
00080 kb.addClause(new WeightedClause(f, weight, false));
00081 if(weight > maxWeight)
00082 maxWeight = weight;
00083 }
00084 return;
00085 }
00086
00087 for(int j = 0; j < domProd[i].getDomain().getOrder(); j++) {
00088 domainIndices[i] = j;
00089 walkCPT4ClauseCollection(cpf, domProd, domainIndices, i+1);
00090 }
00091 }
00092
00093 @Override
00094 public SampledDistribution _infer() throws Exception {
00095 sampler.setDebugMode(this.debug);
00096 sampler.setVerbose(true);
00097 sampler.setInfoInterval(infoInterval);
00098 GroundAtomDistribution gad = sampler.run(numSamples);
00099 return getSampledDistribution(gad);
00100 }
00101
00102 protected SampledDistribution getSampledDistribution(GroundAtomDistribution gad) throws Exception {
00103 gad.normalize();
00104 BeliefNetworkEx bn = gbln.getGroundNetwork();
00105 SampledDistribution dist = new SampledDistribution(bn);
00106 for(BeliefNode n : gbln.getRegularVariables()) {
00107 int idx = bn.getNodeIndex(n);
00108 for(int k = 0; k < n.getDomain().getOrder(); k++) {
00109 GroundLiteral lit = gbln.getGroundLiteral(n, k);
00110 dist.values[idx][k] = gad.getResult(lit.gndAtom.index);
00111 if(!lit.isPositive)
00112 dist.values[idx][k] = 1-dist.values[idx][k];
00113 }
00114 }
00115 for(BeliefNode n : gbln.getAuxiliaryVariables()) {
00116 int idx = bn.getNodeIndex(n);
00117 dist.values[idx][0] = 1.0;
00118 dist.values[idx][1] = 0.0;
00119 }
00120 dist.Z = 1.0;
00121 dist.trials = dist.steps = gad.numSamples;
00122 return dist;
00123 }
00124
00125 public SampledDistribution pollResults() throws Exception {
00126 return getSampledDistribution(sampler.pollResults());
00127 }
00128 }