00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.util.HashMap;
00004 import java.util.Vector;
00005
00006 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00007 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00008 import edu.tum.cs.util.datastruct.MutableDouble;
00009
00014 public class BeliefPropagation extends Sampler {
00015
00016 protected BeliefNode[] nodes;
00017 protected int[] topOrder;
00018 protected HashMap<BeliefNode,double[]> lambda;
00019 protected HashMap<BeliefNode,double[]> pi;
00020 protected HashMap<BeliefNode, BeliefMessageContainer> messages;
00021 protected HashMap<BeliefNode,double[]> priors;
00022
00023 public class BeliefMessageContainer{
00024 public HashMap<BeliefNode, double[]> lambdaMessages;
00025 public HashMap<BeliefNode, double[]> piMessages;
00026 protected BeliefNode node;
00027 protected int nodeOrder;
00028
00029 public BeliefMessageContainer(BeliefNode node){
00030
00031 lambdaMessages = new HashMap<BeliefNode, double[]>();
00032 piMessages = new HashMap<BeliefNode, double[]>();
00033 this.node = node;
00034 nodeOrder = node.getDomain().getOrder();
00035
00036 for (BeliefNode n : bn.bn.getChildren(node)){
00037 double[] initPi = new double[nodeOrder];
00038 for (int i = 0; i < nodeOrder; i++){
00039 initPi[i] = 1.0/nodeOrder;
00040 }
00041 piMessages.put(n,initPi);
00042 }
00043 for(BeliefNode n : bn.bn.getParents(node)){
00044 int parentOrder = n.getDomain().getOrder();
00045 double[] initLambda = new double[parentOrder];
00046 for (int i = 0; i < parentOrder; i++){
00047 initLambda[i] = 1.0/parentOrder;
00048 }
00049 lambdaMessages.put(n, initLambda);
00050 }
00051 }
00052
00053 public void computePiMessages(BeliefNode n){
00054 double normalize = 0.0;
00055 for (int i = 0; i < nodeOrder; i++){
00056 double prod = 1.0;
00057 for (BeliefNode c : piMessages.keySet()){
00058 if (c != n){
00059 prod *= messages.get(c).lambdaMessages.get(node)[i];
00060 }
00061 }
00062 double entry = prod * pi.get(node)[i];
00063 piMessages.get(n)[i] = entry;
00064 normalize += entry;
00065 }
00066
00067 if (normalize != 0.0){
00068 if (normalize == 0.0)
00069 return;
00070 for (int i = 0; i < nodeOrder; i++){
00071 piMessages.get(n)[i] /= normalize;
00072 }
00073 }
00074 }
00075
00076 public void computeLambdaMessages(BeliefNode n, int[] nodeDomainIndices) {
00077
00078 Vector<Integer> varsToSumOver = new Vector<Integer>();
00079 for (BeliefNode p : lambdaMessages.keySet()){
00080 if (p != n && nodeDomainIndices[getNodeIndex(p)] == -1){
00081 varsToSumOver.add(getNodeIndex(p));
00082 }
00083 }
00084
00085 double normalize = 0.0;
00086 for (int i = 0; i < lambdaMessages.get(n).length; i++){
00087 nodeDomainIndices[getNodeIndex(n)] = i;
00088 double sum = 0.0;
00089 for (int j = 0; j < nodeOrder; j++){
00090 nodeDomainIndices[getNodeIndex(node)] = j;
00091 double prod = lambda.get(node)[j];
00092 MutableDouble mutableSum = new MutableDouble(0.0);
00093 computeLambdaMessages(n, varsToSumOver,0,nodeDomainIndices,mutableSum);
00094 sum += prod * mutableSum.value;
00095 }
00096 lambdaMessages.get(n)[i] = sum;
00097 normalize += sum;
00098 }
00099 if (normalize != 0.0){
00100 if (normalize == 0.0)
00101 return;
00102 for (int i = 0; i < lambdaMessages.get(n).length; i++){
00103 lambdaMessages.get(n)[i] /= normalize;
00104 }
00105 }
00106 }
00107
00108 protected void computeLambdaMessages(BeliefNode n, Vector<Integer> varsToSumOver, int i, int[] nodeDomainIndices, MutableDouble sum) {
00109 if (i == varsToSumOver.size()) {
00110 double result = getCPTProbability(node, nodeDomainIndices);
00111
00112 for (BeliefNode p : bn.bn.getParents(node)){
00113 if (n != p){
00114 result *= messages.get(p).piMessages.get(node)[nodeDomainIndices[getNodeIndex(p)]];
00115 }
00116 }
00117 sum.value += result;
00118 return;
00119 }
00120 int idxVar = varsToSumOver.get(i);
00121 for (int v = 0; v < nodes[idxVar].getDomain().getOrder(); v++) {
00122 nodeDomainIndices[idxVar] = v;
00123 computeLambdaMessages(n, varsToSumOver, i + 1, nodeDomainIndices, sum);
00124 }
00125 }
00126
00127 public boolean sentPiMessageTo(BeliefNode c){
00128 if (pi.containsKey(c)){
00129 double sum = 0.0;
00130 for (double d : pi.get(c)){
00131 sum += d;
00132 }
00133 return (sum != 0.0);
00134 }
00135 else{
00136 return false;
00137 }
00138 }
00139
00140 public boolean sentLambdaMessageTo(BeliefNode p){
00141 if (lambdaMessages.containsKey(p)){
00142 double sum = 0.0;
00143 for (double d : lambdaMessages.get(p)){
00144 sum += d;
00145 }
00146 return (sum != 0.0);
00147 }
00148 else{
00149 return false;
00150 }
00151 }
00152 }
00153
00154 public void computePi(BeliefNode n, int[] nodeDomainIndices){
00155
00156 if (evidenceDomainIndices[getNodeIndex(n)] != -1)
00157 return;
00158 Vector<Integer> varsToSumOver = new Vector<Integer>();
00159 for (BeliefNode p : bn.bn.getParents(n)){
00160 if (nodeDomainIndices[getNodeIndex(p)] == -1){
00161 varsToSumOver.add(getNodeIndex(p));
00162 }
00163 }
00164 double normalize = 0.0;
00165 for (int i = 0; i < pi.get(n).length; i++){
00166 nodeDomainIndices[getNodeIndex(n)] = i;
00167 MutableDouble mutableSum = new MutableDouble(0.0);
00168 computePi(n,varsToSumOver,0,nodeDomainIndices,mutableSum);
00169 pi.get(n)[i] = mutableSum.value;
00170 normalize += mutableSum.value;
00171 }
00172 if (normalize == 0.0)
00173 return;
00174 for (int i = 0; i < pi.get(n).length; i++){
00175 pi.get(n)[i] /= normalize;
00176 }
00177 }
00178
00179 protected void computePi(BeliefNode n, Vector<Integer> varsToSumOver, int i, int[] nodeDomainIndices, MutableDouble sum) {
00180 if (i == varsToSumOver.size()) {
00181 double result = getCPTProbability(n, nodeDomainIndices);
00182
00183 for (BeliefNode p : bn.bn.getParents(n)){
00184 result *= messages.get(p).piMessages.get(n)[nodeDomainIndices[getNodeIndex(p)]];
00185 }
00186 sum.value += result;
00187 return;
00188 }
00189 int idxVar = varsToSumOver.get(i);
00190 for (int v = 0; v < nodes[idxVar].getDomain().getOrder(); v++) {
00191 nodeDomainIndices[idxVar] = v;
00192 computePi(n, varsToSumOver, i + 1, nodeDomainIndices, sum);
00193 }
00194 }
00195
00196 public void computeLambda(BeliefNode n){
00197 if (evidenceDomainIndices[getNodeIndex(n)] != -1)
00198 return;
00199 double normalize = 0.0;
00200 for (int i = 0; i < lambda.get(n).length; i++){
00201 double prod = 1.0;
00202 for (BeliefNode c : bn.bn.getChildren(n)){
00203 prod *= messages.get(c).lambdaMessages.get(n)[i];
00204 }
00205 lambda.get(n)[i] = prod;
00206 normalize += prod;
00207 }
00208 if (normalize == 0.0)
00209 return;
00210 for (int i = 0; i < lambda.get(n).length; i++){
00211 lambda.get(n)[i] /= normalize;
00212 }
00213 }
00214
00215 public BeliefPropagation(BeliefNetworkEx bn) throws Exception {
00216 super(bn);
00217
00218 nodes = bn.getNodes();
00219 topOrder = bn.getTopologicalOrder();
00220 lambda = new HashMap<BeliefNode,double[]>();
00221 pi = new HashMap<BeliefNode,double[]>();
00222 messages = new HashMap<BeliefNode, BeliefMessageContainer>();
00223 }
00224
00225 @Override
00226 public String getAlgorithmName() {
00227 return String.format("Belief Propagation");
00228 }
00229
00230 @Override
00231 protected SampledDistribution _infer() throws Exception {
00232
00233 priors = bn.computePriors(evidenceDomainIndices);
00234 for (BeliefNode n : nodes){
00235 int domSize = n.getDomain().getOrder();
00236 int domIdx = evidenceDomainIndices[getNodeIndex(n)];
00237
00238 messages.put(n, new BeliefMessageContainer(n));
00239
00240 double[] init = new double[domSize];
00241 for (int i = 0; i < init.length; i++){
00242 init[i] = 0.0;
00243 }
00244
00245 if (domIdx != -1){
00246 init[domIdx] = 1.0;
00247 lambda.put(n, init.clone());
00248 pi.put(n, init.clone());
00249 }
00250
00251 else{
00252
00253 if (bn.bn.getParents(n).length == 0){
00254 double[] prior = priors.get(n);
00255 pi.put(n, prior);
00256 }
00257
00258 if (bn.bn.getChildren(n).length == 0){
00259 double normalized = 1 / (double) domSize;
00260 double[] uniform = new double[domSize];
00261 for (int i = 0; i < uniform.length; i++){
00262 uniform[i] = normalized;
00263 }
00264 lambda.put(n, uniform);
00265 }
00266 if (!pi.containsKey(n)){
00267 pi.put(n, init.clone());
00268 }
00269 if (!lambda.containsKey(n)){
00270 lambda.put(n, init.clone());
00271 }
00272 }
00273 }
00274
00275 if (debug){
00276 out.println("After initialization process");
00277 for (BeliefNode n : nodes){
00278 out.println(" Node: " + n);
00279 out.println(" Pi(x):" + n);
00280 for(int i = 0; i < pi.get(n).length; i++){
00281 out.println(" " + i + ": " + pi.get(n)[i]);
00282 }
00283 out.println(" Lambda(x):" + n);
00284 for(int i = 0; i < lambda.get(n).length; i++){
00285 out.println(" " + i + ": " + lambda.get(n)[i]);
00286 }
00287 }
00288 }
00289
00290
00291
00292 for (int step = 1; step <= this.numSamples; step++) {
00293
00294 if(verbose && step % this.infoInterval == 0)
00295 out.println("step " + step);
00296
00297
00298 for (BeliefNode n : nodes){
00299 int[] nodeDomainIndices = evidenceDomainIndices.clone();
00300
00301 boolean receivedAll = true;
00302
00303 for (BeliefNode c : bn.bn.getParents(n)){
00304 double sum = 0.0;
00305 for (double d : messages.get(c).piMessages.get(n)){
00306 sum += d;
00307 }
00308 if (sum == 0.0){
00309 receivedAll = false;
00310 }
00311 }
00312 if (receivedAll){
00313 computePi(n, nodeDomainIndices);
00314 }
00315 }
00316
00317
00318 for (BeliefNode n : nodes){
00319 boolean receivedAll = true;
00320
00321 for (BeliefNode c : bn.bn.getChildren(n)){
00322 double sum = 0.0;
00323 for (double d : messages.get(c).lambdaMessages.get(n)){
00324 sum += d;
00325 }
00326 if (sum == 0.0){
00327 receivedAll = false;
00328 break;
00329 }
00330 }
00331 if (receivedAll && bn.bn.getChildren(n).length > 0){
00332 computeLambda(n);
00333 }
00334 }
00335
00336
00337 for (BeliefNode n : nodes){
00338
00339 double sum = 0.0;
00340 for (double d : pi.get(n)){
00341 sum += d;
00342 }
00343 if (sum != 0){
00344 BeliefNode[] children = bn.bn.getChildren(n);
00345 for (BeliefNode c : children){
00346
00347 boolean receivedAll = true;
00348 for (BeliefNode c2 : children){
00349 if ((c2 != c) && !messages.get(c2).sentLambdaMessageTo(n)){
00350 receivedAll = false;
00351 break;
00352 }
00353 }
00354 if (receivedAll){
00355 messages.get(n).computePiMessages(c);
00356 }
00357 }
00358 }
00359 }
00360
00361
00362 for (BeliefNode n : nodes){
00363
00364 double sum = 0.0;
00365 for (double d : lambda.get(n)){
00366 sum += d;
00367 }
00368 if (sum != 0){
00369 for (BeliefNode p : bn.bn.getParents(n)){
00370
00371 boolean receivedAll = true;
00372 for (BeliefNode p2 : bn.bn.getParents(n)){
00373 if ((p2 != p) && !messages.get(p2).sentPiMessageTo(n)){
00374 receivedAll = false;
00375 break;
00376 }
00377 }
00378 if (receivedAll){
00379 int[] nodeDomainIndices = evidenceDomainIndices.clone();
00380 messages.get(n).computeLambdaMessages(p, nodeDomainIndices);
00381 }
00382 }
00383 }
00384 }
00385
00386 if(debug){
00387 out.println("\n\n****After step " + step + "****");
00388 out.println("\n Pi and Lambda Functions");
00389 for (BeliefNode n : nodes){
00390 out.println(" Node: " + n);
00391 out.println(" Pi(x):" + n);
00392 for(int i = 0; i < pi.get(n).length; i++){
00393 out.println(" " + i + ": " + pi.get(n)[i]);
00394 }
00395 out.println(" Lambda(x):" + n);
00396 for(int i = 0; i < lambda.get(n).length; i++){
00397 out.println(" " + i + ": " + lambda.get(n)[i]);
00398 }
00399 }
00400
00401 out.println("\n Message Functions");
00402 for (BeliefNode n : nodes){
00403 out.println("Node: " + n);
00404 for (BeliefNode c : messages.get(n).piMessages.keySet()){
00405 out.println(" Pi-Message to " + c + ":");
00406 for(int i = 0; i < messages.get(n).piMessages.get(c).length; i++){
00407 out.println(" " + i + ": " + messages.get(n).piMessages.get(c)[i]);
00408 }
00409 }
00410 for (BeliefNode c : messages.get(n).lambdaMessages.keySet()){
00411 out.println(" Lambda to (x):" + c);
00412 for(int i = 0; i < messages.get(n).lambdaMessages.get(c).length; i++){
00413 out.println(" " + i + ": " + messages.get(n).lambdaMessages.get(c)[i]);
00414 }
00415 }
00416 }
00417 }
00418
00419 }
00420
00421 if(verbose) out.println("computing results....");
00422 this.createDistribution();
00423 dist.Z = 1.0;
00424 for (BeliefNode n : nodes) {
00425 int i = getNodeIndex(n);
00426 if (evidenceDomainIndices[i] >= 0) {
00427 dist.values[i][evidenceDomainIndices[i]] = 1.0;
00428 continue;
00429 }
00430 int domSize = dist.values[i].length;
00431 double normalize = 0.0;
00432 for (int j = 0; j < domSize; j++) {
00433 dist.values[i][j] = lambda.get(n)[j]*pi.get(n)[j];
00434 normalize += dist.values[i][j];
00435 }
00436 for (int j = 0; j < domSize; j++) {
00437 if (normalize == 0.0)
00438 continue;
00439 dist.values[i][j] /= normalize;
00440 }
00441 }
00442 return dist;
00443 }
00444
00445 }