00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.io.PrintStream;
00004
00005 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00006 import edu.ksu.cis.bnj.ver3.core.Discrete;
00007 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00008 import edu.tum.cs.inference.BasicSampledDistribution;
00009
00016 public class SampledDistribution extends BasicSampledDistribution implements Cloneable {
00020 public BeliefNetworkEx bn;
00024 public int steps, trials, operations;
00025 protected double maxWeight = 0.0;
00026 protected boolean debug = true;
00027 protected BeliefNode[] nodes;
00028
00029 public SampledDistribution(BeliefNetworkEx bn) throws Exception {
00030 this.bn = bn;
00031 this.Z = 0.0;
00032 nodes = bn.bn.getNodes();
00033 values = new double[nodes.length][];
00034 for(int i = 0; i < nodes.length; i++)
00035 values[i] = new double[nodes[i].getDomain().getOrder()];
00036 }
00037
00038 public synchronized void addSample(WeightedSample s) {
00039 if(s.weight == 0.0) {
00040 throw new RuntimeException("Zero-weight sample was added to distribution. Precision loss?");
00041 }
00042
00043
00044 Z += s.weight;
00045 if(maxWeight < s.weight)
00046 maxWeight = s.weight;
00047
00048
00049 if(debug) {
00050 double prob = bn.getWorldProbability(s.nodeDomainIndices);
00051
00052
00053
00054 System.out.printf("sample weight: %s (%.2f%%); max weight: %s (%.2f%%); prob: %s\n", s.weight, s.weight*100/Z, maxWeight, maxWeight*100/Z, prob);
00055 }
00056
00057
00058 for(int i = 0; i < s.nodeIndices.length; i++) {
00059 try {
00060 values[s.nodeIndices[i]][s.nodeDomainIndices[i]] += s.weight;
00061 }
00062 catch(ArrayIndexOutOfBoundsException e) {
00063 System.err.println("Error: Node " + nodes[s.nodeIndices[i]].getName() + " was not sampled correctly.");
00064 throw e;
00065 }
00066 }
00067
00068
00069 trials += s.trials;
00070 operations += s.operations;
00071 steps++;
00072 }
00073
00074 @Override
00075 public void printVariableDistribution(PrintStream out, int index) {
00076 BeliefNode node = nodes[index];
00077 out.println(node.getName() + ":");
00078 Discrete domain = (Discrete)node.getDomain();
00079 for(int j = 0; j < domain.getOrder(); j++) {
00080 double prob = values[index][j] / Z;
00081 out.println(String.format(" %.4f %s", prob, domain.getName(j)));
00082 }
00083 }
00084
00085 public double getTrialsPerStep() {
00086 return (double)trials/steps;
00087 }
00088
00089 @Override
00090 public synchronized SampledDistribution clone() throws CloneNotSupportedException {
00091 return (SampledDistribution)super.clone();
00092 }
00093
00094 @Override
00095 public String[] getDomain(int idx) {
00096 return BeliefNetworkEx.getDiscreteDomainAsArray(bn.getNode(idx));
00097 }
00098
00099 @Override
00100 public String getVariableName(int idx) {
00101 return bn.getNode(idx).getName();
00102 }
00103
00104 @Override
00105 public int getVariableIndex(String name) {
00106 return bn.getNodeIndex(name);
00107 }
00108
00109 public void setDebugMode(boolean active) {
00110 debug = active;
00111 }
00112
00113 @Override
00114 public Integer getNumSamples() {
00115 return steps;
00116 }
00117 }