00001 package edu.tum.cs.srl.mln;
00002 
00003 import java.io.BufferedReader;
00004 import java.io.File;
00005 import java.io.FileNotFoundException;
00006 import java.io.PrintStream;
00007 import java.io.StringReader;
00008 import java.util.ArrayList;
00009 import java.util.Collection;
00010 import java.util.HashMap;
00011 import java.util.Set;
00012 import java.util.TreeSet;
00013 import java.util.regex.Matcher;
00014 import java.util.regex.Pattern;
00015 
00016 import edu.tum.cs.logic.Formula;
00017 import edu.tum.cs.logic.parser.ParseException;
00018 import edu.tum.cs.srl.Database;
00019 import edu.tum.cs.srl.RelationKey;
00020 import edu.tum.cs.srl.RelationalModel;
00021 import edu.tum.cs.srl.Signature;
00022 import edu.tum.cs.srl.taxonomy.Taxonomy;
00023 import edu.tum.cs.tools.JythonInterpreter;
00024 import edu.tum.cs.util.FileUtil;
00025 
00030 public class MarkovLogicNetwork implements RelationalModel {
00031 
00032     protected File mlnFile;
00033     protected HashMap<Formula, Double> formula2weight;
00037     protected HashMap<String, Signature> signatures;
00041     protected HashMap<String, String[]> guaranteedDomainElements;
00045     protected HashMap<String, Integer> functionalPreds;
00046     double sumAbsWeights = 0;
00047 
00053     public MarkovLogicNetwork(String mlnFileLoc) throws Exception {
00054         this();
00055         read(mlnFile = new File(mlnFileLoc));
00056     }
00057     
00061     public MarkovLogicNetwork() {
00062         mlnFile = null;
00063         signatures = new HashMap<String, Signature>();
00064         functionalPreds = new HashMap<String, Integer>();        
00065         guaranteedDomainElements = new HashMap<String, String[]>();
00066         formula2weight = new HashMap<Formula, Double>();
00067     }
00068     
00073     public void addSignature(Signature sig) {
00074         signatures.put(sig.functionName, sig);
00075     }
00076     
00077     public void addFormula(Formula f, double weight) {
00078         this.formula2weight.put(f, weight);
00079     }
00080     
00081     public void addHardFormula(Formula f) {
00082         addFormula(f, getHardWeight());
00083     }
00084     
00085     public void addFunctionalDependency(String predicateName, Integer index) {
00086         this.functionalPreds.put(predicateName, index);
00087     }
00088     
00089     public void addGuaranteedDomainElements(String domain, String[] elements) {
00090         this.guaranteedDomainElements.put(domain, elements);
00091     }
00092     
00093     public Set<Formula> getFormulas() {
00094         return formula2weight.keySet();
00095     }
00096 
00102     public Signature getSignature(String predName) {
00103         return signatures.get(predName);
00104     }
00105 
00110     public Integer getFunctionallyDeterminedArgument(String predicateName) {
00111         return this.functionalPreds.get(predicateName);
00112     }
00113 
00120     public MarkovRandomField ground(Database db) throws Exception {
00121         return ground(db, true, null);
00122     }
00123     
00124     public MarkovRandomField ground(Database db, boolean storeFormulasInMRF, GroundingCallback gc) throws Exception {
00125         return new MarkovRandomField(this, db, storeFormulasInMRF, gc);
00126     }
00127 
00132     public void read(File mlnFile) throws Exception {
00133         String actLine;
00134         ArrayList<Formula> hardFormulas = new ArrayList<Formula>();
00135         
00136         
00137         String content = FileUtil.readTextFile(mlnFile);
00138         
00139         
00140         Pattern comments = Pattern.compile("//.*?$|/\\*.*?\\*/", Pattern.MULTILINE | Pattern.DOTALL);
00141         Matcher matcher = comments.matcher(content);
00142         content = matcher.replaceAll("");
00143         BufferedReader breader = new BufferedReader(new StringReader(content));
00144 
00145         String identifier = "\\w+";
00146         String constant = "(?:[A-Z]\\w*|[0-9]+)";
00147         
00148         Pattern predDecl = Pattern.compile(String.format("(%s)\\(\\s*(%s!?(?:\\s*,\\s*%s!?)*)\\s*\\)", identifier, identifier, identifier));
00149         
00150         Pattern domDecl = Pattern.compile(String.format("(%s)\\s*=\\s*\\{\\s*(%s(?:\\s*,\\s*%s)*)\\s*\\}", identifier, constant, constant));
00151         
00152         JythonInterpreter jython = null;
00153         
00154         
00155         for(actLine = breader.readLine(); breader != null && actLine != null; actLine = breader.readLine()) {
00156                 String line = actLine.trim();
00157                 if(line.length() == 0)
00158                 continue;            
00159 
00160             
00161             if(line.endsWith(".")) {
00162                 Formula f;
00163                 String strF = line.substring(0, line.length() - 1);
00164                 try {
00165                         f = Formula.fromString(strF);
00166                 }
00167                 catch(ParseException e) {
00168                         throw new Exception("The hard formula '" + strF + "' could not be parsed: " + e.toString());
00169                 }
00170                 hardFormulas.add(f);
00171                 continue;
00172             } 
00173             
00174             
00175             Matcher m = predDecl.matcher(line);
00176             if(m.matches()) {                 
00177                 String predName = m.group(1);
00178                 Signature sig = getSignature(predName);
00179                 if(sig != null) {
00180                         throw new Exception(String.format("Signature declared in line '%s' was previously declared as '%s'", line, sig.toString()));
00181                 }
00182                 String[] argTypes = m.group(2).trim().split("\\s*,\\s*");
00183                 for (int c = 0; c < argTypes.length; c++) {
00184                     
00185                     if(argTypes[c].endsWith("!")) {
00186                         argTypes[c] = argTypes[c].replace("!", "");
00187                         Integer oldValue = functionalPreds.put(predName, c);
00188                         if(oldValue != null)
00189                                 throw new Exception(String.format("Predicate '%s' was declared to have more than one functionally determined parameter", predName));
00190                         break;
00191                     }
00192                 }
00193                 sig = new Signature(predName, "boolean", argTypes);
00194                 addSignature(sig);
00195                 continue;
00196             }
00197             
00198             
00199             m = domDecl.matcher(line);
00200             if(m.matches()) { 
00201                 Pattern domName = Pattern.compile("[a-z]+\\w+");
00202                 Pattern domCont = Pattern.compile("\\{(\\s*[A-Z]+\\w*\\s*,?)+\\}");
00203                 Matcher mat = domName.matcher(line);
00204                 Matcher mat2 = domCont.matcher(line);
00205                 
00206                 if (mat.find() && mat2.find()) {
00207                     String domarg = mat2.group(0).substring(1, mat2.group(0).length() - 1);
00208                     String[] cont = domarg.trim().split("\\s*,\\s*");
00209                     addGuaranteedDomainElements(mat.group(0), cont);
00210                 }
00211                 continue;
00212             }
00213             
00214             
00215             int iSpace = line.indexOf(' ');
00216             if(iSpace == -1)
00217                 throw new Exception("This line is not a correct declaration of a weighted formula: " + line);
00218             String strWeight = line.substring(0, iSpace);
00219             Double weight = null;
00220             try {
00221                 weight = Double.parseDouble(strWeight);
00222             }            
00223             catch(NumberFormatException e) {
00224                 if(jython == null) {
00225                         jython = new JythonInterpreter();
00226                         jython.exec("from math import *");
00227                         jython.exec("def logx(x):\n  if x == 0: return -100\n  return log(x)");
00228                 }
00229                 try {
00230                         weight = jython.evalDouble(strWeight);
00231                 }
00232                 catch(Exception e2) {
00233                         throw new Exception("Could not interpret weight '" + strWeight + "': " + e2.toString());
00234                 }            
00235             }
00236             String strF = line.substring(iSpace+1).trim();
00237             Formula f;
00238             try {
00239                 f = Formula.fromString(strF);
00240             }
00241             catch(ParseException e) {
00242                 throw new Exception("The formula '" + strF + "' could not be parsed: " + e.toString());
00243             }
00244             addFormula(f, weight);
00245             sumAbsWeights += Math.abs(weight);
00246         }
00247         
00248         
00249         double hardWeight = getHardWeight();
00250         for (Formula f : hardFormulas)
00251             addFormula(f, hardWeight);
00252     }
00253 
00257     public double getHardWeight() {
00258         return sumAbsWeights + 100;
00259     }
00260 
00265     public double getdeltaMin() {
00266         double deltaMin = Double.MAX_VALUE;
00267         TreeSet<Double> weight = new TreeSet<Double>();
00268         
00269         for(double d : formula2weight.values())
00270             weight.add(d);
00271         if(weight.size() == 1)
00272                 return 1.0e-5;
00273         
00274         while(weight.iterator().hasNext()) {
00275             Double d = weight.first();
00276             weight.remove(d);
00277             if (weight.iterator().hasNext()) {
00278                 if (Math.abs(d - weight.first()) < deltaMin)
00279                     deltaMin = Math.abs(d - weight.first());
00280             }
00281         }
00282         return deltaMin;
00283     }
00284 
00290         public void replaceType(String oldType, String newType) {
00291                 for(Signature sig : signatures.values()) 
00292                         sig.replaceType(oldType, newType);              
00293         }
00294 
00295 
00299     public HashMap<String, String[]> getGuaranteedDomainElements() {
00300         return guaranteedDomainElements;
00301     }
00302 
00308     public Collection<RelationKey> getRelationKeys(String relation) {
00309         
00310         return null;
00311     }
00312     
00316     public Set<String> getFunctionalPreds() {
00317         return functionalPreds.keySet();
00318     }
00319     
00320     public Collection<Signature> getSignatures() {
00321         return this.signatures.values();
00322     }
00323 
00327         public Taxonomy getTaxonomy() {
00328                 return null;
00329         }
00330         
00331         public void write(PrintStream out) {
00332                 MLNWriter writer = new MLNWriter(out);
00333                 
00334                 
00335                 if(this.guaranteedDomainElements.size() > 0) {
00336                         out.println("// domain declarations");
00337                         for(java.util.Map.Entry<String,String[]> e : this.getGuaranteedDomainElements().entrySet()) {
00338                                 writer.writeDomainDecl(e.getKey(), e.getValue());                       
00339                         }
00340                         out.println();
00341                 }
00342                 
00343                 
00344                 out.println("// predicate declarations");
00345                 for(Signature sig : this.getSignatures()) {
00346                         writer.writePredicateDecl(sig, this.getFunctionallyDeterminedArgument(sig.functionName));
00347                 }
00348                 out.println();
00349                 
00350                 out.println("// formulas");
00351                 for(Formula f : getFormulas()) {
00352                         double w = formula2weight.get(f);
00353                         out.printf("%f  %s\n", w, f.toString());
00354                 }
00355         }
00356         
00357         public void write(File f) throws FileNotFoundException {
00358                 write(new PrintStream(f));
00359         }
00360 
00361         @Override
00362         public Collection<String> getPrologRules() {
00363                 return null;
00364         }
00365 }