00001 package edu.tum.cs.srl.bayesnets;
00002 
00003 import java.io.PrintStream;
00004 import java.util.Collection;
00005 import java.util.Vector;
00006 import java.util.regex.Matcher;
00007 import java.util.regex.Pattern;
00008 
00009 import edu.tum.cs.srl.Database;
00010 import edu.tum.cs.srl.RelationKey;
00011 import edu.tum.cs.srl.Signature;
00012 import edu.tum.cs.srl.bayesnets.learning.CPTLearner;
00013 import edu.tum.cs.srl.bayesnets.learning.DomainLearner;
00014 import edu.tum.cs.srl.taxonomy.Concept;
00015 import edu.tum.cs.srl.taxonomy.Taxonomy;
00016 import edu.tum.cs.util.StringTool;
00017 
00023 public class ABL extends BLOGModel {
00024 
00025         public ABL(String[] blogFiles, String networkFile) throws Exception {
00026                 super(blogFiles, networkFile);
00027         }
00028 
00029         public ABL(String blogFile, String networkFile) throws Exception {
00030                 this(new String[] { blogFile }, networkFile);
00031         }
00032 
00033         @Override
00034         protected boolean readDeclaration(String line) throws Exception {
00035                 if (super.readDeclaration(line))
00036                         return true;
00037                 
00038                 if (line.startsWith("relationKey") || line.startsWith("RelationKey")) {
00039                         Pattern pat = Pattern
00040                                         .compile("[Rr]elationKey\\s+(\\w+)\\s*\\((.*)\\)\\s*;?");
00041                         Matcher matcher = pat.matcher(line);
00042                         if (matcher.matches()) {
00043                                 String relation = matcher.group(1);
00044                                 String[] arguments = matcher.group(2).trim().split("\\s*,\\s*");
00045                                 addRelationKey(new RelationKey(relation, arguments));
00046                                 return true;
00047                         }
00048                         return false;
00049                 }
00050                 
00051                 if (line.startsWith("type") || line.startsWith("Type")) {
00052                         if (taxonomy == null)
00053                                 taxonomy = new Taxonomy();
00054                         Pattern pat = Pattern.compile("[Tt]ype\\s+(.*?);?");
00055                         Matcher matcher = pat.matcher(line);
00056                         Pattern typeDecl = Pattern.compile("(\\w+)(?:\\s+isa\\s+(\\w+))?");
00057                         if (matcher.matches()) {
00058                                 String[] decls = matcher.group(1).split("\\s*,\\s*");
00059                                 for (String d : decls) {
00060                                         Matcher m = typeDecl.matcher(d);
00061                                         if (m.matches()) {
00062                                                 Concept c = new Concept(m.group(1));
00063                                                 taxonomy.addConcept(c);
00064                                                 if (m.group(2) != null) {
00065                                                         Concept parent = taxonomy.getConcept(m.group(2));
00066                                                         if (parent == null)
00067                                                                 throw new Exception(
00068                                                                                 "Error in declaration of type '"
00069                                                                                                 + m.group(1)
00070                                                                                                 + "': The parent type '"
00071                                                                                                 + m.group(2)
00072                                                                                                 + "' is undeclared.");
00073                                                         c.setParent(parent);
00074                                                 }
00075                                                 return true;
00076                                         } else
00077                                                 throw new Exception("The type declaration '" + d
00078                                                                 + "' is invalid");
00079                                 }
00080                         }
00081                         return false;
00082                 }
00083                 
00084                 if (line.startsWith("prolog")) {
00085                         String rule = line.substring(6).trim();
00086                         if (!rule.endsWith("."))
00087                                 rule += ".";
00088                         prologRules.add(rule);
00089                         return true;
00090                 }
00091                 
00092                 if(line.startsWith("combining-rule")) {
00093                         Pattern pat = Pattern.compile("combining-rule\\s+(\\w+)\\s+([-\\w]+)\\s*;?");
00094                         Matcher matcher = pat.matcher(line);
00095                         if(matcher.matches()) {
00096                                 String function = matcher.group(1);
00097                                 String strRule = matcher.group(2);
00098                                 Signature sig = getSignature(function);
00099                                 CombiningRule rule;
00100                                 if(sig == null) 
00101                                         throw new Exception("Defined combining rule for unknown function '" + function + "'");
00102                                 try {
00103                                         rule = CombiningRule.fromString(strRule);
00104                                 }
00105                                 catch(IllegalArgumentException e) {
00106                                         Vector<String> v = new Vector<String>();
00107                                         for(CombiningRule cr : CombiningRule.values()) 
00108                                                 v.add(cr.stringRepresention);
00109                                         throw new Exception("Invalid combining rule '" + strRule + "'; valid options: " + StringTool.join(", ", v));
00110                                 }
00111                                 this.combiningRules.put(function, rule);
00112                                 return true;
00113                         }
00114                 }
00115                 return false;
00116         }
00117 
00118         @Override
00119         protected void writeDeclarations(PrintStream out) {
00120                 super.writeDeclarations(out);
00121 
00122                 
00123                 for (Collection<RelationKey> ckey : this.relationKeys.values()) {
00124                         for (RelationKey key : ckey) {
00125                                 out.println("relationKey " + key.toString());
00126                         }
00127                 }
00128                 out.println();
00129         }
00130 
00131         public static void main(String[] args) {
00132                 try {
00133                         String bifFile = "abl/kitchen-places/actseq.xml";
00134                         ABL bn = new ABL(new String[] { "abl/kitchen-places/actseq.abl" },
00135                                         bifFile);
00136                         String dbFile = "abl/kitchen-places/train.blogdb";
00137                         
00138                         System.out.println("Reading data...");
00139                         Database db = new Database(bn);
00140                         db.readBLOGDB(dbFile);
00141                         System.out.println("  " + db.getEntries().size()
00142                                         + " variables read.");
00143                         
00144                         if (true) {
00145                                 System.out.println("Learning domains...");
00146                                 DomainLearner domLearner = new DomainLearner(bn);
00147                                 domLearner.learn(db);
00148                                 domLearner.finish();
00149                         }
00150                         
00151                         System.out.println("Learning parameters...");
00152                         CPTLearner cptLearner = new CPTLearner(bn);
00153                         cptLearner.learnTyped(db, true, true);
00154                         cptLearner.finish();
00155                         System.out.println("Writing XML-BIF output...");
00156                         bn.saveXMLBIF(bifFile);
00157                         if (true) {
00158                                 System.out.println("Showing Bayesian network...");
00159                                 bn.show();
00160                         }
00161                 } catch (Exception e) {
00162                         e.printStackTrace();
00163                 }
00164         }
00165 
00166         public void write(PrintStream out) throws Exception {
00167                 super.writeDeclarations(out);
00168         }
00169 }