00001 package edu.tum.cs.probcog;
00002
00003 import java.io.IOException;
00004 import java.util.ArrayList;
00005 import java.util.Collection;
00006 import java.util.Map;
00007 import java.util.Vector;
00008 import java.util.Map.Entry;
00009
00010 import edu.tum.cs.logic.parser.ParseException;
00011 import edu.tum.cs.srl.Database;
00012 import edu.tum.cs.srl.Signature;
00013 import edu.tum.cs.srl.bayesnets.ABL;
00014 import edu.tum.cs.srl.bayesnets.RelationalNode;
00015 import edu.tum.cs.srl.bayesnets.bln.BayesianLogicNetwork;
00016 import edu.tum.cs.srl.bayesnets.bln.GroundBLN;
00017 import edu.tum.cs.srl.bayesnets.inference.BLNinfer;
00018 import edu.tum.cs.util.StringTool;
00019 import edu.tum.cs.util.datastruct.Pair;
00020
00021 public class BLNModel extends Model {
00022
00023 protected BayesianLogicNetwork bln;
00024 protected GroundBLN gbln;
00025 protected Database db;
00026 protected String filenames;
00027
00028 public BLNModel(String modelName, String blogFile, String networkFile, String logicFile) throws IOException, ParseException, Exception {
00029 super(modelName);
00030 this.filenames = String.format("%s;%s;%s", blogFile, networkFile, logicFile);
00031 this.bln = new BayesianLogicNetwork(new ABL(blogFile, networkFile), logicFile);
00032 }
00033
00034 @Override
00035 public void instantiate() throws Exception {
00036 gbln = bln.ground(db);
00037 paramHandler.addSubhandler(gbln);
00038 gbln.instantiateGroundNetwork();
00039 }
00040
00041 @Override
00042 public void beginSession(Map<String, Object> params) throws Exception {
00043 super.beginSession(params);
00044 db = new Database(bln.rbn);
00045 paramHandler.addSubhandler(db);
00046 }
00047
00048 @Override
00049 protected Vector<InferenceResult> _infer(Iterable<String> queries) throws Exception {
00050 BLNinfer inference = new BLNinfer(actualParams);
00051 paramHandler.addSubhandler(inference);
00052 inference.setGroundBLN(gbln);
00053 inference.setQueries(queries);
00054 Collection<edu.tum.cs.srl.bayesnets.inference.InferenceResult> results = inference.run();
00055
00056
00057 Vector<InferenceResult> ret = new Vector<InferenceResult>();
00058 for(edu.tum.cs.srl.bayesnets.inference.InferenceResult res : results) {
00059 Pair<String, String[]> var = RelationalNode.parse(res.varName);
00060 Signature sig = bln.rbn.getSignature(var.first);
00061 String[] params = var.second;
00062 boolean isBool = sig.isBoolean();
00063 if(!isBool) {
00064 String[] fullParams = new String[params.length+1];
00065 for(int i = 0; i < params.length; i++)
00066 fullParams[i] = params[i];
00067 params = fullParams;
00068 }
00069 for(int i = 0; i < res.domainElements.length; i++) {
00070 if(!isBool)
00071 params[params.length-1] = res.domainElements[i];
00072 else
00073 if(!res.domainElements[i].equalsIgnoreCase("True"))
00074 continue;
00075 ret.add(new InferenceResult(var.first, params.clone(), res.probabilities[i]));
00076 }
00077 }
00078 return ret;
00079 }
00080
00081 @Override
00082 protected void _setEvidence(Iterable<String[]> evidence) throws Exception {
00083 for(String[] tuple : evidence) {
00084 String functionName = tuple[0];
00085 Signature sig = bln.rbn.getSignature(functionName);
00086 if(sig == null)
00087 throw new Exception("Function '" + functionName + "' appearing in evidence not found in model " + name);
00088 String value;
00089 String[] params;
00090 if(sig.argTypes.length == tuple.length-1) {
00091 params = new String[tuple.length-1];
00092 for(int i = 0; i < params.length; i++)
00093 params[i] = tuple[i+1];
00094 value = "True";
00095 }
00096 else {
00097 if(tuple.length < sig.argTypes.length+2)
00098 throw new Exception("Evidence entry has too few parameters: " + StringTool.join(", ", tuple));
00099 params = new String[sig.argTypes.length];
00100 for(int i = 0; i < params.length; i++)
00101 params[i] = tuple[i+1];
00102 value = tuple[params.length+1];
00103 }
00104 db.addVariable(new Database.Variable(functionName, params, value, this.bln.rbn));
00105 }
00106 }
00107
00108 @Override
00109 public Vector<String[]> getPredicates() {
00110 return getPredicatesFromSignatures(this.bln.rbn.getSignatures());
00111 }
00112
00113 public Vector<String[]> getDomains() {
00114 Vector<String[]> ret = new Vector<String[]>();
00115 for(Entry<String,String[]> e : this.bln.rbn.getGuaranteedDomainElements().entrySet()) {
00116 String[] elems = e.getValue();
00117 ArrayList<String> tuple = new ArrayList<String>(elems.length+1);
00118 tuple.add(e.getKey());
00119 for(int i = 0; i < elems.length; i++) {
00120 String c = mapConstantFromProbCog(elems[i]);
00121 if(c == null)
00122 continue;
00123 tuple.add(c);
00124 }
00125 ret.add(tuple.toArray(new String[tuple.size()]));
00126 }
00127 return ret;
00128 }
00129
00130 @Override
00131 protected String _getConstantType(String constant) {
00132 return db.getConstantType(constant);
00133 }
00134
00135 @Override
00136 public String toString() {
00137 return String.format("%s=BLN[%s]", this.name, this.filenames);
00138 }
00139 }