00001 package edu.tum.cs.srl.bayesnets;
00002
00003 import java.util.HashMap;
00004 import java.util.HashSet;
00005 import java.util.Map;
00006 import java.util.Vector;
00007 import java.util.regex.Matcher;
00008 import java.util.regex.Pattern;
00009
00010 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00011 import edu.ksu.cis.bnj.ver3.core.Discrete;
00012 import edu.tum.cs.logic.Atom;
00013 import edu.tum.cs.logic.Biimplication;
00014 import edu.tum.cs.logic.Conjunction;
00015 import edu.tum.cs.logic.Equality;
00016 import edu.tum.cs.logic.Exist;
00017 import edu.tum.cs.logic.Formula;
00018 import edu.tum.cs.logic.Literal;
00019 import edu.tum.cs.logic.Negation;
00020 import edu.tum.cs.srl.Database;
00021 import edu.tum.cs.srl.Signature;
00022 import edu.tum.cs.srl.mln.MLNWriter;
00023 import edu.tum.cs.util.StringTool;
00024 import edu.tum.cs.util.datastruct.Pair;
00025
00026 public class RelationalNode extends ExtendedNode {
00030 protected String functionName;
00034 public String[] params;
00038 public String[] addParams;
00039 public boolean isConstant, isAuxiliary, isPrecondition, isUnobserved;
00043 public Aggregator aggregator;
00047 public String parentMode;
00048 protected Vector<Integer> indicesOfConstantArgs = null;
00049
00053 protected ParentGrounder parentGrounder = null;
00054
00055 public static final String BUILTINPRED_EQUALS = "EQ";
00056 public static final String BUILTINPRED_NEQUALS = "NEQ";
00057
00058 public static enum Aggregator {
00059 FunctionalOr(true, "=OR"),
00060 NoisyOr(false, "OR"),
00061 Average(false, "AVG");
00062
00063 public boolean isFunctional;
00064 protected String syntax;
00065
00066 private Aggregator(boolean isFunctional, String syntax) {
00067 this.isFunctional = isFunctional;
00068 this.syntax = syntax;
00069 }
00070
00071 public String toString() {
00072 return super.toString() + "(\"" + syntax + "\")";
00073 }
00074
00075 public String getFunctionSyntax() {
00076 return syntax;
00077 }
00078
00079 public static Aggregator fromSyntax(String syntax) throws Exception {
00080 for(Aggregator a : Aggregator.values())
00081 if(a.syntax.equals(syntax))
00082 return a;
00083 throw new Exception("There is no aggregator for '" + syntax + "'");
00084 }
00085 }
00086
00087
00093 public static String extractFunctionName(String varName) {
00094 if(varName.contains("("))
00095 return varName.substring(0, varName.indexOf('('));
00096 return varName;
00097
00098 }
00099
00100 public static Pair<String, String[]> parse(String variable) {
00101 Pattern p = Pattern.compile("(\\w+)\\(([^\\)]+)\\)");
00102 Matcher m = p.matcher(variable);
00103 if(!m.matches())
00104 return null;
00105 return new Pair<String, String[]>(m.group(1), m.group(2).split(","));
00106 }
00107
00112 public static boolean isConstant(String identifier) {
00113 return Character.isUpperCase(identifier.charAt(0));
00114 }
00115
00116 public RelationalNode(RelationalBeliefNetwork bn, BeliefNode node) throws Exception {
00117 super(bn, node);
00118 Pattern namePat = Pattern.compile("(\\w+)\\((.*)\\)");
00119 String name = node.getName();
00120
00121 if(name.charAt(0) == '#') {
00122 isAuxiliary = true;
00123 name = name.substring(1);
00124 }
00125 else if(name.charAt(0) == '+') {
00126 isPrecondition = true;
00127 isAuxiliary = true;
00128 name = name.substring(1);
00129 }
00130
00131
00132 aggregator = null;
00133 Pattern aggPat = Pattern.compile("(=?[A-Z]+):.*");
00134 Matcher m = aggPat.matcher(name);
00135 if(m.matches()) {
00136 String aggFunction = m.group(1);
00137 aggregator = Aggregator.fromSyntax(aggFunction);
00138 name = name.substring(aggFunction.length()+1);
00139 }
00140
00141 int sepPos = name.indexOf('|');
00142 if(sepPos != -1) {
00143 String decl = name.substring(sepPos+1);
00144 Pattern declPat = Pattern.compile("([A-Z]+):(.*)");
00145 m = declPat.matcher(decl);
00146 if(m.matches()) {
00147 parentMode = m.group(1);
00148 addParams = m.group(2).split("\\s*,\\s*");
00149 }
00150 else {
00151 addParams = decl.split("\\s*,\\s*");
00152 }
00153 name = name.substring(0, sepPos);
00154 }
00155
00156 Matcher matcher = namePat.matcher(name);
00157 if(matcher.matches()) {
00158 this.functionName = matcher.group(1);
00159 this.params = matcher.group(2).split("\\s*,\\s*");
00160 this.isConstant = false;
00161 }
00162 else {
00163 this.functionName = name;
00164 this.params = new String[]{name};
00165 this.isConstant = true;
00166 }
00167
00168 if(isPrecondition)
00169 bn.setEvidenceFunction(functionName);
00170 }
00171
00176 public boolean isFragment() {
00177 return !isConstant && !isAuxiliary && !isBuiltInPred();
00178 }
00179
00183 public String toString() {
00184 return getName();
00185 }
00186
00191 public String getName() {
00192 return this.node.getName();
00193 }
00194
00199 public int getNodeIndex() {
00200 return this.index;
00201 }
00202
00207 public String getCleanName() {
00208 if(isConstant)
00209 return functionName;
00210 return Signature.formatVarName(this.functionName, this.params);
00211 }
00212
00216 public boolean isBoolean() {
00217 Signature sig = bn.getSignature(this);
00218 if(sig != null)
00219 return sig.isBoolean();
00220 else
00221 return bn.isBooleanDomain(node.getDomain());
00222 }
00223
00228 public String getFunctionName() {
00229 return this.functionName;
00230 }
00231
00238 public String toLiteralString(int setting, HashMap<String,String> constantValues) {
00239
00240 if(this.functionName.equals(BUILTINPRED_NEQUALS))
00241 return String.format("!(%s=%s)", this.params[0], this.params[1]);
00242 if(this.functionName.equals(BUILTINPRED_EQUALS))
00243 return String.format("%s=%s", this.params[0], this.params[1]);
00244
00245
00246
00247 StringBuffer sb = new StringBuffer(String.format("%s(", MLNWriter.lowerCaseString(functionName)));
00248
00249 for(int i = 0; i < params.length; i++) {
00250 if(i > 0)
00251 sb.append(",");
00252 String value = constantValues != null ? constantValues.get(params[i]) : null;
00253 if(value == null)
00254 sb.append(params[i]);
00255 else
00256 sb.append(value);
00257 }
00258
00259 String value = ((Discrete)node.getDomain()).getName(setting);
00260 if(isBoolean()) {
00261 if(value.equalsIgnoreCase("false"))
00262 sb.insert(0, '!');
00263 }
00264 else {
00265 sb.append(',');
00266 sb.append(MLNWriter.upperCaseString(value));
00267 }
00268 sb.append(')');
00269 return sb.toString();
00270 }
00271
00278 public Formula toFormula(Map<String,String> constantValues) throws Exception {
00279 if(!hasAggregator())
00280 return null;
00281 if(aggregator == Aggregator.FunctionalOr) {
00282
00283 Vector<Formula> parents = new Vector<Formula>();
00284 for(RelationalNode parent : this.getRelationalParents()) {
00285 parents.add(parent.toLiteral(0, constantValues));
00286 }
00287 return new Biimplication(this.toLiteral(0, constantValues), new Exist(this.addParams, new Conjunction(parents)));
00288 }
00289 return null;
00290 }
00291
00298 public Formula toLiteral(int domIdx, Map<String,String> constantValues) {
00299
00300 if(this.functionName.equals(BUILTINPRED_NEQUALS))
00301 return new Negation(new Equality(this.params[0], this.params[1]));
00302 if(this.functionName.equals(BUILTINPRED_EQUALS))
00303 return new Equality(this.params[0], this.params[1]);
00304
00305
00306 Vector<String> atomParams = new Vector<String>();
00307 for(int i = 0; i < params.length; i++) {
00308 String value = constantValues != null ? constantValues.get(params[i]) : null;
00309 if(value == null)
00310 atomParams.add(params[i]);
00311 else
00312 atomParams.add(value);
00313 }
00314
00315 String value = ((Discrete)node.getDomain()).getName(domIdx);
00316 if(isBoolean()) {
00317 boolean isTrue = !value.equalsIgnoreCase("false");
00318 return new Literal(isTrue, new Atom(this.functionName, atomParams));
00319 }
00320 else {
00321 atomParams.add(value);
00322 return new Atom(this.functionName, atomParams);
00323 }
00324 }
00325
00329 public RelationalBeliefNetwork getNetwork() {
00330 return bn;
00331 }
00332
00336 public boolean hasCPT() {
00337 return aggregator == null || !aggregator.isFunctional;
00338 }
00339
00343 public boolean hasAggregator() {
00344 return this.aggregator != null;
00345 }
00346
00351 public Signature getSignature() {
00352 return bn.getSignature(this);
00353 }
00354
00358 public boolean isRelation() {
00359 return params != null && params.length > 1;
00360 }
00361
00362
00369 public String getVariableName(String[] actualParams) throws Exception {
00370 if(actualParams.length != params.length)
00371 throw new Exception(String.format("Invalid number of actual parameters suppplied for %s: expected %d, got %d", toString(), params.length, actualParams.length));
00372 return Signature.formatVarName(getFunctionName(), actualParams);
00373 }
00374
00375 public Vector<RelationalNode> getParents() {
00376 return bn.getRelationalParents(this);
00377 }
00378
00384 public boolean hasParams(String[] params) {
00385 for(int i = 0; i < params.length; i++) {
00386 int j = 0;
00387 for(; j < this.params.length; j++)
00388 if(params[i].equals(this.params[j]))
00389 break;
00390 if(j == this.params.length)
00391 return false;
00392 }
00393 return true;
00394 }
00395
00396 public boolean hasParam(String param) {
00397 for(int i = 0; i < params.length; i++)
00398 if(params[i].equals(param))
00399 return true;
00400 return false;
00401 }
00402
00408 public RelationalNode getFreeParamGroundingParent() throws Exception {
00409 if(addParams == null || addParams.length == 0)
00410 throw new Exception("This node has no free parameters for which there could be a parent that grounds them.");
00411
00412 for(RelationalNode parent : getParents()) {
00413 if(parent.isRelation() && parent.hasParams(this.addParams)) {
00414 return parent;
00415 }
00416 }
00417 return null;
00418 }
00419
00425 public String toAtom() throws Exception {
00426 if(!isBoolean())
00427 throw new Exception("Cannot convert non-Boolean node to atom without specifying setting");
00428 return getCleanName();
00429 }
00430
00434 public void setLabel() {
00435 StringBuffer buf = new StringBuffer();
00436 if(this.aggregator != null)
00437 buf.append(aggregator.getFunctionSyntax() + ":");
00438 buf.append(getCleanName());
00439 if(this.addParams != null && this.addParams.length > 0) {
00440 buf.append("|");
00441 if(this.parentMode != null && this.parentMode.length() > 0)
00442 buf.append(parentMode + ":");
00443 buf.append(StringTool.join(",", this.addParams));
00444 }
00445 this.node.setName(buf.toString());
00446 }
00447
00448 public Discrete getDomain() {
00449 return (Discrete)node.getDomain();
00450 }
00451
00459 public String getValueInDB(String[] actualParams, Database db, boolean closedWorld) throws Exception {
00460
00461 if(functionName.equals(BUILTINPRED_NEQUALS))
00462 return actualParams[0].equals(actualParams[1]) ? "False" : "True";
00463 if(functionName.equals(BUILTINPRED_EQUALS))
00464 return actualParams[0].equals(actualParams[1]) ? "True" : "False";
00465
00466 if(!isConstant) {
00467 String curVarName = getVariableName(actualParams);
00468
00469 String value = db.getVariableValue(curVarName, closedWorld);
00470 if(value == null) {
00471 throw new Exception("Could not find value of " + curVarName + " in database. closedWorld = " + closedWorld);
00472 }
00473 return value;
00474
00475 }
00476 else {
00477
00478 return actualParams[0];
00479 }
00480 }
00481
00482 public boolean isBuiltInPred() {
00483 return functionName.equals(BUILTINPRED_EQUALS) || functionName.equals(BUILTINPRED_NEQUALS);
00484 }
00485
00490 public Vector<HashMap<String,String>> getConstantAssignments() {
00491 RelationalNode mainNode = this;
00492
00493 Vector<RelationalNode> constantParents = new Vector<RelationalNode>();
00494 for(RelationalNode parent : this.getNetwork().getRelationalParents(mainNode)) {
00495 if(parent.isConstant)
00496 constantParents.add(parent);
00497 }
00498
00499 Vector<HashMap<String, String>> constantAssignments = new Vector<HashMap<String, String>>();
00500 if(constantParents.isEmpty())
00501 constantAssignments.add(new HashMap<String,String>());
00502 else
00503 collectConstantAssignments(constantParents.toArray(new RelationalNode[0]), 0, new String[constantParents.size()], constantAssignments);
00504 return constantAssignments;
00505 }
00506
00507 protected void collectConstantAssignments(RelationalNode[] constNodes, int i, String[] assignment, Vector<HashMap<String,String>> assignments) {
00508 if(i == constNodes.length) {
00509 HashMap<String,String> m = new HashMap<String,String>();
00510 for(int j = 0; j < assignment.length; j++)
00511 m.put(constNodes[j].getName(), assignment[j]);
00512 assignments.add(m);
00513 }
00514 else {
00515 Discrete dom = (Discrete)constNodes[i].node.getDomain();
00516 for(int j = 0; j < dom.getOrder(); j++) {
00517 assignment[i] = dom.getName(j);
00518 collectConstantAssignments(constNodes, i+1, assignment, assignments);
00519 }
00520 }
00521 }
00522
00527 public Vector<RelationalNode> getRelationalParents() {
00528 return bn.getRelationalParents(this);
00529 }
00530
00535 public Vector<Integer> getIndicesOfConstantParams() {
00536 if(indicesOfConstantArgs == null) {
00537 indicesOfConstantArgs = new Vector<Integer>();
00538 HashSet<String> constantArgs = new HashSet<String>();
00539 for(RelationalNode parent : getRelationalParents()) {
00540 if(parent.isConstant)
00541 constantArgs.add(parent.functionName);
00542 }
00543 for(int i = 0; i < params.length; i++)
00544 if(constantArgs.contains(params[i]))
00545 indicesOfConstantArgs.add(i);
00546 }
00547 return indicesOfConstantArgs;
00548 }
00549
00550 public ParentGrounder getParentGrounder() throws Exception {
00551 if(parentGrounder != null)
00552 return parentGrounder;
00553 return (parentGrounder = new ParentGrounder(this.bn, this));
00554 }
00555
00563 public HashMap<String,String> getParameterBinding(String[] actualParams, Database db) throws Exception {
00564 return getParentGrounder().generateParameterBindings(actualParams, db);
00565 }
00566
00567 public Vector<RelationalNode> getPreconditionParents() {
00568 Vector<RelationalNode> ret = new Vector<RelationalNode>();
00569 BeliefNode[] domprod = this.node.getCPF().getDomainProduct();
00570 for(int i = 1; i < domprod.length; i++) {
00571 ExtendedNode n = this.bn.getExtendedNode(domprod[i]);
00572 if(n instanceof RelationalNode) {
00573 RelationalNode rn = (RelationalNode)n;
00574 if(rn.isPrecondition)
00575 ret.add(rn);
00576 }
00577 }
00578 return ret;
00579 }
00580 }
00581