00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.util.HashMap;
00004
00005 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00006 import edu.ksu.cis.bnj.ver3.core.CPF;
00007 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00008 import edu.tum.cs.util.Stopwatch;
00009
00015 public class SampleSearch extends Sampler {
00016 protected int[] nodeOrder;
00017 protected int currentStep;
00018
00019 public SampleSearch(BeliefNetworkEx bn) throws Exception {
00020 super(bn);
00021
00022 nodeOrder = computeNodeOrdering();
00023 }
00024
00025 protected int[] computeNodeOrdering() throws Exception {
00026 return bn.getTopologicalOrder();
00027 }
00028
00029 protected void info(int step) {
00030 out.println(" step " + step);
00031 }
00032
00033 public SampledDistribution _infer() throws Exception {
00034
00035 Stopwatch sw = new Stopwatch();
00036 createDistribution();
00037 out.println("sampling...");
00038 sw.start();
00039 WeightedSample s = new WeightedSample(bn);
00040 for(int i = 1; i <= numSamples; i++) {
00041 currentStep = i;
00042 if(i % infoInterval == 0)
00043 info(i);
00044 WeightedSample ret = getWeightedSample(s, nodeOrder, evidenceDomainIndices);
00045 if(ret != null) {
00046 if(false) {
00047 out.print("w=" + ret.weight);
00048 double prod = 1.0;
00049 for(int j = 0; j < evidenceDomainIndices.length; j++)
00050 if(true || evidenceDomainIndices[j] == -1) {
00051 BeliefNode node = nodes[j];
00052 out.print(" " + node.getName() + "=" + node.getDomain().getName(s.nodeDomainIndices[j]));
00053 double p = bn.getCPTProbability(node, s.nodeDomainIndices);
00054 out.printf(" %f", p);
00055 if(p == 0.0)
00056 throw new Exception("Sample has 0 probability.");
00057 prod *= p;
00058 if(prod == 0.0)
00059 throw new Exception("Precision loss - product became 0");
00060 }
00061 out.println();
00062 }
00063
00064 addSample(ret);
00065 }
00066 if(converged())
00067 break;
00068 }
00069 sw.stop();
00070 report(String.format("time taken: %.2fs (%.4fs per sample, %.1f trials/sample, %.4f*N assignments/sample, %d samples)\n", sw.getElapsedTimeSecs(), sw.getElapsedTimeSecs()/numSamples, dist.getTrialsPerStep(), (float)dist.operations/nodes.length/numSamples, dist.steps));
00071 return dist;
00072 }
00073
00074 public WeightedSample getWeightedSample(WeightedSample s, int[] nodeOrder, int[] evidenceDomainIndices) throws Exception {
00075 s.trials = 0;
00076 s.operations = 0;
00077 s.weight = 1.0;
00078 s.trials++;
00079 double[] samplingProb = new double[nodeOrder.length];
00080
00081 HashMap<Integer, boolean[]> domExclusions = new HashMap<Integer, boolean[]>();
00082 for(int i=0; i < nodeOrder.length;) {
00083 s.operations++;
00084 if(i == -1)
00085 throw new Exception("It appears that the evidence is contradictory.");
00086 int nodeIdx = nodeOrder[i];
00087 int domainIdx = evidenceDomainIndices[nodeIdx];
00088
00089 boolean[] excluded = domExclusions.get(nodeIdx);
00090 if(excluded == null) {
00091 excluded = new boolean[nodes[nodeIdx].getDomain().getOrder()];
00092 domExclusions.put(nodeIdx, excluded);
00093 }
00094
00095 if(debug) {
00096 int numex = 0;
00097 for(int j=0; j<excluded.length; j++)
00098 if(excluded[j])
00099 numex++;
00100 out.printf(" step %d, node %d '%s' (%d/%d exclusions)\n", currentStep, i, nodes[nodeIdx].getName(), numex, excluded.length);
00101 }
00102
00103 if(domainIdx >= 0) {
00104 s.nodeDomainIndices[nodeIdx] = domainIdx;
00105 samplingProb[nodeIdx] = 1.0;
00106 double prob = getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices);
00107 if(prob != 0.0) {
00108 ++i;
00109 continue;
00110 }
00111 else {
00112 if(debug)
00113 out.println(" evidence with probability 0.0; backtracking...");
00114 }
00115 }
00116
00117 else {
00118 SampledAssignment sa = sampleForward(nodes[nodeIdx], s.nodeDomainIndices, excluded);
00119 if(sa != null) {
00120 domainIdx = sa.domIdx;
00121 samplingProb[nodeIdx] = sa.probability;
00122 s.nodeDomainIndices[nodeIdx] = domainIdx;
00123 ++i;
00124 continue;
00125 }
00126 else if(debug)
00127 out.println(" impossible case; backtracking...");
00128 }
00129
00130
00131 s.trials++;
00132 do {
00133
00134 domExclusions.remove(nodeIdx);
00135
00136 --i;
00137 if(i < 0)
00138 throw new Exception("Could not find a sample with non-zero probability. Most likely, the evidence specified has 0 probability.");
00139 nodeIdx = nodeOrder[i];
00140 boolean[] prevExcl = domExclusions.get(nodeIdx);
00141 prevExcl[s.nodeDomainIndices[nodeIdx]] = true;
00142
00143 } while(evidenceDomainIndices[nodeIdx] != -1);
00144 }
00145
00146 for(int i = 0; i < this.nodes.length; i++) {
00147 s.weight *= getCPTProbability(nodes[i], s.nodeDomainIndices) / samplingProb[i];
00148 }
00149 return s;
00150 }
00151
00152 protected class SampledAssignment {
00153 public int domIdx;
00154 public double probability;
00155 public SampledAssignment(int domainIdx, double p) {
00156 domIdx = domainIdx;
00157 probability = p;
00158 }
00159 }
00160
00168 protected SampledAssignment sampleForward(BeliefNode node, int[] nodeDomainIndices, boolean[] excluded) {
00169 CPF cpf = node.getCPF();
00170 BeliefNode[] domProd = cpf.getDomainProduct();
00171 int[] addr = new int[domProd.length];
00172
00173 for(int i = 1; i < addr.length; i++)
00174 addr[i] = nodeDomainIndices[this.nodeIndices.get(domProd[i])];
00175 addr[0] = 0;
00176 int realAddr = cpf.addr2realaddr(addr);
00177 addr[0] = 1;
00178 int diff = cpf.addr2realaddr(addr) - realAddr;
00179
00180 double[] cpt_entries = new double[domProd[0].getDomain().getOrder()];
00181 double sum = 0;
00182 for(int i = 0; i < cpt_entries.length; i++) {
00183 double value;
00184 if(excluded[i])
00185 value = 0.0;
00186 else
00187 value = cpf.getDouble(realAddr);
00188 cpt_entries[i] = value;
00189 sum += value;
00190 realAddr += diff;
00191 }
00192
00193 if(sum == 0)
00194 return null;
00195 int domIdx = sample(cpt_entries, sum, generator);
00196 return new SampledAssignment(domIdx, cpt_entries[domIdx]/sum);
00197 }
00198 }