00001 import java.io.File;
00002 import java.io.PrintStream;
00003 import java.util.Vector;
00004 import java.util.regex.Pattern;
00005
00006 import edu.tum.cs.srl.Database;
00007 import edu.tum.cs.srl.Signature;
00008 import edu.tum.cs.srl.bayesnets.ABL;
00009 import edu.tum.cs.srl.bayesnets.BLOGModel;
00010 import edu.tum.cs.srl.bayesnets.learning.CPTLearner;
00011 import edu.tum.cs.srl.bayesnets.learning.DomainLearner;
00012
00013 public class learnBLOG {
00014
00015 public static enum Mode {
00016 BLOG, ABL
00017 }
00018
00019 public static void learn(Mode mode, String[] args) {
00020 try {
00021 String acronym = mode == Mode.ABL ? "ABL" : "BLOG";
00022
00023 boolean showBN = false, learnDomains = false, ignoreUndefPreds = false, toMLN = false, debug = false, uniformDefault = false;
00024 String blogFile = null, bifFile = null, dbFile = null, outFileBLOG = null, outFileNetwork = null;
00025 boolean noNormalization = false;
00026 for(int i = 0; i < args.length; i++) {
00027 if(args[i].equals("-s"))
00028 showBN = true;
00029 else if(args[i].equals("-d"))
00030 learnDomains = true;
00031 else if(args[i].equals("-i"))
00032 ignoreUndefPreds = true;
00033 else if(args[i].equals("-b"))
00034 blogFile = args[++i];
00035 else if(args[i].equals("-x"))
00036 bifFile = args[++i];
00037 else if(args[i].equals("-t"))
00038 dbFile = args[++i];
00039 else if(args[i].equals("-ob"))
00040 outFileBLOG = args[++i];
00041 else if(args[i].equals("-ox"))
00042 outFileNetwork = args[++i];
00043 else if(args[i].equals("-mln"))
00044 toMLN = true;
00045 else if(args[i].equals("-nn"))
00046 noNormalization = true;
00047 else if(args[i].equals("-ud"))
00048 uniformDefault = true;
00049 else if(args[i].equals("-debug"))
00050 debug = true;
00051 }
00052 if(bifFile == null || dbFile == null || outFileBLOG == null || outFileNetwork == null) {
00053 System.out.println("\n usage: learn" + acronym + " [-b <" + acronym + " file>] <-x <network file>> <-t <training db pattern>> <-ob <" + acronym + " output>> <-ox <network output>> [-s] [-d]\n\n"+
00054 " -b " + acronym + " file from which to read function signatures\n" +
00055 " -s show learned fragment network\n" +
00056 " -d learn domains\n" +
00057 " -i ignore data on predicates not defined in the model\n" +
00058 " -ud apply uniform distribution by default (for CPT columns with no examples)\n" +
00059 " -nn no normalization (i.e. keep counts in CPTs)\n" +
00060 " -mln convert learnt model to a Markov logic network\n" +
00061 " -debug output debug information\n");
00062 return;
00063 }
00064
00065 BLOGModel bn;
00066 if(mode == Mode.BLOG) {
00067 if(blogFile != null)
00068 bn = new BLOGModel(blogFile, bifFile);
00069 else
00070 bn = new BLOGModel(bifFile);
00071 }
00072 else {
00073 if(blogFile != null)
00074 bn = new ABL(blogFile, bifFile);
00075 else
00076 bn = new BLOGModel(bifFile);
00077 }
00078
00079
00080 bn.prepareForLearning();
00081
00082 System.out.println("Signatures:");
00083 for(Signature sig : bn.getSignatures()) {
00084 System.out.println(" " + sig);
00085 }
00086
00087
00088 System.out.println("Reading data...");
00089 Vector<Database> dbs = new Vector<Database>();
00090 String regex = new File(dbFile).getName();
00091 Pattern p = Pattern.compile( regex );
00092 File directory = new File(dbFile).getParentFile();
00093 if(directory == null || !directory.exists())
00094 directory = new File(".");
00095 System.out.printf("Searching for '%s' in '%s'...\n", regex, directory);
00096 for (File file : directory.listFiles()) {
00097 if(p.matcher(file.getName()).matches()) {
00098 Database db = new Database(bn);
00099 System.out.printf("reading %s...\n", file.getAbsolutePath());
00100 db.readBLOGDB(file.getPath(), ignoreUndefPreds);
00101 dbs.add(db);
00102 }
00103 }
00104
00105
00106 System.out.println("Checking domains...");
00107 for(Database db : dbs)
00108 db.checkDomains(true);
00109
00110
00111 if(learnDomains) {
00112 System.out.println("Learning domains...");
00113 DomainLearner domLearner = new DomainLearner(bn);
00114 for(Database db : dbs) {
00115 domLearner.learn(db);
00116 }
00117 domLearner.finish();
00118 }
00119 System.out.println("Domains:");
00120 for(Signature sig : bn.getSignatures()) {
00121 System.out.println(" " + sig.functionName + ": " + sig.returnType + " ");
00122 }
00123
00124 boolean learnParams = true;
00125 if(learnParams) {
00126 System.out.println("Learning parameters...");
00127 CPTLearner cptLearner = new CPTLearner(bn, uniformDefault, debug);
00128
00129 for(Database db : dbs)
00130 cptLearner.learnTyped(db, true, true);
00131 if(!noNormalization)
00132 cptLearner.finish();
00133
00134 System.out.println("Writing "+ acronym + " output to " + outFileBLOG + "...");
00135 PrintStream out = new PrintStream(new File(outFileBLOG));
00136 bn.write(out);
00137 out.close();
00138
00139 System.out.println("Writing network output to " + outFileNetwork + "...");
00140 bn.save(outFileNetwork);
00141 }
00142
00143 if(toMLN) {
00144 String filename = outFileBLOG + ".mln";
00145 System.out.println("Writing MLN " + filename);
00146 PrintStream out = new PrintStream(new File(outFileBLOG + ".mln"));
00147 bn.toMLN(out, false, false, false);
00148 }
00149
00150 if(showBN) {
00151 System.out.println("Showing Bayesian network...");
00152 bn.show();
00153 }
00154 }
00155 catch(Exception e) {
00156 e.printStackTrace();
00157 }
00158 }
00159
00160 public static void main(String[] args) {
00161 learn(Mode.BLOG, args);
00162 }
00163 }