00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.util.HashMap;
00004
00005 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00006 import edu.ksu.cis.bnj.ver3.core.Discrete;
00007 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00008 import edu.tum.cs.util.Stopwatch;
00009
00010 public class GibbsSampling extends Sampler {
00011 int[] nodeOrder;
00012 HashMap<BeliefNode, BeliefNode[]> children;
00013
00014 public GibbsSampling(BeliefNetworkEx bn) throws Exception {
00015 super(bn);
00016 children = new HashMap<BeliefNode, BeliefNode[]>();
00017 for(int i = 0; i < nodes.length; i++) {
00018 children.put(nodes[i], bn.bn.getChildren(nodes[i]));
00019 }
00020 nodeOrder = bn.getTopologicalOrder();
00021 }
00022
00023 public SampledDistribution _infer() throws Exception {
00024 Stopwatch sw = new Stopwatch();
00025 createDistribution();
00026
00027
00028 out.println("initial setting...");
00029 WeightedSample s = bn.getWeightedSample(nodeOrder, evidenceDomainIndices, generator);
00030 if(s == null)
00031 throw new Exception("Could not find an initial state with non-zero probability in given number of trials.");
00032
00033
00034 out.println("Gibbs sampling...");
00035 sw.start();
00036
00037 for(int i = 1; i <= numSamples; i++) {
00038 if(i % infoInterval == 0)
00039 out.println(" step " + i);
00040 gibbsStep(evidenceDomainIndices, s);
00041 s.trials = 1;
00042 s.weight = 1;
00043 addSample(s);
00044 }
00045
00046 sw.stop();
00047 report(String.format("time taken: %.2fs (%.4fs per sample, %.1f trials/step)\n", sw.getElapsedTimeSecs(), sw.getElapsedTimeSecs()/numSamples, dist.getTrialsPerStep()));
00048 return dist;
00049 }
00050
00051 public void gibbsStep(int[] evidenceDomainIndices, WeightedSample s) {
00052
00053 for(int j = 0; j < nodes.length; j++) {
00054
00055 if(evidenceDomainIndices[j] != -1)
00056 continue;
00057
00058 BeliefNode n = nodes[j];
00059 Discrete dom = (Discrete)n.getDomain();
00060 int domSize = dom.getOrder();
00061 double[] distribution = new double[domSize];
00062
00063 for(int d = 0; d < domSize; d++) {
00064 s.nodeDomainIndices[j] = d;
00065
00066 double value = getCPTProbability(n, s.nodeDomainIndices);
00067
00068 for(BeliefNode child : children.get(n)) {
00069 value *= getCPTProbability(child, s.nodeDomainIndices);
00070 }
00071 distribution[d] = value;
00072 }
00073 s.nodeDomainIndices[j] = sample(distribution, generator);
00074 }
00075 }
00076 }