00001 
00002 
00003 
00004 
00005 
00006 
00007 package edu.tum.cs.srl.bayesnets.inference;
00008 
00009 import java.util.Vector;
00010 
00011 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00012 import edu.ksu.cis.bnj.ver3.core.CPF;
00013 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00014 import edu.tum.cs.bayesnets.inference.ITimeLimitedInference;
00015 import edu.tum.cs.bayesnets.inference.SampledDistribution;
00016 import edu.tum.cs.logic.Disjunction;
00017 import edu.tum.cs.logic.Formula;
00018 import edu.tum.cs.logic.GroundLiteral;
00019 import edu.tum.cs.logic.sat.weighted.WeightedClausalKB;
00020 import edu.tum.cs.logic.sat.weighted.WeightedClause;
00021 import edu.tum.cs.logic.sat.weighted.WeightedFormula;
00022 import edu.tum.cs.logic.sat.weighted.MCSAT.GroundAtomDistribution;
00023 import edu.tum.cs.srl.bayesnets.bln.GroundBLN;
00024 
00029 public class MCSAT extends Sampler implements ITimeLimitedInference {
00030 
00031         protected GroundBLN gbln;
00032         protected WeightedClausalKB kb;
00033         protected double maxWeight = 0;
00037         protected Vector<Disjunction> hardConstraintsInCPTs = new Vector<Disjunction>();
00038         protected edu.tum.cs.logic.sat.weighted.MCSAT sampler;
00039         
00040         public MCSAT(GroundBLN gbln) throws Exception {
00041                 super(gbln);
00042                 this.gbln = gbln;
00043                 kb = new WeightedClausalKB();           
00044                 
00045                 for(BeliefNode n : gbln.getRegularVariables()) {
00046                         CPF cpf = n.getCPF();
00047                         BeliefNode[] domProd = cpf.getDomainProduct();
00048                         walkCPT4ClauseCollection(cpf, domProd, new int[domProd.length], 0);                                             
00049                 }
00050                 
00051                 double hardWeight = maxWeight + 100;
00052                 for(Formula f : gbln.getKB()) {
00053                         kb.addFormula(new WeightedFormula(f, hardWeight, true), false);
00054                 }
00055                 for(Disjunction f : hardConstraintsInCPTs) 
00056                         kb.addClause(new WeightedClause(f, hardWeight, true));
00057                 
00058                 hardConstraintsInCPTs = null;
00059                 sampler = new edu.tum.cs.logic.sat.weighted.MCSAT(kb, gbln.getWorldVars(), gbln.getDatabase());
00060                 
00061                 paramHandler.addSubhandler(sampler.getParameterHandler());
00062         }
00063         
00064         protected void walkCPT4ClauseCollection(CPF cpf, BeliefNode[] domProd, int[] domainIndices, int i) throws Exception {
00065                 if(i == domainIndices.length) {
00066                         
00067                         GroundLiteral[] lits = new GroundLiteral[domainIndices.length];
00068                         for(int j = 0; j < domainIndices.length; j++) {
00069                                 lits[j] = gbln.getGroundLiteral(domProd[j], domainIndices[j]);
00070                                 lits[j].negate();
00071                         }
00072                         Disjunction f = new Disjunction(lits);
00073                         
00074                         double p = cpf.getDouble(domainIndices);
00075                         if(p == 0.0) { 
00076                                 hardConstraintsInCPTs.add(f);
00077                         }
00078                         else { 
00079                                 double weight = -Math.log(p);
00080                                 kb.addClause(new WeightedClause(f, weight, false));
00081                                 if(weight > maxWeight)
00082                                         maxWeight = weight;
00083                         }
00084                         return;
00085                 }       
00086                 
00087                 for(int j = 0; j < domProd[i].getDomain().getOrder(); j++) {
00088                         domainIndices[i] = j;
00089                         walkCPT4ClauseCollection(cpf, domProd, domainIndices, i+1);
00090                 }
00091         }
00092         
00093         @Override
00094         public SampledDistribution _infer() throws Exception {
00095                 sampler.setDebugMode(this.debug);
00096                 sampler.setVerbose(true);
00097                 sampler.setInfoInterval(infoInterval);
00098                 GroundAtomDistribution gad = sampler.run(numSamples);           
00099                 return getSampledDistribution(gad);     
00100         }
00101         
00102         protected SampledDistribution getSampledDistribution(GroundAtomDistribution gad) throws Exception {
00103                 gad.normalize();
00104                 BeliefNetworkEx bn = gbln.getGroundNetwork();
00105                 SampledDistribution dist = new SampledDistribution(bn);
00106                 for(BeliefNode n : gbln.getRegularVariables()) {
00107                         int idx = bn.getNodeIndex(n);
00108                         for(int k = 0; k < n.getDomain().getOrder(); k++) {
00109                                 GroundLiteral lit = gbln.getGroundLiteral(n, k);
00110                                 dist.values[idx][k] = gad.getResult(lit.gndAtom.index);
00111                                 if(!lit.isPositive)
00112                                         dist.values[idx][k] = 1-dist.values[idx][k];
00113                         }
00114                 }
00115                 for(BeliefNode n : gbln.getAuxiliaryVariables()) {
00116                         int idx = bn.getNodeIndex(n);
00117                         dist.values[idx][0] = 1.0;
00118                         dist.values[idx][1] = 0.0;
00119                 }
00120                 dist.Z = 1.0;
00121                 dist.trials = dist.steps = gad.numSamples;
00122                 return dist;
00123         }
00124 
00125         public SampledDistribution pollResults() throws Exception {             
00126                 return getSampledDistribution(sampler.pollResults());
00127         }
00128 }