00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00004 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00005 import edu.tum.cs.util.Stopwatch;
00006
00007 public class LikelihoodWeighting extends Sampler {
00008 int[] nodeOrder;
00009
00010 public LikelihoodWeighting(BeliefNetworkEx bn) throws Exception {
00011 super(bn);
00012 nodeOrder = bn.getTopologicalOrder();
00013 }
00014
00015 @Override
00016 public SampledDistribution _infer() throws Exception {
00017
00018 Stopwatch sw = new Stopwatch();
00019 createDistribution();
00020 out.println("sampling...");
00021 sw.start();
00022 WeightedSample s = new WeightedSample(bn);
00023 for(int i = 1; i <= numSamples; i++) {
00024 if(i % infoInterval == 0)
00025 out.println(" step " + i);
00026 WeightedSample ret = getWeightedSample(s, nodeOrder, evidenceDomainIndices);
00027 if(ret != null) {
00028 addSample(ret);
00029
00030 if(false) {
00031 out.print("w=" + ret.weight);
00032 for(int j = 0; j < evidenceDomainIndices.length; j++)
00033 if(evidenceDomainIndices[j] == -1) {
00034 BeliefNode node = nodes[j];
00035 out.print(" " + node.getName() + "=" + node.getDomain().getName(s.nodeDomainIndices[j]));
00036 }
00037 out.println();
00038 }
00039 }
00040 if(converged())
00041 break;
00042 }
00043 sw.stop();
00044 out.println(String.format("time taken: %.2fs (%.4fs per sample, %.1f trials/sample, %d samples)\n", sw.getElapsedTimeSecs(), sw.getElapsedTimeSecs()/numSamples, dist.getTrialsPerStep(), dist.steps));
00045 return dist;
00046 }
00047
00048 public WeightedSample getWeightedSample(WeightedSample s, int[] nodeOrder, int[] evidenceDomainIndices) throws Exception {
00049 s.trials = 0;
00050 boolean successful = false;
00051 loop: while(!successful) {
00052 s.weight = 1.0;
00053 s.trials++;
00054 if(maxTrials > 0 && s.trials > this.maxTrials) {
00055 if(!this.skipFailedSteps)
00056 throw new Exception("Could not obtain a countable sample in the maximum allowed number of trials (" + maxTrials + ")");
00057 else
00058 return null;
00059 }
00060
00061 for(int i=0; i < nodeOrder.length; i++) {
00062 int nodeIdx = nodeOrder[i];
00063 int domainIdx = evidenceDomainIndices[nodeIdx];
00064
00065 if(domainIdx >= 0) {
00066 s.nodeDomainIndices[nodeIdx] = domainIdx;
00067 double prob = getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices);
00068 if(prob == 0.0) {
00069 if(debug)
00070 out.println("!!! evidence probability was 0 at node " + nodes[nodeIdx] + " in step " + (dist.steps+1));
00071 continue loop;
00072 }
00073 s.weight *= prob;
00074 }
00075
00076 else {
00077 domainIdx = sampleForward(nodes[nodeIdx], s.nodeDomainIndices);
00078 if(domainIdx < 0) {
00079 if(debug)
00080 out.println("!!! could not sample forward because of column with only 0s in CPT of " + nodes[nodeIdx].getName() + " in step " + (dist.steps+1));
00081 bn.removeAllEvidences();
00082 continue loop;
00083 }
00084 s.nodeDomainIndices[nodeIdx] = domainIdx;
00085 }
00086 }
00087 successful = true;
00088 }
00089 return s;
00090 }
00091 }