00001
00002
00003
00004
00005
00006 package edu.tum.cs.logic.sat.weighted;
00007
00008 import java.io.PrintStream;
00009 import java.util.Collection;
00010 import java.util.HashMap;
00011 import java.util.HashSet;
00012 import java.util.Random;
00013 import java.util.Vector;
00014
00015 import edu.tum.cs.logic.Formula;
00016 import edu.tum.cs.logic.GroundAtom;
00017 import edu.tum.cs.logic.GroundLiteral;
00018 import edu.tum.cs.logic.PossibleWorld;
00019 import edu.tum.cs.logic.WorldVariables;
00020 import edu.tum.cs.logic.WorldVariables.Block;
00021 import edu.tum.cs.logic.sat.EvidenceHandler;
00022 import edu.tum.cs.srl.Database;
00023 import edu.tum.cs.util.StringTool;
00024
00029 public class MaxWalkSAT {
00030
00031 protected HashMap<Integer, Vector<Constraint>> bottlenecks;
00032 protected HashMap<Integer, Vector<Constraint>> GAOccurrences;
00033 protected PossibleWorld state;
00034 protected Vector<Constraint> unsatisfiedConstraints;
00035 protected Vector<Integer> nonEvidenceGndAtomIndices;
00036 protected Random rand;
00037 protected WorldVariables vars;
00038 protected HashMap<Integer, Boolean> evidence;
00039 protected EvidenceHandler evidenceHandler;
00040 protected HashMap<edu.tum.cs.logic.sat.weighted.WeightedClause, Formula> clFormula;
00041 protected HashMap<WeightedClause, Formula> cl2Formula;
00042 protected HashMap<Formula, Double> formula2weight;
00043 protected HashMap<Formula, HashSet<WeightedClause>> formula2clauses;
00044 protected HashMap<Formula, HashSet<WeightedClause>> formula2satClause;
00045 protected int countUnsCon;
00046 protected int lastMinStep;
00047 protected double unsSum;
00048 protected double unsSumBeta;
00049 public int step;
00050 protected final boolean verbose = false;
00051 protected PossibleWorld bestState;
00052 protected int greedyMoves;
00053 protected int SAMoves;
00054 protected int flips;
00055 double minSum;
00056 protected int deltaCostCalcMethod = 1;
00057 protected int maxSteps = 1000;
00061 protected double p = 0.999;
00062
00071 public MaxWalkSAT(WeightedClausalKB kb, PossibleWorld state, WorldVariables vars, Database evidence) throws Exception {
00072
00073 this.state = state;
00074 this.vars = vars;
00075 flips = 0;
00076 this.unsatisfiedConstraints = new Vector<Constraint>();
00077 cl2Formula = new HashMap<WeightedClause, Formula>();
00078 formula2weight = new HashMap<Formula, Double>();
00079 formula2clauses = new HashMap<Formula, HashSet<WeightedClause>>();
00080 clFormula = kb.getClause2Formula();
00081 nonEvidenceGndAtomIndices = new Vector<Integer>();
00082 bottlenecks = new HashMap<Integer, Vector<Constraint>>();
00083 GAOccurrences = new HashMap<Integer, Vector<Constraint>>();
00084 formula2satClause = new HashMap<Formula, HashSet<WeightedClause>>();
00085 rand = new Random();
00086
00087 evidenceHandler = new EvidenceHandler(vars, evidence.getEntries());
00088 this.evidence = evidenceHandler.getEvidence();
00089 evidenceHandler.setEvidenceInState(state);
00090
00091
00092 for (int c = 0; c < vars.size(); c++) {
00093 if (!this.evidence.containsKey(c)) {
00094 nonEvidenceGndAtomIndices.add(c);
00095 }
00096 }
00097
00098
00099 for (edu.tum.cs.logic.sat.weighted.WeightedClause c : kb) {
00100 WeightedClause wcl = new WeightedClause(c.lits, c.isHard);
00101 Formula f = clFormula.get(c);
00102
00103 cl2Formula.put(wcl, f);
00104
00105 formula2weight.put(f, c.weight);
00106
00107 if (formula2clauses.get(f) != null) {
00108 formula2clauses.get(f).add(wcl);
00109 } else {
00110 HashSet<WeightedClause> addWC = new HashSet<WeightedClause>();
00111 addWC.add(wcl);
00112 formula2clauses.put(f, addWC);
00113 }
00114 }
00115 }
00116
00117 public void setMaxSteps(int steps) {
00118 this.maxSteps = steps;
00119 }
00120
00125 protected void addUnsatisfiedConstraint(Constraint c) {
00126 unsatisfiedConstraints.add(c);
00127 }
00128
00134 protected void addBottleneck(GroundAtom ga, Constraint c) {
00135 Vector<Constraint> v = bottlenecks.get(ga.index);
00136 if (v == null) {
00137 v = new Vector<Constraint>();
00138 bottlenecks.put(ga.index, v);
00139 }
00140 v.add(c);
00141 }
00142
00148 protected void addGAOccurrence(GroundAtom ga, Constraint c) {
00149 Vector<Constraint> v = GAOccurrences.get(ga.index);
00150 if (v == null) {
00151 v = new Vector<Constraint>();
00152 GAOccurrences.put(ga.index, v);
00153 }
00154 v.add(c);
00155 }
00156
00160 protected void unsatisfiedSum() {
00161 HashSet<Formula> unsForm = new HashSet<Formula>();
00162 double sum = 0;
00163 countUnsCon = 0;
00164 for (Constraint wcl : unsatisfiedConstraints) {
00165 if (wcl.isHard()) {
00166 countUnsCon += 1;
00167 }
00168 Formula f = cl2Formula.get(wcl);
00169 unsForm.add(f);
00170 }
00171 for (Formula f : unsForm) {
00172 sum += formula2weight.get(f);
00173 }
00174 unsSum = sum;
00175 }
00176
00180 protected void initFormulaState() {
00181 for (WeightedClause wcl : cl2Formula.keySet()) {
00182 Formula parent = cl2Formula.get(wcl);
00183 HashSet<WeightedClause> wclSet = formula2satClause.get(parent);
00184 if (wclSet == null) {
00185 wclSet = new HashSet<WeightedClause>();
00186 formula2satClause.put(parent, wclSet);
00187 }
00188 if (!unsatisfiedConstraints.contains(wcl)) {
00189 wclSet.add(wcl);
00190 }
00191 }
00192 }
00193
00197 public void run() {
00198
00199 bottlenecks.clear();
00200 unsatisfiedConstraints.clear();
00201 setState();
00202 minSum = Integer.MAX_VALUE;
00203 step = 1;
00204 lastMinStep = 0;
00205 double diffSum = 0;
00206 int minSteps = 0;
00207
00208
00209 for (Constraint c : cl2Formula.keySet()) {
00210 c.initState();
00211 }
00212
00213
00214 initFormulaState();
00215
00216
00217 unsatisfiedSum();
00218
00219
00220 bestState = state.clone();
00221 while (step < maxSteps && unsatisfiedConstraints.size() > 0) {
00222
00223 diffSum = unsSum - minSum;
00224 boolean newBest = false;
00225
00226 if (unsSum <= minSum) {
00227 if (unsSum < minSum) {
00228 newBest = true;
00229
00230 minSum = unsSum;
00231 minSteps = 0;
00232
00233 bestState = state.clone();
00234
00235 lastMinStep = step;
00236 }
00237
00238 if (unsSum == minSum)
00239 minSteps++;
00240 }
00241
00242
00243 String move;
00244 if (rand.nextDouble() < p) {
00245 walkSATMove();
00246 move = "greedy";
00247 }
00248 else {
00249 SAMove();
00250 move = "SA";
00251 }
00252 step++;
00253
00254
00255 if(step % 100 == 0 || newBest) {
00256 System.out.printf(" step %d: %s move, %d hard constraints unsatisfied, sum of unsatisfied weights: %f, best: %f %s\n", step, move, countUnsCon, unsSum, minSum, newBest ? "[NEW BEST]" : "");
00257 }
00258 }
00259 }
00260
00265 public PossibleWorld getState() {
00266 return state;
00267 }
00268
00272 public void printBestState(PrintStream fr) {
00273 boolean[] s = bestState.getState();
00274 for (int c = 0; c < s.length; c++) {
00275
00276 String temp = "";
00277 if (!s[c]) {
00278 temp = "!";
00279 }
00280
00281 fr.println(temp += vars.get(c).toString());
00282 }
00283 fr.println("Unsatisfied Sum: " + minSum);
00284 }
00285
00289 protected void setState() {
00290 evidenceHandler.setRandomState(state);
00291 }
00292
00296 protected void walkSATMove() {
00297
00298 Constraint c = randomlyChosen();
00299 Vector<Object> bestGAinFormula = new Vector<Object>();
00300 double formulaDelta = 0;
00301
00302 Formula parent = cl2Formula.get(c);
00303 do {
00304 for (WeightedClause con : formula2clauses.get(parent)) {
00305 if (unsatisfiedConstraints.contains(con)) {
00306
00307 bestGAinFormula.addAll(con.greedySatisfy());
00308 }
00309 }
00310 for (Object o : bestGAinFormula) {
00311 if (o instanceof Double) {
00312
00313 formulaDelta += ((Double) o).doubleValue();
00314 }
00315 }
00316 if (formulaDelta < 0) {
00317
00318 for (Object o : bestGAinFormula) {
00319 if (o instanceof GroundAtom) {
00320 flipGndAtom((GroundAtom) o);
00321 }
00322 }
00323 } else {
00324 break;
00325 }
00326 } while (!parent.isTrue(state));
00327 }
00328
00332 protected void SAMove() {
00333 boolean done = false;
00334
00335 while (!done) {
00336
00337 int idxGA = nonEvidenceGndAtomIndices.get(rand.nextInt(nonEvidenceGndAtomIndices.size()));
00338 GroundAtom gndAtom = vars.get(idxGA), gndAtom2 = null;
00339
00340 Block block = vars.getBlock(gndAtom.index);
00341 if (block != null) {
00342 GroundAtom trueOne = block.getTrueOne(state);
00343
00344 if (gndAtom == trueOne) {
00345 Vector<GroundAtom> others = new Vector<GroundAtom>();
00346 for (GroundAtom ga : block) {
00347 if (ga != trueOne && !evidence.containsKey(ga.index)) {
00348 others.add(ga);
00349 }
00350 }
00351 if (others.isEmpty()) {
00352 continue;
00353 }
00354
00355 gndAtom2 = others.get(rand.nextInt(others.size()));
00356 } else {
00357 gndAtom2 = trueOne;
00358 }
00359 }
00360
00361 flipGndAtom(gndAtom);
00362 if (gndAtom2 != null) {
00363 flipGndAtom(gndAtom2);
00364 }
00365 done = true;
00366 }
00367 }
00368
00374 @SuppressWarnings("empty-statement")
00375 protected Vector<Object> pickAndFlipVar(Collection<GroundAtom> candidates) {
00376 GroundAtom bestGA = null, bestGASecond = null;
00377 double bestDelta = Integer.MAX_VALUE;
00378
00379 for (GroundAtom gndAtom : candidates) {
00380
00381 double delta = 0;
00382
00383 switch (deltaCostCalcMethod) {
00384 case 1:
00385 delta = deltaCost(gndAtom);
00386 break;
00387 case 2:
00388 delta = deltaCostFormula(gndAtom);
00389 break;
00390 case 3:
00391 delta = deltaCostConAndForm(gndAtom);
00392 }
00393
00394 Block block = vars.getBlock(gndAtom.index);
00395 GroundAtom secondGA = null;
00396 if (block != null) {
00397 GroundAtom trueOne = block.getTrueOne(state);
00398 double delta2 = Integer.MAX_VALUE;
00399 if (gndAtom != trueOne) {
00400
00401 secondGA = trueOne;
00402 delta2 = deltaCost(secondGA);
00403 } else {
00404 for (GroundAtom ga2 : block) {
00405 if (ga2 == gndAtom) {
00406 continue;
00407 }
00408 double d = 0;
00409
00410 switch (deltaCostCalcMethod) {
00411 case 1:
00412 d = deltaCost(ga2);
00413 break;
00414 case 2:
00415 d = deltaCostFormula(ga2);
00416 break;
00417 case 3:
00418 d = deltaCostConAndForm(ga2);
00419 }
00420
00421 if (d < delta2) {
00422 delta2 = d;
00423 secondGA = ga2;
00424 }
00425 }
00426 }
00427
00428 delta += delta2;
00429 }
00430
00431 boolean newBest = false;
00432 if (delta < bestDelta) {
00433
00434 newBest = true;
00435 } else if (delta == bestDelta && rand.nextInt(2) == 1) {
00436 newBest = true;
00437 }
00438 if (newBest) {
00439 bestGA = gndAtom;
00440 bestGASecond = secondGA;
00441 bestDelta = delta;
00442 }
00443 }
00444
00445 Vector<Object> sol = new Vector();
00446 sol.add(bestGA);
00447 sol.add(bestGASecond);
00448 sol.add(bestDelta);
00449 return sol;
00450 }
00451
00456 protected void flipGndAtom(GroundAtom gndAtom) {
00457
00458 boolean value = state.isTrue(gndAtom);
00459 state.set(gndAtom, !value);
00460
00461
00462 Vector<Constraint> bn = this.bottlenecks.get(gndAtom.index);
00463 if (bn != null) {
00464 for (Constraint wcl : bn) {
00465 Formula parent = cl2Formula.get((WeightedClause) wcl);
00466 int satConFormula = formula2satClause.get(parent).size();
00467
00468 if (satConFormula == formula2clauses.get(parent).size()) {
00469 unsSum += formula2weight.get(parent);
00470 }
00471 if(wcl.isHard())
00472 this.countUnsCon++;
00473
00474 formula2satClause.get(parent).remove(wcl);
00475 }
00476
00477 this.unsatisfiedConstraints.addAll(bn);
00478 bn.clear();
00479 }
00480
00481 Vector<Constraint> occ = this.GAOccurrences.get(gndAtom.index);
00482 if (occ != null) {
00483 for (Constraint c : occ) {
00484 c.handleFlip(gndAtom);
00485 }
00486 }
00487 }
00488
00495 protected double deltaCost(GroundAtom gndAtom) {
00496 double delta = 0;
00497
00498 Vector<Constraint> bn = this.bottlenecks.get(gndAtom.index);
00499 if (bn != null) {
00500 for (Constraint con : bn) {
00501 delta += con.getDelta();
00502 }
00503 }
00504
00505 for (Constraint c : this.GAOccurrences.get(gndAtom.index)) {
00506 if (c.flipSatisfies(gndAtom)) {
00507 delta -= c.getDelta();
00508 }
00509 }
00510
00511 return delta;
00512 }
00513
00520 protected double deltaCostFormula(GroundAtom gndAtom) {
00521 double delta = 0;
00522
00523 Vector<Constraint> bn = this.bottlenecks.get(gndAtom.index);
00524 if (bn != null) {
00525 HashSet<Formula> checkedFormulas = new HashSet<Formula>();
00526 for (Constraint con : bn) {
00527
00528 if (!checkedFormulas.contains(cl2Formula.get(con))) {
00529 delta += con.getDeltaFormula(false);
00530 checkedFormulas.add(cl2Formula.get(con));
00531 }
00532 }
00533 }
00534
00535 for (Constraint c : this.GAOccurrences.get(gndAtom.index)) {
00536 if (c.flipSatisfies(gndAtom)) {
00537 delta -= c.getDeltaFormula(true);
00538 }
00539 }
00540 return delta;
00541 }
00542
00550 protected double deltaCostConAndForm(GroundAtom gndAtom) {
00551
00552 double delta = deltaCostFormula(gndAtom);
00553
00554 if (delta == 0) {
00555 delta += deltaCost(gndAtom);
00556 }
00557 return delta;
00558 }
00559
00564 protected Constraint randomlyChosen() {
00565 return unsatisfiedConstraints.get(rand.nextInt(unsatisfiedConstraints.size()));
00566 }
00567
00571 public void printunsCons() {
00572 for (Constraint c : unsatisfiedConstraints) {
00573 System.out.println("F: " + cl2Formula.get(c));
00574 System.out.println("unsCon: " + c.toString());
00575 }
00576 System.out.println("summe: " + unsSum);
00577 }
00578
00583 public PossibleWorld getBestState() {
00584 return bestState;
00585 }
00586
00591 public int getStep() {
00592 return step;
00593 }
00594
00599 public void setP(double p) {
00600 this.p = p;
00601 }
00602
00607 public double getP() {
00608 return p;
00609 }
00610
00618 public void setDeltaCostCalcMethod(int deltaCostCalcMethod) {
00619 this.deltaCostCalcMethod = deltaCostCalcMethod;
00620 }
00621
00622 protected abstract class Constraint {
00623
00624 public abstract Vector<Object> greedySatisfy();
00625
00626 public abstract boolean flipSatisfies(GroundAtom gndAtom);
00627
00628 public abstract void handleFlip(GroundAtom gndAtom);
00629
00630 public abstract void initState();
00631
00632 public abstract double getDelta();
00633
00634 public abstract boolean isHard();
00635
00636 public abstract double getDeltaFormula(boolean trueFlip);
00637
00638 public abstract Vector<GroundAtom> getGAsOfConstraint();
00639 }
00640
00644 protected class WeightedClause extends Constraint {
00645
00646 protected GroundLiteral[] lits;
00647 protected Vector<GroundAtom> gndAtoms;
00648 protected HashSet<GroundAtom> trueOnes;
00649 protected double weight;
00650 public boolean hard;
00651
00657 public WeightedClause(GroundLiteral[] lits, boolean hard) {
00658 this.lits = lits;
00659 this.hard = hard;
00660
00661
00662 gndAtoms = new Vector<GroundAtom>();
00663 trueOnes = new HashSet<GroundAtom>();
00664 for (GroundLiteral lit : lits) {
00665 GroundAtom gndAtom = lit.gndAtom;
00666 gndAtoms.add(gndAtom);
00667 addGAOccurrence(gndAtom, this);
00668 }
00669 }
00670
00675 @Override
00676 public Vector<Object> greedySatisfy() {
00677 return (pickAndFlipVar(gndAtoms));
00678 }
00679
00685 @Override
00686 public boolean flipSatisfies(GroundAtom gndAtom) {
00687 return trueOnes.size() == 0;
00688 }
00689
00694 @Override
00695 public void handleFlip(GroundAtom gndAtom) {
00696 int numTrueLits = trueOnes.size();
00697 Formula parent = cl2Formula.get(this);
00698
00699 if (trueOnes.contains(gndAtom)) {
00700 trueOnes.remove(gndAtom);
00701 numTrueLits--;
00702
00703 }
00704 else {
00705 if (numTrueLits == 0) {
00706
00707 unsatisfiedConstraints.remove(this);
00708 formula2satClause.get(parent).add(this);
00709
00710 if(formula2satClause.get(parent).size() == formula2clauses.get(parent).size()) {
00711 unsSum -= formula2weight.get(parent);
00712 }
00713 if(hard)
00714 countUnsCon--;
00715 }
00716 else if (numTrueLits == 1) {
00717 bottlenecks.get(trueOnes.iterator().next().index).remove(this);
00718 }
00719 trueOnes.add(gndAtom);
00720 numTrueLits++;
00721 }
00722
00723 if (numTrueLits == 1) {
00724 addBottleneck(trueOnes.iterator().next(), this);
00725 }
00726 }
00727
00732 @Override
00733 public String toString() {
00734 return StringTool.join(" v ", lits);
00735 }
00736
00740 @Override
00741 public void initState() {
00742 trueOnes.clear();
00743
00744 for (GroundLiteral lit : lits) {
00745 if (lit.isTrue(state)) {
00746 trueOnes.add(lit.gndAtom);
00747 }
00748 }
00749
00750 if (trueOnes.size() == 0) {
00751 addUnsatisfiedConstraint(this);
00752 }
00753 else if (trueOnes.size() == 1) {
00754 addBottleneck(trueOnes.iterator().next(), this);
00755 }
00756 }
00757
00762 @Override
00763 public boolean isHard() {
00764 return hard;
00765 }
00766
00771 @Override
00772 public double getDelta() {
00773 double delta = 0;
00774 Formula relatedFormula = cl2Formula.get(this);
00775
00776 delta += formula2weight.get(relatedFormula) / formula2clauses.get(relatedFormula).size();
00777 return delta;
00778 }
00779
00785 @Override
00786 public double getDeltaFormula(boolean trueFlip) {
00787 double delta = 0;
00788 Formula relatedFormula = cl2Formula.get(this);
00789 if (trueFlip) {
00790 if (formula2clauses.get(relatedFormula).size() - formula2satClause.get(relatedFormula).size() == 1) {
00791 delta += formula2weight.get(relatedFormula);
00792 }
00793 } else {
00794 if (formula2clauses.get(relatedFormula).size() == formula2satClause.get(relatedFormula).size()) {
00795 delta += formula2weight.get(relatedFormula);
00796 }
00797 }
00798 return delta;
00799 }
00800
00805 @Override
00806 public Vector<GroundAtom> getGAsOfConstraint() {
00807 return gndAtoms;
00808 }
00809 }
00810 }
00811