00001 import java.util.Collections;
00002 import java.util.HashMap;
00003 import java.util.List;
00004 import java.util.Vector;
00005 import java.util.regex.Pattern;
00006
00007 import edu.tum.cs.logic.sat.weighted.MaxWalkSATRoom;
00008 import edu.tum.cs.logic.sat.weighted.WeightedFormula;
00009 import edu.tum.cs.srl.Database;
00010 import edu.tum.cs.srl.mln.MarkovLogicNetwork;
00011 import edu.tum.cs.srl.mln.MarkovRandomField;
00012 import edu.tum.cs.srl.mln.inference.InferenceAlgorithm;
00013 import edu.tum.cs.srl.mln.inference.InferenceResult;
00014 import edu.tum.cs.srl.mln.inference.MAPInferenceAlgorithm;
00015 import edu.tum.cs.srl.mln.inference.MCSAT;
00016 import edu.tum.cs.srl.mln.inference.MaxWalkSAT;
00017 import edu.tum.cs.srl.mln.inference.Toulbar2MAPInference;
00018 import edu.tum.cs.util.Stopwatch;
00019
00024 public class MLNinfer {
00025
00026 enum Algorithm {MaxWalkSAT, MCSAT, Toulbar2, MaxWalkSATRooms};
00027
00031 public static void main(String[] args) {
00032 try {
00033 String mlnFile = null;
00034 String dbFile = null;
00035 String query = null;
00036 int maxSteps = 1000;
00037 Algorithm algo = Algorithm.MCSAT;
00038 String[] cwPreds = null;
00039 boolean debug = false;
00040 HashMap<String,Object> params = new HashMap<String,Object>();
00041
00042
00043 for(int i = 0; i < args.length; i++) {
00044 if(args[i].equals("-i"))
00045 mlnFile = args[++i];
00046 else if(args[i].equals("-q"))
00047 query = args[++i];
00048 else if(args[i].equals("-e"))
00049 dbFile = args[++i];
00050 else if(args[i].equals("-cw"))
00051 cwPreds = args[++i].split(",");
00052 else if(args[i].equals("-maxSteps"))
00053 maxSteps = Integer.parseInt(args[++i]);
00054 else if(args[i].equals("-mws"))
00055 algo = Algorithm.MaxWalkSAT;
00056 else if(args[i].equals("-mwsr"))
00057 algo = Algorithm.MaxWalkSATRooms;
00058 else if(args[i].equals("-mcsat"))
00059 algo = Algorithm.MCSAT;
00060 else if(args[i].equals("-t2"))
00061 algo = Algorithm.Toulbar2;
00062 else if(args[i].equals("-debug"))
00063 debug = true;
00064 else if(args[i].startsWith("-p") || args[i].startsWith("--")) {
00065 String[] pair = args[i].substring(2).split("=");
00066 if(pair.length != 2)
00067 throw new Exception("Argument '" + args[i] + "' for algorithm-specific parameterization is incorrectly formatted.");
00068 params.put(pair[0], pair[1]);
00069 }
00070 else
00071 System.err.println("Warning: unknown option " + args[i] + " ignored!");
00072 }
00073 if(mlnFile == null || dbFile == null || query == null) {
00074 System.out.println("\n usage: MLNinfer <-i <MLN file>> <-e <evidence db file>> <-q <comma-sep. queries>> [options]\n\n"+
00075 " -maxSteps # the maximum number of steps to take [default: 1000]\n" +
00076 " -mws algorithm: MaxWalkSAT (MAP inference)\n" +
00077 " -mcsat algorithm: MC-SAT (default)\n" +
00078 " -t2 algorithm: Toulbar2 branch & bound\n" +
00079 " -debug debug mode with additional outputs\n"
00080
00081 );
00082 return;
00083 }
00084
00085
00086 Pattern comma = Pattern.compile("\\s*,\\s*");
00087 String[] candQueries = comma.split(query);
00088 Vector<String> queries = new Vector<String>();
00089 String q = "";
00090 for(int i = 0; i < candQueries.length; i++) {
00091 if(!q.equals(""))
00092 q += ",";
00093 q += candQueries[i];
00094 if(balancedParentheses(q)) {
00095 queries.add(q);
00096 q = "";
00097 }
00098 }
00099 if(!q.equals(""))
00100 throw new IllegalArgumentException("Unbalanced parentheses in queries");
00101
00102
00103 Stopwatch constructSW = new Stopwatch();
00104 constructSW.start();
00105 System.out.printf("reading model %s...\n", mlnFile);
00106 MarkovLogicNetwork mln = new MarkovLogicNetwork(mlnFile);
00107
00108
00109 System.out.printf("reading database %s...\n", dbFile);
00110 Database db = new Database(mln);
00111 db.readMLNDB(dbFile);
00112 System.out.printf("creating ground MRF...\n");
00113 MarkovRandomField mrf = mln.ground(db);
00114 if(debug) {
00115 System.out.println("MRF:");
00116 for(WeightedFormula wf : mrf)
00117 System.out.println(" " + wf.toString());
00118 }
00119 constructSW.stop();
00120
00121
00122 Stopwatch sw = new Stopwatch();
00123 sw.start();
00124 InferenceAlgorithm infer = null;
00125 switch(algo) {
00126 case MCSAT:
00127 infer = new MCSAT(mrf);
00128 break;
00129 case MaxWalkSAT:
00130 case MaxWalkSATRooms:
00131 infer = new MaxWalkSAT(mrf, algo == Algorithm.MaxWalkSAT ? edu.tum.cs.logic.sat.weighted.MaxWalkSAT.class : MaxWalkSATRoom.class);
00132 break;
00133 case Toulbar2:
00134 infer = new Toulbar2MAPInference(mrf);
00135 break;
00136 }
00137 infer.setDebugMode(debug);
00138 infer.getParameterHandler().handle(params, true);
00139 System.out.printf("algorithm: %s, steps: %d\n", infer.getAlgorithmName(), maxSteps);
00140 List<InferenceResult> results = infer.infer(queries, maxSteps);
00141 sw.stop();
00142
00143
00144 System.out.printf("\nconstruction time: %.4fs, inference time: %.4fs\n", constructSW.getElapsedTimeSecs(), sw.getElapsedTimeSecs());
00145 System.out.println("results:");
00146 Collections.sort(results);
00147 for(InferenceResult r : results)
00148 r.print();
00149 if(infer instanceof MAPInferenceAlgorithm) {
00150 MAPInferenceAlgorithm mapi = (MAPInferenceAlgorithm)infer;
00151 double value = mrf.getWorldValue(mapi.getSolution());
00152 System.out.printf("\nsolution value: %f\n", value);
00153 }
00154 }
00155 catch(Exception e) {
00156 e.printStackTrace();
00157 }
00158 }
00159
00160 public static boolean balancedParentheses(String s) {
00161 int n = 0;
00162 for(int i = 0; i < s.length(); i++) {
00163 if(s.charAt(i) == '(')
00164 n++;
00165 else if(s.charAt(i) == ')')
00166 n--;
00167 }
00168 return n == 0;
00169 }
00170 }