00001
00002
00003
00004
00005
00006
00007 package edu.tum.cs.logic.sat.weighted;
00008
00009 import java.util.Random;
00010 import java.util.Vector;
00011 import java.util.Map.Entry;
00012
00013 import edu.tum.cs.inference.IParameterHandler;
00014 import edu.tum.cs.inference.ParameterHandler;
00015 import edu.tum.cs.logic.GroundAtom;
00016 import edu.tum.cs.logic.PossibleWorld;
00017 import edu.tum.cs.logic.WorldVariables;
00018 import edu.tum.cs.logic.sat.SampleSAT;
00019 import edu.tum.cs.srl.Database;
00020
00025 public class MCSAT implements IParameterHandler {
00026
00027 protected WeightedClausalKB kb;
00028 protected WorldVariables vars;
00029 protected Database db;
00030 protected Random rand;
00031 protected GroundAtomDistribution dist;
00032 protected boolean verbose = false, debug = false;
00033 protected int infoInterval = 100;
00034 protected ParameterHandler paramHandler;
00035 protected SampleSAT sat;
00036
00037 public MCSAT(WeightedClausalKB kb, WorldVariables vars, Database db) throws Exception {
00038 this.kb = kb;
00039 this.vars = vars;
00040 this.db = db;
00041 this.rand = new Random();
00042 this.dist = new GroundAtomDistribution(vars);
00043 this.paramHandler = new ParameterHandler(this);
00044 PossibleWorld state = new PossibleWorld(vars);
00045 sat = new SampleSAT(state, vars, db.getEntries());
00046 paramHandler.addSubhandler(sat.getParameterHandler());
00047 paramHandler.add("infoInterval", "setInfoInterval");
00048 paramHandler.add("verbose", "setVerbose");
00049 }
00050
00051 public WeightedClausalKB getKB() {
00052 return kb;
00053 }
00054
00055 public void setVerbose(boolean verbose) {
00056 this.verbose = verbose;
00057 }
00058
00059 public void setDebugMode(boolean active) {
00060 this.debug = active;
00061 }
00062
00063 public void setInfoInterval(int interval) {
00064 this.infoInterval = interval;
00065 }
00066
00067 public GroundAtomDistribution run(int steps) throws Exception {
00068 if(debug) {
00069 System.out.println("\nMC-SAT constraints:");
00070 for(WeightedClause wc : kb)
00071 System.out.println(" " + wc);
00072 System.out.println();
00073 }
00074 verbose = verbose || debug;
00075 if(verbose)
00076 System.out.printf("%s sampling...\n", this.getAlgorithmName());
00077
00078
00079 if(verbose) System.out.println("finding initial state...");
00080 Vector<WeightedClause> M = new Vector<WeightedClause>();
00081 for(Entry<WeightedFormula, Vector<WeightedClause>> e : kb.getFormulasAndClauses()) {
00082 WeightedFormula wf = e.getKey();
00083 if(wf.isHard) {
00084 M.addAll(e.getValue());
00085 }
00086 }
00087 sat.setDebugMode(debug);
00088 sat.initConstraints(M);
00089 sat.run();
00090
00091
00092 for(int i = 0; i < steps; i++) {
00093 M.clear();
00094 for(Entry<WeightedFormula, Vector<WeightedClause>> e : kb.getFormulasAndClauses()) {
00095 WeightedFormula wf = e.getKey();
00096 if(wf.formula.isTrue(sat.getState())){
00097 boolean satisfy = wf.isHard || rand.nextDouble() * Math.exp(wf.weight) > 1.0;
00098 if(satisfy)
00099 M.addAll(e.getValue());
00100 }
00101 }
00102 if(verbose && (i+1) % infoInterval == 0) {
00103 System.out.printf("MC-SAT step %d: %d constraints to be satisfied\n", i+1, M.size());
00104 if(debug) {
00105 for(WeightedClause wc : M)
00106 System.out.println(" " + wc);
00107 }
00108 }
00109 sat.initConstraints(M);
00110 sat.run();
00111
00112 if(false) {
00113 sat.getState().print();
00114 }
00115
00116 synchronized(dist) {
00117 dist.addSample(sat.getState(), 1.0);
00118 }
00119 }
00120 synchronized(dist) {
00121 dist.normalize();
00122 }
00123
00124 return dist;
00125 }
00126
00127 public void setP(double p) {
00128 sat.setPSampleSAT(p);
00129 }
00130
00131 public static class GroundAtomDistribution implements Cloneable {
00132 public double[] sums;
00133 public double Z;
00134 public int numSamples;
00135
00136 public GroundAtomDistribution(WorldVariables vars){
00137 this.Z = 0.0;
00138 this.numSamples = 0;
00139 this.sums = new double[vars.size()];
00140 }
00141
00142 public void addSample(PossibleWorld w, double weight){
00143 for(GroundAtom ga : w.getVariables()){
00144 if(w.isTrue(ga)){
00145 sums[ga.index] += weight;
00146 }
00147 }
00148 Z += weight;
00149 numSamples++;
00150 }
00151
00152 public void normalize(){
00153 if(Z != 1.0) {
00154 for(int i = 0; i < sums.length; i++){
00155 sums[i] /= Z;
00156 }
00157 Z = 1.0;
00158 }
00159 }
00160
00161 public double getResult(int indx){
00162 return sums[indx];
00163 }
00164
00165 public GroundAtomDistribution clone() throws CloneNotSupportedException {
00166 return (GroundAtomDistribution)super.clone();
00167 }
00168 }
00169
00170 public double getResult(GroundAtom ga) {
00171 return dist.getResult(ga.index);
00172 }
00173
00174 public GroundAtomDistribution pollResults() throws CloneNotSupportedException {
00175 GroundAtomDistribution ret = null;
00176 synchronized(dist) {
00177 ret = this.dist.clone();
00178 }
00179 return ret;
00180 }
00181
00182 public ParameterHandler getParameterHandler() {
00183 return paramHandler;
00184 }
00185
00186 public String getAlgorithmName() {
00187 return String.format("%s[%s]", this.getClass().getSimpleName(), sat.getAlgorithmName());
00188 }
00189 }