00001 package edu.tum.cs.srl.mln;
00002
00003 import java.io.BufferedReader;
00004 import java.io.File;
00005 import java.io.FileNotFoundException;
00006 import java.io.PrintStream;
00007 import java.io.StringReader;
00008 import java.util.ArrayList;
00009 import java.util.Collection;
00010 import java.util.HashMap;
00011 import java.util.Set;
00012 import java.util.TreeSet;
00013 import java.util.regex.Matcher;
00014 import java.util.regex.Pattern;
00015
00016 import edu.tum.cs.logic.Formula;
00017 import edu.tum.cs.logic.parser.ParseException;
00018 import edu.tum.cs.srl.Database;
00019 import edu.tum.cs.srl.RelationKey;
00020 import edu.tum.cs.srl.RelationalModel;
00021 import edu.tum.cs.srl.Signature;
00022 import edu.tum.cs.srl.taxonomy.Taxonomy;
00023 import edu.tum.cs.tools.JythonInterpreter;
00024 import edu.tum.cs.util.FileUtil;
00025
00030 public class MarkovLogicNetwork implements RelationalModel {
00031
00032 protected File mlnFile;
00033 protected HashMap<Formula, Double> formula2weight;
00037 protected HashMap<String, Signature> signatures;
00041 protected HashMap<String, String[]> guaranteedDomainElements;
00045 protected HashMap<String, Integer> functionalPreds;
00046 double sumAbsWeights = 0;
00047
00053 public MarkovLogicNetwork(String mlnFileLoc) throws Exception {
00054 this();
00055 read(mlnFile = new File(mlnFileLoc));
00056 }
00057
00061 public MarkovLogicNetwork() {
00062 mlnFile = null;
00063 signatures = new HashMap<String, Signature>();
00064 functionalPreds = new HashMap<String, Integer>();
00065 guaranteedDomainElements = new HashMap<String, String[]>();
00066 formula2weight = new HashMap<Formula, Double>();
00067 }
00068
00073 public void addSignature(Signature sig) {
00074 signatures.put(sig.functionName, sig);
00075 }
00076
00077 public void addFormula(Formula f, double weight) {
00078 this.formula2weight.put(f, weight);
00079 }
00080
00081 public void addHardFormula(Formula f) {
00082 addFormula(f, getHardWeight());
00083 }
00084
00085 public void addFunctionalDependency(String predicateName, Integer index) {
00086 this.functionalPreds.put(predicateName, index);
00087 }
00088
00089 public void addGuaranteedDomainElements(String domain, String[] elements) {
00090 this.guaranteedDomainElements.put(domain, elements);
00091 }
00092
00093 public Set<Formula> getFormulas() {
00094 return formula2weight.keySet();
00095 }
00096
00102 public Signature getSignature(String predName) {
00103 return signatures.get(predName);
00104 }
00105
00110 public Integer getFunctionallyDeterminedArgument(String predicateName) {
00111 return this.functionalPreds.get(predicateName);
00112 }
00113
00120 public MarkovRandomField ground(Database db) throws Exception {
00121 return ground(db, true, null);
00122 }
00123
00124 public MarkovRandomField ground(Database db, boolean storeFormulasInMRF, GroundingCallback gc) throws Exception {
00125 return new MarkovRandomField(this, db, storeFormulasInMRF, gc);
00126 }
00127
00132 public void read(File mlnFile) throws Exception {
00133 String actLine;
00134 ArrayList<Formula> hardFormulas = new ArrayList<Formula>();
00135
00136
00137 String content = FileUtil.readTextFile(mlnFile);
00138
00139
00140 Pattern comments = Pattern.compile("//.*?$|/\\*.*?\\*/", Pattern.MULTILINE | Pattern.DOTALL);
00141 Matcher matcher = comments.matcher(content);
00142 content = matcher.replaceAll("");
00143 BufferedReader breader = new BufferedReader(new StringReader(content));
00144
00145 String identifier = "\\w+";
00146 String constant = "(?:[A-Z]\\w*|[0-9]+)";
00147
00148 Pattern predDecl = Pattern.compile(String.format("(%s)\\(\\s*(%s!?(?:\\s*,\\s*%s!?)*)\\s*\\)", identifier, identifier, identifier));
00149
00150 Pattern domDecl = Pattern.compile(String.format("(%s)\\s*=\\s*\\{\\s*(%s(?:\\s*,\\s*%s)*)\\s*\\}", identifier, constant, constant));
00151
00152 JythonInterpreter jython = null;
00153
00154
00155 for(actLine = breader.readLine(); breader != null && actLine != null; actLine = breader.readLine()) {
00156 String line = actLine.trim();
00157 if(line.length() == 0)
00158 continue;
00159
00160
00161 if(line.endsWith(".")) {
00162 Formula f;
00163 String strF = line.substring(0, line.length() - 1);
00164 try {
00165 f = Formula.fromString(strF);
00166 }
00167 catch(ParseException e) {
00168 throw new Exception("The hard formula '" + strF + "' could not be parsed: " + e.toString());
00169 }
00170 hardFormulas.add(f);
00171 continue;
00172 }
00173
00174
00175 Matcher m = predDecl.matcher(line);
00176 if(m.matches()) {
00177 String predName = m.group(1);
00178 Signature sig = getSignature(predName);
00179 if(sig != null) {
00180 throw new Exception(String.format("Signature declared in line '%s' was previously declared as '%s'", line, sig.toString()));
00181 }
00182 String[] argTypes = m.group(2).trim().split("\\s*,\\s*");
00183 for (int c = 0; c < argTypes.length; c++) {
00184
00185 if(argTypes[c].endsWith("!")) {
00186 argTypes[c] = argTypes[c].replace("!", "");
00187 Integer oldValue = functionalPreds.put(predName, c);
00188 if(oldValue != null)
00189 throw new Exception(String.format("Predicate '%s' was declared to have more than one functionally determined parameter", predName));
00190 break;
00191 }
00192 }
00193 sig = new Signature(predName, "boolean", argTypes);
00194 addSignature(sig);
00195 continue;
00196 }
00197
00198
00199 m = domDecl.matcher(line);
00200 if(m.matches()) {
00201 Pattern domName = Pattern.compile("[a-z]+\\w+");
00202 Pattern domCont = Pattern.compile("\\{(\\s*[A-Z]+\\w*\\s*,?)+\\}");
00203 Matcher mat = domName.matcher(line);
00204 Matcher mat2 = domCont.matcher(line);
00205
00206 if (mat.find() && mat2.find()) {
00207 String domarg = mat2.group(0).substring(1, mat2.group(0).length() - 1);
00208 String[] cont = domarg.trim().split("\\s*,\\s*");
00209 addGuaranteedDomainElements(mat.group(0), cont);
00210 }
00211 continue;
00212 }
00213
00214
00215 int iSpace = line.indexOf(' ');
00216 if(iSpace == -1)
00217 throw new Exception("This line is not a correct declaration of a weighted formula: " + line);
00218 String strWeight = line.substring(0, iSpace);
00219 Double weight = null;
00220 try {
00221 weight = Double.parseDouble(strWeight);
00222 }
00223 catch(NumberFormatException e) {
00224 if(jython == null) {
00225 jython = new JythonInterpreter();
00226 jython.exec("from math import *");
00227 jython.exec("def logx(x):\n if x == 0: return -100\n return log(x)");
00228 }
00229 try {
00230 weight = jython.evalDouble(strWeight);
00231 }
00232 catch(Exception e2) {
00233 throw new Exception("Could not interpret weight '" + strWeight + "': " + e2.toString());
00234 }
00235 }
00236 String strF = line.substring(iSpace+1).trim();
00237 Formula f;
00238 try {
00239 f = Formula.fromString(strF);
00240 }
00241 catch(ParseException e) {
00242 throw new Exception("The formula '" + strF + "' could not be parsed: " + e.toString());
00243 }
00244 addFormula(f, weight);
00245 sumAbsWeights += Math.abs(weight);
00246 }
00247
00248
00249 double hardWeight = getHardWeight();
00250 for (Formula f : hardFormulas)
00251 addFormula(f, hardWeight);
00252 }
00253
00257 public double getHardWeight() {
00258 return sumAbsWeights + 100;
00259 }
00260
00265 public double getdeltaMin() {
00266 double deltaMin = Double.MAX_VALUE;
00267 TreeSet<Double> weight = new TreeSet<Double>();
00268
00269 for(double d : formula2weight.values())
00270 weight.add(d);
00271 if(weight.size() == 1)
00272 return 1.0e-5;
00273
00274 while(weight.iterator().hasNext()) {
00275 Double d = weight.first();
00276 weight.remove(d);
00277 if (weight.iterator().hasNext()) {
00278 if (Math.abs(d - weight.first()) < deltaMin)
00279 deltaMin = Math.abs(d - weight.first());
00280 }
00281 }
00282 return deltaMin;
00283 }
00284
00290 public void replaceType(String oldType, String newType) {
00291 for(Signature sig : signatures.values())
00292 sig.replaceType(oldType, newType);
00293 }
00294
00295
00299 public HashMap<String, String[]> getGuaranteedDomainElements() {
00300 return guaranteedDomainElements;
00301 }
00302
00308 public Collection<RelationKey> getRelationKeys(String relation) {
00309
00310 return null;
00311 }
00312
00316 public Set<String> getFunctionalPreds() {
00317 return functionalPreds.keySet();
00318 }
00319
00320 public Collection<Signature> getSignatures() {
00321 return this.signatures.values();
00322 }
00323
00327 public Taxonomy getTaxonomy() {
00328 return null;
00329 }
00330
00331 public void write(PrintStream out) {
00332 MLNWriter writer = new MLNWriter(out);
00333
00334
00335 if(this.guaranteedDomainElements.size() > 0) {
00336 out.println("// domain declarations");
00337 for(java.util.Map.Entry<String,String[]> e : this.getGuaranteedDomainElements().entrySet()) {
00338 writer.writeDomainDecl(e.getKey(), e.getValue());
00339 }
00340 out.println();
00341 }
00342
00343
00344 out.println("// predicate declarations");
00345 for(Signature sig : this.getSignatures()) {
00346 writer.writePredicateDecl(sig, this.getFunctionallyDeterminedArgument(sig.functionName));
00347 }
00348 out.println();
00349
00350 out.println("// formulas");
00351 for(Formula f : getFormulas()) {
00352 double w = formula2weight.get(f);
00353 out.printf("%f %s\n", w, f.toString());
00354 }
00355 }
00356
00357 public void write(File f) throws FileNotFoundException {
00358 write(new PrintStream(f));
00359 }
00360
00361 @Override
00362 public Collection<String> getPrologRules() {
00363 return null;
00364 }
00365 }