00001 package edu.tum.cs.srl.bayesnets.inference;
00002
00003 import java.util.HashMap;
00004 import java.util.HashSet;
00005
00006 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00007 import edu.ksu.cis.bnj.ver3.core.CPF;
00008 import edu.ksu.cis.bnj.ver3.core.Discrete;
00009 import edu.ksu.cis.bnj.ver3.core.Domain;
00010 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00011 import edu.tum.cs.bayesnets.inference.BackwardSamplingWithPriors;
00012 import edu.tum.cs.bayesnets.inference.SampledDistribution;
00013 import edu.tum.cs.bayesnets.inference.WeightedSample;
00014 import edu.tum.cs.srl.bayesnets.bln.AbstractGroundBLN;
00015 import edu.tum.cs.srl.bayesnets.bln.GroundBLN;
00016 import edu.tum.cs.util.Stopwatch;
00017 import edu.tum.cs.util.datastruct.Cache2D;
00018 import edu.tum.cs.util.datastruct.MutableDouble;
00019
00020 public class LiftedBackwardSampling extends Sampler {
00021
00025 HashMap<BeliefNode,Integer> node2class = new HashMap<BeliefNode, Integer>();
00026
00027 public LiftedBackwardSampling(GroundBLN gbln) throws Exception {
00028 super(gbln);
00029 }
00030
00031 @Override
00032 public SampledDistribution _infer() throws Exception {
00033
00034
00035
00036 System.out.println("computing equivalence classes...");
00037 Integer classNo = 0;
00038 Cache2D<String, String, Integer> classes = new Cache2D<String, String, Integer>();
00039 BeliefNetworkEx groundBN = gbln.getGroundNetwork();
00040 for(BeliefNode node : groundBN.bn.getNodes()) {
00041
00042 StringBuffer key = new StringBuffer();
00043 BeliefNode[] domprod = node.getCPF().getDomainProduct();
00044 for(int i = 1; i < domprod.length; i++) {
00045 key.append(",").append(gbln.getCPFID(domprod[i]));
00046 for(BeliefNode c : groundBN.bn.getChildren(domprod[i])) {
00047 for(BeliefNode d : c.getCPF().getDomainProduct()) {
00048 key.append(",").append(gbln.getCPFID(d));
00049 }
00050 }
00051 }
00052 String skey = key.toString();
00053
00054 String mainCPFID = gbln.getCPFID(node);
00055 if(mainCPFID == null)
00056 throw new Exception("Node " + node + " has no CPF-ID");
00057 Integer value = classes.get(mainCPFID, skey);
00058 if(value == null) {
00059 value = ++classNo;
00060 classes.put(classNo);
00061 }
00062 node2class.put(node, value);
00063 if(debug)
00064 System.out.println(node + " is class " + value + "\n " + mainCPFID + skey);
00065 }
00066 System.out.println(" reduced " + groundBN.bn.getNodes().length + " nodes to " + classNo + " equivalence classes");
00067
00068
00069 String[][] evidence = this.gbln.getDatabase().getEntriesAsArray();
00070 int[] evidenceDomainIndices = gbln.getFullEvidence(evidence);
00071 Sampler sampler = new Sampler(gbln.getGroundNetwork());
00072 sampler.setDebugMode(debug);
00073 sampler.setNumSamples(numSamples);
00074 sampler.setInfoInterval(infoInterval);
00075 sampler.setEvidence(evidenceDomainIndices);
00076
00077
00078 SampledDistribution dist = sampler.infer();
00079
00080 return dist;
00081 }
00082
00088 protected class Sampler extends BackwardSamplingWithPriors {
00089
00090
00091
00092
00093 protected Cache2D<String, Integer, Double> probCache;
00097 protected Cache2D<Integer, Long, BackSamplingDistribution> distCache;
00098 protected Stopwatch probSW, distSW;
00099 protected boolean useDistributionCache = true;
00100 protected boolean useProbabilityCache = false;
00101
00102 public class BackSamplingDistribution extends edu.tum.cs.bayesnets.inference.BackwardSamplingWithPriors.BackSamplingDistribution {
00103
00104 public BackSamplingDistribution(BackwardSamplingWithPriors sampler) {
00105 super(sampler);
00106 }
00107
00114 @Override
00115 protected void construct(int i, int[] addr, CPF cpf, int[] nodeDomainIndices) {
00116 BeliefNode[] domProd = cpf.getDomainProduct();
00117 if(i == addr.length) {
00118 double child_prob = cpf.getDouble(addr);
00119
00120 boolean[] tempEvidence = new boolean[addr.length];
00121 for(int k = 1; k < addr.length; k++) {
00122 int nodeIdx = sampler.nodeIndices.get(domProd[k]);
00123 tempEvidence[k] = nodeDomainIndices[nodeIdx] == -1;
00124 if(tempEvidence[k])
00125 nodeDomainIndices[nodeIdx] = addr[k];
00126 }
00127
00128 double parent_prob = 1.0;
00129 HashSet<BeliefNode> handledChildren = new HashSet<BeliefNode>();
00130 handledChildren.add(domProd[0]);
00131 for(int j = 1; j < addr.length; j++) {
00132 double[] parentPrior = ((BackwardSamplingWithPriors)sampler).priors.get(domProd[j]);
00133 parent_prob *= parentPrior[addr[j]];
00134
00135
00136 BeliefNode[] children = sampler.bn.bn.getChildren(domProd[j]);
00137 for(BeliefNode child : children) {
00138 if(nodeDomainIndices[sampler.getNodeIndex(child)] >= 0 && !handledChildren.contains(child)) {
00139
00140 double p = getProb(child, nodeDomainIndices);
00141 parent_prob *= p;
00142 handledChildren.add(child);
00143 }
00144 }
00145 }
00146
00147 for(int k = 1; k < addr.length; k++) {
00148 if(tempEvidence[k])
00149 nodeDomainIndices[sampler.nodeIndices.get(domProd[k])] = -1;
00150 }
00151
00152 double p = child_prob * parent_prob;
00153 if(p != 0) {
00154 addValue(p, addr.clone());
00155 parentProbs.add(parent_prob);
00156 }
00157 return;
00158 }
00159 int nodeIdx = sampler.nodeIndices.get(domProd[i]);
00160 if(nodeDomainIndices[nodeIdx] >= 0) {
00161 addr[i] = nodeDomainIndices[nodeIdx];
00162 construct(i+1, addr, cpf, nodeDomainIndices);
00163 }
00164 else {
00165 Discrete dom = (Discrete)domProd[i].getDomain();
00166 for(int j = 0; j < dom.getOrder(); j++) {
00167 addr[i] = j;
00168 construct(i+1, addr, cpf, nodeDomainIndices);
00169 }
00170 }
00171 }
00172
00173 protected double getProb(BeliefNode node, int[] nodeDomainIndices) {
00174 CPF cpf = node.getCPF();
00175 boolean debugCache = debug;
00176 probSW.start();
00177
00178 Double cacheValue = null;
00179 BeliefNode[] domProd = cpf.getDomainProduct();
00180 int[] addr = new int[domProd.length];
00181 boolean allSet = true;
00182 int key = 0;
00183 for(int i = 0; i < addr.length; i++) {
00184 int idx = nodeDomainIndices[sampler.getNodeIndex(domProd[i])];
00185 allSet = allSet && idx >= 0;
00186 addr[i] = idx;
00187 key *= cpf._SizeBuffer[i]+1;
00188 key += idx == -1 ? cpf._SizeBuffer[i] : idx;
00189 }
00190 if(allSet) {
00191 probSW.stop();
00192 return cpf.getDouble(addr);
00193 }
00194
00195 Double value = null;
00196 if(useProbabilityCache)
00197 value = cacheValue = probCache.get(gbln.getCPFID(node), key);
00198 if(value != null) {
00199 probSW.stop();
00200 if(!debugCache)
00201 return value;
00202 }
00203
00204 MutableDouble p = new MutableDouble(0.0);
00205 getProb(cpf, 0, addr, nodeDomainIndices, p);
00206
00207 if(useProbabilityCache) {
00208 probCache.put(p.value);
00209 if(cacheValue != null && p.value != cacheValue) {
00210 throw new RuntimeException("Probability cache mismatch");
00211 }
00212 }
00213
00214 probSW.stop();
00215 return p.value;
00216 }
00217
00227 protected void getProb(CPF cpf, int i, int[] addr, int[] nodeDomainIndices, MutableDouble ret) {
00228 BeliefNode[] domProd = cpf.getDomainProduct();
00229
00230 if(i == addr.length) {
00231 double p = cpf.getDouble(addr);
00232 for(int j = 1; j < addr.length; j++) {
00233 if(nodeDomainIndices[sampler.getNodeIndex(domProd[j])] == -1); {
00234 double[] parentPrior = ((BackwardSamplingWithPriors)sampler).priors.get(domProd[j]);
00235 p *= parentPrior[addr[j]];
00236 }
00237 }
00238 ret.value += p;
00239 return;
00240 }
00241
00242 BeliefNode node = domProd[i];
00243 int nodeIdx = sampler.getNodeIndex(node);
00244
00245 if(nodeDomainIndices[nodeIdx] >= 0) {
00246 addr[i] = nodeDomainIndices[nodeIdx];
00247 getProb(cpf, i+1, addr, nodeDomainIndices, ret);
00248 }
00249
00250 else {
00251 Domain dom = node.getDomain();
00252 for(int j = 0; j < dom.getOrder(); j++) {
00253 addr[i] = j;
00254 getProb(cpf, i+1, addr, nodeDomainIndices, ret);
00255 }
00256 }
00257 }
00258 }
00259
00260 @Override
00261 protected BackSamplingDistribution getBackSamplingDistribution(BeliefNode node, WeightedSample s) {
00262 BackSamplingDistribution d;
00263 long key = 0;
00264 distSW.start();
00265
00266 if(useDistributionCache) {
00267
00268 BeliefNode[] domProd = node.getCPF().getDomainProduct();
00269
00270 for(int i = 0; i < domProd.length; i++) {
00271 BeliefNode n = domProd[i];
00272 int idx = s.nodeDomainIndices[getNodeIndex(n)];
00273 int order = n.getDomain().getOrder();
00274 key *= order + 1;
00275 key += idx == -1 ? order : idx;
00276
00277 if(i != 0) {
00278 BeliefNode[] children = bn.bn.getChildren(n);
00279 for(int j = 0; j < children.length; j++) {
00280 if(children[j] != node) {
00281 n = children[j];
00282 idx = s.nodeDomainIndices[getNodeIndex(n)];
00283 order = n.getDomain().getOrder();
00284 key *= order + 1;
00285 key += idx == -1 ? order : idx;
00286
00287 BeliefNode[] parentsofchildren = children[j].getCPF().getDomainProduct();
00288 for(int k = 1; k < parentsofchildren.length; k++) {
00289 n = parentsofchildren[k];
00290 idx = s.nodeDomainIndices[getNodeIndex(n)];
00291 order = n.getDomain().getOrder();
00292 key *= order + 1;
00293 key += idx == -1 ? order : idx;
00294 }
00295 }
00296 }
00297 }
00298 }
00299
00300
00301 d = distCache.get(node2class.get(node), key);
00302 if(d != null)
00303 return d;
00304 }
00305
00306
00307 d = new BackSamplingDistribution(this);
00308 d.construct(node, s.nodeDomainIndices);
00309
00310
00311 if(useDistributionCache)
00312 distCache.put(d);
00313
00314 distSW.stop();
00315 return d;
00316 }
00317
00318 public Sampler(BeliefNetworkEx bn) throws Exception {
00319 super(bn);
00320 }
00321
00322 @Override
00323 public void prepareInference(int[] evidenceDomainIndices) throws Exception {
00324 probCache = new Cache2D<String, Integer, Double>();
00325 distCache = new Cache2D<Integer, Long, BackSamplingDistribution>();
00326 super.prepareInference(evidenceDomainIndices);
00327 }
00328
00329 @Override
00330 public SampledDistribution _infer() throws Exception {
00331 probSW = new Stopwatch();
00332 distSW = new Stopwatch();
00333 SampledDistribution d = super.infer();
00334 System.out.println("prob time: " + probSW.getElapsedTimeSecs());
00335 System.out.println(String.format(" cache hit ratio: %f (%d accesses)", this.probCache.getHitRatio(), this.probCache.getNumAccesses()));
00336 System.out.println("dist time: " + distSW.getElapsedTimeSecs());
00337 System.out.println(String.format(" cache hit ratio: %f (%d accesses)", this.distCache.getHitRatio(), this.distCache.getNumAccesses()));
00338 System.out.println();
00339 return d;
00340 }
00341 }
00342 }