00001 package edu.tum.cs.logic.sat;
00002
00003 import java.util.HashMap;
00004 import java.util.HashSet;
00005 import java.util.LinkedList;
00006 import java.util.Map;
00007 import java.util.Random;
00008 import java.util.Vector;
00009
00010 import edu.tum.cs.inference.IParameterHandler;
00011 import edu.tum.cs.inference.ParameterHandler;
00012 import edu.tum.cs.logic.GroundAtom;
00013 import edu.tum.cs.logic.GroundLiteral;
00014 import edu.tum.cs.logic.PossibleWorld;
00015 import edu.tum.cs.logic.WorldVariables;
00016 import edu.tum.cs.logic.WorldVariables.Block;
00017 import edu.tum.cs.srl.AbstractVariable;
00018 import edu.tum.cs.srl.Database;
00019 import edu.tum.cs.srl.bayesnets.ABL;
00020 import edu.tum.cs.srl.bayesnets.bln.BayesianLogicNetwork;
00021 import edu.tum.cs.srl.bayesnets.bln.GroundBLN;
00022 import edu.tum.cs.util.Stopwatch;
00023 import edu.tum.cs.util.StringTool;
00024
00031 public class SampleSAT implements IParameterHandler {
00032 protected HashMap<Integer,Vector<Constraint>> bottlenecks;
00033 protected HashMap<Integer,Vector<Constraint>> GAOccurrences;
00034 protected PossibleWorld state;
00035 protected Vector<Constraint> unsatisfiedConstraints;
00036 protected Vector<Constraint> constraints;
00037 protected Random rand;
00038 protected WorldVariables vars;
00039 protected boolean debug = false;
00040 protected EvidenceHandler evidenceHandler;
00041 protected HashMap<Integer,Boolean> evidence;
00042 protected boolean useUnitPropagation = false;
00043 Iterable<? extends edu.tum.cs.logic.sat.Clause> kb;
00044 protected ParameterHandler paramHandler;
00048 protected double pSampleSAT = 0.9;
00049
00054 protected double pWalkSAT = 0.5;
00055
00056
00064 public SampleSAT(Iterable<? extends edu.tum.cs.logic.sat.Clause> kb, PossibleWorld state, WorldVariables vars, Iterable<? extends AbstractVariable> db) throws Exception {
00065 this.state = state;
00066 this.vars = vars;
00067 this.kb = kb;
00068 rand = new Random();
00069 constraints = null;
00070
00071
00072 paramHandler = new ParameterHandler(this);
00073 paramHandler.add("pSampleSAT", "setPSampleSAT");
00074 paramHandler.add("pWalkSAT", "setPWalkSAT");
00075
00076
00077 evidenceHandler = new EvidenceHandler(vars, db);
00078 evidence = evidenceHandler.getEvidence();
00079 }
00080
00088 public SampleSAT(PossibleWorld state, WorldVariables vars, Iterable<? extends AbstractVariable> db) throws Exception {
00089 this(null, state, vars, db);
00090 }
00091
00092 public void setDebugMode(boolean active) {
00093 debug = active;
00094 }
00095
00099 public void enableUnitPropagation() {
00100 useUnitPropagation = true;
00101 }
00102
00108 public void initConstraints(Iterable<? extends edu.tum.cs.logic.sat.Clause> kb) throws Exception {
00109
00110 if(constraints != null && useUnitPropagation)
00111 throw new Exception("Resetting the set of constraints is not allowed when using unit propagation, because unit propagation extends the evidence database, which currently cannot be reversed.");
00112 this.kb = kb;
00113
00114
00115 unsatisfiedConstraints = new Vector<Constraint>();
00116 bottlenecks = new HashMap<Integer,Vector<Constraint>>();
00117
00118
00119 constraints = new Vector<Constraint>();
00120 GAOccurrences = new HashMap<Integer,Vector<Constraint>>();
00121 for(edu.tum.cs.logic.sat.Clause c : kb)
00122 constraints.add(new Clause(c.lits));
00123
00124
00125 if(useUnitPropagation)
00126 unitPropagation();
00127
00128
00129 evidenceHandler.setEvidenceInState(state);
00130 }
00131
00135 protected void unitPropagation() {
00136 int oldSize = constraints.size();
00137 LinkedList<Clause> unitClauses = new LinkedList<Clause>();
00138 for(Constraint c : constraints) {
00139 if(c instanceof Clause) {
00140 Clause cl = (Clause)c;
00141 if(cl.size() == 1)
00142 unitClauses.add(cl);
00143 }
00144 }
00145 while(!unitClauses.isEmpty()) {
00146 Clause cl = unitClauses.remove();
00147 GroundLiteral lit = cl.getLiterals()[0];
00148 evidence.put(lit.gndAtom.index, lit.isPositive);
00149 Vector<Constraint> affected = GAOccurrences.get(lit.gndAtom.index);
00150 if(affected != null) {
00151 Vector<Clause> scheduledForRemoval = new Vector<Clause>();
00152 for(Constraint c : affected) {
00153 if(c instanceof Clause) {
00154 Clause acl = (Clause)c;
00155 for(GroundLiteral l : acl.getLiterals()) {
00156 if(l.gndAtom.index == lit.gndAtom.index) {
00157 if(l.isPositive == lit.isPositive)
00158 scheduledForRemoval.add(acl);
00159 else {
00160 acl.removeLiteral(lit.gndAtom.index);
00161 if(acl.size() == 1)
00162 unitClauses.add(acl);
00163 if(acl.size() == 0)
00164 constraints.remove(acl);
00165 }
00166 }
00167 }
00168 }
00169 }
00170 for(Clause acl : scheduledForRemoval)
00171 removeClause(acl);
00172 }
00173
00174 constraints.remove(cl);
00175
00176 GAOccurrences.remove(lit.gndAtom.index);
00177 }
00178 int newSize = constraints.size();
00179 if(debug || true) System.out.println("unit propagation removed " + (oldSize-newSize) + " constraints");
00180 }
00181
00182 protected void removeClause(Clause c) {
00183 constraints.remove(c);
00184
00185 for(GroundLiteral lit : c.getLiterals())
00186 GAOccurrences.get(lit.gndAtom.index).remove(c);
00187 }
00188
00189 protected void addUnsatisfiedConstraint(Constraint c) {
00190 unsatisfiedConstraints.add(c);
00191 }
00192
00193 protected void addBottleneck(GroundAtom a, Constraint c) {
00194 Vector<Constraint> v = bottlenecks.get(a.index);
00195 if(v == null) {
00196 v = new Vector<Constraint>();
00197 bottlenecks.put(a.index, v);
00198 }
00199 v.add(c);
00200 }
00201
00202 protected void addGAOccurrence(GroundAtom a, Constraint c) {
00203 Vector<Constraint> v = GAOccurrences.get(a.index);
00204 if(v == null) {
00205 v = new Vector<Constraint>();
00206 GAOccurrences.put(a.index, v);
00207 }
00208 v.add(c);
00209 }
00210
00215 public void run() throws Exception {
00216
00217 if(constraints == null)
00218 initConstraints(kb);
00219
00220
00221 bottlenecks.clear();
00222 unsatisfiedConstraints.clear();
00223 if(debug) System.out.println("setting random state...");
00224 setRandomState();
00225 if(debug) state.print();
00226 for(Constraint c : constraints)
00227 c.initState();
00228
00229 int step = 1;
00230 while(unsatisfiedConstraints.size() > 0) {
00231
00232 if(debug) {
00233 System.out.println("SAT step " + step + ", " + unsatisfiedConstraints.size() + " constraints unsatisfied");
00234 if(true) {
00235
00236 if(unsatisfiedConstraints.size() < 30)
00237 for(Constraint c : unsatisfiedConstraints) {
00238 System.out.println(" unsatisfied: " + c);
00239 }
00240 }
00241 checkIntegrity();
00242 }
00243
00244 makeMove();
00245 step++;
00246 }
00247 }
00248
00253 protected void checkIntegrity() throws Exception {
00254
00255 for(Constraint c : this.constraints) {
00256 if(c instanceof Clause) {
00257 Clause cl = (Clause)c;
00258 int numTrue = 0;
00259 for(GroundLiteral lit : cl.lits)
00260 if(lit.isTrue(state)) {
00261 numTrue++;
00262 if(!cl.trueOnes.contains(lit.gndAtom))
00263 throw new Exception("Clause.trueOnes corrupted (1)");
00264 }
00265 if(numTrue != cl.trueOnes.size())
00266 throw new Exception("Clause.trueOnes corrupted (2)");
00267 boolean isTrue = numTrue > 0;
00268 boolean contained = unsatisfiedConstraints.contains(c);
00269 if(contained != !isTrue)
00270 throw new Exception("Unsatisfied constraints corrupted");
00271 }
00272 }
00273
00274 for(java.util.Map.Entry<Integer,Vector<Constraint>> entry : bottlenecks.entrySet()) {
00275 GroundAtom ga = this.vars.get(entry.getKey());
00276 for(Constraint c : entry.getValue()) {
00277 if(c instanceof Clause) {
00278 Clause cl = (Clause)c;
00279 boolean haveTrueOne = false;
00280 for(GroundLiteral lit : cl.lits) {
00281 if(lit.isTrue(state)) {
00282 if(haveTrueOne)
00283 throw new Exception("Bottlenecks corrupted: Clause " + cl + " contains a second true literal.");
00284 if(lit.gndAtom != ga)
00285 throw new Exception("Bottlenecks corrupted: Clause " + cl + " contains a true literal that isn't the bottleneck.");
00286 haveTrueOne = true;
00287 }
00288 if(lit.gndAtom == ga && !lit.isTrue(state))
00289 throw new Exception("Bottlenecks corrupted: Clause " + cl + " has " + ga + " as a bottleneck but contains a literal with " + ga + " that is false; it is likely that the clause is a tautology which should never have bottlenecks.");
00290 }
00291 }
00292 }
00293 }
00294 }
00295
00296 public PossibleWorld getState() {
00297 return state;
00298 }
00299
00303 protected void setRandomState() {
00304 evidenceHandler.setRandomState(state);
00305 }
00306
00307 protected void makeMove() {
00308 if(rand.nextDouble() < this.pSampleSAT) {
00309 if(debug) System.out.println(" WalkSAT move:");
00310 walkSATMove();
00311 }
00312 else {
00313 if(debug) System.out.println(" SA move:");
00314 SAMove();
00315 }
00316 }
00317
00318 protected void walkSATMove() {
00319
00320 Constraint c = unsatisfiedConstraints.get(rand.nextInt(unsatisfiedConstraints.size()));
00321
00322 if(rand.nextDouble() < this.pWalkSAT)
00323 c.satisfyRandomly();
00324
00325 else
00326 c.satisfyGreedily();
00327 }
00328
00329 protected void SAMove() {
00330 boolean done = false;
00331 while(!done) {
00332
00333 int idxGA = rand.nextInt(vars.size());
00334 GroundAtom gndAtom = vars.get(idxGA);
00335
00336 if(evidence.containsKey(idxGA))
00337 continue;
00338
00339 done = pickSecondAtRandomAndFlip(gndAtom);
00340 }
00341 }
00342
00348 protected boolean pickSecondAtRandomAndFlip(GroundAtom gndAtom) {
00349
00350 GroundAtom gndAtom2 = null;
00351 Block block = vars.getBlock(gndAtom.index);
00352 if(block != null) {
00353 GroundAtom trueOne = block.getTrueOne(state);
00354 if(gndAtom == trueOne) {
00355 Vector<GroundAtom> others = new Vector<GroundAtom>();
00356 for(GroundAtom ga : block) {
00357 if(ga != trueOne && !evidence.containsKey(ga.index))
00358 others.add(ga);
00359 }
00360 if(others.isEmpty())
00361 return false;
00362 gndAtom2 = others.get(rand.nextInt(others.size()));
00363 }
00364 else {
00365 if(evidence.containsKey(trueOne.index))
00366 return false;
00367 gndAtom2 = trueOne;
00368 }
00369 }
00370
00371 flipGndAtom(gndAtom);
00372 if(gndAtom2 != null)
00373 flipGndAtom(gndAtom2);
00374 return true;
00375 }
00376
00377 protected void pickAndFlipVar(Iterable<GroundAtom> candidates) {
00378
00379 GroundAtom bestGA = null, bestGASecond = null;
00380 int bestDelta = Integer.MIN_VALUE;
00381 for(GroundAtom gndAtom : candidates) {
00382
00383 if(evidence.containsKey(gndAtom.index))
00384 continue;
00385
00386 int delta = deltaCost(gndAtom);
00387
00388 Block block = vars.getBlock(gndAtom.index);
00389 GroundAtom secondGA = null;
00390 if(block != null) {
00391 GroundAtom trueOne = block.getTrueOne(state);
00392 int delta2 = Integer.MIN_VALUE;
00393 if(gndAtom != trueOne) {
00394 secondGA = trueOne;
00395 delta2 = deltaCost(secondGA);
00396 }
00397 else {
00398 for(GroundAtom ga2 : block) {
00399 if(evidence.containsKey(ga2.index) || ga2 == gndAtom)
00400 continue;
00401 int d = deltaCost(ga2);
00402 if(d > delta2) {
00403 delta2 = d;
00404 secondGA = ga2;
00405 }
00406 }
00407 }
00408 if(secondGA == null)
00409 continue;
00410 delta += delta2;
00411 }
00412
00413 boolean newBest = false;
00414 if(delta > bestDelta)
00415 newBest = true;
00416 else if(delta == bestDelta && rand.nextInt(2) == 1)
00417 newBest = true;
00418 if(newBest) {
00419 bestGA = gndAtom;
00420 bestGASecond = secondGA;
00421 bestDelta = delta;
00422 }
00423 }
00424
00425 flipGndAtom(bestGA);
00426 if(bestGASecond != null)
00427 flipGndAtom(bestGASecond);
00428 }
00429
00430 protected void flipGndAtom(GroundAtom gndAtom) {
00431 if(debug) System.out.println(" flipping " + gndAtom);
00432
00433 boolean value = state.isTrue(gndAtom);
00434 state.set(gndAtom, !value);
00435
00436 Vector<Constraint> bn = this.bottlenecks.get(gndAtom.index);
00437 if(bn != null) {
00438 this.unsatisfiedConstraints.addAll(bn);
00439 bn.clear();
00440 }
00441
00442 Vector<Constraint> occ = this.GAOccurrences.get(gndAtom.index);
00443 if(occ != null)
00444 for(Constraint c : occ)
00445 c.handleFlip(gndAtom);
00446 }
00447
00448 protected int deltaCost(GroundAtom gndAtom) {
00449 int delta = 0;
00450
00451 Vector<Constraint> bn = this.bottlenecks.get(gndAtom.index);
00452 if(bn != null)
00453 delta -= bn.size();
00454
00455 Vector<Constraint> occs = this.GAOccurrences.get(gndAtom.index);
00456 if(occs != null)
00457 for(Constraint c : occs)
00458 if(c.flipSatisfies(gndAtom))
00459 delta++;
00460 return delta;
00461 }
00462
00467 public void setPSampleSAT(double p) {
00468 this.pSampleSAT = p;
00469 }
00470
00475 public void setPWalkSAT(double p) {
00476 this.pWalkSAT = p;
00477 }
00478
00479 protected abstract class Constraint {
00480 public abstract void satisfyGreedily();
00481 public abstract void satisfyRandomly();
00482 public abstract boolean flipSatisfies(GroundAtom gndAtom);
00483 public abstract void handleFlip(GroundAtom gndAtom);
00484 public abstract void initState();
00485 public abstract boolean isTrue(PossibleWorld w);
00486 }
00487
00488 protected class Clause extends Constraint {
00489 protected GroundLiteral[] lits;
00490 protected Vector<GroundAtom> gndAtoms;
00491 protected HashSet<GroundAtom> trueOnes;
00492
00493 public Clause(GroundLiteral[] lits) {
00494 this.lits = lits;
00495
00496 gndAtoms = new Vector<GroundAtom>(lits.length);
00497 trueOnes = new HashSet<GroundAtom>((lits.length+1)/2);
00498 for(GroundLiteral lit : lits) {
00499 GroundAtom gndAtom = lit.gndAtom;
00500 gndAtoms.add(gndAtom);
00501 addGAOccurrence(gndAtom, this);
00502 }
00503 }
00504
00505 public boolean isTrue(PossibleWorld w) {
00506 for(GroundLiteral lit : lits)
00507 if(lit.isTrue(w))
00508 return true;
00509 return false;
00510 }
00511
00512 @Override
00513 public void satisfyGreedily() {
00514 pickAndFlipVar(gndAtoms);
00515 }
00516
00517 public void satisfyRandomly() {
00518 boolean done = false;
00519 while(!done) {
00520
00521 GroundAtom gndAtom = this.gndAtoms.get(rand.nextInt(this.gndAtoms.size()));
00522
00523 if(evidence.containsKey(gndAtom.index))
00524 continue;
00525
00526 done = pickSecondAtRandomAndFlip(gndAtom);
00527 }
00528 }
00529
00530 @Override
00531 public boolean flipSatisfies(GroundAtom gndAtom) {
00532 return trueOnes.size() == 0;
00533 }
00534
00535 @Override
00536 public void handleFlip(GroundAtom gndAtom) {
00537 int numTrueLits = trueOnes.size();
00538 if(trueOnes.contains(gndAtom)) {
00539 trueOnes.remove(gndAtom);
00540 numTrueLits--;
00541
00542 }
00543 else {
00544 if(numTrueLits == 0)
00545 unsatisfiedConstraints.remove(this);
00546 else if(numTrueLits == 1)
00547 bottlenecks.get(trueOnes.iterator().next().index).remove(this);
00548 trueOnes.add(gndAtom);
00549 numTrueLits++;
00550 }
00551 if(numTrueLits == 1)
00552 addBottleneck(trueOnes.iterator().next(), this);
00553 }
00554
00555 @Override
00556 public String toString() {
00557 return StringTool.join(" v ", lits);
00558 }
00559
00560 @Override
00561 public void initState() {
00562 trueOnes.clear();
00563
00564 for(GroundLiteral lit : lits)
00565 if(lit.isTrue(state))
00566 trueOnes.add(lit.gndAtom);
00567
00568 if(trueOnes.size() == 0)
00569 addUnsatisfiedConstraint(this);
00570
00571
00572
00573 else if(trueOnes.size() == 1)
00574 addBottleneck(trueOnes.iterator().next(), this);
00575 }
00576
00577 public int size() {
00578 return this.lits.length;
00579 }
00580
00581 public GroundLiteral[] getLiterals() {
00582 return lits;
00583 }
00584
00585 public void removeLiteral(int idxGndAtom) {
00586 GroundLiteral[] newlits = new GroundLiteral[this.lits.length-1];
00587 gndAtoms.clear();
00588 for(int i = 0, j = 0; i < lits.length; i++)
00589 if(lits[i].gndAtom.index != idxGndAtom) {
00590 newlits[j++] = lits[i];
00591 gndAtoms.add(lits[i].gndAtom);
00592 }
00593 lits = newlits;
00594 }
00595 }
00596
00597 public static void main(String[] args) throws Exception {
00598
00599
00600
00601
00602
00603
00604 String blog = "meals_any_for.blog";
00605 String net = "meals_any_for_functional.xml";
00606 String blnfile = "meals_any_for_functional.bln";
00607 String dbfile = "lorenzExample.blogdb";
00608 BayesianLogicNetwork bln = new BayesianLogicNetwork(new ABL(blog, net), blnfile);
00609
00610 Database db = new Database(bln.rbn);
00611 db.readBLOGDB(dbfile);
00612
00613 GroundBLN gbln = new GroundBLN(bln, db);
00614 gbln.instantiateGroundNetwork();
00615
00616 PossibleWorld state = new PossibleWorld(gbln.getWorldVars());
00617 ClausalKB ckb = new ClausalKB(gbln.getKB());
00618 Stopwatch sw = new Stopwatch();
00619 sw.start();
00620 SampleSAT ss = new SampleSAT(ckb, state, gbln.getWorldVars(), gbln.getDatabase().getEntries());
00621 ss.run();
00622 sw.stop();
00623
00624
00625 System.out.println("done");
00626 state.print();
00627 System.out.println("time taken: " + sw.getElapsedTimeSecs());
00628 }
00629
00630 public ParameterHandler getParameterHandler() {
00631 return paramHandler;
00632 }
00633
00634 public String getAlgorithmName() {
00635 return String.format("%s[%f;%f]", this.getClass().getSimpleName(), pSampleSAT, pWalkSAT);
00636 }
00637 }