00001 import java.io.BufferedReader;
00002 import java.io.File;
00003 import java.io.StringReader;
00004 import java.util.Arrays;
00005 import java.util.Vector;
00006 import java.util.regex.Matcher;
00007 import java.util.regex.Pattern;
00008
00009 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00010 import edu.ksu.cis.bnj.ver3.core.CPF;
00011 import edu.ksu.cis.bnj.ver3.core.Discrete;
00012 import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
00013 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00014 import edu.tum.cs.bayesnets.inference.Algorithm;
00015 import edu.tum.cs.bayesnets.inference.ITimeLimitedInference;
00016 import edu.tum.cs.bayesnets.inference.SampledDistribution;
00017 import edu.tum.cs.bayesnets.inference.Sampler;
00018 import edu.tum.cs.bayesnets.inference.TimeLimitedInference;
00019 import edu.tum.cs.inference.GeneralSampledDistribution;
00020 import edu.tum.cs.srl.bayesnets.inference.BLNinfer;
00021 import edu.tum.cs.util.FileUtil;
00022 import edu.tum.cs.util.Stopwatch;
00023
00024
00025 public class BNinfer {
00026
00030 public static void main(String[] args) {
00031 try {
00032 String networkFile = null;
00033 String dbFile = null;
00034 String query = null;
00035 int maxSteps = 1000;
00036 int maxTrials = 5000;
00037 int infoInterval = 100;
00038 Algorithm algo = Algorithm.LikelihoodWeighting;
00039 boolean debug = false;
00040 boolean skipFailedSteps = false;
00041 boolean removeDeterministicCPTEntries = false;
00042 double timeLimit = 10.0, infoIntervalTime = 1.0;
00043 boolean timeLimitedInference = false;
00044 boolean useMaxSteps = false;
00045 String outputDistFile = null, referenceDistFile = null;
00046
00047
00048 for(int i = 0; i < args.length; i++) {
00049 if(args[i].equals("-n"))
00050 networkFile = args[++i];
00051 else if(args[i].equals("-q"))
00052 query = args[++i];
00053 else if(args[i].equals("-e"))
00054 dbFile = args[++i];
00055 else if(args[i].equals("-nodetcpt"))
00056 removeDeterministicCPTEntries = true;
00057 else if(args[i].equals("-skipFailedSteps"))
00058 skipFailedSteps = true;
00059 else if(args[i].equals("-maxSteps")) {
00060 maxSteps = Integer.parseInt(args[++i]);
00061 useMaxSteps = true;
00062 }
00063 else if(args[i].equals("-maxTrials"))
00064 maxTrials = Integer.parseInt(args[++i]);
00065 else if(args[i].equals("-ia")) {
00066 try {
00067 algo = Algorithm.valueOf(args[++i]);
00068 }
00069 catch(IllegalArgumentException e) {
00070 System.err.println("Error: Unknown inference algorithm '" + args[i] + "'");
00071 System.exit(1);
00072 }
00073 }
00074 else if(args[i].equals("-infoInterval"))
00075 infoInterval = Integer.parseInt(args[++i]);
00076 else if(args[i].equals("-debug"))
00077 debug = true;
00078 else if(args[i].equals("-t")) {
00079 timeLimitedInference = true;
00080 if(i+1 < args.length && !args[i+1].startsWith("-"))
00081 timeLimit = Double.parseDouble(args[++i]);
00082 }
00083 else if(args[i].equals("-od"))
00084 outputDistFile = args[++i];
00085 else if(args[i].equals("-cd"))
00086 referenceDistFile = args[++i];
00087 else
00088 System.err.println("Warning: unknown option " + args[i] + " ignored!");
00089 }
00090 if(networkFile == null || dbFile == null || query == null) {
00091 System.out.println("\n usage: BNinfer <arguments>\n\n" +
00092 " required arguments:\n\n" +
00093 " -n <network file> fragment network (XML-BIF or PMML)\n" +
00094 " -e <evidence db pattern> an evidence database file or file mask\n" +
00095 " -q <comma-sep. queries> queries (predicate names or partially grounded terms with lower-case vars)\n\n" +
00096 " options:\n\n" +
00097 " -maxSteps # the maximum number of steps to take, where applicable (default: 1000)\n" +
00098 " -maxTrials # the maximum number of trials per step for BN sampling algorithms (default: 5000)\n" +
00099 " -infoInterval # the number of steps after which to output a status message\n" +
00100 " -skipFailedSteps failed steps (> max trials) should just be skipped\n\n" +
00101 " -t [secs] use time-limited inference (default: 10 seconds)\n" +
00102 " -infoTime # interval in secs after which to display intermediate results (time-limited inference, default: 1.0)\n" +
00103 " -ia <name> inference algorithm selection; valid names:");
00104 for(Algorithm a : Algorithm.values())
00105 System.out.printf(" %-28s %s\n", a.toString(), a.getDescription());
00106 System.out.println(
00107 " -od <file> save output distribution to file\n" +
00108 " -cd <file> compare results of inference to reference distribution in file\n" +
00109 " -debug debug mode with additional outputs\n" +
00110 " -nodetcpt remove deterministic CPT columns by replacing 0s with low prob. values\n");
00111
00112 System.exit(1);
00113 }
00114
00115
00116 Pattern comma = Pattern.compile("\\s*,\\s*");
00117 String[] candQueries = comma.split(query);
00118 Vector<String> queries = new Vector<String>();
00119 String q = "";
00120 for(int i = 0; i < candQueries.length; i++) {
00121 if(!q.equals(""))
00122 q += ",";
00123 q += candQueries[i];
00124 if(balancedParentheses(q)) {
00125 queries.add(q);
00126 q = "";
00127 }
00128 }
00129 if(!q.equals(""))
00130 throw new IllegalArgumentException("Unbalanced parentheses in queries");
00131
00132
00133 BeliefNetworkEx bn = new BeliefNetworkEx(networkFile);
00134 BeliefNode[] nodes = bn.bn.getNodes();
00135
00136
00137 if(removeDeterministicCPTEntries) {
00138 final double lowProb = 0.001;
00139 for(BeliefNode node : nodes) {
00140 CPF cpf = node.getCPF();
00141 for(int i = 0; i < cpf.size(); i++)
00142 if(cpf.getDouble(i) == 0.0)
00143 cpf.put(i, new ValueDouble(lowProb));
00144 cpf.normalizeByDomain();
00145 }
00146 }
00147
00148
00149 int[] evidenceDomainIndices = new int[nodes.length];
00150 Arrays.fill(evidenceDomainIndices, -1);
00151 for(int i = 0; i < nodes.length; i++) {
00152
00153 String dbContent = FileUtil.readTextFile(dbFile);
00154
00155 Pattern comments = Pattern.compile("//.*?$|/\\*.*?\\*/", Pattern.MULTILINE | Pattern.DOTALL);
00156 Matcher matcher = comments.matcher(dbContent);
00157 dbContent = matcher.replaceAll("");
00158
00159 BufferedReader br = new BufferedReader(new StringReader(dbContent));
00160 String line;
00161 while((line = br.readLine()) != null) {
00162 line = line.trim();
00163 if(line.length() > 0) {
00164 String[] entry = line.split("\\s*=\\s*");
00165 if(entry.length != 2)
00166 throw new Exception("Incorrectly formatted evidence entry: " + line);
00167 BeliefNode node = bn.getNode(entry[0]);
00168 if(node == null)
00169 throw new Exception("Evidence node '" + entry[0] + "' not found in model.");
00170 Discrete dom = (Discrete)node.getDomain();
00171 int domidx = dom.findName(entry[1]);
00172 if(domidx == -1)
00173 throw new Exception("Value '" + entry[1] + "' not found in domain of node '" + entry[0] + "'");
00174 evidenceDomainIndices[bn.getNodeIndex(node)] = domidx;
00175 }
00176 }
00177 }
00178
00179
00180 GeneralSampledDistribution referenceDist = null;
00181 if(referenceDistFile != null) {
00182 referenceDist = GeneralSampledDistribution.fromFile(new File(referenceDistFile));
00183 }
00184
00185
00186 Vector<Integer> queryVars = new Vector<Integer>();
00187 for(String qq : queries) {
00188 int varIdx = bn.getNodeIndex(qq);
00189 if(varIdx == -1)
00190 throw new Exception("Unknown variable '" + qq + "'");
00191 queryVars.add(varIdx);
00192 }
00193
00194
00195 Stopwatch sw = new Stopwatch();
00196 sw.start();
00197
00198 Sampler sampler = algo.createSampler(bn);
00199
00200 sampler.setEvidence(evidenceDomainIndices);
00201 sampler.setQueryVars(queryVars);
00202 sampler.setDebugMode(debug);
00203 sampler.setMaxTrials(maxTrials);
00204 sampler.setSkipFailedSteps(skipFailedSteps);
00205 sampler.setNumSamples(maxSteps);
00206 sampler.setInfoInterval(infoInterval);
00207
00208 SampledDistribution dist = null;
00209 if(timeLimitedInference) {
00210 if(!(sampler instanceof ITimeLimitedInference))
00211 throw new Exception(sampler.getAlgorithmName() + " does not support time-limited inference");
00212 ITimeLimitedInference tliSampler = (ITimeLimitedInference) sampler;
00213 if(!useMaxSteps)
00214 sampler.setNumSamples(Integer.MAX_VALUE);
00215 sampler.setInfoInterval(Integer.MAX_VALUE);
00216 TimeLimitedInference tli = new TimeLimitedInference(tliSampler, timeLimit, infoIntervalTime);
00217 tli.setReferenceDistribution(referenceDist);
00218 dist = tli.run();
00219 if(referenceDist != null)
00220 System.out.println("MSEs: " + tli.getMSEs());
00221 }
00222 else
00223 dist = sampler.infer();
00224 sw.stop();
00225
00226
00227 for(String qq : queries) {
00228 int varIdx = bn.getNodeIndex(qq);
00229 dist.printVariableDistribution(System.out, varIdx);
00230 }
00231
00232
00233 if(outputDistFile != null) {
00234 GeneralSampledDistribution gdist = dist.toGeneralDistribution();
00235 File f= new File(outputDistFile);
00236 gdist.write(f);
00237 GeneralSampledDistribution gdist2 = GeneralSampledDistribution.fromFile(f);
00238 gdist2.print(System.out);
00239 }
00240
00241
00242 if(referenceDist != null) {
00243 System.out.println("comparing to reference distribution...");
00244 BLNinfer.compareDistributions(referenceDist, dist);
00245 }
00246 }
00247 catch(Exception e) {
00248 e.printStackTrace();
00249 }
00250 }
00251
00252 public static boolean balancedParentheses(String s) {
00253 int n = 0;
00254 for(int i = 0; i < s.length(); i++) {
00255 if(s.charAt(i) == '(')
00256 n++;
00257 else if(s.charAt(i) == ')')
00258 n--;
00259 }
00260 return n == 0;
00261 }
00262 }