00001 package edu.tum.cs.srl.bayesnets.inference;
00002 import java.io.File;
00003 import java.util.ArrayList;
00004 import java.util.Collection;
00005 import java.util.Collections;
00006 import java.util.Comparator;
00007 import java.util.HashMap;
00008 import java.util.Map;
00009 import java.util.Vector;
00010 import java.util.regex.Pattern;
00011
00012 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00013 import edu.ksu.cis.bnj.ver3.core.CPF;
00014 import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
00015 import edu.tum.cs.bayesnets.inference.ITimeLimitedInference;
00016 import edu.tum.cs.bayesnets.inference.SampledDistribution;
00017 import edu.tum.cs.inference.BasicSampledDistribution;
00018 import edu.tum.cs.inference.GeneralSampledDistribution;
00019 import edu.tum.cs.inference.IParameterHandler;
00020 import edu.tum.cs.inference.ParameterHandler;
00021 import edu.tum.cs.inference.BasicSampledDistribution.DistributionComparison;
00022 import edu.tum.cs.srl.Database;
00023 import edu.tum.cs.srl.bayesnets.ABL;
00024 import edu.tum.cs.srl.bayesnets.RelationalBeliefNetwork;
00025 import edu.tum.cs.srl.bayesnets.bln.AbstractBayesianLogicNetwork;
00026 import edu.tum.cs.srl.bayesnets.bln.AbstractGroundBLN;
00027 import edu.tum.cs.srl.bayesnets.bln.BayesianLogicNetwork;
00028 import edu.tum.cs.srl.bayesnets.bln.py.BayesianLogicNetworkPy;
00029 import edu.tum.cs.util.Stopwatch;
00030
00035 public class BLNinfer implements IParameterHandler {
00036 String declsFile = null;
00037 String networkFile = null;
00038 String logicFile = null;
00039 String dbFile = null;
00040 boolean useMaxSteps = false;
00041 Algorithm algo = Algorithm.LikelihoodWeighting;
00042 String[] cwPreds = null;
00043 boolean showBN = false;
00044 boolean usePython = false;
00045 boolean verbose = true;
00046 boolean saveInstance = false;
00047 boolean skipFailedSteps = false;
00048 boolean removeDeterministicCPTEntries = false;
00049 boolean resultsFilterEvidence = false;
00050 double timeLimit = 10.0, infoIntervalTime = 1.0;
00051 boolean timeLimitedInference = false;
00052 String outputDistFile = null, referenceDistFile = null;
00053 Map<String,Object> params;
00054 AbstractBayesianLogicNetwork bln = null;
00055 AbstractGroundBLN gbln = null;
00056 Database db = null;
00057 Iterable<String> queries = null;
00058 ParameterHandler paramHandler;
00059
00060 enum SortOrder implements Comparator<InferenceResult> {
00061 Atom {
00062 public int compare(InferenceResult o1, InferenceResult o2) {
00063 return o1.varName.compareTo(o2.varName);
00064 }
00065 },
00066 Probability {
00067 public int compare(InferenceResult o1, InferenceResult o2) {
00068 return -Double.compare(o1.probabilities[0], o2.probabilities[0]);
00069 }
00070 },
00071 PredicateProbability {
00072 public int compare(InferenceResult o1, InferenceResult o2) {
00073 String pred1 = o1.varName.substring(0, o1.varName.indexOf('('));
00074 String pred2 = o2.varName.substring(0, o2.varName.indexOf('('));
00075 int res = pred1.compareTo(pred2);
00076 if(res != 0)
00077 return res;
00078 else
00079 return -Double.compare(o1.probabilities[0], o2.probabilities[0]);
00080 }
00081 };
00082 };
00083 SortOrder resultsSortOrder = SortOrder.Atom;
00084
00085
00086 Collection<InferenceResult> results;
00087 double samplingTime;
00088 int stepsTaken;
00089
00090 public BLNinfer() throws Exception {
00091 this(new HashMap<String,Object>());
00092 }
00093
00094 public BLNinfer(Map<String,Object> params) throws Exception {
00095 paramHandler = new ParameterHandler(this);
00096 paramHandler.add("verbose", "setVerbose");
00097 paramHandler.add("maxSteps", "setMaxSteps");
00098 paramHandler.add("numSamples", "setMaxSteps");
00099 paramHandler.add("inferenceMethod", "setInferenceMethod");
00100 paramHandler.add("timeLimit", "setTimeLimit");
00101 this.params = params;
00102 }
00103
00104 public void setVerbose(Boolean verbose) {
00105 this.verbose = verbose;
00106 }
00107
00108 public void setMaxSteps(Integer steps) {
00109 useMaxSteps = true;
00110 }
00111
00112 public void setInferenceMethod(String methodName) {
00113 try {
00114 algo = Algorithm.valueOf(methodName);
00115 }
00116 catch(IllegalArgumentException e) {
00117 System.err.println("Error: Unknown inference algorithm '" + methodName + "'");
00118 Algorithm.printList("");
00119 System.exit(1);
00120 }
00121 }
00122
00123 public void setTimeLimit(double seconds) {
00124 timeLimitedInference = true;
00125 this.timeLimit = seconds;
00126 }
00127
00128 public void readArgs(String[] args) throws Exception {
00129
00130 for(int i = 0; i < args.length; i++) {
00131 if(args[i].equals("-b"))
00132 declsFile = args[++i];
00133 else if(args[i].equals("-x"))
00134 networkFile = args[++i];
00135 else if(args[i].equals("-l"))
00136 logicFile = args[++i];
00137 else if(args[i].equals("-q")) {
00138 String query = args[++i];
00139 Pattern comma = Pattern.compile("\\s*,\\s*");
00140 String[] candQueries = comma.split(query);
00141 Vector<String> queries = new Vector<String>();
00142 String q = "";
00143 for(int j = 0; j < candQueries.length; j++) {
00144 if(!q.equals(""))
00145 q += ",";
00146 q += candQueries[j];
00147 if(balancedParentheses(q)) {
00148 queries.add(q);
00149 q = "";
00150 }
00151 }
00152 this.queries = queries;
00153 if(!q.equals(""))
00154 throw new IllegalArgumentException("Unbalanced parentheses in queries");
00155 }
00156 else if(args[i].equals("-e"))
00157 dbFile = args[++i];
00158 else if(args[i].equals("-s"))
00159 showBN = true;
00160 else if(args[i].equals("-rfe"))
00161 resultsFilterEvidence = true;
00162 else if(args[i].equals("-nodetcpt"))
00163 removeDeterministicCPTEntries = true;
00164 else if(args[i].equals("-si"))
00165 saveInstance = true;
00166 else if(args[i].equals("-skipFailedSteps"))
00167 skipFailedSteps = true;
00168 else if(args[i].equals("-py"))
00169 usePython = true;
00170 else if(args[i].equals("-cw"))
00171 cwPreds = args[++i].split(",");
00172 else if(args[i].equals("-maxSteps")) {
00173 int steps = Integer.parseInt(args[++i]);
00174 params.put("numSamples", steps);
00175 setMaxSteps(steps);
00176 }
00177 else if(args[i].equals("-maxTrials"))
00178 params.put("maxTrials", args[++i]);
00179 else if(args[i].equals("-ia"))
00180 setInferenceMethod(args[++i]);
00181 else if(args[i].equals("-infoInterval"))
00182 params.put("infoInterval", args[++i]);
00183 else if(args[i].equals("-debug"))
00184 params.put("debug", Boolean.TRUE);
00185 else if(args[i].equals("-t")) {
00186 if(i+1 < args.length && !args[i+1].startsWith("-"))
00187 setTimeLimit(Double.parseDouble(args[++i]));
00188 else
00189 setTimeLimit(timeLimit);
00190 }
00191 else if(args[i].equals("-infoTime"))
00192 infoIntervalTime = Double.parseDouble(args[++i]);
00193 else if(args[i].equals("-od"))
00194 outputDistFile = args[++i];
00195 else if(args[i].equals("-cd"))
00196 referenceDistFile = args[++i];
00197 else if(args[i].startsWith("-O")) {
00198 String order = args[i].substring(2);
00199 if(order.equals("a"))
00200 resultsSortOrder = SortOrder.Atom;
00201 else if(order.equals("p"))
00202 resultsSortOrder = SortOrder.Probability;
00203 else if(order.equals("pp"))
00204 resultsSortOrder = SortOrder.PredicateProbability;
00205 else
00206 throw new Exception("Unknown sort order '" + order + "'");
00207 }
00208 else if(args[i].startsWith("-p") || args[i].startsWith("--")) {
00209 String[] pair = args[i].substring(2).split("=");
00210 if(pair.length != 2)
00211 throw new Exception("Argument '" + args[i] + "' for algorithm-specific parameterization is incorrectly formatted.");
00212 params.put(pair[0], pair[1]);
00213 }
00214 else
00215 throw new Exception("Unknown option " + args[i]);
00216 }
00217 }
00218
00219 public void setBLN(AbstractBayesianLogicNetwork bln) {
00220 this.bln = bln;
00221 }
00222
00223 public void setDatabase(Database db) {
00224 this.db = db;
00225 }
00226
00227 public void setQueries(Iterable<String> queries) {
00228 this.queries = queries;
00229 }
00230
00231 public void setGroundBLN(AbstractGroundBLN gbln) {
00232 this.gbln = gbln;
00233 setBLN(gbln.getBLN());
00234 setDatabase(gbln.getDatabase());
00235 }
00236
00237 public Collection<InferenceResult> run() throws Exception {
00238 if(networkFile == null && bln == null)
00239 throw new IllegalArgumentException("No fragment network given");
00240 if(dbFile == null && db == null)
00241 throw new IllegalArgumentException("No evidence given");
00242 if(declsFile == null && bln == null)
00243 throw new IllegalArgumentException("No model declarations given");
00244 if(logicFile == null && bln == null)
00245 throw new IllegalArgumentException("No logical constraints definitions given");
00246 if(queries == null)
00247 throw new IllegalArgumentException("No queries given");
00248
00249
00250 paramHandler.handle(params, false);
00251
00252
00253 RelationalBeliefNetwork blog;
00254 if(bln == null)
00255 blog = new ABL(declsFile, networkFile);
00256 else
00257 blog = bln.rbn;
00258
00259
00260 if(removeDeterministicCPTEntries) {
00261 final double lowProb = 0.001;
00262 for(BeliefNode node : blog.bn.getNodes()) {
00263 CPF cpf = node.getCPF();
00264 for(int i = 0; i < cpf.size(); i++)
00265 if(cpf.getDouble(i) == 0.0)
00266 cpf.put(i, new ValueDouble(lowProb));
00267 cpf.normalizeByDomain();
00268 }
00269 }
00270
00271
00272 if(db == null)
00273 db = new Database(blog);
00274 paramHandler.addSubhandler(db.getParameterHandler());
00275 if(dbFile != null)
00276 db.readBLOGDB(dbFile);
00277 if(cwPreds != null) {
00278 for(String predName : cwPreds)
00279 db.setClosedWorldPred(predName);
00280 }
00281
00282
00283 if(gbln == null) {
00284 if(bln == null) {
00285 if(!usePython)
00286 bln = new BayesianLogicNetwork(blog, logicFile);
00287 else
00288 bln = new BayesianLogicNetworkPy(blog, logicFile);
00289 }
00290 gbln = bln.ground(db);
00291 paramHandler.addSubhandler(gbln);
00292 gbln.instantiateGroundNetwork();
00293 }
00294 if(showBN) {
00295 gbln.getGroundNetwork().show();
00296 }
00297 if(saveInstance) {
00298 String baseName = networkFile.substring(0, networkFile.lastIndexOf('.'));
00299 gbln.getGroundNetwork().saveXMLBIF(baseName + ".instance.xml");
00300 }
00301
00302
00303 GeneralSampledDistribution referenceDist = null;
00304 if(referenceDistFile != null) {
00305 referenceDist = GeneralSampledDistribution.fromFile(new File(referenceDistFile));
00306 }
00307
00308
00309 Stopwatch sw = new Stopwatch();
00310 sw.start();
00311
00312 Sampler sampler = algo.createSampler(gbln);
00313 sampler.setQueries(queries);
00314
00315 paramHandler.addSubhandler(sampler);
00316
00317 SampledDistribution dist;
00318 if(timeLimitedInference) {
00319 if(!(sampler instanceof ITimeLimitedInference))
00320 throw new Exception(sampler.getAlgorithmName() + " does not support time-limited inference");
00321 ITimeLimitedInference tliSampler = (ITimeLimitedInference) sampler;
00322 if(!useMaxSteps)
00323 sampler.setNumSamples(Integer.MAX_VALUE);
00324 sampler.setInfoInterval(Integer.MAX_VALUE);
00325 TimeLimitedInference tli = new TimeLimitedInference(tliSampler, timeLimit, infoIntervalTime);
00326 paramHandler.addSubhandler(tli);
00327 tli.setReferenceDistribution(referenceDist);
00328 dist = tli.run();
00329 if(referenceDist != null)
00330 System.out.println("MSEs: " + tli.getMSEs());
00331 results = tli.getResults(dist);
00332 }
00333 else {
00334 dist = sampler.infer();
00335 results = sampler.getResults(dist);
00336 }
00337 this.samplingTime = sampler.getSamplingTime();
00338 this.stepsTaken = dist.steps;
00339 sw.stop();
00340
00341
00342 if(verbose) {
00343 ArrayList<InferenceResult> sortedResults = new ArrayList<InferenceResult>(results);
00344 Collections.sort(sortedResults, this.resultsSortOrder);
00345 for(InferenceResult res : sortedResults) {
00346 boolean show = true;
00347 if(resultsFilterEvidence)
00348 if(db.contains(res.varName))
00349 show = false;
00350 if(show) res.print();
00351 }
00352 }
00353
00354
00355 if(outputDistFile != null) {
00356 GeneralSampledDistribution gdist = dist.toGeneralDistribution();
00357 File f= new File(outputDistFile);
00358 gdist.write(f);
00359 GeneralSampledDistribution gdist2 = GeneralSampledDistribution.fromFile(f);
00360 gdist2.print(System.out);
00361 }
00362
00363
00364 if(referenceDist != null) {
00365 System.out.println("comparing to reference distribution...");
00366 compareDistributions(referenceDist, dist);
00367 }
00368
00369 return results;
00370 }
00371
00375 public Collection<InferenceResult> getResults() {
00376 return this.results;
00377 }
00378
00382 public double getSamplingTime() {
00383 return samplingTime;
00384 }
00385
00389 public int getNumSteps() {
00390 return stepsTaken;
00391 }
00392
00396 public static void main(String[] args) {
00397 try {
00398 BLNinfer infer = new BLNinfer();
00399 infer.readArgs(args);
00400 infer.run();
00401
00402 ParameterHandler handler = infer.getParameterHandler();
00403 Collection<String> unhandledParams = handler.getUnhandledParams();
00404 if(!unhandledParams.isEmpty())
00405 System.err.println("Warning: Some parameters could not be handled: " + unhandledParams.toString() + "; supported parameters: " + handler.getHandledParameters().toString());
00406 }
00407 catch(IllegalArgumentException e) {
00408 System.err.println(e);
00409 System.out.println("\n usage: BLNinfer <arguments>\n\n" +
00410 " required arguments:\n\n" +
00411 " -b <declarations file> declarations file (types, domains, signatures, etc.)\n" +
00412 " -x <network file> fragment network (XML-BIF or PMML)\n" +
00413 " -l <logic file> logical constraints file\n" +
00414 " -e <evidence db pattern> an evidence database file or file mask\n" +
00415 " -q <comma-sep. queries> queries (predicate names or partially grounded terms with lower-case vars)\n\n" +
00416 " options:\n\n" +
00417 " -maxSteps # the maximum number of steps to take (default: 1000 for non-time-limited inf.)\n" +
00418 " -maxTrials # the maximum number of trials per step for BN sampling algorithms (default: 5000)\n" +
00419 " -infoInterval # the number of steps after which to output a status message\n" +
00420 " -skipFailedSteps failed steps (> max trials) should just be skipped\n\n" +
00421 " -t [secs] use time-limited inference (default: 10 seconds)\n" +
00422 " -infoTime # interval in secs after which to display intermediate results (time-limited inference, default: 1.0)\n" +
00423 " -ia <name> inference algorithm selection; valid names:");
00424 Algorithm.printList(" ");
00425 System.out.println(
00426 " --<key>=<value> set algorithm-specific parameter\n" +
00427 " -debug debug mode with additional outputs\n" +
00428 " -s show ground network in editor\n" +
00429 " -si save ground network instance in BIF format (.instance.xml)\n" +
00430 " -rfe filter evidence in results\n" +
00431 " -nodetcpt remove deterministic CPT columns by replacing 0s with low prob. values\n" +
00432 " -cw <predNames> set predicates as closed-world (comma-separated list of names)\n" +
00433 " -O<a|p|pp> order printed results by atom name (a), probability (p), predicate then probability (pp)\n" +
00434 " -od <file> save output distribution to file\n" +
00435 " -cd <file> compare results of inference to reference distribution in file\n" +
00436 " -py use Python-based logic engine [deprecated]\n");
00437 System.exit(1);
00438 }
00439 catch(Exception e) {
00440 e.printStackTrace();
00441 System.exit(1);
00442 }
00443 }
00444
00445 public static boolean balancedParentheses(String s) {
00446 int n = 0;
00447 for(int i = 0; i < s.length(); i++) {
00448 if(s.charAt(i) == '(')
00449 n++;
00450 else if(s.charAt(i) == ')')
00451 n--;
00452 }
00453 return n == 0;
00454 }
00455
00456 public static void compareDistributions(BasicSampledDistribution d1, BasicSampledDistribution d2) throws Exception {
00457 BasicSampledDistribution.DistributionComparison dc = new DistributionComparison(d1, d2);
00458 dc.addEntryComparison(new BasicSampledDistribution.ErrorList(d1));
00459 dc.addEntryComparison(new BasicSampledDistribution.MeanSquaredError(d1));
00460 dc.addEntryComparison(new BasicSampledDistribution.HellingerDistance(d1));
00461 dc.compare();
00462 dc.printResults();
00463 }
00464
00465 @Override
00466 public ParameterHandler getParameterHandler() {
00467 return paramHandler;
00468 }
00469 }