00001
00002
00003
00004
00005
00006 package edu.tum.cs.srl.mln;
00007
00008 import java.io.PrintStream;
00009 import java.util.Iterator;
00010 import java.util.Vector;
00011
00012 import edu.tum.cs.logic.Formula;
00013 import edu.tum.cs.logic.GroundAtom;
00014 import edu.tum.cs.logic.IPossibleWorld;
00015 import edu.tum.cs.logic.WorldVariables;
00016 import edu.tum.cs.logic.sat.weighted.WeightedFormula;
00017 import edu.tum.cs.srl.AbstractVariable;
00018 import edu.tum.cs.srl.Database;
00019 import edu.tum.cs.srl.Signature;
00020
00025 public class MarkovRandomField implements Iterable<WeightedFormula> {
00026 protected Database db;
00027 public MarkovLogicNetwork mln;
00028 protected Vector<WeightedFormula> weightedFormulas;
00029 protected WorldVariables vars;
00033 protected final boolean simplifyGroundedFormulas = true;
00034
00043 public MarkovRandomField(MarkovLogicNetwork mln, Database db, boolean storeFormulas, GroundingCallback gc) throws Exception {
00044 this.db = db;
00045 this.vars = new WorldVariables();
00046 this.mln = mln;
00047 groundVariables();
00048 groundFormulas(storeFormulas, gc);
00049 }
00050
00051 public MarkovRandomField(MarkovLogicNetwork mln, Database db) throws Exception {
00052 this(mln, db, true, null);
00053 }
00054
00059 public WorldVariables getWorldVariables() {
00060 return vars;
00061 }
00062
00067 protected void groundVariables() throws Exception {
00068 for(Signature sig : mln.getSignatures()) {
00069 groundVariables(sig, new String[sig.argTypes.length], 0, mln.getFunctionallyDeterminedArgument(sig.functionName));
00070 }
00071 }
00072
00073 protected void groundVariables(Signature sig, String[] args, int i, Integer functionallyDeterminedArg) throws Exception {
00074 if(i == args.length) {
00075 if(functionallyDeterminedArg != null) {
00076
00077 Vector<GroundAtom> block = new Vector<GroundAtom>();
00078 Iterable<String> dom = db.getDomain(sig.argTypes[functionallyDeterminedArg]);
00079 GroundAtom trueOne = null;
00080 for(String value : dom) {
00081 args[functionallyDeterminedArg] = value;
00082 GroundAtom ga = new GroundAtom(sig.functionName, args.clone());
00083 block.add(ga);
00084 AbstractVariable var = db.getVariable(ga.toString());
00085 if(var != null && var.isTrue()) {
00086 if(trueOne != null)
00087 throw new Exception(String.format("The block the variable '%s' is in contains more than one true ground atom", ga.toString()));
00088 trueOne = ga;
00089 }
00090 }
00091
00092 if(trueOne != null) {
00093 for(GroundAtom ga : block)
00094 if(ga != trueOne && !db.contains(ga.toString()))
00095 db.addVariable(new Database.Variable(ga.predicate, ga.args, "False", mln));
00096 }
00097
00098 vars.addBlock(block);
00099
00100 }
00101 else {
00102 vars.add(new GroundAtom(sig.functionName, args.clone()));
00103 }
00104 return;
00105 }
00106 if(functionallyDeterminedArg != null && functionallyDeterminedArg.equals(i)) {
00107 groundVariables(sig, args, i+1, functionallyDeterminedArg);
00108 }
00109 else {
00110 Iterable<String> dom = db.getDomain(sig.argTypes[i]);
00111 if(dom == null)
00112 throw new Exception("Domain '" + sig.argTypes[i] + "' not found in the database");
00113 for(String value : dom) {
00114 args[i] = value;
00115 groundVariables(sig, args, i+1, functionallyDeterminedArg);
00116 }
00117 }
00118 }
00119
00126 protected void groundFormulas(boolean makelist, GroundingCallback gc) throws Exception {
00127 weightedFormulas = new Vector<WeightedFormula>();
00128 for(Formula form : mln.getFormulas()) {
00129 Double weight = mln.formula2weight.get(form);
00130 if(weight == null)
00131 throw new Exception(String.format("MLN does not assign a weight to '%s'; mapped formulas are %s.", form.toString(), mln.formula2weight.keySet().toString()));
00132 boolean isHard = weight.equals(mln.getHardWeight());
00133 for(Formula gf : form.getAllGroundings(db, vars, simplifyGroundedFormulas)) {
00134 WeightedFormula wf = new WeightedFormula(gf, weight, isHard);
00135 if(makelist)
00136 weightedFormulas.add(wf);
00137 if(gc != null)
00138 gc.onGroundedFormula(wf, this);
00139 }
00140 }
00141 }
00142
00147 public Database getDb() {
00148 return db;
00149 }
00150
00151 public Iterator<WeightedFormula> iterator() {
00152 return weightedFormulas.iterator();
00153 }
00154
00155 public void print(PrintStream out) {
00156 for(WeightedFormula wf : this)
00157 out.println(wf.toString());
00158 }
00159
00164 public double getWorldValue(IPossibleWorld w) {
00165 double s = 0;
00166 for(WeightedFormula wf : this)
00167 if(wf.formula.isTrue(w))
00168 s += wf.weight;
00169 return s;
00170 }
00171 }