00001
00002
00003
00004
00005 package edu.tum.cs.wcsp;
00006
00007 import java.io.FileNotFoundException;
00008 import java.io.PrintStream;
00009 import java.util.ArrayList;
00010 import java.util.HashMap;
00011 import java.util.HashSet;
00012 import java.util.Iterator;
00013 import java.util.Map.Entry;
00014
00015 import edu.tum.cs.logic.Formula;
00016 import edu.tum.cs.logic.GroundAtom;
00017 import edu.tum.cs.logic.Negation;
00018 import edu.tum.cs.logic.PossibleWorld;
00019 import edu.tum.cs.logic.WorldVariables;
00020 import edu.tum.cs.logic.sat.weighted.WeightedFormula;
00021 import edu.tum.cs.srl.Database;
00022 import edu.tum.cs.srl.Signature;
00023 import edu.tum.cs.srl.mln.GroundingCallback;
00024 import edu.tum.cs.srl.mln.MarkovLogicNetwork;
00025 import edu.tum.cs.srl.mln.MarkovRandomField;
00026
00031 public class WCSPConverter implements GroundingCallback {
00032
00033 protected MarkovLogicNetwork mln;
00034 protected PossibleWorld wld;
00035 protected Double deltaMin;
00036 protected HashMap<String, HashSet<String>> doms;
00037 protected HashMap<Integer, Integer> gndID_BlockID;
00038 protected ArrayList<String> vars;
00039 protected ArrayList<String> simplifiedVars;
00040 protected HashMap<GroundAtom, Integer> gnd_varidx;
00041 protected HashMap<Integer, Integer> gnd_sfvaridx;
00042 protected HashMap<String, HashSet<GroundAtom>> vars_gnd;
00046 protected HashMap<Integer, HashSet<GroundAtom>> sfvars_gnd;
00047 protected HashMap<String, String> func_dom;
00048 protected HashMap<Integer, Integer> sfvars_vars;
00049 protected StringBuffer sb_result, sb_settings;
00050 protected PrintStream ps;
00051 protected int numConstraints = 0;
00052 protected boolean initialized = false;
00053 protected long sumSoftCosts = 0;
00054
00061 public WCSPConverter(String mlnFileLoc, String dbFileLoc) throws Exception {
00062 this.mln = new MarkovLogicNetwork(mlnFileLoc);
00063 deltaMin = mln.getdeltaMin();
00064 sb_result = new StringBuffer();
00065 Database db = new Database(mln);
00066 db.readMLNDB(dbFileLoc);
00067 mln.ground(db, false, this);
00068 }
00069
00074 public WCSPConverter(MarkovRandomField mrf) throws Exception {
00075 this.mln = mrf.mln;
00076 deltaMin = mln.getdeltaMin();
00077 sb_result = new StringBuffer();
00078 for(WeightedFormula wf : mrf) {
00079 convertFormula(wf, mrf);
00080 }
00081 }
00082
00089 public void run(String wcspFilename, String scenarioSettingsFilename) throws FileNotFoundException {
00090 if(scenarioSettingsFilename != null) {
00091
00092 sb_settings = saveSzenarioSettings(new StringBuffer());
00093 ps = new PrintStream(scenarioSettingsFilename);
00094 ps.print(sb_settings);
00095 ps.flush();
00096 ps.close();
00097 }
00098
00099
00100
00101 ps = new PrintStream(wcspFilename);
00102 generateHead(ps);
00103 ps.print(sb_result);
00104 ps.flush();
00105 ps.close();
00106 }
00107
00114 private StringBuffer saveSzenarioSettings(StringBuffer sb) {
00115
00116
00117 sb.append("Domains:" + System.getProperty("line.separator"));
00118 for (Entry e : doms.entrySet())
00119 sb.append(e.getKey().toString() + "=" + e.getValue().toString() + ";");
00120 sb.append(System.getProperty("line.separator"));
00121
00122
00123 sb.append("Variables:" + System.getProperty("line.separator"));
00124 for (String s : vars)
00125 sb.append(s + ";");
00126 sb.append(System.getProperty("line.separator"));
00127
00128
00129 sb.append("gnd_varsidx:" + System.getProperty("line.separator"));
00130 for (Entry e : gnd_varidx.entrySet())
00131 sb.append(e.getKey().toString() + "=" + e.getValue().toString() + ";");
00132 sb.append(System.getProperty("line.separator"));
00133
00134
00135 sb.append("var_domain:" + System.getProperty("line.separator"));
00136 for (Entry e : func_dom.entrySet())
00137 sb.append(e.getKey().toString() + "=" + e.getValue().toString() + ";");
00138 sb.append(System.getProperty("line.separator"));
00139
00140
00141 sb.append("vars_gnd:" + System.getProperty("line.separator"));
00142 for (Entry e : vars_gnd.entrySet())
00143 sb.append(e.getKey().toString() + "=" + e.getValue().toString() + ";");
00144 sb.append(System.getProperty("line.separator"));
00145
00146
00147 sb.append("sfvars_vars:" + System.getProperty("line.separator"));
00148 for (Entry e : sfvars_vars.entrySet())
00149 sb.append(e.getKey().toString() + "=" + e.getValue().toString() + ";");
00150 sb.append(System.getProperty("line.separator"));
00151
00152 return sb;
00153 }
00154
00158 protected void atom2var() {
00159 vars = new ArrayList<String>();
00160 gnd_varidx = new HashMap<GroundAtom, Integer>();
00161 func_dom = new HashMap<String, String>();
00162 vars_gnd = new HashMap<String, HashSet<GroundAtom>>();
00163
00164 WorldVariables ww = wld.getVariables();
00165 for (int i = 0; i < ww.size(); i++) {
00166
00167 if (ww.getBlock(ww.get(i).index) != null)
00168 atom2func(ww.get(i));
00169 else {
00170 vars.add(ww.get(i).toString());
00171 gnd_varidx.put(ww.get(i), vars.indexOf(ww.get(i).toString()));
00172
00173 func_dom.put(vars.get(vars.indexOf(ww.get(i).toString())), "boolean");
00174
00175 HashSet<GroundAtom> tmp = new HashSet<GroundAtom>();
00176 tmp.add(ww.get(i));
00177 vars_gnd.put(vars.get(vars.indexOf(ww.get(i).toString())), tmp);
00178 }
00179 }
00180 }
00181
00188 protected void atom2func(GroundAtom gnd) {
00189 String shortend = "";
00190 int x = 0;
00191
00192 for (int i = 0; i < gnd.args.length - 1; i++) {
00193 if (x++ > 0)
00194 shortend = shortend + ",";
00195 shortend = shortend + gnd.args[i];
00196 }
00197 String function = gnd.predicate + "(" + shortend + ")";
00198
00199
00200
00201 if (vars.contains(function)) {
00202 gnd_varidx.put(gnd, vars.indexOf(function));
00203 HashSet<GroundAtom> temp = vars_gnd.get(function);
00204 temp.add(gnd);
00205 } else {
00206 vars.add(function);
00207 gnd_varidx.put(gnd, vars.indexOf(function));
00208 Signature sig = mln.getSignature(gnd.predicate);
00209 func_dom.put(function, sig.argTypes[mln.getFunctionallyDeterminedArgument(gnd.predicate)]);
00210 HashSet<GroundAtom> temp = new HashSet<GroundAtom>();
00211 temp.add(gnd);
00212 vars_gnd.put(vars.get(vars.indexOf(function)), temp);
00213 }
00214 }
00215
00222 private void simplyfyVars(ArrayList<String> variables, Database db) throws Exception {
00223 sfvars_vars = new HashMap<Integer, Integer>();
00224 simplifiedVars = new ArrayList<String>();
00225 gnd_sfvaridx = new HashMap<Integer, Integer>();
00226 sfvars_gnd = new HashMap<Integer, HashSet<GroundAtom>>();
00227
00228
00229 for (int i = 0; i < variables.size(); i++) {
00230 HashSet<GroundAtom> givenAtoms = new HashSet<GroundAtom>();
00231
00232 HashSet<GroundAtom> gndAtoms = vars_gnd.get(variables.get(i));
00233 for (GroundAtom g : gndAtoms) {
00234 if (db.getVariableValue(g.toString(), false) != null)
00235 givenAtoms.add(g);
00236 }
00237
00238
00239
00240
00241 if ((gndAtoms.size() != givenAtoms.size())) {
00242
00243 int idx = simplifiedVars.size();
00244 simplifiedVars.add(variables.get(i));
00245
00246 sfvars_vars.put(idx, i);
00247
00248 for (GroundAtom g : vars_gnd.get(variables.get(i)))
00249 gnd_sfvaridx.put(g.index, idx);
00250
00251 sfvars_gnd.put(idx, (HashSet<GroundAtom>) gndAtoms.clone());
00252 }
00253 }
00254 }
00255
00262 protected void convertFormula(WeightedFormula wf) throws Exception {
00263 Formula f = wf.formula;
00264 double weight = wf.weight;
00265
00266
00267 HashSet<GroundAtom> gndAtoms = new HashSet<GroundAtom>();
00268 f.getGroundAtoms(gndAtoms);
00269
00270
00271 ArrayList<Integer> referencedVarIndices = new ArrayList<Integer>(gndAtoms.size());
00272 for (GroundAtom g : gndAtoms) {
00273
00274 int idx = gnd_sfvaridx.get(g.index);
00275 if (!referencedVarIndices.contains(idx))
00276 referencedVarIndices.add(idx);
00277 }
00278
00279
00280 ArrayList<String> settingsZero = new ArrayList<String>();
00281 ArrayList<String> settingsOther = new ArrayList<String>();
00282 convertFormula(f, referencedVarIndices, 0, wld, new int[referencedVarIndices.size()], weight, settingsZero, settingsOther);
00283 ArrayList<String> smallerSet;
00284 long cost = Math.round(wf.weight / deltaMin);
00285 long defaultCosts;
00286 if(settingsOther.size() < settingsZero.size()) {
00287 smallerSet = settingsOther;
00288
00289 defaultCosts = 0;
00290 }
00291 else {
00292 smallerSet = settingsZero;
00293
00294 defaultCosts = cost;
00295 }
00296
00297
00298 if(smallerSet.size() == 0)
00299 return;
00300
00301 numConstraints++;
00302
00303 if(!wf.isHard)
00304 sumSoftCosts += cost;
00305
00306
00307 String nl = System.getProperty("line.separator");
00308
00309 sb_result.append(referencedVarIndices.size() + " ");
00310 for(Integer in : referencedVarIndices) {
00311 sb_result.append(in + " ");
00312 }
00313 sb_result.append(defaultCosts + " " + smallerSet.size());
00314 sb_result.append(nl);
00315
00316 for (String s : smallerSet)
00317 sb_result.append(s + nl);
00318 }
00319
00324 protected void generateHead(PrintStream out) {
00325 int maxDomSize = 0;
00326
00327
00328 StringBuffer strDomSizes = new StringBuffer();
00329 for(int i = 0; i < simplifiedVars.size(); i++) {
00330 HashSet<String> domSet = doms.get(func_dom.get(simplifiedVars.get(i)));
00331 int domSize = domSet == null ? 2 : domSet.size();
00332 strDomSizes.append(domSize + " ");
00333 if(domSize > maxDomSize)
00334 maxDomSize = domSize;
00335 }
00336
00337
00338
00339 long top = sumSoftCosts+1;
00340 out.printf("WCSPfromMLN %d %d %d %d\n", simplifiedVars.size(), maxDomSize, numConstraints, top);
00341
00342 out.println(strDomSizes.toString());
00343 }
00344
00357 protected void convertFormula(Formula f, ArrayList<Integer> wcspVarIndices, int i, PossibleWorld w, int[] g, double weight, ArrayList<String> settingsZero, ArrayList<String> settingsOther) throws Exception {
00358 if(weight < 0)
00359 throw new Exception("Weights must be positive");
00360
00361 if (i == wcspVarIndices.size()) {
00362 StringBuffer zeile = new StringBuffer();
00363
00364 for (Integer c : g)
00365 zeile.append(c + " ");
00366
00367
00368 if (!f.isTrue(w)) {
00369 zeile.append(Math.round(weight / deltaMin));
00370 settingsOther.add(zeile.toString());
00371 } else {
00372 zeile.append("0");
00373 settingsZero.add(zeile.toString());
00374 }
00375 } else {
00376 int wcspVarIdx = wcspVarIndices.get(i);
00377
00378 HashSet<String> domSet = doms.get(func_dom.get(simplifiedVars.get(wcspVarIdx)));
00379 int domSize;
00380 if(domSet == null)
00381 domSize = 2;
00382 else
00383 domSize = domSet.size();
00384 for(int j = 0; j < domSize; j++) {
00385 g[i] = j;
00386 setGroundAtomState(w, wcspVarIdx, j);
00387 convertFormula(f, wcspVarIndices, i + 1, w, g, weight, settingsZero, settingsOther);
00388 }
00389 }
00390 }
00391
00398 public void setGroundAtomState(PossibleWorld w, int wcspVarIdx, int domIdx) {
00399 HashSet<GroundAtom> atoms = sfvars_gnd.get(wcspVarIdx);
00400 if(atoms.size() == 1) {
00401 w.set(atoms.iterator().next(), domIdx == 0);
00402 }
00403 else {
00404 Object[] dom = doms.get(func_dom.get(simplifiedVars.get(wcspVarIdx))).toArray();
00405 setBlockState(w, atoms, dom[domIdx].toString());
00406 }
00407
00408 }
00409
00416 protected void setBlockState(PossibleWorld w, HashSet<GroundAtom> block, String value) {
00417 int detArgIdx = this.mln.getFunctionallyDeterminedArgument(block.iterator().next().predicate);
00418 Iterator<GroundAtom> it = block.iterator();
00419 GroundAtom g;
00420 while(it.hasNext()) {
00421 g = it.next();
00422 boolean v = g.args[detArgIdx].equals(value);
00423 w.set(g.index, v);
00424 if(v)
00425 break;
00426 }
00427
00428 while(it.hasNext())
00429 w.set(it.next().index, false);
00430 }
00431
00440 public void onGroundedFormula(WeightedFormula wf, MarkovRandomField mrf) throws Exception {
00441 convertFormula(wf, mrf);
00442 }
00443
00444 public void convertFormula(WeightedFormula wf, MarkovRandomField mrf) throws Exception {
00445
00446 if(!initialized) {
00447 this.wld = new PossibleWorld(mrf.getWorldVariables());
00448 doms = mrf.getDb().getDomains();
00449 atom2var();
00450 simplyfyVars(vars, mrf.getDb());
00451 initialized = true;
00452 }
00453
00454
00455 if(wf.weight < 0) {
00456 wf.formula = new Negation(wf.formula);
00457 wf.weight *= -1;
00458 }
00459
00460
00461 convertFormula(wf);
00462 }
00463 }