00001 
00002 
00003 
00004 
00005 
00006 
00007 package edu.tum.cs.probcog;
00008 
import java.util.Arrays;
import java.util.Map;
import java.util.Vector;

import edu.tum.cs.srl.Database;
import edu.tum.cs.srl.Signature;
import edu.tum.cs.srl.mln.MarkovLogicNetwork;
import edu.tum.cs.srl.mln.MarkovRandomField;
import edu.tum.cs.srl.mln.inference.InferenceAlgorithm;
import edu.tum.cs.srl.mln.inference.MCSAT;
00018 
00019 public class MLNModel extends Model {
00020 
00021         protected MarkovLogicNetwork mln;
00022         protected Database db;
00023         protected MarkovRandomField mrf;
00024         
00025         public MLNModel(String name, String mln) throws Exception {
00026                 super(name);
00027                 this.mln = new MarkovLogicNetwork(mln);
00028         }
00029 
00030         @Override
00031         protected String _getConstantType(String constant) {
00032                 return db.getConstantType(constant);
00033         }
00034         
00035         @Override
00036         public void beginSession(Map<String, Object> params) throws Exception {
00037                 super.beginSession(params);
00038                 db = new Database(mln);
00039         }
00040 
00041         @Override
00042         protected Vector<InferenceResult> _infer(Iterable<String> queries) throws Exception {
00043                 InferenceAlgorithm ia = new MCSAT(mrf);
00044                 paramHandler.addSubhandler(ia);
00045                 Vector<InferenceResult> res = new Vector<InferenceResult>();
00046                 int maxSteps = 5000;
00047                 for(edu.tum.cs.srl.mln.inference.InferenceResult r : ia.infer(queries, maxSteps)) {
00048                         InferenceResult r2 = new InferenceResult(r.ga.predicate, r.ga.args, r.value);
00049                         res.add(r2);
00050                 }
00051                 return res;
00052         }
00053 
00054         @Override
00055         protected void _setEvidence(Iterable<String[]> evidence) throws Exception {
00056                 for(String[] tuple : evidence) {
00057                         String functionName = tuple[0];
00058                         Signature sig = mln.getSignature(functionName);
00059                         if(sig == null)
00060                                 throw new Exception("Function '" + functionName + "' appearing in evidence not found in model " + name);
00061                         String value;
00062                         String[] params;
00063                         if(sig.argTypes.length == tuple.length-1) {
00064                                 params = new String[tuple.length-1];
00065                                 for(int i = 0; i < params.length; i++)
00066                                         params[i] = tuple[i+1];
00067                                 value = "True";
00068                         }
00069                         else {
00070                                 params = new String[tuple.length-2];
00071                                 for(int i = 0; i < params.length; i++)
00072                                         params[i] = tuple[i+1];
00073                                 value = tuple[tuple.length-1];
00074                         }
00075                         db.addVariable(new Database.Variable(functionName, params, value, mln));
00076                 }
00077         }
00078 
00079         @Override
00080         public Vector<String[]> getDomains() {
00081                 throw new RuntimeException("not implemented"); 
00082         }
00083 
00084         @Override
00085         public Vector<String[]> getPredicates() {               
00086                 return getPredicatesFromSignatures(mln.getSignatures());
00087         }
00088 
00089         @Override
00090         public void instantiate() throws Exception {
00091                 mrf = mln.ground(db);           
00092         }
00093 }