00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.util.Arrays;
00004 import java.util.HashSet;
00005 import java.util.LinkedList;
00006 import java.util.Vector;
00007
00008 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00009 import edu.ksu.cis.bnj.ver3.core.CPF;
00010 import edu.ksu.cis.bnj.ver3.core.Discrete;
00011 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00012 import edu.tum.cs.util.Stopwatch;
00013 import edu.tum.cs.util.StringTool;
00014
00021 public class JointBackwardSampling extends Sampler {
00022
00023 Vector<BeliefNode> backwardSampledNodes;
00024 Vector<BeliefNode> forwardSampledNodes;
00025 HashSet<BeliefNode> outsideSamplingOrder;
00026 int[] evidenceDomainIndices;
00027
00028 protected static class BackSamplingDistribution {
00029 public Vector<Double> distribution;
00030 public Vector<int[]> states;
00031 double Z;
00032 protected Sampler sampler;
00033
00034 public BackSamplingDistribution(Sampler sampler) {
00035 Z = 0.0;
00036 distribution = new Vector<Double>();
00037 states = new Vector<int[]>();
00038 this.sampler = sampler;
00039 }
00040
00041 public void addValue(double p, int[] state) {
00042 distribution.add(p);
00043 states.add(state);
00044 Z += p;
00045 }
00046
00047 public void applyWeight(WeightedSample s, int sampledValue) {
00048 s.weight *= Z;
00049 }
00050
00051 public void construct(BeliefNode node, int[] nodeDomainIndices) {
00052 CPF cpf = node.getCPF();
00053 BeliefNode[] domProd = cpf.getDomainProduct();
00054 int[] addr = new int[domProd.length];
00055 addr[0] = nodeDomainIndices[sampler.nodeIndices.get(node)];
00056 construct(1, addr, cpf, nodeDomainIndices);
00057 }
00058
00066 protected void construct(int i, int[] addr, CPF cpf, int[] nodeDomainIndices) {
00067 if(i == addr.length) {
00068 double p = cpf.getDouble(addr);
00069 if(p != 0)
00070 addValue(p, addr.clone());
00071 return;
00072 }
00073 BeliefNode[] domProd = cpf.getDomainProduct();
00074 int nodeIdx = sampler.nodeIndices.get(domProd[i]);
00075 if(nodeDomainIndices[nodeIdx] >= 0) {
00076 addr[i] = nodeDomainIndices[nodeIdx];
00077 construct(i+1, addr, cpf, nodeDomainIndices);
00078 }
00079 else {
00080 Discrete dom = (Discrete)domProd[i].getDomain();
00081 for(int j = 0; j < dom.getOrder(); j++) {
00082 addr[i] = j;
00083 construct(i+1, addr, cpf, nodeDomainIndices);
00084 }
00085 }
00086 }
00087 }
00088
00089 public JointBackwardSampling(BeliefNetworkEx bn) throws Exception {
00090 super(bn);
00091 }
00092
00097 protected void getOrdering(int[] evidenceDomainIndices) {
00098 HashSet<BeliefNode> uninstantiatedNodes = new HashSet<BeliefNode>(Arrays.asList(nodes));
00099 backwardSampledNodes = new Vector<BeliefNode>();
00100 forwardSampledNodes = new Vector<BeliefNode>();
00101 outsideSamplingOrder = new HashSet<BeliefNode>();
00102 LinkedList<BeliefNode> backSamplingCandidates = new LinkedList<BeliefNode>();
00103
00104
00105 for(int i = 0; i < evidenceDomainIndices.length; i++) {
00106 if(evidenceDomainIndices[i] >= 0) {
00107 backSamplingCandidates.add(nodes[i]);
00108 uninstantiatedNodes.remove(nodes[i]);
00109 }
00110 }
00111
00112
00113 while(!backSamplingCandidates.isEmpty()) {
00114 BeliefNode node = backSamplingCandidates.removeFirst();
00115
00116 BeliefNode[] domProd = node.getCPF().getDomainProduct();
00117 boolean doBackSampling = false;
00118 for(int j = 1; j < domProd.length; j++) {
00119 BeliefNode parent = domProd[j];
00120
00121 if(uninstantiatedNodes.remove(parent)) {
00122 doBackSampling = true;
00123 backSamplingCandidates.add(parent);
00124 }
00125 }
00126 if(doBackSampling)
00127 backwardSampledNodes.add(node);
00128
00129
00130 else
00131 outsideSamplingOrder.add(node);
00132 }
00133
00134
00135 int[] topOrder = bn.getTopologicalOrder();
00136 for(int i : topOrder) {
00137 if(uninstantiatedNodes.contains(nodes[i]))
00138 forwardSampledNodes.add(nodes[i]);
00139 }
00140 }
00141
00148 protected boolean sampleBackward(BeliefNode node, WeightedSample s) {
00149
00150
00151 BackSamplingDistribution d = getBackSamplingDistribution(node, s);
00152
00153 int idx = sample(d.distribution, generator);
00154 if(idx == -1)
00155 return false;
00156 int[] state = d.states.get(idx);
00157
00158 d.applyWeight(s, idx);
00159 if(s.weight == 0.0)
00160 return false;
00161
00162 BeliefNode[] domProd = node.getCPF().getDomainProduct();
00163 for(int i = 1; i < state.length; i++) {
00164 int nodeIdx = this.nodeIndices.get(domProd[i]);
00165 s.nodeDomainIndices[nodeIdx] = state[i];
00166
00167 }
00168 return true;
00169 }
00170
00171 protected BackSamplingDistribution getBackSamplingDistribution(BeliefNode node, WeightedSample s) {
00172 BackSamplingDistribution d = new BackSamplingDistribution(this);
00173 d.construct(node, s.nodeDomainIndices);
00174 return d;
00175 }
00176
00177 protected void prepareInference(int[] evidenceDomainIndices) {
00178 this.evidenceDomainIndices = evidenceDomainIndices;
00179 getOrdering(evidenceDomainIndices);
00180 if(true) {
00181 System.out.println("sampling backward: " + this.backwardSampledNodes);
00182 System.out.println("sampling forward: " + this.forwardSampledNodes);
00183 System.out.println("not in order: " + this.outsideSamplingOrder);
00184 }
00185 }
00186
00187 @Override
00188 public SampledDistribution _infer() throws Exception {
00189 Stopwatch sw = new Stopwatch();
00190 sw.start();
00191
00192 this.prepareInference(evidenceDomainIndices);
00193
00194 this.createDistribution();
00195 System.out.println("sampling...");
00196 WeightedSample s = new WeightedSample(this.bn, evidenceDomainIndices.clone(), 1.0, null, 0);
00197 for(int i = 1; i <= this.numSamples; i++) {
00198 if(i % infoInterval == 0)
00199 System.out.println(" step " + i);
00200 getSample(s);
00201 this.addSample(s);
00202 }
00203
00204 sw.stop();
00205 System.out.println(String.format("time taken: %.2fs (%.4fs per sample, %.1f trials/step)\n", sw.getElapsedTimeSecs(), sw.getElapsedTimeSecs()/numSamples, dist.getTrialsPerStep()));
00206 return this.dist;
00207 }
00208
00213 public void getSample(WeightedSample s) {
00214 int MAX_TRIALS = 5000;
00215 boolean debug = true;
00216 loop1: for(int t = 1; t <= MAX_TRIALS; t++) {
00217
00218 s.nodeDomainIndices = evidenceDomainIndices.clone();
00219 s.weight = 1.0;
00220
00221 for(BeliefNode node : backwardSampledNodes) {
00222 if(!sampleBackward(node, s)) {
00223 if(debug) System.out.println("!!! backward sampling failed at " + node);
00224 continue loop1;
00225 }
00226 }
00227
00228
00229 for(BeliefNode node : forwardSampledNodes) {
00230 if(!sampleForward(node, s)) {
00231 if(debug) System.out.println("!!! forward sampling failed at " + node);
00232 continue loop1;
00233 }
00234 }
00235
00236
00237 for(BeliefNode node : outsideSamplingOrder) {
00238 s.weight *= this.getCPTProbability(node, s.nodeDomainIndices);
00239 if(s.weight == 0.0) {
00240
00241 if(debug) System.out.println("!!! weight became zero at unordered node " + node);
00242 continue loop1;
00243 }
00244 }
00245 s.trials = t;
00246 return;
00247 }
00248 throw new RuntimeException("Maximum number of trials exceeded.");
00249 }
00250
00251 protected boolean sampleForward(BeliefNode node, WeightedSample s) {
00252 int idx = super.sampleForward(node, s.nodeDomainIndices);
00253 if(idx == -1)
00254 return false;
00255 s.nodeDomainIndices[this.nodeIndices.get(node)] = idx;
00256 return true;
00257 }
00258 }