00001
00002
00003
00004
00005
00006
00007 package edu.tum.cs.probcog;
00008
00009 import java.util.Map;
00010 import java.util.Vector;
00011
00012 import edu.tum.cs.srl.Database;
00013 import edu.tum.cs.srl.Signature;
00014 import edu.tum.cs.srl.mln.MarkovLogicNetwork;
00015 import edu.tum.cs.srl.mln.MarkovRandomField;
00016 import edu.tum.cs.srl.mln.inference.InferenceAlgorithm;
00017 import edu.tum.cs.srl.mln.inference.MCSAT;
00018
00019 public class MLNModel extends Model {
00020
00021 protected MarkovLogicNetwork mln;
00022 protected Database db;
00023 protected MarkovRandomField mrf;
00024
00025 public MLNModel(String name, String mln) throws Exception {
00026 super(name);
00027 this.mln = new MarkovLogicNetwork(mln);
00028 }
00029
00030 @Override
00031 protected String _getConstantType(String constant) {
00032 return db.getConstantType(constant);
00033 }
00034
00035 @Override
00036 public void beginSession(Map<String, Object> params) throws Exception {
00037 super.beginSession(params);
00038 db = new Database(mln);
00039 }
00040
00041 @Override
00042 protected Vector<InferenceResult> _infer(Iterable<String> queries) throws Exception {
00043 InferenceAlgorithm ia = new MCSAT(mrf);
00044 paramHandler.addSubhandler(ia);
00045 Vector<InferenceResult> res = new Vector<InferenceResult>();
00046 int maxSteps = 5000;
00047 for(edu.tum.cs.srl.mln.inference.InferenceResult r : ia.infer(queries, maxSteps)) {
00048 InferenceResult r2 = new InferenceResult(r.ga.predicate, r.ga.args, r.value);
00049 res.add(r2);
00050 }
00051 return res;
00052 }
00053
00054 @Override
00055 protected void _setEvidence(Iterable<String[]> evidence) throws Exception {
00056 for(String[] tuple : evidence) {
00057 String functionName = tuple[0];
00058 Signature sig = mln.getSignature(functionName);
00059 if(sig == null)
00060 throw new Exception("Function '" + functionName + "' appearing in evidence not found in model " + name);
00061 String value;
00062 String[] params;
00063 if(sig.argTypes.length == tuple.length-1) {
00064 params = new String[tuple.length-1];
00065 for(int i = 0; i < params.length; i++)
00066 params[i] = tuple[i+1];
00067 value = "True";
00068 }
00069 else {
00070 params = new String[tuple.length-2];
00071 for(int i = 0; i < params.length; i++)
00072 params[i] = tuple[i+1];
00073 value = tuple[tuple.length-1];
00074 }
00075 db.addVariable(new Database.Variable(functionName, params, value, mln));
00076 }
00077 }
00078
00079 @Override
00080 public Vector<String[]> getDomains() {
00081 throw new RuntimeException("not implemented");
00082 }
00083
00084 @Override
00085 public Vector<String[]> getPredicates() {
00086 return getPredicatesFromSignatures(mln.getSignatures());
00087 }
00088
00089 @Override
00090 public void instantiate() throws Exception {
00091 mrf = mln.ground(db);
00092 }
00093 }