00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.util.Arrays;
00004 import java.util.Comparator;
00005 import java.util.HashSet;
00006 import java.util.PriorityQueue;
00007 import java.util.Vector;
00008
00009 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00010 import edu.ksu.cis.bnj.ver3.core.CPF;
00011 import edu.ksu.cis.bnj.ver3.core.Discrete;
00012 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00013 import edu.tum.cs.bayesnets.util.TopologicalOrdering;
00014 import edu.tum.cs.bayesnets.util.TopologicalSort;
00015 import edu.tum.cs.util.Stopwatch;
00016 import edu.tum.cs.util.StringTool;
00017
00024 public class BackwardSampling extends Sampler {
00025
00026 protected Vector<BeliefNode> backwardSampledNodes;
00027 protected Vector<BeliefNode> forwardSampledNodes;
00028 protected HashSet<BeliefNode> outsideSamplingOrder;
00029 protected int currentStep;
00030
00031 protected static class BackSamplingDistribution {
00032 public Vector<Double> distribution;
00033 public Vector<int[]> states;
00034 double Z;
00035 protected Sampler sampler;
00036
00037 public BackSamplingDistribution(Sampler sampler) {
00038 Z = 0.0;
00039 distribution = new Vector<Double>();
00040 states = new Vector<int[]>();
00041 this.sampler = sampler;
00042 }
00043
00044 public void addValue(double p, int[] state) {
00045 distribution.add(p);
00046 states.add(state);
00047 Z += p;
00048 }
00049
00050 public void applyWeight(WeightedSample s, int sampledValue) {
00051 s.weight *= Z;
00052 }
00053
00054 public void construct(BeliefNode node, int[] nodeDomainIndices) {
00055 CPF cpf = node.getCPF();
00056 BeliefNode[] domProd = cpf.getDomainProduct();
00057 int[] addr = new int[domProd.length];
00058 addr[0] = nodeDomainIndices[sampler.nodeIndices.get(node)];
00059 construct(1, addr, cpf, nodeDomainIndices);
00060 }
00061
00069 protected void construct(int i, int[] addr, CPF cpf, int[] nodeDomainIndices) {
00070 if(i == addr.length) {
00071 double p = cpf.getDouble(addr);
00072 if(p != 0)
00073 addValue(p, addr.clone());
00074 return;
00075 }
00076 BeliefNode[] domProd = cpf.getDomainProduct();
00077 int nodeIdx = sampler.nodeIndices.get(domProd[i]);
00078 if(nodeDomainIndices[nodeIdx] >= 0) {
00079 addr[i] = nodeDomainIndices[nodeIdx];
00080 construct(i+1, addr, cpf, nodeDomainIndices);
00081 }
00082 else {
00083 Discrete dom = (Discrete)domProd[i].getDomain();
00084 for(int j = 0; j < dom.getOrder(); j++) {
00085 addr[i] = j;
00086 construct(i+1, addr, cpf, nodeDomainIndices);
00087 }
00088 }
00089 }
00090 }
00091
00092 public BackwardSampling(BeliefNetworkEx bn) throws Exception {
00093 super(bn);
00094 }
00095
00101 public static class TierComparator implements Comparator<BeliefNode> {
00102
00103 TopologicalOrdering topOrder;
00104
00105 public TierComparator(TopologicalOrdering topOrder) {
00106 this.topOrder = topOrder;
00107 }
00108
00109 public int compare(BeliefNode o1, BeliefNode o2) {
00110 return -(topOrder.getTier(o1) - topOrder.getTier(o2));
00111 }
00112 }
00113
00119 protected void getOrdering(int[] evidenceDomainIndices) throws Exception {
00120 HashSet<BeliefNode> uninstantiatedNodes = new HashSet<BeliefNode>(Arrays.asList(nodes));
00121 backwardSampledNodes = new Vector<BeliefNode>();
00122 forwardSampledNodes = new Vector<BeliefNode>();
00123 outsideSamplingOrder = new HashSet<BeliefNode>();
00124 TopologicalOrdering topOrder = new TopologicalSort(bn.bn).run(true);
00125 PriorityQueue<BeliefNode> backSamplingCandidates = new PriorityQueue<BeliefNode>(1, new TierComparator(topOrder));
00126
00127
00128 for(int i = 0; i < evidenceDomainIndices.length; i++) {
00129 if(evidenceDomainIndices[i] >= 0) {
00130 backSamplingCandidates.add(nodes[i]);
00131 uninstantiatedNodes.remove(nodes[i]);
00132 }
00133 }
00134
00135
00136 while(!backSamplingCandidates.isEmpty()) {
00137 BeliefNode node = backSamplingCandidates.remove();
00138
00139 BeliefNode[] domProd = node.getCPF().getDomainProduct();
00140 boolean doBackSampling = false;
00141 for(int j = 1; j < domProd.length; j++) {
00142 BeliefNode parent = domProd[j];
00143
00144 if(uninstantiatedNodes.remove(parent)) {
00145 doBackSampling = true;
00146 backSamplingCandidates.add(parent);
00147 }
00148 }
00149 if(doBackSampling)
00150 backwardSampledNodes.add(node);
00151
00152
00153 else
00154 outsideSamplingOrder.add(node);
00155 }
00156
00157
00158 for(int i : topOrder) {
00159 if(uninstantiatedNodes.contains(nodes[i]))
00160 forwardSampledNodes.add(nodes[i]);
00161 }
00162 }
00163
00170 protected boolean sampleBackward(BeliefNode node, WeightedSample s) {
00171
00172
00173 BackSamplingDistribution d = getBackSamplingDistribution(node, s);
00174
00175 int idx = sample(d.distribution, generator);
00176 if(idx == -1)
00177 return false;
00178 int[] state = d.states.get(idx);
00179
00180 d.applyWeight(s, idx);
00181 if(s.weight == 0.0)
00182 return false;
00183
00184 BeliefNode[] domProd = node.getCPF().getDomainProduct();
00185 for(int i = 1; i < state.length; i++) {
00186 int nodeIdx = this.nodeIndices.get(domProd[i]);
00187 s.nodeDomainIndices[nodeIdx] = state[i];
00188
00189 }
00190 return true;
00191 }
00192
00193 protected BackSamplingDistribution getBackSamplingDistribution(BeliefNode node, WeightedSample s) {
00194 BackSamplingDistribution d = new BackSamplingDistribution(this);
00195 d.construct(node, s.nodeDomainIndices);
00196 return d;
00197 }
00198
00199 protected void prepareInference(int[] evidenceDomainIndices) throws Exception {
00200 this.evidenceDomainIndices = evidenceDomainIndices;
00201 getOrdering(evidenceDomainIndices);
00202 if(debug) {
00203 out.println("sampling backward: " + this.backwardSampledNodes);
00204 out.println("sampling forward: " + this.forwardSampledNodes);
00205 out.println("not in order: " + this.outsideSamplingOrder);
00206 }
00207 }
00208
00209 @Override
00210 public SampledDistribution _infer() throws Exception {
00211 Stopwatch sw = new Stopwatch();
00212 sw.start();
00213
00214 this.prepareInference(evidenceDomainIndices);
00215
00216 this.createDistribution();
00217 if(verbose) out.println("sampling...");
00218 WeightedSample s = new WeightedSample(this.bn, evidenceDomainIndices.clone(), 1.0, null, 0);
00219 for(currentStep = 1; currentStep <= this.numSamples; currentStep++) {
00220 if(verbose && currentStep % infoInterval == 0)
00221 out.println(" step " + currentStep);
00222 getSample(s);
00223 this.addSample(s);
00224 onAddedSample(s);
00225 if(converged())
00226 break;
00227 }
00228
00229 sw.stop();
00230 report(String.format("time taken: %.2fs (%.4fs per sample, %.1f trials/step)\n", sw.getElapsedTimeSecs(), sw.getElapsedTimeSecs()/numSamples, dist.getTrialsPerStep()));
00231 return this.dist;
00232 }
00233
00239 public void getSample(WeightedSample s) throws Exception {
00240 int MAX_TRIALS = this.maxTrials;
00241 loop1: for(int t = 1; t <= MAX_TRIALS; t++) {
00242
00243 initSample(s);
00244
00245 for(BeliefNode node : backwardSampledNodes) {
00246 if(!sampleBackward(node, s)) {
00247 if(debug) out.println("!!! backward sampling failed at " + node + " in step " + currentStep);
00248 continue loop1;
00249 }
00250 }
00251
00252
00253 for(BeliefNode node : forwardSampledNodes) {
00254 if(!sampleForward(node, s)) {
00255 if(debug) {
00256
00257
00258
00259
00260
00261
00262
00263
00264 out.println("!!! forward sampling failed at " + node + " in step " + currentStep + "; cond: " + s.getCPDLookupString(node));
00265 }
00266 continue loop1;
00267 }
00268 }
00269
00270
00271 for(BeliefNode node : outsideSamplingOrder) {
00272 double p = this.getCPTProbability(node, s.nodeDomainIndices);
00273 s.weight *= p;
00274 if(s.weight == 0.0) {
00275 if(p != 0.0)
00276 throw new Exception("Precision loss in weight calculation");
00277
00278 if(debug) out.println("!!! weight became zero at unordered node " + node + " in step " + currentStep + "; cond: " + s.getCPDLookupString(node));
00279 if(debug && this instanceof BackwardSamplingWithPriors) {
00280 double[] dist = ((BackwardSamplingWithPriors)this).priors.get(node);
00281 out.println("prior: " + StringTool.join(", ", dist) + " value=" + s.nodeDomainIndices[getNodeIndex(node)]);
00282 CPF cpf = node.getCPF();
00283 BeliefNode[] domProd = cpf.getDomainProduct();
00284 int[] addr = new int[domProd.length];
00285 for(int i = 1; i < addr.length; i++)
00286 addr[i] = s.nodeDomainIndices[getNodeIndex(domProd[i])];
00287 for(int i = 0; i < dist.length; i++) {
00288 addr[0] = i;
00289 dist[i] = cpf.getDouble(addr);
00290 }
00291 out.println("cpd: " + StringTool.join(", ", dist));
00292 }
00293 continue loop1;
00294 }
00295 }
00296
00297 s.trials = t;
00298 return;
00299 }
00300 throw new RuntimeException("Maximum number of trials exceeded.");
00301 }
00302
00303 public void initSample(WeightedSample s) throws Exception {
00304 s.nodeDomainIndices = evidenceDomainIndices.clone();
00305 s.weight = 1.0;
00306 }
00307
00308 protected boolean sampleForward(BeliefNode node, WeightedSample s) {
00309 int idx = super.sampleForward(node, s.nodeDomainIndices);
00310 if(idx == -1)
00311 return false;
00312 s.nodeDomainIndices[this.nodeIndices.get(node)] = idx;
00313 return true;
00314 }
00315
00316 protected void onAddedSample(WeightedSample s) throws Exception {
00317 }
00318 }