00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.util.HashSet;
00004
00005 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00006 import edu.ksu.cis.bnj.ver3.core.CPF;
00007 import edu.ksu.cis.bnj.ver3.core.Discrete;
00008 import edu.ksu.cis.bnj.ver3.core.Domain;
00009 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00010 import edu.tum.cs.util.Stopwatch;
00011 import edu.tum.cs.util.datastruct.Cache2D;
00012 import edu.tum.cs.util.datastruct.MutableDouble;
00013
00023 public class BackwardSamplingWithChildren extends BackwardSamplingWithPriors {
00024
00025 protected Cache2D<CPF, Integer, Double> probCache;
00026 protected Cache2D<BeliefNode, Long, BackSamplingDistribution> distCache;
00027 protected Stopwatch probSW, distSW;
00028
00029 public class BackSamplingDistribution extends edu.tum.cs.bayesnets.inference.BackwardSamplingWithPriors.BackSamplingDistribution {
00030
00031 public BackSamplingDistribution(BackwardSamplingWithPriors sampler) {
00032 super(sampler);
00033 }
00034
00041 @Override
00042 protected void construct(int i, int[] addr, CPF cpf, int[] nodeDomainIndices) {
00043 BeliefNode[] domProd = cpf.getDomainProduct();
00044 if(i == addr.length) {
00045 double child_prob = cpf.getDouble(addr);
00046
00047 boolean[] tempEvidence = new boolean[addr.length];
00048 for(int k = 1; k < addr.length; k++) {
00049 int nodeIdx = sampler.nodeIndices.get(domProd[k]);
00050 tempEvidence[k] = nodeDomainIndices[nodeIdx] == -1;
00051 if(tempEvidence[k])
00052 nodeDomainIndices[nodeIdx] = addr[k];
00053 }
00054
00055 double parent_prob = 1.0;
00056 HashSet<BeliefNode> handledChildren = new HashSet<BeliefNode>();
00057 handledChildren.add(domProd[0]);
00058 for(int j = 1; j < addr.length; j++) {
00059 double[] parentPrior = ((BackwardSamplingWithPriors)sampler).priors.get(domProd[j]);
00060 parent_prob *= parentPrior[addr[j]];
00061
00062
00063 BeliefNode[] children = sampler.bn.bn.getChildren(domProd[j]);
00064 for(BeliefNode child : children) {
00065 if(nodeDomainIndices[sampler.getNodeIndex(child)] >= 0 && !handledChildren.contains(child)) {
00066 CPF childCPF = child.getCPF();
00067
00068 double p = getProb(childCPF, nodeDomainIndices);
00069 parent_prob *= p;
00070 handledChildren.add(child);
00071 }
00072 }
00073 }
00074
00075 for(int k = 1; k < addr.length; k++) {
00076 if(tempEvidence[k])
00077 nodeDomainIndices[sampler.nodeIndices.get(domProd[k])] = -1;
00078 }
00079
00080 double p = child_prob * parent_prob;
00081 if(p != 0) {
00082 addValue(p, addr.clone());
00083 parentProbs.add(parent_prob);
00084 }
00085 return;
00086 }
00087 int nodeIdx = sampler.nodeIndices.get(domProd[i]);
00088 if(nodeDomainIndices[nodeIdx] >= 0) {
00089 addr[i] = nodeDomainIndices[nodeIdx];
00090 construct(i+1, addr, cpf, nodeDomainIndices);
00091 }
00092 else {
00093 Discrete dom = (Discrete)domProd[i].getDomain();
00094 for(int j = 0; j < dom.getOrder(); j++) {
00095 addr[i] = j;
00096 construct(i+1, addr, cpf, nodeDomainIndices);
00097 }
00098 }
00099 }
00100
00101 protected double getProb(CPF cpf, int[] nodeDomainIndices) {
00102 final boolean debugCache = false;
00103 probSW.start();
00104
00105 Double cacheValue = null;
00106 BeliefNode[] domProd = cpf.getDomainProduct();
00107 int[] addr = new int[domProd.length];
00108 boolean allSet = true;
00109 int key = 0;
00110 for(int i = 0; i < addr.length; i++) {
00111 int idx = nodeDomainIndices[sampler.getNodeIndex(domProd[i])];
00112 allSet = allSet && idx >= 0;
00113 addr[i] = idx;
00114 key *= cpf._SizeBuffer[i]+1;
00115 key += idx == -1 ? cpf._SizeBuffer[i] : idx;
00116 }
00117 if(allSet) {
00118 probSW.stop();
00119 return cpf.getDouble(addr);
00120 }
00121
00122 Double value = cacheValue = probCache.get(cpf, key);
00123 if(!debugCache && value != null) {
00124 probSW.stop();
00125 return value;
00126 }
00127
00128 MutableDouble p = new MutableDouble(0.0);
00129 getProb(cpf, 0, addr, nodeDomainIndices, p);
00130
00131 probCache.put(p.value);
00132
00133 if(cacheValue != null && p.value != cacheValue) {
00134 throw new RuntimeException("cache mismatch");
00135 }
00136 probSW.stop();
00137 return p.value;
00138 }
00139
00149 protected void getProb(CPF cpf, int i, int[] addr, int[] nodeDomainIndices, MutableDouble ret) {
00150 BeliefNode[] domProd = cpf.getDomainProduct();
00151
00152 if(i == addr.length) {
00153 double p = cpf.getDouble(addr);
00154 for(int j = 1; j < addr.length; j++) {
00155 if(nodeDomainIndices[sampler.getNodeIndex(domProd[j])] == -1); {
00156 double[] parentPrior = ((BackwardSamplingWithPriors)sampler).priors.get(domProd[j]);
00157 p *= parentPrior[addr[j]];
00158 }
00159 }
00160 ret.value += p;
00161 return;
00162 }
00163
00164 BeliefNode node = domProd[i];
00165 int nodeIdx = sampler.getNodeIndex(node);
00166
00167 if(nodeDomainIndices[nodeIdx] >= 0) {
00168 addr[i] = nodeDomainIndices[nodeIdx];
00169 getProb(cpf, i+1, addr, nodeDomainIndices, ret);
00170 }
00171
00172 else {
00173 Domain dom = node.getDomain();
00174 for(int j = 0; j < dom.getOrder(); j++) {
00175 addr[i] = j;
00176 getProb(cpf, i+1, addr, nodeDomainIndices, ret);
00177 }
00178 }
00179 }
00180 }
00181
00182 @Override
00183 protected BackSamplingDistribution getBackSamplingDistribution(BeliefNode node, WeightedSample s) {
00184 BackSamplingDistribution d;
00185 long key = 0;
00186 final boolean useCache = true;
00187 distSW.start();
00188
00189 if(useCache) {
00190
00191 BeliefNode[] domProd = node.getCPF().getDomainProduct();
00192
00193 for(int i = 0; i < domProd.length; i++) {
00194 BeliefNode n = domProd[i];
00195 int idx = s.nodeDomainIndices[getNodeIndex(n)];
00196 int order = n.getDomain().getOrder();
00197 key *= order + 1;
00198 key += idx == -1 ? order : idx;
00199
00200 if(i != 0) {
00201 BeliefNode[] children = bn.bn.getChildren(n);
00202 for(int j = 0; j < children.length; j++) {
00203 if(children[j] != node) {
00204 n = children[j];
00205 idx = s.nodeDomainIndices[getNodeIndex(n)];
00206 order = n.getDomain().getOrder();
00207 key *= order + 1;
00208 key += idx == -1 ? order : idx;
00209
00210 BeliefNode[] parentsofchildren = children[j].getCPF().getDomainProduct();
00211 for(int k = 1; k < parentsofchildren.length; k++) {
00212 n = parentsofchildren[k];
00213 idx = s.nodeDomainIndices[getNodeIndex(n)];
00214 order = n.getDomain().getOrder();
00215 key *= order + 1;
00216 key += idx == -1 ? order : idx;
00217 }
00218 }
00219 }
00220 }
00221 }
00222
00223
00224 d = distCache.get(node, key);
00225 if(d != null)
00226 return d;
00227 }
00228
00229
00230 d = new BackSamplingDistribution(this);
00231 d.construct(node, s.nodeDomainIndices);
00232
00233
00234 if(useCache)
00235 distCache.put(d);
00236
00237 distSW.stop();
00238 return d;
00239 }
00240
00241 public BackwardSamplingWithChildren(BeliefNetworkEx bn) throws Exception {
00242 super(bn);
00243 }
00244
00245 @Override
00246 public void prepareInference(int[] evidenceDomainIndices) throws Exception {
00247 probCache = new Cache2D<CPF, Integer, Double>();
00248 distCache = new Cache2D<BeliefNode, Long, BackSamplingDistribution>();
00249 super.prepareInference(evidenceDomainIndices);
00250 }
00251
00252 public SampledDistribution infer() throws Exception {
00253 probSW = new Stopwatch();
00254 distSW = new Stopwatch();
00255 SampledDistribution d = super.infer();
00256 report("prob time: " + probSW.getElapsedTimeSecs());
00257 report(String.format(" cache hit ratio: %f (%d accesses)", this.probCache.getHitRatio(), this.probCache.getNumAccesses()));
00258 report("dist time: " + distSW.getElapsedTimeSecs());
00259 report(String.format(" cache hit ratio: %f (%d accesses)", this.distCache.getHitRatio(), this.distCache.getNumAccesses()));
00260 return d;
00261 }
00262 }