00001 import java.io.BufferedReader;
00002 import java.io.File;
00003 import java.io.StringReader;
00004 import java.util.Arrays;
00005 import java.util.Vector;
00006 import java.util.regex.Matcher;
00007 import java.util.regex.Pattern;
00008 
00009 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00010 import edu.ksu.cis.bnj.ver3.core.CPF;
00011 import edu.ksu.cis.bnj.ver3.core.Discrete;
00012 import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
00013 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00014 import edu.tum.cs.bayesnets.inference.Algorithm;
00015 import edu.tum.cs.bayesnets.inference.ITimeLimitedInference;
00016 import edu.tum.cs.bayesnets.inference.SampledDistribution;
00017 import edu.tum.cs.bayesnets.inference.Sampler;
00018 import edu.tum.cs.bayesnets.inference.TimeLimitedInference;
00019 import edu.tum.cs.inference.GeneralSampledDistribution;
00020 import edu.tum.cs.srl.bayesnets.inference.BLNinfer;
00021 import edu.tum.cs.util.FileUtil;
00022 import edu.tum.cs.util.Stopwatch;
00023 
00024 
00025 public class BNinfer {
00026         
00030         public static void main(String[] args) {
00031                 try {
00032                         String networkFile = null;
00033                         String dbFile = null;
00034                         String query = null;
00035                         int maxSteps = 1000;
00036                         int maxTrials = 5000;
00037                         int infoInterval = 100;
00038                         Algorithm algo = Algorithm.LikelihoodWeighting;
00039                         boolean debug = false;
00040                         boolean skipFailedSteps = false;
00041                         boolean removeDeterministicCPTEntries = false;
00042                         double timeLimit = 10.0, infoIntervalTime = 1.0;
00043                         boolean timeLimitedInference = false;
00044                         boolean useMaxSteps = false;
00045                         String outputDistFile = null, referenceDistFile = null;
00046                         
00047                         
00048                         for(int i = 0; i < args.length; i++) {
00049                                 if(args[i].equals("-n"))
00050                                         networkFile = args[++i];
00051                                 else if(args[i].equals("-q"))
00052                                         query = args[++i];
00053                                 else if(args[i].equals("-e"))
00054                                         dbFile = args[++i];                             
00055                                 else if(args[i].equals("-nodetcpt"))
00056                                         removeDeterministicCPTEntries = true;                           
00057                                 else if(args[i].equals("-skipFailedSteps"))
00058                                         skipFailedSteps = true;                         
00059                                 else if(args[i].equals("-maxSteps")) {
00060                                         maxSteps = Integer.parseInt(args[++i]);
00061                                         useMaxSteps = true;
00062                                 }
00063                                 else if(args[i].equals("-maxTrials"))
00064                                         maxTrials = Integer.parseInt(args[++i]);
00065                                 else if(args[i].equals("-ia")) {
00066                                         try {
00067                                                 algo = Algorithm.valueOf(args[++i]);
00068                                         }
00069                                         catch(IllegalArgumentException e) {
00070                                                 System.err.println("Error: Unknown inference algorithm '" + args[i] + "'");
00071                                                 System.exit(1);
00072                                         }
00073                                 }
00074                                 else if(args[i].equals("-infoInterval"))
00075                                         infoInterval = Integer.parseInt(args[++i]);
00076                                 else if(args[i].equals("-debug"))
00077                                         debug = true;
00078                                 else if(args[i].equals("-t")) {
00079                                         timeLimitedInference = true;
00080                                         if(i+1 < args.length && !args[i+1].startsWith("-"))
00081                                                 timeLimit = Double.parseDouble(args[++i]);                                      
00082                                 }
00083                                 else if(args[i].equals("-od"))
00084                                         outputDistFile = args[++i];     
00085                                 else if(args[i].equals("-cd"))
00086                                         referenceDistFile = args[++i];
00087                                 else
00088                                         System.err.println("Warning: unknown option " + args[i] + " ignored!");
00089                         }                       
00090                         if(networkFile == null || dbFile == null || query == null) {
00091                                 System.out.println("\n usage: BNinfer <arguments>\n\n" +
00092                                                              "   required arguments:\n\n" +
00093                                                              "     -n <network file>         fragment network (XML-BIF or PMML)\n" + 
00094                                                              "     -e <evidence db pattern>  an evidence database file or file mask\n" +
00095                                                              "     -q <comma-sep. queries>   queries (predicate names or partially grounded terms with lower-case vars)\n\n" +
00096                                                              "   options:\n\n" +
00097                                                                          "     -maxSteps #      the maximum number of steps to take, where applicable (default: 1000)\n" +
00098                                                                          "     -maxTrials #     the maximum number of trials per step for BN sampling algorithms (default: 5000)\n" +
00099                                                                          "     -infoInterval #  the number of steps after which to output a status message\n" +
00100                                                                          "     -skipFailedSteps failed steps (> max trials) should just be skipped\n\n" +       
00101                                                                          "     -t [secs]        use time-limited inference (default: 10 seconds)\n" +
00102                                                                          "     -infoTime #      interval in secs after which to display intermediate results (time-limited inference, default: 1.0)\n" +
00103                                                                          "     -ia <name>       inference algorithm selection; valid names:");
00104                                 for(Algorithm a : Algorithm.values()) 
00105                                         System.out.printf("                        %-28s  %s\n", a.toString(), a.getDescription());                             
00106                                 System.out.println(
00107                                                                  "     -od <file>       save output distribution to file\n" +
00108                                                                  "     -cd <file>       compare results of inference to reference distribution in file\n" + 
00109                                                                          "     -debug           debug mode with additional outputs\n" + 
00110                                                                  "     -nodetcpt        remove deterministic CPT columns by replacing 0s with low prob. values\n");
00111                                 
00112                                 System.exit(1);
00113                         }                       
00114 
00115                         
00116                         Pattern comma = Pattern.compile("\\s*,\\s*");
00117                         String[] candQueries = comma.split(query);
00118                         Vector<String> queries = new Vector<String>();
00119                         String q = "";
00120                         for(int i = 0; i < candQueries.length; i++) {
00121                                 if(!q.equals(""))
00122                                         q += ",";
00123                                 q += candQueries[i];
00124                                 if(balancedParentheses(q)) {
00125                                         queries.add(q);
00126                                         q = "";
00127                                 }
00128                         }
00129                         if(!q.equals(""))
00130                                 throw new IllegalArgumentException("Unbalanced parentheses in queries");
00131 
00132                         
00133                         BeliefNetworkEx bn = new BeliefNetworkEx(networkFile);
00134                         BeliefNode[] nodes = bn.bn.getNodes();
00135                         
00136                         
00137                         if(removeDeterministicCPTEntries) {
00138                                 final double lowProb = 0.001; 
00139                                 for(BeliefNode node : nodes) {
00140                                         CPF cpf = node.getCPF();                                        
00141                                         for(int i = 0; i < cpf.size(); i++)
00142                                                 if(cpf.getDouble(i) == 0.0)
00143                                                         cpf.put(i, new ValueDouble(lowProb));
00144                                         cpf.normalizeByDomain();
00145                                 }
00146                         }
00147                         
00148                         
00149                         int[] evidenceDomainIndices = new int[nodes.length];
00150                         Arrays.fill(evidenceDomainIndices, -1);
00151                         for(int i = 0; i < nodes.length; i++) {
00152                                 
00153                                 String dbContent = FileUtil.readTextFile(dbFile);                               
00154                                 
00155                                 Pattern comments = Pattern.compile("//.*?$|/\\*.*?\\*/", Pattern.MULTILINE | Pattern.DOTALL);
00156                                 Matcher matcher = comments.matcher(dbContent);
00157                                 dbContent = matcher.replaceAll("");
00158                                 
00159                                 BufferedReader br = new BufferedReader(new StringReader(dbContent));
00160                                 String line;
00161                                 while((line = br.readLine()) != null) {
00162                                         line = line.trim();                     
00163                                         if(line.length() > 0) {
00164                                                 String[] entry = line.split("\\s*=\\s*");
00165                                                 if(entry.length != 2)
00166                                                         throw new Exception("Incorrectly formatted evidence entry: " + line);
00167                                                 BeliefNode node = bn.getNode(entry[0]);
00168                                                 if(node == null)
00169                                                         throw new Exception("Evidence node '" + entry[0] + "' not found in model.");
00170                                                 Discrete dom = (Discrete)node.getDomain();
00171                                                 int domidx = dom.findName(entry[1]);
00172                                                 if(domidx == -1)
00173                                                         throw new Exception("Value '" + entry[1] + "' not found in domain of node '" + entry[0] + "'");
00174                                                 evidenceDomainIndices[bn.getNodeIndex(node)] = domidx;
00175                                         }
00176                                 }
00177                         }
00178                         
00179                         
00180                         GeneralSampledDistribution referenceDist = null;
00181                         if(referenceDistFile != null) {
00182                                 referenceDist = GeneralSampledDistribution.fromFile(new File(referenceDistFile));
00183                         }
00184                         
00185                         
00186                         Vector<Integer> queryVars = new Vector<Integer>();
00187                         for(String qq : queries) {
00188                                 int varIdx = bn.getNodeIndex(qq);
00189                                 if(varIdx == -1)
00190                                         throw new Exception("Unknown variable '" + qq + "'");
00191                                 queryVars.add(varIdx);
00192                         }
00193                         
00194                         
00195                         Stopwatch sw = new Stopwatch();
00196                         sw.start();
00197                         
00198                         Sampler sampler = algo.createSampler(bn);
00199                         
00200                         sampler.setEvidence(evidenceDomainIndices);
00201                         sampler.setQueryVars(queryVars);
00202                         sampler.setDebugMode(debug);
00203                         sampler.setMaxTrials(maxTrials);
00204                         sampler.setSkipFailedSteps(skipFailedSteps);
00205                         sampler.setNumSamples(maxSteps);
00206                         sampler.setInfoInterval(infoInterval);
00207                         
00208                         SampledDistribution dist = null; 
00209                         if(timeLimitedInference) {
00210                                 if(!(sampler instanceof ITimeLimitedInference)) 
00211                                         throw new Exception(sampler.getAlgorithmName() + " does not support time-limited inference");                                   
00212                                 ITimeLimitedInference tliSampler = (ITimeLimitedInference) sampler;
00213                                 if(!useMaxSteps)                                
00214                                         sampler.setNumSamples(Integer.MAX_VALUE);
00215                                 sampler.setInfoInterval(Integer.MAX_VALUE); 
00216                                 TimeLimitedInference tli = new TimeLimitedInference(tliSampler, timeLimit, infoIntervalTime);
00217                                 tli.setReferenceDistribution(referenceDist);
00218                                 dist = tli.run();
00219                                 if(referenceDist != null)
00220                                         System.out.println("MSEs: " + tli.getMSEs());                           
00221                         }
00222                         else                            
00223                                 dist = sampler.infer();                 
00224                         sw.stop();
00225                         
00226                         
00227                         for(String qq : queries) {
00228                                 int varIdx = bn.getNodeIndex(qq);
00229                                 dist.printVariableDistribution(System.out, varIdx);
00230                         }
00231                         
00232                         
00233                         if(outputDistFile != null) {
00234                                 GeneralSampledDistribution gdist = dist.toGeneralDistribution();
00235                                 File f=  new File(outputDistFile);
00236                                 gdist.write(f);         
00237                                 GeneralSampledDistribution gdist2 = GeneralSampledDistribution.fromFile(f);
00238                                 gdist2.print(System.out);
00239                         }
00240                         
00241                         
00242                         if(referenceDist != null) {                             
00243                                 System.out.println("comparing to reference distribution...");
00244                                 BLNinfer.compareDistributions(referenceDist, dist);
00245                         }
00246                 }
00247                 catch(Exception e) {
00248                         e.printStackTrace();
00249                 }
00250         }
00251 
00252         public static boolean balancedParentheses(String s) {
00253                 int n = 0;
00254                 for(int i = 0; i < s.length(); i++) {
00255                         if(s.charAt(i) == '(')
00256                                 n++;
00257                         else if(s.charAt(i) == ')')
00258                                 n--;
00259                 }
00260                 return n == 0;
00261         }
00262 }