00001 package edu.tum.cs.srl.bayesnets.bln;
00002
00003 import java.util.ArrayList;
00004 import java.util.Collection;
00005 import java.util.HashMap;
00006 import java.util.HashSet;
00007 import java.util.Iterator;
00008 import java.util.Map;
00009 import java.util.Vector;
00010 import java.util.Map.Entry;
00011
00012 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00013 import edu.ksu.cis.bnj.ver3.core.CPF;
00014 import edu.ksu.cis.bnj.ver3.core.Discrete;
00015 import edu.ksu.cis.bnj.ver3.core.Domain;
00016 import edu.ksu.cis.bnj.ver3.core.Value;
00017 import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
00018 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00019 import edu.tum.cs.inference.IParameterHandler;
00020 import edu.tum.cs.inference.ParameterHandler;
00021 import edu.tum.cs.srl.Database;
00022 import edu.tum.cs.srl.ParameterGrounder;
00023 import edu.tum.cs.srl.Signature;
00024 import edu.tum.cs.srl.bayesnets.CombiningRule;
00025 import edu.tum.cs.srl.bayesnets.DecisionNode;
00026 import edu.tum.cs.srl.bayesnets.ExtendedNode;
00027 import edu.tum.cs.srl.bayesnets.ParentGrounder;
00028 import edu.tum.cs.srl.bayesnets.RelationalBeliefNetwork;
00029 import edu.tum.cs.srl.bayesnets.RelationalNode;
00030 import edu.tum.cs.srl.bayesnets.RelationalNode.Aggregator;
00031 import edu.tum.cs.util.Stopwatch;
00032 import edu.tum.cs.util.StringTool;
00033 import edu.tum.cs.util.datastruct.Pair;
00034
/**
 * Base class for the ground (instantiated) network of a Bayesian logic network.
 * It builds a {@link BeliefNetworkEx} from the relational templates of the
 * model and a {@link Database}, and leaves the grounding of formula nodes to
 * subclasses (see {@code groundFormulaicNodes()}).
 */
00035 public abstract class AbstractGroundBLN implements IParameterHandler {
/** The ground network; created anew by each call to instantiateGroundNetwork. */
00039 protected BeliefNetworkEx groundBN;
/** The model this ground network is instantiated from. */
00043 protected AbstractBayesianLogicNetwork bln;
/**
 * Nodes added via addHardFormulaNode; getFullEvidence sets each of them to "True".
 * Only allocated when the network is instantiated with auxiliary variables.
 */
00047 protected Vector<BeliefNode> hardFormulaNodes;
/** Name of the database file; only set by the file-based constructor. */
00051 protected String databaseFile;
/** The database the network is grounded against. */
00055 protected Database db;
/**
 * Maps a function name to the fragment nodes that can serve as its templates.
 * Temporary: only non-null while the regular nodes are being instantiated.
 */
00059 protected HashMap<String, Vector<RelationalNode>> functionTemplates;
/**
 * Names of the ground variables instantiated so far.
 * Temporary: only non-null while the regular nodes are being instantiated.
 */
00063 protected HashSet<String> instantiatedVariables;
/**
 * Cache of CPF value arrays, keyed by CPF id.
 * Temporary: only non-null while the regular nodes are being instantiated.
 */
00064 protected HashMap<String, Value[]> cpfCache;
/** Whether progress messages are printed to standard output. */
00065 protected boolean verbose = true;
/** Whether detailed per-node messages are printed to standard output. */
00066 protected boolean debug = false;
/** Parameter handler; created in init and registered with the "verbose" parameter. */
00067 protected ParameterHandler paramHandler;
/** Maps a ground node to the id of the CPF that was built for it. */
00071 protected HashMap<BeliefNode, String> cpfIDs;
/** Maps a ground atom node to the relational node it was instantiated from. */
00075 protected HashMap<BeliefNode, RelationalNode> groundNode2TemplateNode;
00076
/**
 * Creates the ground network object for a model and an already loaded database.
 * The ground network itself is not built here; see instantiateGroundNetwork.
 *
 * @param bln the model to ground
 * @param db the database to ground against
 * @throws Exception if initialisation fails
 */
public AbstractGroundBLN(AbstractBayesianLogicNetwork bln, Database db) throws Exception {
	this.init(bln, db);
}
00080
/**
 * Creates the ground network object for a model, reading the database from a file.
 *
 * @param bln the model to ground
 * @param databaseFile name of the database file, read via {@code Database.readBLOGDB}
 * @throws Exception if the database cannot be read or initialisation fails
 */
public AbstractGroundBLN(AbstractBayesianLogicNetwork bln, String databaseFile) throws Exception {
	Database evidence = new Database(bln.rbn);
	evidence.readBLOGDB(databaseFile, true);
	this.databaseFile = databaseFile;
	init(bln, evidence);
}
00087
00088 protected void init(AbstractBayesianLogicNetwork bln, Database db) throws Exception {
00089 paramHandler = new ParameterHandler(this);
00090 paramHandler.add("verbose", "setVerbose");
00091 this.bln = bln;
00092 this.db = db;
00093 cpfIDs = new HashMap<BeliefNode, String>();
00094 groundNode2TemplateNode = new HashMap<BeliefNode, RelationalNode>();
00095 }
00096
00097 public AbstractBayesianLogicNetwork getBLN() {
00098 return this.bln;
00099 }
00100
00105 public void instantiateGroundNetwork() throws Exception {
00106 instantiateGroundNetwork(true);
00107 }
00108
00114 public void instantiateGroundNetwork(boolean addAuxiliaryVars) throws Exception {
00115 Stopwatch sw = new Stopwatch();
00116 sw.start();
00117
00118 if(verbose) System.out.println("generating network...");
00119 groundBN = new BeliefNetworkEx();
00120
00121
00122 if(verbose) System.out.println(" regular nodes");
00123 RelationalBeliefNetwork rbn = bln.rbn;
00124
00125
00126 functionTemplates = new HashMap<String, Vector<RelationalNode>>();
00127 BeliefNode[] nodes = rbn.bn.getNodes();
00128 for(int i = 0; i < nodes.length; i++) {
00129 ExtendedNode extNode = rbn.getExtendedNode(i);
00130
00131 if(!(extNode instanceof RelationalNode))
00132 continue;
00133 RelationalNode relNode = (RelationalNode)extNode;
00134 if(!relNode.isFragment())
00135 continue;
00136
00137 String f = relNode.getFunctionName();
00138 Vector<RelationalNode> v = functionTemplates.get(f);
00139 if(v == null) {
00140 v = new Vector<RelationalNode>();
00141 functionTemplates.put(f, v);
00142 }
00143 v.add(relNode);
00144 }
00145
00146
00147 instantiatedVariables = new HashSet<String>();
00148 cpfCache = new HashMap<String, Value[]>();
00149 for(String functionName : functionTemplates.keySet()) {
00150 if(verbose) System.out.println(" " + functionName);
00151 Collection<String[]> parameterSets = ParameterGrounder.generateGroundings(bln.rbn, functionName, db);
00152 for(String[] params : parameterSets)
00153 instantiateVariable(functionName, params);
00154 }
00155
00156
00157 instantiatedVariables = null;
00158 functionTemplates = null;
00159 cpfCache = null;
00160
00161
00162 if(addAuxiliaryVars) {
00163 if(verbose) System.out.println(" formulaic nodes");
00164 hardFormulaNodes = new Vector<BeliefNode>();
00165 groundFormulaicNodes();
00166 }
00167
00168 if(verbose) {
00169 System.out.println("network size: " + getGroundNetwork().bn.getNodes().length + " nodes");
00170 System.out.println(String.format("construction time: %.4fs", sw.getElapsedTimeSecs()));
00171 }
00172 }
00173
00181 protected BeliefNode instantiateVariable(String functionName, String[] params) throws Exception {
00182
00183 String varName = Signature.formatVarName(functionName, params);
00184 if(instantiatedVariables.contains(varName))
00185 return groundBN.getNode(varName);
00186
00187 BeliefNode ret = null;
00188
00189
00190 Vector<RelationalNode> templates = functionTemplates.get(functionName);
00191 if(templates == null)
00192 throw new Exception("There are no templates from which " + Signature.formatVarName(functionName, params) + " could be constructed.");
00193
00194 boolean combiningRuleNeeded = false;
00195
00196 Vector<Pair<RelationalNode, Vector<Map<Integer, String[]>>>> suitableTemplates = new Vector<Pair<RelationalNode, Vector<Map<Integer, String[]>>>>();
00197
00198
00199 for(RelationalNode relNode : templates) {
00200
00201
00202 boolean preconditionsMet = true;
00203 for(DecisionNode decision : relNode.getDecisionParents()) {
00204 if(!decision.isTrue(relNode.params, params, db, false)) {
00205 preconditionsMet = false;
00206 break;
00207 }
00208 }
00209 if(!preconditionsMet)
00210 continue;
00211
00212
00213 ParentGrounder pg = bln.rbn.getParentGrounder(relNode);
00214 Vector<Map<Integer, String[]>> groundings = pg.getGroundings(params, db);
00215
00216
00217
00218 Vector<RelationalNode> preconds = relNode.getPreconditionParents();
00219 for(RelationalNode precond : preconds) {
00220 Iterator<Map<Integer, String[]>> iter = groundings.iterator();
00221 while(iter.hasNext()) {
00222 Map<Integer, String[]> grounding = iter.next();
00223 String value = db.getVariableValue(precond.getVariableName(grounding.get(precond.index)), true);
00224 if(!value.equals("True"))
00225 iter.remove();
00226 }
00227 }
00228
00229 if(groundings.isEmpty())
00230 continue;
00231
00232
00233
00234
00235 if(groundings.size() > 1 && !relNode.hasAggregator())
00236 combiningRuleNeeded = true;
00237
00238
00239 if(!suitableTemplates.isEmpty())
00240 combiningRuleNeeded = true;
00241
00242 suitableTemplates.add(new Pair<RelationalNode, Vector<Map<Integer, String[]>>>(relNode, groundings));
00243 }
00244
00245
00246 if(suitableTemplates.isEmpty()) {
00247 if(!this.bln.rbn.isEvidenceFunction(functionName))
00248 throw new Exception("No relational node was found that could serve as the template for the variable " + varName);
00249 else {
00250
00251
00252
00253
00254
00255
00256
00257
00258
00259
00260
00261 if(debug)
00262 System.out.println(" " + varName + " (skipped, is evidence)");
00263 }
00264 }
00265
00266
00267 Pair<RelationalNode, Vector<Map<Integer, String[]>>> template = suitableTemplates.iterator().next();
00268 RelationalNode relNode = template.first;
00269
00270
00271 String mainNodeName = relNode.getVariableName(params);
00272 instantiatedVariables.add(mainNodeName);
00273 if(debug)
00274 System.out.println(" " + mainNodeName);
00275
00276
00277 BeliefNode mainNode = groundBN.addNode(mainNodeName, relNode.node.getDomain());
00278 groundNode2TemplateNode.put(mainNode, relNode);
00279 onAddGroundAtomNode(relNode, params, mainNode);
00280
00281
00282 if(!combiningRuleNeeded) {
00283 instantiateVariableFromSingleTemplate(mainNode, template.first, template.second);
00284 }
00285 else {
00286 CombiningRule r = bln.rbn.getCombiningRule(functionName);
00287 if(r == null)
00288 throw new Exception("More than one group of parents for variable " + varName + " but no combining rule was specified");
00289 instantiateVariableWithCombiningRule(mainNode, suitableTemplates, r);
00290 }
00291
00292 return mainNode;
00293 }
00294
00302 protected void instantiateVariableFromSingleTemplate(BeliefNode mainNode, RelationalNode relNode, Vector<Map<Integer, String[]>> groundings) throws Exception {
00303
00304
00305 if(!relNode.hasAggregator()) {
00306 if(groundings.size() != 1)
00307 throw new Exception("Cannot instantiate " + mainNode.getName() + " for " + groundings.size() + " groups of parents.");
00308 if(debug) {
00309 System.out.println(" relevant nodes/parents");
00310 Map<Integer, String[]> grounding = groundings.firstElement();
00311 for(Entry<Integer, String[]> e : grounding.entrySet()) {
00312 System.out.println(" " + bln.rbn.getRelationalNode(e.getKey()).getVariableName(e.getValue()));
00313 }
00314 }
00315 instantiateCPF(groundings.firstElement(), relNode, mainNode);
00316 }
00317
00318 else {
00319 ArrayList<BeliefNode> domprod = new ArrayList<BeliefNode>();
00320 domprod.add(mainNode);
00321
00322 if(!relNode.aggregator.isFunctional) {
00323
00324 Vector<BeliefNode> auxNodes = new Vector<BeliefNode>();
00325 int k = 0;
00326 for(Map<Integer, String[]> grounding : groundings) {
00327
00328 String auxNodeName = String.format("AUX%d_%s", k++, mainNode.getName());
00329 BeliefNode auxNode = groundBN.addNode(auxNodeName, mainNode.getDomain());
00330 auxNodes.add(auxNode);
00331
00332 instantiateCPF(grounding, relNode, auxNode);
00333 }
00334
00335 for(BeliefNode parent : auxNodes) {
00336
00337 groundBN.connect(parent, mainNode, false);
00338 domprod.add(parent);
00339 }
00340 }
00341
00342
00343 else {
00344
00345 for(Map<Integer, String[]> grounding : groundings) {
00346 HashMap<BeliefNode,BeliefNode> src2targetParent = new HashMap<BeliefNode,BeliefNode>();
00347 connectParents(grounding, relNode, mainNode, src2targetParent, null);
00348 domprod.addAll(src2targetParent.values());
00349 }
00350 }
00351
00352 Aggregator combFunc = relNode.aggregator;
00353 if(combFunc == Aggregator.FunctionalOr || combFunc == Aggregator.NoisyOr) {
00354
00355 if(!RelationalBeliefNetwork.isBooleanDomain(mainNode.getDomain()))
00356 throw new Exception("Cannot use OR aggregator on non-Boolean node " + relNode.toString());
00357
00358 String cpfid = combFunc.getFunctionSyntax();
00359 switch(combFunc) {
00360 case FunctionalOr:
00361 cpfid += String.format("-%d-%d", groundings.size(), groundings.firstElement().size());
00362 break;
00363 case NoisyOr:
00364 cpfid += String.format("-%d", groundings.size());
00365 break;
00366 }
00367
00368 CPF cpf = mainNode.getCPF();
00369 BeliefNode[] domprod_arr = domprod.toArray(new BeliefNode[domprod.size()]);
00370
00371 Value[] values = cpfCache.get(cpfid);
00372 if(values != null)
00373 cpf.build(domprod_arr, values);
00374
00375 else {
00376 cpf.buildZero(domprod_arr, false);
00377 CPFFiller filler;
00378 if(combFunc == Aggregator.FunctionalOr)
00379 filler = new CPFFiller_ORGrouped(mainNode, groundings.firstElement().size()-1);
00380 else
00381 filler = new CPFFiller_OR(mainNode);
00382 filler.fill();
00383
00384 cpfCache.put(cpfid, cpf.getValues());
00385 }
00386
00387 cpfIDs.put(mainNode, cpfid);
00388 }
00389 else
00390 throw new Exception("Cannot ground structure because of multiple parent sets for node " + mainNode.getName() + " with unhandled aggregator " + relNode.aggregator);
00391 }
00392 }
00393
00394 protected BeliefNode instantiateVariableWithCombiningRule(BeliefNode mainNode, Vector<Pair<RelationalNode, Vector<Map<Integer, String[]>>>> suitableTemplates, CombiningRule r) throws Exception {
00395
00396 HashMap<BeliefNode, Integer> parentIndices = new HashMap<BeliefNode, Integer>();
00397 Vector<Pair<RelationalNode, Map<BeliefNode,Integer>>> templateDomprodMap = new Vector<Pair<RelationalNode, Map<BeliefNode,Integer>>>();
00398 int domProdIndex = 1;
00399 for(Pair<RelationalNode, Vector<Map<Integer, String[]>>> template : suitableTemplates) {
00400 RelationalNode relNode = template.first;
00401 Vector<Map<Integer, String[]>> nodeGroundings = template.second;
00402 for(Map<Integer, String[]> nodeGrounding : nodeGroundings) {
00403 Map<BeliefNode,Integer> relParentIndex2domprodIndex = new HashMap<BeliefNode,Integer>();
00404 for(Entry<Integer,String[]> entry : nodeGrounding.entrySet()) {
00405 RelationalNode relParent = bln.rbn.getRelationalNode(entry.getKey());
00406 if(relParent == relNode)
00407 continue;
00408 BeliefNode parent = instantiateVariable(relParent.getFunctionName(), entry.getValue());
00409 if(parent == null)
00410 throw new Exception();
00411 Integer index = parentIndices.get(parent);
00412 if(index == null) {
00413 index = domProdIndex++;
00414 parentIndices.put(parent, index);
00415 }
00416 relParentIndex2domprodIndex.put(relParent.node, index);
00417 }
00418 templateDomprodMap.add(new Pair<RelationalNode, Map<BeliefNode,Integer>>(relNode, relParentIndex2domprodIndex));
00419 }
00420 }
00421
00422
00423 CPF cpf = mainNode.getCPF();
00424 BeliefNode[] domprod = new BeliefNode[1 + parentIndices.size()];
00425 domprod[0] = mainNode;
00426 for(Entry<BeliefNode, Integer> e : parentIndices.entrySet()) {
00427 domprod[e.getValue()] = e.getKey();
00428 this.groundBN.connect(e.getKey(), mainNode, false);
00429 }
00430 cpf.buildZero(domprod, false);
00431
00432
00433 fillCPFCombiningRule(cpf, 1, new int[domprod.length], templateDomprodMap, r);
00434
00435 return mainNode;
00436 }
00437
00438 protected void fillCPFCombiningRule(CPF cpf, int i, int[] addr, Vector<Pair<RelationalNode, Map<BeliefNode,Integer>>> templateDomprodMap, CombiningRule r) {
00439 BeliefNode[] domprod = cpf.getDomainProduct();
00440 if(i == domprod.length) {
00441 if(r.booleanSemantics) {
00442 double trueCase = fillCPFCombiningRule_computeColumnEntry(0, addr, templateDomprodMap, r);
00443 cpf.put(addr, new ValueDouble(trueCase));
00444 addr[0] = 1;
00445 cpf.put(addr, new ValueDouble(1.0-trueCase));
00446 }
00447 else {
00448 int domSize = domprod[0].getDomain().getOrder();
00449 double[] values = new double[domSize];
00450 double Z = 0.0;
00451 for(int j = 0; j < domSize; j++) {
00452 values[j] = fillCPFCombiningRule_computeColumnEntry(j, addr, templateDomprodMap, r);
00453 Z += values[j];
00454 }
00455 for(int j = 0; j < domSize; j++) {
00456 values[j] /= Z;
00457 addr[0] = j;
00458 cpf.put(addr, new ValueDouble(values[j]));
00459 }
00460 }
00461 return;
00462 }
00463
00464 int domSize = domprod[i].getDomain().getOrder();
00465 for(int domIdx = 0; domIdx < domSize; domIdx++) {
00466 addr[i] = domIdx;
00467 fillCPFCombiningRule(cpf, i+1, addr, templateDomprodMap, r);
00468 }
00469 }
00470
00471 protected double fillCPFCombiningRule_computeColumnEntry(int idx0, int[] addr, Vector<Pair<RelationalNode, Map<BeliefNode,Integer>>> templateDomprodMap, CombiningRule r) {
00472
00473 addr[0] = idx0;
00474 Vector<Double> values = new Vector<Double>();
00475 for(Pair<RelationalNode, Map<BeliefNode, Integer>> m : templateDomprodMap) {
00476 RelationalNode relNode = m.first;
00477 CPF cpf2 = relNode.node.getCPF();
00478 BeliefNode[] domprod2 = cpf2.getDomainProduct();
00479 int[] addr2 = new int[domprod2.length];
00480 addr2[0] = addr[0];
00481 for(int i2 = 1; i2 < domprod2.length; i2++) {
00482 Integer i1 = m.second.get(domprod2[i2]);
00483 if(i1 != null)
00484 addr2[i2] = addr[i1];
00485 else
00486 addr2[i2] = 0;
00487 }
00488 Double v = cpf2.getDouble(addr2);
00489 values.add(v);
00490 }
00491 return r.compute(values);
00492 }
00493
00494
/**
 * Parameterless initialisation method with an empty body. Nothing in this
 * class calls it; presumably kept as an overridable hook - TODO confirm.
 */
00495 protected void init() {}
00496
/**
 * Grounds the formula nodes of the model. Called by instantiateGroundNetwork
 * after the regular nodes have been instantiated, and only if auxiliary
 * variables were requested; at that point {@code hardFormulaNodes} is an empty
 * list and the temporary grounding structures have already been released.
 *
 * @throws Exception if grounding fails
 */
00497 protected abstract void groundFormulaicNodes() throws Exception;
00498
/**
 * Callback invoked by instantiateVariable right after a ground atom node has
 * been added to the ground network, before its parents and CPF are set up.
 *
 * @param relNode the template the node is instantiated from
 * @param params the actual parameters of the ground atom
 * @param instance the newly added ground node
 */
00499 protected abstract void onAddGroundAtomNode(RelationalNode relNode, String[] params, BeliefNode instance);
00500
00508 public BeliefNode addHardFormulaNode(String nodeName, Collection<String> parentGAs) throws Exception {
00509 BeliefNode[] domprod = new BeliefNode[1+parentGAs.size()];
00510 BeliefNode node = groundBN.addNode(nodeName);
00511 domprod[0] = node;
00512 hardFormulaNodes.add(node);
00513 int i = 1;
00514 for(String strGA : parentGAs) {
00515 BeliefNode parent = groundBN.getNode(strGA);
00516 if(parent == null) {
00517 String parentName = strGA.substring(0, strGA.lastIndexOf(",")) + ")";
00518 parent = groundBN.getNode(parentName);
00519 if(parent == null)
00520 throw new Exception("Could not find node for ground atom " + strGA);
00521 }
00522 domprod[i++] = parent;
00523 groundBN.connect(parent, node, false);
00524 }
00525 node.getCPF().buildZero(domprod, false);
00526 return node;
00527 }
00528
00529 public Database getDatabase() {
00530 return db;
00531 }
00532
00543 protected Vector<BeliefNode> connectParents(Map<Integer, String[]> parentGrounding, RelationalNode srcRelNode, BeliefNode targetNode, HashMap<BeliefNode, BeliefNode> src2targetParent, HashMap<BeliefNode, Integer> constantSettings) throws Exception {
00544 Vector<BeliefNode> domprod = new Vector<BeliefNode>();
00545 domprod.add(targetNode);
00546 HashSet<BeliefNode> handledTargetParents = new HashSet<BeliefNode>();
00547 for(Entry<Integer, String[]> entry : parentGrounding.entrySet()) {
00548 RelationalNode relParent = bln.rbn.getRelationalNode(entry.getKey());
00549 if(relParent == srcRelNode)
00550 continue;
00551 if(relParent.isConstant) {
00552
00553 if(constantSettings != null)
00554 constantSettings.put(relParent.node, ((Discrete)relParent.node.getDomain()).findName(entry.getValue()[0]));
00555 continue;
00556 }
00557 if(relParent.isPrecondition) {
00558 if(constantSettings != null)
00559 constantSettings.put(relParent.node, 0);
00560 continue;
00561 }
00562 BeliefNode parent = instantiateVariable(relParent.getFunctionName(), entry.getValue());
00563 if(parent == null)
00564 throw new Exception("Error instantiating parent '" + Signature.formatVarName(relParent.getFunctionName(), entry.getValue()) + "' while instantiating " + targetNode);
00565 if(handledTargetParents.contains(parent))
00566 throw new Exception("Error instantiating " + targetNode + " from " + srcRelNode + ": Duplicate parent " + parent);
00567
00568 handledTargetParents.add(parent);
00569 groundBN.connect(parent, targetNode, false);
00570 domprod.add(parent);
00571 if(src2targetParent != null) src2targetParent.put(relParent.node, parent);
00572 }
00573 return domprod;
00574 }
00575
00583 protected void instantiateCPF(Map<Integer, String[]> parentGrounding, RelationalNode srcRelNode, BeliefNode targetNode) throws Exception {
00584
00585 HashMap<BeliefNode, BeliefNode> src2targetParent = new HashMap<BeliefNode, BeliefNode>();
00586 HashMap<BeliefNode, Integer> constantSettings = new HashMap<BeliefNode, Integer>();
00587 Vector<BeliefNode> vDomProd = connectParents(parentGrounding, srcRelNode, targetNode, src2targetParent, constantSettings);
00588
00589
00590
00591 BeliefNode[] srcDomainProd = srcRelNode.node.getCPF().getDomainProduct();
00592 for(int i = 1; i < srcDomainProd.length; i++) {
00593 if(srcDomainProd[i].getType() == BeliefNode.NODE_DECISION)
00594 constantSettings.put(srcDomainProd[i], 0);
00595 }
00596
00597
00598 CPF targetCPF = targetNode.getCPF();
00599 BeliefNode[] targetDomainProd = vDomProd.toArray(new BeliefNode[vDomProd.size()]);
00600 int j = 1;
00601 HashSet<BeliefNode> handledParents = new HashSet<BeliefNode>();
00602 for(int i = 1; i < srcDomainProd.length; i++) {
00603 BeliefNode targetParent = src2targetParent.get(srcDomainProd[i]);
00604
00605 if(targetParent != null) {
00606 if(handledParents.contains(targetParent))
00607 throw new Exception("Cannot instantiate " + targetNode + " using template " + srcRelNode + ": Duplicate parent " + targetParent);
00608 if(j >= targetDomainProd.length)
00609 throw new Exception("Domain product of " + targetNode + " too small; size = " + targetDomainProd.length + "; tried to add " + targetParent + "; already added " + StringTool.join(",", targetDomainProd));
00610 targetDomainProd[j++] = targetParent;
00611 handledParents.add(targetParent);
00612 }
00613 }
00614 if(j != targetDomainProd.length)
00615 throw new Exception("CPF domain product not fully filled: handled " + j + ", needed " + targetDomainProd.length);
00616
00617
00618 String cpfID = Integer.toString(srcRelNode.index);
00619
00620
00621 if(srcDomainProd.length == targetDomainProd.length) {
00622 targetCPF.build(targetDomainProd, srcRelNode.node.getCPF().getValues());
00623 }
00624
00625 else {
00626 Value[] subCPF;
00627
00628 cpfID += constantSettings.toString();
00629 subCPF = cpfCache.get(cpfID);
00630 if(subCPF == null) {
00631 subCPF = getSubCPFValues(srcRelNode.node.getCPF(), constantSettings);
00632 cpfCache.put(cpfID, subCPF);
00633 }
00634
00635 targetCPF.build(targetDomainProd, subCPF);
00636 }
00637 cpfIDs.put(targetNode, cpfID);
00638
00639
00640
00641
00642
00643
00644
00645
00646
00647
00648
00649
00650
00651
00652
00653
00654
00655 }
00656
00657 protected Value[] getSubCPFValues(CPF cpf, HashMap<BeliefNode, Integer> constantSettings) {
00658 BeliefNode[] domProd = cpf.getDomainProduct();
00659 int[] addr = new int[domProd.length];
00660 Vector<Value> v = new Vector<Value>();
00661 getSubCPFValues(cpf, constantSettings, 0, addr, v);
00662 return v.toArray(new Value[0]);
00663 }
00664
00665 protected void getSubCPFValues(CPF cpf, HashMap<BeliefNode, Integer> constantSettings, int i, int[] addr, Vector<Value> ret) {
00666 BeliefNode[] domProd = cpf.getDomainProduct();
00667 if(i == domProd.length) {
00668 ret.add(cpf.get(addr));
00669 return;
00670 }
00671 BeliefNode n = domProd[i];
00672
00673 Integer setting = constantSettings.get(n);
00674 if(setting != null) {
00675 addr[i] = setting;
00676 getSubCPFValues(cpf, constantSettings, i+1, addr, ret);
00677 }
00678
00679 else {
00680 Domain d = domProd[i].getDomain();
00681 for(int j = 0; j < d.getOrder(); j++) {
00682 addr[i] = j;
00683 getSubCPFValues(cpf, constantSettings, i+1, addr, ret);
00684 }
00685 }
00686 }
00687
00692 public abstract class CPFFiller {
00693 CPF cpf;
00694 BeliefNode[] nodes;
00695
00696 public CPFFiller(BeliefNode node) {
00697 cpf = node.getCPF();
00698 nodes = cpf.getDomainProduct();
00699 }
00700
00701 public void fill() throws Exception {
00702 int[] addr = new int[nodes.length];
00703 fill(0, addr);
00704 }
00705
00706 protected void fill(int iNode, int[] addr) throws Exception {
00707
00708
00709 if(iNode == nodes.length) {
00710 cpf.put(addr, new ValueDouble(getValue(addr)));
00711 return;
00712 }
00713 Discrete domain = (Discrete)nodes[iNode].getDomain();
00714
00715 for(int i = 0; i < domain.getOrder(); i++) {
00716
00717 addr[iNode] = i;
00718
00719 fill(iNode+1, addr);
00720 }
00721 }
00722
00723 protected abstract double getValue(int[] addr);
00724 }
00725
00730 public class CPFFiller_OR extends CPFFiller {
00731 public CPFFiller_OR(BeliefNode node) {
00732 super(node);
00733 }
00734
00735 @Override
00736 protected double getValue(int[] addr) {
00737
00738 boolean isTrue = false;
00739 for(int i = 1; i < addr.length; i++)
00740 isTrue = isTrue || addr[i] == 0;
00741 return (addr[0] == 0 && isTrue) || (addr[0] == 1 && !isTrue) ? 1.0 : 0.0;
00742 }
00743 }
00744
00749 public class CPFFiller_ORGrouped extends CPFFiller {
00750 int groupSize;
00751
00757 public CPFFiller_ORGrouped(BeliefNode node, int groupSize) {
00758 super(node);
00759 this.groupSize = groupSize;
00760 }
00761
00762 @Override
00763 protected double getValue(int[] addr) {
00764
00765
00766 boolean isTrue = false;
00767 int g = 0;
00768 for(int i = 1; i < addr.length;) {
00769 if((i-1) % groupSize == 0) {
00770 if(isTrue)
00771 break;
00772 }
00773 isTrue = addr[i] == 0;
00774 if(!isTrue) {
00775 ++g;
00776 i = 1 + g * groupSize;
00777 continue;
00778 }
00779 ++i;
00780 }
00781 return (addr[0] == 0 && isTrue) || (addr[0] == 1 && !isTrue) ? 1.0 : 0.0;
00782 }
00783 }
00784
00785 public void show() {
00786 groundBN.show();
00787 }
00788
00794 public int[] getFullEvidence(String[][] evidence) {
00795 String[][] fullEvidence = new String[evidence.length+this.hardFormulaNodes.size()][2];
00796 for(int i = 0; i < evidence.length; i++) {
00797 fullEvidence[i][0] = evidence[i][0];
00798 fullEvidence[i][1] = evidence[i][1];
00799 }
00800 {
00801 int i = evidence.length;
00802 for(BeliefNode node : hardFormulaNodes) {
00803 fullEvidence[i][0] = node.getName();
00804 fullEvidence[i][1] = "True";
00805 i++;
00806 }
00807 }
00808 return groundBN.evidence2DomainIndices(fullEvidence);
00809 }
00810
00811 public BeliefNetworkEx getGroundNetwork() {
00812 return this.groundBN;
00813 }
00814
00820 public String getCPFID(BeliefNode node) {
00821 String cpfID = cpfIDs.get(node);
00822 return cpfID;
00823 }
00824
00825 public void setDebugMode(boolean enabled) {
00826 this.debug = enabled;
00827 }
00828
00829 public RelationalBeliefNetwork getRBN() {
00830 return bln.rbn;
00831 }
00832
00838 public RelationalNode getTemplateOf(BeliefNode node) {
00839 return this.groundNode2TemplateNode.get(node);
00840 }
00841
00846 public Vector<BeliefNode> getAuxiliaryVariables() {
00847 return this.hardFormulaNodes;
00848 }
00849
00850 public void setVerbose(boolean verbose) {
00851 this.verbose = verbose;
00852 }
00853
00854 public ParameterHandler getParameterHandler() {
00855 return paramHandler;
00856 }
00857 }