00001 import java.io.File;
00002 import java.io.FileWriter;
00003 import java.util.Vector;
00004 import java.util.regex.Pattern;
00005
00006 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00007 import edu.ksu.cis.bnj.ver3.core.CPF;
00008 import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
00009 import edu.tum.cs.bayesnets.inference.BackwardSampling;
00010 import edu.tum.cs.bayesnets.inference.BackwardSamplingWithChildren;
00011 import edu.tum.cs.bayesnets.inference.BackwardSamplingWithPriors;
00012 import edu.tum.cs.bayesnets.inference.LikelihoodWeightingWithUncertainEvidence;
00013 import edu.tum.cs.bayesnets.inference.SmileBackwardSampling;
00014 import edu.tum.cs.bayesnets.inference.SmileEPIS;
00015 import edu.tum.cs.srl.bayesnets.ABL;
00016 import edu.tum.cs.srl.bayesnets.bln.AbstractGroundBLN;
00017 import edu.tum.cs.srl.bayesnets.bln.BayesianLogicNetwork;
00018 import edu.tum.cs.srl.bayesnets.bln.GroundBLN;
00019 import edu.tum.cs.srl.bayesnets.bln.py.BayesianLogicNetworkPy;
00020 import edu.tum.cs.srl.bayesnets.inference.BNSampler;
00021 import edu.tum.cs.srl.bayesnets.inference.GibbsSampling;
00022 import edu.tum.cs.srl.bayesnets.inference.InferenceResult;
00023 import edu.tum.cs.srl.bayesnets.inference.LiftedBackwardSampling;
00024 import edu.tum.cs.srl.bayesnets.inference.SATIS;
00025 import edu.tum.cs.srl.bayesnets.inference.SATISEx;
00026 import edu.tum.cs.srl.bayesnets.inference.Sampler;
00027 import edu.tum.cs.util.Stopwatch;
00028
00029
00030 public class BLNinferMSNBC {
00031
00032 enum Algorithm {LikelihoodWeighting, LWU, CSP, GibbsSampling, EPIS, BackwardSampling, SmileBackwardSampling, BackwardSamplingPriors, Experimental, LiftedBackwardSampling, SATIS, SATISEx};
00033
00037 public static void main(String[] args) {
00038 try {
00039 String blogFile = null;
00040 String bifFile = null;
00041 String blnFile = null;
00042 String dbFile = null;
00043 String query = null;
00044 int maxSteps = 1000;
00045 int maxTrials = 5000;
00046 Algorithm algo = Algorithm.LikelihoodWeighting;
00047 String[] cwPreds = null;
00048 boolean showBN = false;
00049 boolean usePython = false;
00050 boolean debug = false;
00051 boolean saveInstance = false;
00052 boolean skipFailedSteps = false;
00053 boolean removeDeterministicCPTEntries = false;
00054
00055
00056 for(int i = 0; i < args.length; i++) {
00057 if(args[i].equals("-b"))
00058 blogFile = args[++i];
00059 else if(args[i].equals("-x"))
00060 bifFile = args[++i];
00061 else if(args[i].equals("-l"))
00062 blnFile = args[++i];
00063 else if(args[i].equals("-q"))
00064 query = args[++i];
00065 else if(args[i].equals("-e"))
00066 dbFile = args[++i];
00067 else if(args[i].equals("-s"))
00068 showBN = true;
00069 else if(args[i].equals("-nodetcpt"))
00070 removeDeterministicCPTEntries = true;
00071 else if(args[i].equals("-si"))
00072 saveInstance = true;
00073 else if(args[i].equals("-skipFailedSteps"))
00074 skipFailedSteps = true;
00075 else if(args[i].equals("-py"))
00076 usePython = true;
00077 else if(args[i].equals("-cw"))
00078 cwPreds = args[++i].split(",");
00079 else if(args[i].equals("-maxSteps"))
00080 maxSteps = Integer.parseInt(args[++i]);
00081 else if(args[i].equals("-maxTrials"))
00082 maxTrials = Integer.parseInt(args[++i]);
00083 else if(args[i].equals("-lw"))
00084 algo = Algorithm.LikelihoodWeighting;
00085 else if(args[i].equals("-lwu"))
00086 algo = Algorithm.LWU;
00087 else if(args[i].equals("-epis"))
00088 algo = Algorithm.EPIS;
00089 else if(args[i].equals("-csp"))
00090 algo = Algorithm.CSP;
00091 else if(args[i].equals("-gs"))
00092 algo = Algorithm.GibbsSampling;
00093 else if(args[i].equals("-bs"))
00094 algo = Algorithm.BackwardSampling;
00095 else if(args[i].equals("-sbs"))
00096 algo = Algorithm.SmileBackwardSampling;
00097 else if(args[i].equals("-bsp"))
00098 algo = Algorithm.BackwardSamplingPriors;
00099 else if(args[i].equals("-lbs"))
00100 algo = Algorithm.LiftedBackwardSampling;
00101 else if(args[i].equals("-exp"))
00102 algo = Algorithm.Experimental;
00103 else if(args[i].equals("-satis"))
00104 algo = Algorithm.SATIS;
00105 else if(args[i].equals("-satisex"))
00106 algo = Algorithm.SATISEx;
00107 else if(args[i].equals("-debug"))
00108 debug = true;
00109 else
00110 System.err.println("Warning: unknown option " + args[i] + " ignored!");
00111 }
00112 if(bifFile == null || dbFile == null || blogFile == null || blnFile == null || query == null) {
00113 System.out.println("\n usage: inferBLN <-b <BLOG file>> <-x <xml-BIF file>> <-l <BLN file>> <-e <evidence db>> <-q <comma-sep. queries>> [options]\n\n"+
00114 " -maxSteps # the maximum number of steps to take\n" +
00115 " -maxTrials # the maximum number of trials per step for BN sampling algorithms\n" +
00116 " -skipFailedSteps failed steps (> max trials) should just be skipped\n" +
00117 " -lw algorithm: likelihood weighting (default)\n" +
00118 " -lwu algorithm: likelihood weighting with uncertain evidence (default)\n" +
00119 " -gs algorithm: Gibbs sampling\n" +
00120 " -exp algorithm: Experimental\n" +
00121 " -satis algorithm: SAT-IS\n" +
00122 " -satisex algorithm: SAT-IS (extended with hard CPT constraints) \n" +
00123 " -bs algorithm: backward sampling\n" +
00124 " -sbs algorithm: SMILE backward sampling\n" +
00125 " -epis algorithm: SMILE evidence prepropagation importance sampling\n" +
00126 " -py use Python-based logic engine\n" +
00127 " -debug debug mode with additional outputs\n" +
00128 " -s show ground network in editor\n" +
00129 " -si save ground network instance in BIF format (.instance.xml)\n" +
00130 " -nodetcpt remove deterministic CPT columns by replacing 0s with low prob. values\n" +
00131 " -cw <predNames> set predicates as closed-world (comma-separated list of names)\n");
00132 return;
00133 }
00134
00135
00136 Pattern comma = Pattern.compile("\\s*,\\s*");
00137 String[] candQueries = comma.split(query);
00138 Vector<String> queries = new Vector<String>();
00139 String q = "";
00140 for(int i = 0; i < candQueries.length; i++) {
00141 if(!q.equals(""))
00142 q += ",";
00143 q += candQueries[i];
00144 if(balancedParentheses(q)) {
00145 queries.add(q);
00146 q = "";
00147 }
00148 }
00149 if(!q.equals(""))
00150 throw new IllegalArgumentException("Unbalanced parentheses in queries");
00151
00152
00153 ABL blog = new ABL(blogFile, bifFile);
00154
00155
00156 if(removeDeterministicCPTEntries) {
00157 final double lowProb = 0.001;
00158 for(BeliefNode node : blog.bn.getNodes()) {
00159 CPF cpf = node.getCPF();
00160 for(int i = 0; i < cpf.size(); i++)
00161 if(cpf.getDouble(i) == 0.0)
00162 cpf.put(i, new ValueDouble(lowProb));
00163 cpf.normalizeByDomain();
00164 }
00165 }
00166
00167
00168
00169 System.out.println("Reading data...");
00170 String[] pathName = dbFile.split("/");
00171 String dirName=".";
00172 for(int p=0;p<pathName.length-1;p++) {
00173 dirName+="/"+pathName[p];
00174 }
00175
00176 Pattern p = Pattern.compile( pathName[pathName.length-1] );
00177 FileWriter resFile = new FileWriter("results.csv" );
00178
00179 for (File file : new File( dirName ).listFiles()) {
00180 if(p.matcher(file.getName()).matches()) {
00181
00182 String testDBfile = dirName+"/"+file.getName();
00183
00184
00185
00186 AbstractGroundBLN gbln;
00187 if(!usePython) {
00188 BayesianLogicNetwork bln = new BayesianLogicNetwork(blog, blnFile);
00189 gbln = new GroundBLN(bln, testDBfile);
00190 }
00191 else {
00192 BayesianLogicNetworkPy bln = new BayesianLogicNetworkPy(blog, blnFile);
00193 gbln = new edu.tum.cs.srl.bayesnets.bln.py.GroundBLN(bln, testDBfile);
00194 }
00195 if(cwPreds != null) {
00196 for(String predName : cwPreds)
00197 gbln.getDatabase().setClosedWorldPred(predName);
00198 }
00199 gbln.instantiateGroundNetwork();
00200 if(showBN) {
00201 gbln.getGroundNetwork().show();
00202 }
00203 if(saveInstance) {
00204 String baseName = bifFile.substring(0, bifFile.lastIndexOf('.'));
00205 gbln.getGroundNetwork().saveXMLBIF(baseName + ".instance.xml");
00206 }
00207
00208
00209 Stopwatch sw = new Stopwatch();
00210 sw.start();
00211 Sampler sampler = null;
00212 switch(algo) {
00213 case LikelihoodWeighting:
00214 sampler = new BNSampler(gbln, edu.tum.cs.bayesnets.inference.LikelihoodWeighting.class); break;
00215 case LWU:
00216 sampler = new BNSampler(gbln, LikelihoodWeightingWithUncertainEvidence.class); break;
00217 case GibbsSampling:
00218 sampler = new GibbsSampling(gbln); break;
00219 case EPIS:
00220 sampler = new BNSampler(gbln, SmileEPIS.class); break;
00221 case SmileBackwardSampling:
00222 sampler = new BNSampler(gbln, SmileBackwardSampling.class); break;
00223 case BackwardSampling:
00224 sampler = new BNSampler(gbln, BackwardSampling.class); break;
00225 case BackwardSamplingPriors:
00226 sampler = new BNSampler(gbln, BackwardSamplingWithPriors.class); break;
00227 case Experimental:
00228 sampler = new BNSampler(gbln, BackwardSamplingWithChildren.class); break;
00229 case LiftedBackwardSampling:
00230 sampler = new LiftedBackwardSampling((GroundBLN)gbln); break;
00231 case SATIS:
00232 sampler = new SATIS((GroundBLN)gbln); break;
00233 case SATISEx:
00234 sampler = new SATISEx((GroundBLN)gbln); break;
00235 default:
00236 throw new Exception("algorithm not handled");
00237 }
00238 sampler.setDebugMode(debug);
00239 if(sampler instanceof BNSampler) {
00240 ((BNSampler)sampler).setMaxTrials(maxTrials);
00241 ((BNSampler)sampler).setSkipFailedSteps(skipFailedSteps);
00242 }
00243 sampler.setNumSamples(maxSteps);
00244 sampler.setQueries(queries);
00245 Vector<InferenceResult> results = sampler.inferQueries();
00246 sw.stop();
00247
00248 for(InferenceResult res : results) {
00249
00250 double max = res.probabilities[0];
00251 int maxIdx = 0;
00252 for (int i = 1; i < res.probabilities.length; i++) {
00253 if (res.probabilities[i] > max) {max = res.probabilities[i];maxIdx=i;}
00254 }
00255
00256 String truth = res.varName.substring(0, res.varName.length()-1).split("\\(")[1];
00257 truth = truth.split("_")[3];
00258 int eq=0; if(truth.equals(res.domainElements[maxIdx])) eq=1;
00259
00260 if(max<1)
00261 resFile.write(res.varName +","+truth +"," + res.domainElements[maxIdx] + ","+max+ ","+eq+"\n");
00262
00263 }
00264 }
00265 }
00266 resFile.close();
00267 }
00268 catch(Exception e) {
00269 e.printStackTrace();
00270 }
00271 }
00272
00273 public static boolean balancedParentheses(String s) {
00274 int n = 0;
00275 for(int i = 0; i < s.length(); i++) {
00276 if(s.charAt(i) == '(')
00277 n++;
00278 else if(s.charAt(i) == ')')
00279 n--;
00280 }
00281 return n == 0;
00282 }
00283 }