00001
00002
00003
00004
00005
00006
00007 package edu.tum.cs.bayesnets.inference;
00008
00009 import java.io.File;
00010 import java.io.FileNotFoundException;
00011 import java.io.PrintStream;
00012 import java.util.Collection;
00013 import java.util.HashMap;
00014 import java.util.HashSet;
00015 import java.util.Vector;
00016
00017 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00018 import edu.ksu.cis.bnj.ver3.core.CPF;
00019 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00020 import edu.tum.cs.bayesnets.inference.IJGP.JoinGraph.Arc;
00021 import edu.tum.cs.util.StringTool;
00022 import edu.tum.cs.util.datastruct.MutableDouble;
00023
00028 public class IJGP extends Sampler {
00029
/** Join graph built in the constructor; messages are propagated along its arcs. */
protected JoinGraph jg;
/** Join-graph vertices in the order returned by {@code jg.getTopologicalorder()}; assigned in {@code _infer()}. */
Vector<JoinGraph.Node> jgNodes;
/** Nodes of the underlying network as returned by {@code bn.bn.getNodes()}. */
protected BeliefNode[] nodes;
/** Compile-time switch for additional diagnostic output. */
protected final boolean debug = false;
/** Bound handed to the join-graph construction. */
protected int ibound;
/** Whether progress messages are written to {@code out}. */
protected boolean verbose = true;

/**
 * Creates the inference object and builds the join graph for the given network.
 * The i-bound is chosen as the largest CPF domain-product length found among
 * the network's nodes (never less than 1).
 *
 * @param bn the network to run inference on
 * @throws Exception declared by the signature (presumably raised by the
 *         superclass constructor - confirm against Sampler)
 */
public IJGP(BeliefNetworkEx bn) throws Exception {
    super(bn);
    this.nodes = bn.bn.getNodes();

    // widest CPF (child plus parents) determines the bound
    int widest = 1;
    for (int idx = 0; idx < nodes.length; idx++) {
        widest = Math.max(widest, nodes[idx].getCPF().getDomainProduct().length);
    }
    this.ibound = widest;

    if (verbose) {
        out.printf("constructing join-graph with i-bound %d...\n", ibound);
    }
    this.jg = new JoinGraph(bn, ibound);
}
00053
00054 @Override
00055 public String getAlgorithmName() {
00056 return String.format("IJGP[i-bound %d]", this.ibound);
00057 }
00058
00059
00060
00061
00062
00063
00064
00065 @Override
00066 public SampledDistribution _infer() throws Exception {
00067
00068 if(verbose) out.println("determining order...");
00069 jgNodes = jg.getTopologicalorder();
00070 if(debug) {
00071 out.println("Topological Order: ");
00072 for (int i = 0; i < jgNodes.size(); i++) {
00073 out.println(jgNodes.get(i).getShortName());
00074 }
00075 }
00076
00077 if(verbose) out.println("processing observed variables...");
00078 for (JoinGraph.Node n : jgNodes) {
00079 Vector<BeliefNode> nodes = new Vector<BeliefNode>(n.getNodes());
00080 for (BeliefNode belNode : nodes) {
00081 int nodeIdx = bn.getNodeIndex(belNode);
00082 int domainIdx = evidenceDomainIndices[nodeIdx];
00083 if (domainIdx > -1)
00084 n.nodes.remove(belNode);
00085 }
00086 }
00087 out.printf("running propagation (%d steps)...\n", this.numSamples);
00088 for (int step = 1; step <= this.numSamples; step++) {
00089 out.printf("step %d\n", step);
00090
00091 int s = jgNodes.size();
00092 boolean direction = true;
00093 for (int j = 0; j < 2 * s; j++) {
00094 int i;
00095 if (j < s)
00096 i = j;
00097 else {
00098 i = 2 * s - j - 1;
00099 direction = false;
00100 }
00101 JoinGraph.Node u = jgNodes.get(i);
00102
00103 int topIndex = jgNodes.indexOf(u);
00104 for (JoinGraph.Node v : u.getNeighbors()) {
00105 if ((direction && jgNodes.indexOf(v) < topIndex)
00106 || (!direction && jgNodes.indexOf(v) > topIndex)) {
00107 continue;
00108 }
00109 Arc arc = u.getArcToNode(v);
00110 arc.clearOutMessages(u);
00111
00112 Cluster cluster_u = new Cluster(u, v);
00113
00114
00115 HashSet<BeliefNode> elim = new HashSet<BeliefNode>(u.nodes);
00116
00117
00118
00119 elim.removeAll(arc.separator);
00120 Cluster cluster_H = cluster_u.getReducedCluster(elim);
00121
00122 Cluster cluster_A = cluster_u.copy();
00123 cluster_A.subtractCluster(cluster_H);
00124
00125 if (debug) {
00126 out.println(" cluster_v(u): \n" + cluster_u);
00127 out.println(" A: \n" + cluster_A);
00128 out.println(" H_(u,v): \n" + cluster_H);
00129 }
00130
00131 int[] varsToSumOver = new int[elim.size()];
00132 int k = 0;
00133 for (BeliefNode n : elim)
00134 varsToSumOver[k++] = bn.getNodeIndex(n);
00135
00136 MessageFunction m = new MessageFunction(arc.separator,
00137 varsToSumOver, cluster_A);
00138 m.calcuSave(evidenceDomainIndices.clone());
00139 arc.addOutMessage(u, m);
00140 for (MessageFunction mf : cluster_H.functions) {
00141 mf.calcuSave(evidenceDomainIndices.clone());
00142 arc.addOutMessage(u, mf);
00143 }
00144 for (BeliefNode n : cluster_H.cpts) {
00145 arc.addCPTOutMessage(u, n);
00146 }
00147 }
00148 }
00149 }
00150
00151
00152 out.println("computing results...");
00153 this.createDistribution();
00154 dist.Z = 1.0;
00155 for (int i = 0; i < nodes.length; i++) {
00156
00157 if (evidenceDomainIndices[i] >= 0) {
00158 dist.values[i][evidenceDomainIndices[i]] = 1.0;
00159 continue;
00160 }
00161
00162
00163
00164 JoinGraph.Node u = null;
00165 for (JoinGraph.Node node : jgNodes) {
00166 if (node.nodes.contains(nodes[i])) {
00167 u = node;
00168 break;
00169 }
00170 }
00171 if (u == null)
00172 throw new Exception(
00173 "Could not find vertex in join graph containing variable "
00174 + nodes[i].getName());
00175
00176
00177
00178 int domSize = dist.values[i].length;
00179 double Z = 0.0;
00180 int[] nodeDomainIndices = evidenceDomainIndices.clone();
00181 for (int j = 0; j < domSize; j++) {
00182 nodeDomainIndices[i] = j;
00183 MutableDouble sum = new MutableDouble(0.0);
00184 BeliefNode[] nodesToSumOver = u.nodes
00185 .toArray(new BeliefNode[u.nodes.size()]);
00186 computeSum(0, nodesToSumOver, nodes[i], new Cluster(u),
00187 nodeDomainIndices, sum);
00188 Z += (dist.values[i][j] = sum.value);
00189 }
00190
00191 for (int j = 0; j < domSize; j++)
00192 dist.values[i][j] /= Z;
00193 }
00194
00195 return dist;
00196 }
00197
00198 protected void computeSum(int i, BeliefNode[] varsToSumOver,
00199 BeliefNode excludedNode, Cluster u, int[] nodeDomainIndices,
00200 MutableDouble result) {
00201 if (i == varsToSumOver.length) {
00202 result.value += u.product(nodeDomainIndices);
00203 return;
00204 }
00205 if (varsToSumOver[i] == excludedNode)
00206 computeSum(i + 1, varsToSumOver, excludedNode, u,
00207 nodeDomainIndices, result);
00208 else {
00209 for (int j = 0; j < varsToSumOver[i].getDomain().getOrder(); j++) {
00210 nodeDomainIndices[this.getNodeIndex(varsToSumOver[i])] = j;
00211 computeSum(i + 1, varsToSumOver, excludedNode, u,
00212 nodeDomainIndices, result);
00213 }
00214 }
00215 }
00216
00217 protected class Cluster {
00218 HashSet<BeliefNode> cpts = new HashSet<BeliefNode>();
00219 HashSet<MessageFunction> functions = new HashSet<MessageFunction>();
00220 JoinGraph.Node node;
00221
00222 public Cluster(JoinGraph.Node u) {
00223
00224 this.node = u;
00225
00226 for (CPF cpf : u.functions)
00227 cpts.add(cpf.getDomainProduct()[0]);
00228
00229 for (JoinGraph.Node nb : u.getNeighbors()) {
00230 JoinGraph.Arc arc = u.arcs.get(nb);
00231 HashSet<MessageFunction> m = arc.getInMessage(u);
00232 if (!m.isEmpty())
00233 functions.addAll(m);
00234 HashSet<BeliefNode> bn = arc.getCPTInMessage(u);
00235 if (!bn.isEmpty())
00236 cpts.addAll(bn);
00237 }
00238 }
00239
00240 public Cluster(JoinGraph.Node u, JoinGraph.Node v) {
00241
00242 this.node = u;
00243
00244 for (CPF cpf : u.functions)
00245 cpts.add(cpf.getDomainProduct()[0]);
00246
00247 for (JoinGraph.Node nb : u.getNeighbors()) {
00248 if (!nb.equals(v)) {
00249 JoinGraph.Arc arc = u.arcs.get(nb);
00250 HashSet<MessageFunction> m = arc.getInMessage(u);
00251 if (!m.isEmpty())
00252 functions.addAll(m);
00253 HashSet<BeliefNode> bn = arc.getCPTInMessage(u);
00254 if (!bn.isEmpty())
00255 cpts.addAll(bn);
00256 }
00257 }
00258 }
00259
00260 public Cluster() {
00261 }
00262
00263 public String toString() {
00264 StringBuffer sb = new StringBuffer();
00265 sb.append(StringTool.join(", ", this.cpts));
00266 sb.append("; ");
00267 sb.append(StringTool.join(", ", this.functions));
00268 return sb.toString();
00269 }
00270
00271 public void excludeMessagesFrom(JoinGraph.Node n) {
00272 JoinGraph.Arc arc = node.arcs.get(n);
00273 for (MessageFunction mf : arc.getInMessage(node)) {
00274 if (functions.contains(mf))
00275 functions.remove(mf);
00276 }
00277 for (BeliefNode bn : arc.getCPTInMessage(node)) {
00278 if (cpts.contains(bn))
00279 functions.remove(bn);
00280 }
00281 }
00282
00283 public Cluster copy() {
00284 Cluster copyCluster = new Cluster();
00285 for (BeliefNode cpt : cpts) {
00286 copyCluster.cpts.add(cpt);
00287 }
00288 for (MessageFunction f : functions) {
00289 copyCluster.functions.add(f);
00290 }
00291 return copyCluster;
00292 }
00293
00294 public Cluster getReducedCluster(HashSet<BeliefNode> nodes)
00295 throws CloneNotSupportedException {
00296
00297
00298 Cluster redCluster = this.copy();
00299 for (BeliefNode bn : nodes) {
00300 HashSet<BeliefNode> foo = (HashSet<BeliefNode>) cpts.clone();
00301 for (BeliefNode n : foo) {
00302 BeliefNode[] domProd = n.getCPF().getDomainProduct();
00303
00304
00305
00306 for (int i = 0; i < domProd.length; i++) {
00307 if (bn.equals(domProd[i])) {
00308 redCluster.cpts.remove(n);
00309 break;
00310 }
00311 }
00312 }
00313 for (MessageFunction m : ((HashSet<MessageFunction>) functions
00314 .clone())) {
00315 if (m.scope.contains(bn))
00316 redCluster.functions.remove(m);
00317 }
00318 }
00319 return redCluster;
00320 }
00321
00322 public void subtractCluster(Cluster c2) {
00323
00324
00325 for (BeliefNode n : ((HashSet<BeliefNode>) c2.cpts.clone())) {
00326
00327 cpts.remove(n);
00328 }
00329 for (MessageFunction m : ((HashSet<MessageFunction>) c2.functions
00330 .clone())) {
00331 functions.remove(m);
00332 }
00333 }
00334
00335 public double product(int[] nodeDomainIndices) {
00336 double ret = 1.0;
00337 for (BeliefNode n : cpts) {
00338
00339 ret *= getCPTProbability(n, nodeDomainIndices);
00340 }
00341 for (MessageFunction f : this.functions) {
00342
00343 ret *= f.compute(nodeDomainIndices);
00344 }
00345 return ret;
00346 }
00347 }
00348
00349 protected class MessageFunction {
00350
00351 protected int[] varsToSumOver;
00352 protected MessageTable table;
00353 HashSet<BeliefNode> cpts;
00354 Iterable<MessageFunction> childFunctions;
00355 HashSet<BeliefNode> scope;
00356
00357 public MessageFunction(HashSet<BeliefNode> scope, int[] varsToSumOver,
00358 Cluster cluster) {
00359 this.scope = scope;
00360 this.varsToSumOver = varsToSumOver;
00361 this.cpts = cluster.cpts;
00362 this.childFunctions = cluster.functions;
00363 this.table = null;
00364 }
00365
00366 public void calcuSave(int[] nodeDomainIndices) {
00367 table = new MessageTable(new Vector<BeliefNode>(scope), 0);
00368 int[] scopeToSumOver = new int[scope.size()];
00369 int k = 0;
00370 for (BeliefNode n : scope)
00371 scopeToSumOver[k++] = bn.getNodeIndex(n);
00372 calcuSave(scopeToSumOver, 0, nodeDomainIndices.clone());
00373 }
00374
00375 public void calcuSave(int[] scopeToSumOver, int i, int[] nodeDomainIndices) {
00376 if (i == scope.size()) {
00377 table.addEntry(nodeDomainIndices, compute(nodeDomainIndices));
00378 return;
00379 } else {
00380 int idxVar = scopeToSumOver[i];
00381 for (int v = 0; v < nodes[idxVar].getDomain().getOrder(); v++) {
00382 nodeDomainIndices[idxVar] = v;
00383 calcuSave(scopeToSumOver, i + 1, nodeDomainIndices);
00384 }
00385 }
00386 }
00387
00388 public double compute(int[] nodeDomainIndices) {
00389 if (!table.containsEntry(nodeDomainIndices)) {
00390 MutableDouble sum = new MutableDouble(0.0);
00391 compute(varsToSumOver, 0, nodeDomainIndices, sum);
00392 return sum.value;
00393 } else {
00394 return table.getEntry(nodeDomainIndices);
00395 }
00396 }
00397
00398 protected void compute(int[] varsToSumOver, int i,
00399 int[] nodeDomainIndices, MutableDouble sum) {
00400 if (i == varsToSumOver.length) {
00401 double result = 1.0;
00402 for (BeliefNode node : cpts)
00403 result *= getCPTProbability(node, nodeDomainIndices);
00404 for (MessageFunction h : childFunctions)
00405 result *= h.compute(nodeDomainIndices);
00406 sum.value += result;
00407 return;
00408 }
00409 int idxVar = varsToSumOver[i];
00410 for (int v = 0; v < nodes[idxVar].getDomain().getOrder(); v++) {
00411 nodeDomainIndices[idxVar] = v;
00412 compute(varsToSumOver, i + 1, nodeDomainIndices, sum);
00413 }
00414 }
00415
00416 public String toString() {
00417 StringBuffer sb = new StringBuffer("MF[");
00418 sb.append("scope: " + StringTool.join(", ", scope));
00419 sb.append("; CPFs:");
00420 int i = 0;
00421 for (BeliefNode n : this.cpts) {
00422 if (i++ > 0)
00423 sb.append("; ");
00424 sb.append(n.getCPF().toString());
00425 }
00426 sb.append("; children: ");
00427 sb.append(StringTool.join("; ", this.childFunctions));
00428 sb.append("]");
00429 return sb.toString();
00430 }
00431
00432 protected class MessageTable {
00433
00434 protected Vector<BeliefNode> scope;
00435 protected boolean leaf;
00436 protected MessageTable[] map;
00437 protected Double[] result;
00438
00439 public MessageTable(Vector<BeliefNode> scope, int i) {
00440 int domSize = scope.get(i).getDomain().getOrder();
00441 this.map = new MessageTable[domSize];
00442 this.scope = scope;
00443 if (i == scope.size() - 1) {
00444 leaf = true;
00445 result = new Double[domSize];
00446 } else {
00447 leaf = false;
00448 result = null;
00449 for (int j = 0; j < scope.get(i).getDomain().getOrder(); j++) {
00450 map[j] = new MessageTable(scope, i + 1);
00451 }
00452 }
00453 }
00454
00455 public void addEntry(int[] domainIndices, double entry) {
00456 addEntry(domainIndices, 0, entry);
00457 }
00458
00459 public void addEntry(int[] domainIndices, int i, double entry) {
00460 if (i != scope.size() - 1) {
00461 int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
00462 map[idx].addEntry(domainIndices, i + 1, entry);
00463 } else {
00464 int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
00465 result[idx] = entry;
00466 }
00467 }
00468
00469 public double getEntry(int[] domainIndices) {
00470 return getEntry(domainIndices, 0);
00471 }
00472
00473 public double getEntry(int[] domainIndices, int i) {
00474 if (i != scope.size() - 1) {
00475 int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
00476 return map[idx].getEntry(domainIndices, i + 1);
00477 }
00478 else {
00479 int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
00480 return result[idx];
00481 }
00482 }
00483
00484 public boolean containsEntry(int[] domainIndices){
00485 return containsEntry(domainIndices, 0);
00486 }
00487
00488 public boolean containsEntry(int[] domainIndices, int i){
00489 if (i != scope.size() - 1){
00490 int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
00491 if (map[idx] == null){
00492 return false;
00493 }
00494 else{
00495 return map[idx].containsEntry(domainIndices, i+1);
00496 }
00497 }
00498 else{
00499 int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
00500 return (result[idx] != null);
00501 }
00502 }
00503 }
00504 }
00505
00506 protected static class BucketVar {
00507
00508 public HashSet<BeliefNode> nodes;
00509 public CPF cpf = null;
00510 public Vector<MiniBucket> parents;
00511 public BeliefNode idxVar;
00512
00513 public BucketVar(HashSet<BeliefNode> nodes) {
00514 this(nodes, null);
00515 }
00516
00517 public BucketVar(HashSet<BeliefNode> nodes, MiniBucket parent) {
00518 this.nodes = nodes;
00519 if (nodes.size() == 0)
00520 throw new RuntimeException(
00521 "Must provide non-empty set of nodes.");
00522 this.parents = new Vector<MiniBucket>();
00523 if (parent != null)
00524 parents.add(parent);
00525 }
00526
00527 public void setFunction(CPF cpf) {
00528 this.cpf = cpf;
00529 }
00530
00531 public void addInArrow(MiniBucket parent) {
00532 parents.add(parent);
00533 }
00534
00535 public BeliefNode getMaxNode(BeliefNetworkEx bn) {
00536
00537
00538 BeliefNode maxNode = null;
00539 int[] topOrder = bn.getTopologicalOrder();
00540 for (int i = topOrder.length - 1; i > -1; i--) {
00541 for (BeliefNode node : nodes) {
00542 if (bn.getNodeIndex(node) == topOrder[i]) {
00543 return node;
00544 }
00545 }
00546 }
00547 return maxNode;
00548 }
00549
00550 public String toString() {
00551 return "[" + StringTool.join(" ", this.nodes) + "]";
00552 }
00553
00554 public boolean equals(BucketVar other) {
00555 if (other.nodes.size() != this.nodes.size())
00556 return false;
00557 for (BeliefNode n : nodes)
00558 if (!other.nodes.contains(n))
00559 return false;
00560 return true;
00561 }
00562 }
00563
00564 protected static class MiniBucket {
00565
00566 public HashSet<BucketVar> items;
00567 public Bucket bucket;
00568 public HashSet<MiniBucket> parents;
00569 public BucketVar child;
00570
00571 public MiniBucket(Bucket bucket) {
00572 this.items = new HashSet<BucketVar>();
00573 this.bucket = bucket;
00574 this.child = null;
00575 this.parents = new HashSet<MiniBucket>();
00576 }
00577
00578 public void addVar(BucketVar bv) {
00579 items.add(bv);
00580 for (MiniBucket p : bv.parents)
00581 parents.add(p);
00582 }
00583
00584 public String toString() {
00585 return "Minibucket[" + StringTool.join(" ", items) + "]";
00586 }
00587 }
00588
00589 protected static class Bucket {
00590
00591 public BeliefNode bucketNode;
00592 public HashSet<BucketVar> vars = new HashSet<BucketVar>();
00593 public Vector<MiniBucket> minibuckets = new Vector<MiniBucket>();
00594
00595 public Bucket(BeliefNode bucketNode) {
00596 this.bucketNode = bucketNode;
00597 }
00598
00599 public void addVar(BucketVar bv) {
00600 for (BucketVar v : vars)
00601 if (v.equals(bv)) {
00602 for (MiniBucket p : bv.parents)
00603 v.addInArrow(p);
00604 return;
00605 }
00606 vars.add(bv);
00607 }
00608
00614 public void partition(int bound) {
00615 minibuckets.add(new MiniBucket(this));
00616 HashSet<BeliefNode> count = new HashSet<BeliefNode>();
00617 for (BucketVar bv : vars) {
00618 int newNodes = 0;
00619 for (BeliefNode n : bv.nodes) {
00620 if (!count.contains(n)) {
00621 newNodes++;
00622 }
00623 }
00624 if (count.size() + newNodes > bound) {
00625
00626 minibuckets.add(new MiniBucket(this));
00627 count.clear();
00628 count.addAll(bv.nodes);
00629 } else {
00630 count.addAll(bv.nodes);
00631 }
00632 minibuckets.lastElement().addVar(bv);
00633 }
00634 }
00635
00636 public HashSet<BucketVar> createScopeFunctions() {
00637 HashSet<BucketVar> newVars = new HashSet<BucketVar>();
00638 for (MiniBucket mb : minibuckets) {
00639 HashSet<BeliefNode> nodes = new HashSet<BeliefNode>();
00640 for (BucketVar bv : mb.items) {
00641 for (BeliefNode bn : bv.nodes) {
00642 if (bn != bucketNode)
00643 nodes.add(bn);
00644 }
00645 }
00646 if (nodes.size() != 0) {
00647 BucketVar newBucketVar = new BucketVar(nodes, mb);
00648 newVars.add(newBucketVar);
00649 }
00650 }
00651 return newVars;
00652 }
00653
00654 public String toString() {
00655 return StringTool.join(" ", vars);
00656 }
00657 }
00658
00659 protected static class SchematicMiniBucket {
00660
00661 public HashMap<BeliefNode, Bucket> bucketMap;
00662 public BeliefNetworkEx bn;
00663
00664 public SchematicMiniBucket(BeliefNetworkEx bn, int bound) {
00665 this.bn = bn;
00666 bucketMap = new HashMap<BeliefNode, Bucket>();
00667
00668 BeliefNode[] nodes = bn.bn.getNodes();
00669 int[] topOrder = bn.getTopologicalOrder();
00670
00671 for (int i = topOrder.length - 1; i > -1; i--) {
00672 Bucket bucket = new Bucket(nodes[topOrder[i]]);
00673 int[] cpt = bn.getDomainProductNodeIndices(nodes[topOrder[i]]);
00674 HashSet<BeliefNode> cptNodes = new HashSet<BeliefNode>();
00675 for (int j : cpt) {
00676 cptNodes.add(nodes[j]);
00677 }
00678 BucketVar bv = new BucketVar(cptNodes);
00679 bv.setFunction(nodes[topOrder[i]].getCPF());
00680 bucket.addVar(bv);
00681 bucketMap.put(nodes[topOrder[i]], bucket);
00682 }
00683
00684 for (int i = topOrder.length - 1; i > -1; i--) {
00685 Bucket oldVar = bucketMap.get(nodes[topOrder[i]]);
00686 oldVar.partition(bound);
00687 HashSet<BucketVar> scopes = oldVar.createScopeFunctions();
00688 for (BucketVar bv : scopes) {
00689
00690 BeliefNode node = bv.getMaxNode(bn);
00691 bucketMap.get(node).addVar(bv);
00692 }
00693 }
00694 }
00695
00696 public void print(PrintStream out) {
00697 BeliefNode[] nodes = bn.bn.getNodes();
00698 int[] order = bn.getTopologicalOrder();
00699 for (int i = nodes.length - 1; i >= 0; i--) {
00700 BeliefNode n = nodes[order[i]];
00701 out.printf("%s: %s\n", n.toString(), bucketMap.get(n));
00702 }
00703 }
00704
00705 public Vector<MiniBucket> getMiniBuckets() {
00706 Vector<MiniBucket> mb = new Vector<MiniBucket>();
00707 for (Bucket b : bucketMap.values()) {
00708 mb.addAll(b.minibuckets);
00709 }
00710 return mb;
00711 }
00712
00713 public Vector<Bucket> getBuckets() {
00714 return new Vector<Bucket>(bucketMap.values());
00715 }
00716 }
00717
00718 protected static class JoinGraph {
00719
00720 HashSet<Node> nodes;
00721 HashMap<MiniBucket, Node> bucket2node = new HashMap<MiniBucket, Node>();
00722
00723 public JoinGraph(BeliefNetworkEx bn, int bound) {
00724 nodes = new HashSet<Node>();
00725
00726 SchematicMiniBucket smb = new SchematicMiniBucket(bn, bound);
00727
00728
00729 Vector<MiniBucket> minibuckets = smb.getMiniBuckets();
00730
00731
00732 for (MiniBucket mb : minibuckets) {
00733
00734 Node newNode = new Node(mb);
00735
00736 nodes.add(newNode);
00737 bucket2node.put(mb, newNode);
00738 }
00739
00740 for (MiniBucket mb : minibuckets) {
00741 for (MiniBucket p : mb.parents) {
00742 bucket2node.get(mb).parents.add(bucket2node.get(p));
00743 }
00744 }
00745
00746 for (MiniBucket mb : minibuckets) {
00747 for (MiniBucket par : mb.parents) {
00748 Node n1 = bucket2node.get(par);
00749 Node n2 = bucket2node.get(mb);
00750 new Arc(n1, n2);
00751 }
00752 }
00753
00754 for (MiniBucket mb1 : minibuckets) {
00755 for (MiniBucket mb2 : minibuckets) {
00756 if (mb1 != mb2 && mb1.bucket == mb2.bucket) {
00757 new Arc(bucket2node.get(mb1), bucket2node.get(mb2));
00758 }
00759 }
00760 }
00761 }
00762
00763 public void print(PrintStream out) {
00764 int i = 0;
00765 for (Node n : nodes) {
00766 out.printf("Node%d: %s\n", i++, StringTool.join(", ", n.nodes));
00767 for (CPF cpf : n.functions) {
00768 out.printf(" CPFS: %s | %s\n", cpf.getDomainProduct()[0],
00769 StringTool.join(", ", cpf.getDomainProduct()));
00770 }
00771 }
00772 }
00773
00774 public void writeDOT(File f) throws FileNotFoundException {
00775 PrintStream ps = new PrintStream(f);
00776 ps.println("graph {");
00777 for (Node n : nodes) {
00778 for (Node n2 : n.getNeighbors()) {
00779 ps.printf("\"%s\" -- \"%s\";\n", n.getShortName(), n2
00780 .getShortName());
00781 }
00782 }
00783 ps.println("}");
00784 }
00785
00786 public Vector<Node> getTopologicalorder() {
00787 Vector<Node> topOrder = new Vector<Node>();
00788 HashSet<Node> nodesLeft = new HashSet<Node>();
00789 nodesLeft.addAll(nodes);
00790 for (Node n : nodes) {
00791 if (n.parents.isEmpty()) {
00792 topOrder.add(n);
00793 nodesLeft.remove(n);
00794 }
00795 }
00796
00797
00798 int i = 0;
00799 while (!nodesLeft.isEmpty() && i < 10) {
00800 HashSet<Node> removeNodes = new HashSet<Node>();
00801
00802
00803 for (Node n : nodesLeft) {
00804
00805
00806 if (topOrder.containsAll(n.parents)) {
00807
00808 topOrder.add(n);
00809 removeNodes.add(n);
00810 }
00811 }
00812 nodesLeft.removeAll(removeNodes);
00813
00814 }
00815 return topOrder;
00816 }
00817
00818 public static class Arc {
00819 HashSet<BeliefNode> separator = new HashSet<BeliefNode>();
00820
00821 Vector<Node> nodes = new Vector<Node>();
00822 HashMap<Node, HashSet<MessageFunction>> outMessage = new HashMap<Node, HashSet<MessageFunction>>();
00823 HashMap<Node, HashSet<BeliefNode>> outCPTMessage = new HashMap<Node, HashSet<BeliefNode>>();
00824
00825 public Arc(Node n0, Node n1) {
00826 if (n0 != n1) {
00827
00828
00829
00830
00831
00832 separator = (HashSet<BeliefNode>) n0.nodes.clone();
00833 separator.retainAll(n1.nodes);
00834
00835
00836 nodes.add(n0);
00837 nodes.add(n1);
00838 n0.addArc(n1, this);
00839 n1.addArc(n0, this);
00840 outMessage.put(n0, new HashSet<MessageFunction>());
00841 outMessage.put(n1, new HashSet<MessageFunction>());
00842 outCPTMessage.put(n0, new HashSet<BeliefNode>());
00843 outCPTMessage.put(n1, new HashSet<BeliefNode>());
00844 } else
00845 throw new RuntimeException("1-node loop in graph");
00846 }
00847
00848 public Node getNeighbor(Node n) {
00849
00850 return nodes.get((nodes.indexOf(n) + 1) % 2);
00851 }
00852
00853 public void addOutMessage(Node n, MessageFunction m) {
00854 outMessage.get(n).add(m);
00855 }
00856
00857 public HashSet<MessageFunction> getOutMessages(Node n) {
00858 return outMessage.get(n);
00859 }
00860
00861 public HashSet<MessageFunction> getInMessage(Node n) {
00862 return this.getOutMessages(this.getNeighbor(n));
00863 }
00864
00865 public void addCPTOutMessage(Node n, BeliefNode bn) {
00866 outCPTMessage.get(n).add(bn);
00867 }
00868
00869 public HashSet<BeliefNode> getCPTOutMessages(Node n) {
00870 return outCPTMessage.get(n);
00871 }
00872
00873 public HashSet<BeliefNode> getCPTInMessage(Node n) {
00874 return this.getCPTOutMessages(this.getNeighbor(n));
00875 }
00876
00877 public void clearOutMessages(Node n) {
00878 outMessage.get(n).clear();
00879 outCPTMessage.get(n).clear();
00880 }
00881 }
00882
00883 public static class Node {
00884 MiniBucket mb;
00885 Vector<CPF> functions = new Vector<CPF>();
00886 HashSet<BeliefNode> nodes = new HashSet<BeliefNode>();
00887 HashSet<Node> parents;
00888 HashMap<Node, Arc> arcs = new HashMap<Node, Arc>();
00889
00890 public Node(MiniBucket mb) {
00891 this.mb = mb;
00892 this.parents = new HashSet<Node>();
00893 for (BucketVar var : mb.items) {
00894 nodes.addAll(var.nodes);
00895 if (var.cpf != null)
00896 functions.add(var.cpf);
00897 }
00898 }
00899
00900 public void addArc(Node n, Arc arc) {
00901 arcs.put(n, arc);
00902 }
00903
00904 public HashSet<Node> getNeighbors() {
00905 return new HashSet<Node>(arcs.keySet());
00906 }
00907
00908 public Arc getArcToNode(Node n) {
00909 return arcs.get(n);
00910 }
00911
00912 public Collection<BeliefNode> getNodes() {
00913 return nodes;
00914 }
00915
00916 public String toString() {
00917 return "Supernode[" + StringTool.join(",", nodes) + "; "
00918 + StringTool.join("; ", this.functions) + "]";
00919 }
00920
00921 public String getShortName() {
00922 return StringTool.join(",", nodes);
00923 }
00924 }
00925 }
00926 }