00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.util.HashMap;
00004 import java.util.Vector;
00005
00006 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00007 import edu.ksu.cis.bnj.ver3.core.CPF;
00008 import edu.ksu.cis.bnj.ver3.core.Discrete;
00009 import edu.ksu.cis.bnj.ver3.core.Domain;
00010 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00011
00012 public class BackwardSamplingWithPriors extends BackwardSampling {
00013
00014 public HashMap<BeliefNode, double[]> priors;
00015
00016 public static class BackSamplingDistribution extends edu.tum.cs.bayesnets.inference.BackwardSampling.BackSamplingDistribution {
00017
00018 public Vector<Double> parentProbs;
00019
00020 public BackSamplingDistribution(BackwardSamplingWithPriors sampler) {
00021 super(sampler);
00022 parentProbs = new Vector<Double>();
00023 }
00024
00031 @Override
00032 protected void construct(int i, int[] addr, CPF cpf, int[] nodeDomainIndices) {
00033 BeliefNode[] domProd = cpf.getDomainProduct();
00034 if(i == addr.length) {
00035 double child_prob = cpf.getDouble(addr);
00036 double parent_prob = 1.0;
00037 for(int j = 1; j < addr.length; j++) {
00038 double[] parentPrior = ((BackwardSamplingWithPriors)sampler).priors.get(domProd[j]);
00039 parent_prob *= parentPrior[addr[j]];
00040 }
00041 double p = child_prob * parent_prob;
00042 if(p != 0) {
00043 addValue(p, addr.clone());
00044 parentProbs.add(parent_prob);
00045 }
00046 return;
00047 }
00048 int nodeIdx = sampler.nodeIndices.get(domProd[i]);
00049 if(nodeDomainIndices[nodeIdx] >= 0) {
00050 addr[i] = nodeDomainIndices[nodeIdx];
00051 construct(i+1, addr, cpf, nodeDomainIndices);
00052 }
00053 else {
00054 Discrete dom = (Discrete)domProd[i].getDomain();
00055 for(int j = 0; j < dom.getOrder(); j++) {
00056 addr[i] = j;
00057 construct(i+1, addr, cpf, nodeDomainIndices);
00058 }
00059 }
00060 }
00061
00062 @Override
00063 public void applyWeight(WeightedSample s, int sampledValue) {
00064 s.weight *= Z / parentProbs.get(sampledValue);
00065 }
00066 }
00067
00068 public BackwardSamplingWithPriors(BeliefNetworkEx bn) throws Exception {
00069 super(bn);
00070 }
00071
00072 @Override
00073 protected BackSamplingDistribution getBackSamplingDistribution(BeliefNode node, WeightedSample s) {
00074 BackSamplingDistribution d = new BackSamplingDistribution(this);
00075 d.construct(node, s.nodeDomainIndices);
00076 return d;
00077 }
00078
00079 @Override
00080 protected void prepareInference(int[] evidenceDomainIndices) throws Exception {
00081 super.prepareInference(evidenceDomainIndices);
00082 if(verbose) out.println("computing priors...");
00083 computePriors(evidenceDomainIndices);
00084 }
00085
00086 protected void computePriors(int[] evidenceDomainIndices) {
00087 priors = new HashMap<BeliefNode, double[]>();
00088 int[] topOrder = bn.getTopologicalOrder();
00089 for(int i : topOrder) {
00090 BeliefNode node = nodes[i];
00091 double[] dist = new double[node.getDomain().getOrder()];
00092 int evidence = evidenceDomainIndices[i];
00093 if(evidence >= 0) {
00094 for(int j = 0; j < dist.length; j++)
00095 dist[j] = evidence == j ? 1.0 : 0.0;
00096 }
00097 else {
00098 CPF cpf = node.getCPF();
00099 computePrior(cpf, 0, new int[cpf.getDomainProduct().length], dist);
00100 }
00101 priors.put(node, dist);
00102 }
00103 }
00104
00105 protected void computePrior(CPF cpf, int i, int[] addr, double[] dist) {
00106 BeliefNode[] domProd = cpf.getDomainProduct();
00107 if(i == addr.length) {
00108 double p = cpf.getDouble(addr);
00109 for(int j = 1; j < addr.length; j++) {
00110 double[] parentPrior = priors.get(domProd[j]);
00111 p *= parentPrior[addr[j]];
00112 }
00113 dist[addr[0]] += p;
00114 return;
00115 }
00116 BeliefNode node = domProd[i];
00117 int nodeIdx = getNodeIndex(node);
00118 if(evidenceDomainIndices[nodeIdx] >= 0) {
00119 addr[i] = evidenceDomainIndices[nodeIdx];
00120 computePrior(cpf, i+1, addr, dist);
00121 }
00122 else {
00123 Domain dom = node.getDomain();
00124 for(int j = 0; j < dom.getOrder(); j++) {
00125 addr[i] = j;
00126 computePrior(cpf, i+1, addr, dist);
00127 }
00128 }
00129 }
00130 }