00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.util.HashSet;
00004 import java.util.Vector;
00005
00006 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00007 import edu.ksu.cis.bnj.ver3.core.CPF;
00008 import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
00009 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00010 import edu.tum.cs.util.Stopwatch;
00011 import edu.tum.cs.util.StringTool;
00012
00017 public class VariableElimination extends Sampler {
00018 int[] nodeOrder;
00019 Stopwatch timer;
00020 int[] nodeDomainIndices;
00021
00022 public VariableElimination(BeliefNetworkEx bn) throws Exception {
00023 super(bn);
00024 nodeOrder = bn.getTopologicalOrder();
00025 }
00026
00027 protected class Factor {
00028 CPF cpf;
00029
00030 public Factor(BeliefNode n) {
00031 cpf = n.getCPF();
00032 BeliefNode[] domprod = cpf.getDomainProduct();
00033 for(int i = 0; i < domprod.length; i++) {
00034 if(evidenceDomainIndices[getNodeIndex(domprod[i])] != 0) {
00035 cpf = removeEvidence(cpf);
00036 break;
00037 }
00038 }
00039 }
00040
00041 protected CPF removeEvidence(CPF cpf) {
00042 BeliefNode[] domprod = cpf.getDomainProduct();
00043 Vector<BeliefNode> domprod2 = new Vector<BeliefNode>();
00044 for(int i = 0; i < domprod.length; i++)
00045 if(evidenceDomainIndices[getNodeIndex(domprod[i])] == -1)
00046 domprod2.add(domprod[i]);
00047 CPF cpf2 = new CPF(domprod2.toArray(new BeliefNode[domprod2.size()]));
00048 int[] addr = new int[domprod.length];
00049 int[] addr2 = new int[domprod2.size()];
00050 removeEvidence(cpf, cpf2, 0, addr, 0, addr2);
00051 return cpf2;
00052 }
00053
00054 protected void removeEvidence(CPF cpf, CPF cpf2, int i, int[] addr, int j, int[] addr2) {
00055 if(i == addr.length) {
00056 cpf2.put(addr2, cpf.get(addr));
00057 return;
00058 }
00059
00060 BeliefNode[] domprod = cpf.getDomainProduct();
00061 BeliefNode[] domprod2 = cpf2.getDomainProduct();
00062 BeliefNode node = domprod[i];
00063 boolean transfer = false;
00064 if(j < domprod2.length)
00065 transfer = domprod2[j] == domprod[i];
00066 int evidence = evidenceDomainIndices[getNodeIndex(node)];
00067 if(evidence != -1) {
00068 addr[i] = evidence;
00069 if(transfer)
00070 addr2[j] = evidence;
00071 removeEvidence(cpf, cpf2, i+1, addr, transfer ? j+1 : j, addr2);
00072 }
00073 else {
00074 int domSize = node.getDomain().getOrder();
00075 for(int domIdx = 0; domIdx < domSize; domIdx++) {
00076 addr[i] = domIdx;
00077 if(transfer)
00078 addr2[j] = domIdx;
00079 removeEvidence(cpf, cpf2, i+1, addr, transfer ? j+1 : j, addr2);
00080 }
00081 }
00082 }
00083
/**
 * Wraps an existing table as a factor; the table is stored as-is, not copied.
 *
 * @param cpf the table backing the new factor
 */
public Factor(CPF cpf) {
    this.cpf = cpf;
}
00087
00088 public double getValue(int[] nodeDomainIndices) {
00089 BeliefNode[] domProd = cpf.getDomainProduct();
00090 int[] addr = new int[domProd.length];
00091 for(int i = 0; i < addr.length; i++)
00092 addr[i] = nodeDomainIndices[getNodeIndex(domProd[i])];
00093 return cpf.getDouble(addr);
00094 }
00095
00096 public Factor sumOut(BeliefNode n) {
00097 BeliefNode[] domprod = cpf.getDomainProduct();
00098 BeliefNode[] domprod2 = new BeliefNode[domprod.length-1];
00099 int j = 0;
00100 for(int i = 0; i < domprod.length; i++)
00101 if(domprod[i] != n)
00102 domprod2[j++] = domprod[i];
00103 CPF cpf2 = new CPF(domprod2);
00104 int[] addr = new int[domprod.length];
00105 int[] addr2 = new int[domprod2.length];
00106 sumOut(cpf2, n, 0, addr, 0, addr2);
00107 return new Factor(cpf2);
00108 }
00109
00110 protected void sumOut(CPF cpf2, BeliefNode n, int i, int[] addr, int j, int[] addr2) {
00111 if(i == addr.length) {
00112 int realaddr2 = cpf2.addr2realaddr(addr2);
00113 double v = cpf2.getDouble(realaddr2);
00114 v += cpf.getDouble(addr);
00115 cpf2.put(realaddr2, new ValueDouble(v));
00116 return;
00117 }
00118
00119 BeliefNode node = this.cpf.getDomainProduct()[i];
00120 int evidence = evidenceDomainIndices[getNodeIndex(node)];
00121 if(evidence != -1) {
00122 addr[i] = evidence;
00123 if(node != n)
00124 addr2[j] = evidence;
00125 sumOut(cpf2, n, i+1, addr, node == n ? j : j+1, addr2);
00126 }
00127 else {
00128 int domSize = node.getDomain().getOrder();
00129 for(int domIdx = 0; domIdx < domSize; domIdx++) {
00130 addr[i] = domIdx;
00131 if(node != n)
00132 addr2[j] = domIdx;
00133 sumOut(cpf2, n, i+1, addr, node == n ? j : j+1, addr2);
00134 }
00135 }
00136 }
00137
00138 public String toString() {
00139 return "F(" + StringTool.join(",", cpf.getDomainProduct()) + ")";
00140 }
00141 }
00142
00143 protected Factor join(Iterable<Factor> factors) {
00144 HashSet<BeliefNode> domain = new HashSet<BeliefNode>();
00145 for(Factor f : factors) {
00146 for(BeliefNode n : f.cpf.getDomainProduct())
00147 domain.add(n);
00148 }
00149 BeliefNode[] domProd = domain.toArray(new BeliefNode[domain.size()]);
00150 CPF cpf = new CPF(domProd);
00151 int[] addr = new int[domProd.length];
00152 fillCPF(factors, cpf, 0, addr);
00153 return new Factor(cpf);
00154 }
00155
00156 protected void fillCPF(Iterable<Factor> factors, CPF cpf, int i, int[] addr) {
00157 if(i == addr.length) {
00158 double value = 1.0;
00159 for(Factor f : factors) {
00160 value *= f.getValue(nodeDomainIndices);
00161 }
00162 cpf.put(addr, new ValueDouble(value));
00163 return;
00164 }
00165 BeliefNode[] domProd = cpf.getDomainProduct();
00166 int domSize = domProd[i].getDomain().getOrder();
00167 for(int j = 0; j < domSize; j++) {
00168 addr[i] = j;
00169 nodeDomainIndices[getNodeIndex(domProd[i])] = j;
00170 fillCPF(factors, cpf, i+1, addr);
00171 }
00172 }
00173
00174 protected Vector<Factor> sumout(Vector<Factor> factors, BeliefNode n) {
00175 Vector<Factor> newFacs = new Vector<Factor>();
00176 Vector<Factor> joinFacs = new Vector<Factor>();
00177 for(Factor f : factors) {
00178 BeliefNode[] domProd = f.cpf.getDomainProduct();
00179 boolean sumover = false;
00180 for(int i = 0; i < domProd.length; i++)
00181 if(domProd[i] == n)
00182 sumover = true;
00183 if(sumover)
00184 joinFacs.add(f);
00185 else
00186 newFacs.add(f);
00187 }
00188 Factor joinedFac = join(joinFacs);
00189 if(debug) out.println("Summing out " + n + " from " + joinedFac);
00190 newFacs.add(joinedFac.sumOut(n));
00191 return newFacs;
00192 }
00193
00194 protected void computeMarginal(BeliefNode Q) {
00195 Vector<Factor> factors = new Vector<Factor>();
00196 for(int i = nodeOrder.length-1; i >= 0; i--) {
00197 if(!debug) out.printf(" %s %d \r", Q.getName(), i);
00198 int nodeIdx = nodeOrder[i];
00199 BeliefNode node = nodes[nodeIdx];
00200 if(debug) out.println("Current node: " + node);
00201
00202 factors.add(new Factor(node));
00203 if(debug) out.println(factors);
00204
00205 if(evidenceDomainIndices[nodeIdx] == -1 && node != Q)
00206 factors = sumout(factors, node);
00207 }
00208 if(!debug) out.println();
00209
00210 if(debug) out.printf("%d final factors: %s\n", factors.size(), StringTool.join(", ", factors));
00211
00212
00213 int nodeIdx = getNodeIndex(Q);
00214 double[] marginal = new double[Q.getDomain().getOrder()];
00215 double Z = 0.0;
00216 for(int i = 0; i < marginal.length; i++) {
00217 nodeDomainIndices[nodeIdx] = i;
00218 marginal[i] = 1.0;
00219 for(Factor f : factors) {
00220 marginal[i] *= f.getValue(nodeDomainIndices);
00221 }
00222
00223 Z += marginal[i];
00224 }
00225 for(int i = 0; i < marginal.length; i++)
00226 marginal[i] /= Z;
00227 dist.values[nodeIdx] = marginal;
00228 }
00229
00230 public SampledDistribution _infer() throws Exception {
00231 Stopwatch sw = new Stopwatch();
00232
00233 createDistribution();
00234 dist.Z = 1.0;
00235
00236 sw.start();
00237
00238 nodeDomainIndices = evidenceDomainIndices.clone();
00239 for(Integer nodeIdx : queryVars)
00240 computeMarginal(nodes[nodeIdx]);
00241
00242 sw.stop();
00243 return dist;
00244 }
00245 }