00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import edu.ksu.cis.bnj.ver3.core.Domain;
00004 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00005 import edu.tum.cs.util.Stopwatch;
00006
00007 public class EnumerationAsk extends Sampler {
00008 int[] nodeOrder;
00009 int numPathsPruned;
00010 double numWorldsPruned, numWorldsCounted;
00011 Stopwatch timer;
00015 double numTotalWorlds;
00016
00017 public EnumerationAsk(BeliefNetworkEx bn) throws Exception {
00018 super(bn);
00019 nodeOrder = bn.getTopologicalOrder();
00020 numTotalWorlds = bn.getNumWorlds();
00021 }
00022
00023 public SampledDistribution _infer() throws Exception {
00024 Stopwatch sw = new Stopwatch();
00025 numPathsPruned = 0;
00026 numWorldsPruned = numWorldsCounted = 0;
00027 createDistribution();
00028 if(verbose) out.printf("enumerating %s worlds...\n", numTotalWorlds);
00029 sw.start();
00030 WeightedSample s = new WeightedSample(bn);
00031 timer = new Stopwatch();
00032 timer.start();
00033 enumerateWorlds(s, nodeOrder, evidenceDomainIndices, 0, 1);
00034 sw.stop();
00035 report(String.format("\ntime taken: %.2fs (%d worlds enumerated, %d paths pruned)\n", sw.getElapsedTimeSecs(), dist.steps, numPathsPruned));
00036 return dist;
00037 }
00038
00039 public void enumerateWorlds(WeightedSample s, int[] nodeOrder, int[] evidenceDomainIndices, int i, double combinationsHandled) throws Exception {
00040
00041
00042 if(timer.getElapsedTimeSecs() > 1) {
00043 double numDone = numWorldsCounted+numWorldsPruned;
00044 if(verbose) out.printf(" ~ %.4f%% done (%s worlds handled, %d paths pruned)\r", 100.0*numDone/numTotalWorlds, numDone, numPathsPruned);
00045 timer = new Stopwatch();
00046 timer.start();
00047 }
00048
00049 if(i == nodes.length) {
00050
00051 addSample(s);
00052 numWorldsCounted++;
00053 return;
00054 }
00055
00056 int nodeIdx = nodeOrder[i];
00057 combinationsHandled *= nodes[nodeOrder[i]].getDomain().getOrder();
00058 int domainIdx = evidenceDomainIndices[nodeIdx];
00059
00060 if(domainIdx >= 0) {
00061 s.nodeDomainIndices[nodeIdx] = domainIdx;
00062 double prob = getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices);
00063 s.weight *= prob;
00064 if(prob == 0.0) {
00065
00066 numPathsPruned++;
00067 numWorldsPruned += numTotalWorlds / combinationsHandled;
00068 return;
00069 }
00070 enumerateWorlds(s, nodeOrder, evidenceDomainIndices, i+1, combinationsHandled);
00071 }
00072
00073 else {
00074 Domain d = nodes[nodeIdx].getDomain();
00075 int order = d.getOrder();
00076
00077 double weight = s.weight;
00078 for(int j = 0; j < order; j++) {
00079 s.nodeDomainIndices[nodeIdx] = j;
00080 double prob = getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices);
00081 if(prob == 0.0) {
00082
00083 numPathsPruned++;
00084 numWorldsPruned += numTotalWorlds / combinationsHandled;
00085 continue;
00086 }
00087 s.weight = weight * prob;
00088 enumerateWorlds(s, nodeOrder, evidenceDomainIndices, i+1, combinationsHandled);
00089 }
00090 }
00091 }
00092 }