00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.io.PrintStream;
00004 import java.util.Collection;
00005 import java.util.HashMap;
00006 import java.util.Random;
00007
00008 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00009 import edu.ksu.cis.bnj.ver3.core.CPF;
00010 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00011 import edu.tum.cs.inference.IParameterHandler;
00012 import edu.tum.cs.inference.ParameterHandler;
00013 import edu.tum.cs.inference.BasicSampledDistribution.ConfidenceInterval;
00014 import edu.tum.cs.util.Stopwatch;
00015
00016 public abstract class Sampler implements ITimeLimitedInference, IParameterHandler {
00017 public BeliefNetworkEx bn;
00018 public SampledDistribution dist;
00019 public HashMap<BeliefNode, Integer> nodeIndices;
00020 public Random generator;
00021 public BeliefNode[] nodes;
00022 public int[] evidenceDomainIndices;
00023 protected ParameterHandler paramHandler;
00024 protected Collection<Integer> queryVars = null;
00025 protected StringBuffer report = new StringBuffer();
00026 protected boolean verbose;
00027 protected PrintStream out;
00028
00032 public int numSamples = 1000;
00033
00034 protected int maxTrials = 5000;
00035 protected boolean skipFailedSteps = false;
00036 protected Double confidenceIntervalSizeThreshold = null;
00037 public double convergenceCheckInterval = 100;
00038 public double samplingTime;
00039
00043 public int infoInterval = 100;
00044
00045 public boolean debug = false;
00046
00047 public Sampler(BeliefNetworkEx bn) throws Exception {
00048 this.bn = bn;
00049 this.nodes = bn.bn.getNodes();
00050 nodeIndices = new HashMap<BeliefNode, Integer>();
00051 for(int i = 0; i < nodes.length; i++) {
00052 nodeIndices.put(nodes[i], i);
00053 }
00054 generator = new Random();
00055 setVerbose(true);
00056 paramHandler = new ParameterHandler(this);
00057 paramHandler.add("confidenceIntervalSizeThreshold", "setConfidenceIntervalSizeThreshold");
00058 paramHandler.add("randomSeed", "setRandomSeed");
00059 paramHandler.add("verbose", "setVerbose");
00060 }
00061
00062 protected void createDistribution() throws Exception {
00063 this.dist = new SampledDistribution(bn);
00064 dist.setDebugMode(debug);
00065 paramHandler.addSubhandler(dist.getParameterHandler());
00066 }
00067
00068 protected synchronized void addSample(WeightedSample s) throws Exception {
00069
00070 if(debug) {
00071 for(int i = 0; i < evidenceDomainIndices.length; i++)
00072 if(evidenceDomainIndices[i] >= 0 && s.nodeDomainIndices[i] != evidenceDomainIndices[i])
00073 throw new Exception("Attempted to add sample to distribution that does not respect evidence");
00074 }
00075
00076 this.dist.addSample(s);
00077 }
00078
00079 public void setQueryVars(Collection<Integer> queryVars) {
00080 this.queryVars = queryVars;
00081 }
00082
00083 protected boolean converged() throws Exception {
00084 if(dist.getNumSamples() % this.convergenceCheckInterval != 0)
00085 return false;
00086
00087 if(confidenceIntervalSizeThreshold != null) {
00088 if(!dist.usesConfidenceComputation())
00089 throw new Exception("Cannot determine convergence based on confidence interval size: No confidence level specified.");
00090 double max = 0;
00091 for(Integer i : queryVars) {
00092 ConfidenceInterval interval = dist.getConfidenceInterval(i, 0);
00093 max = Math.max(max, interval.getSize());
00094 }
00095 if(max <= confidenceIntervalSizeThreshold) {
00096 if(verbose) System.out.printf("Convergence criterion reached: maximum confidence interval size = %f\n", max);
00097 return true;
00098 }
00099 }
00100 return false;
00101 }
00102
00103 public void setConfidenceIntervalSizeThreshold(double t) {
00104 confidenceIntervalSizeThreshold = t;
00105 }
00106
00112 public synchronized SampledDistribution pollResults() throws CloneNotSupportedException {
00113 if(dist == null)
00114 return null;
00115 return dist.clone();
00116 }
00117
00124 public static int sample(double[] distribution, Random generator) {
00125 double sum = 0;
00126 for(int i = 0; i < distribution.length; i++)
00127 sum += distribution[i];
00128 return sample(distribution, sum, generator);
00129 }
00130
00138 public static int sample(double[] distribution, double sum, Random generator) {
00139 double random = generator.nextDouble() * sum;
00140 int ret = 0;
00141 sum = 0;
00142 int i = 0;
00143 while(sum < random && i < distribution.length) {
00144 sum += distribution[ret = i++];
00145 }
00146 return sum >= random ? ret : -1;
00147 }
00148
00155 public static int sample(Collection<Double> distribution, Random generator) {
00156 double sum = 0;
00157 for(Double d : distribution)
00158 sum += d;
00159 return sample(distribution, sum, generator);
00160 }
00161
00169 public static int sample(Collection<Double> distribution, double sum, Random generator) {
00170 double random = generator.nextDouble() * sum;
00171 sum = 0;
00172 int i = 0;
00173 for(Double d : distribution) {
00174 sum += d;
00175 if(sum >= random)
00176 return i;
00177 ++i;
00178 }
00179 return -1;
00180 }
00181
00188 protected double getCPTProbability(BeliefNode node, int[] nodeDomainIndices) {
00189 CPF cpf = node.getCPF();
00190 BeliefNode[] domProd = cpf.getDomainProduct();
00191 int[] addr = new int[domProd.length];
00192 for(int i = 0; i < addr.length; i++)
00193 addr[i] = nodeDomainIndices[this.nodeIndices.get(domProd[i])];
00194 return cpf.getDouble(addr);
00195 }
00196
00197 public void setNumSamples(int numSamples) {
00198 this.numSamples = numSamples;
00199 }
00200
00201 public void setInfoInterval(int infoInterval) {
00202 this.infoInterval = infoInterval;
00203 }
00204
00205 public void setMaxTrials(int maxTrials) {
00206 this.maxTrials = maxTrials;
00207 }
00208
00209 public void setSkipFailedSteps(boolean canSkip) {
00210 this.skipFailedSteps = canSkip;
00211 }
00212
00213 public void setEvidence(int[] evidenceDomainIndices) throws Exception {
00214 this.evidenceDomainIndices = evidenceDomainIndices;
00215 }
00216
00217 public void setRandomSeed(int seed) {
00218 generator.setSeed(seed);
00219 }
00220
00221 protected abstract SampledDistribution _infer() throws Exception;
00222
00223 public SampledDistribution infer() throws Exception {
00224 Stopwatch sw = new Stopwatch();
00225 sw.start();
00226 SampledDistribution ret = _infer();
00227 samplingTime = sw.getElapsedTimeSecs();
00228 if(verbose) out.print(report.toString());
00229 return ret;
00230 }
00231
00235 public double getSamplingTime() {
00236 return samplingTime;
00237 }
00238
00245 protected int sampleForward(BeliefNode node, int[] nodeDomainIndices) {
00246 CPF cpf = node.getCPF();
00247 BeliefNode[] domProd = cpf.getDomainProduct();
00248 int[] addr = new int[domProd.length];
00249
00250 for(int i = 1; i < addr.length; i++)
00251 addr[i] = nodeDomainIndices[this.nodeIndices.get(domProd[i])];
00252 addr[0] = 0;
00253 int realAddr = cpf.addr2realaddr(addr);
00254 addr[0] = 1;
00255 int diff = cpf.addr2realaddr(addr) - realAddr;
00256
00257 double[] cpt_entries = new double[domProd[0].getDomain().getOrder()];
00258 double sum = 0;
00259 for(int i = 0; i < cpt_entries.length; i++){
00260 cpt_entries[i] = cpf.getDouble(realAddr);
00261 sum += cpt_entries[i];
00262 realAddr += diff;
00263 }
00264
00265 if(sum == 0)
00266 return -1;
00267 return sample(cpt_entries, sum, generator);
00268 }
00269
00270 public double[] getConditionalDistribution(BeliefNode node, int[] nodeDomainIndices) {
00271 CPF cpf = node.getCPF();
00272 BeliefNode[] domProd = cpf.getDomainProduct();
00273 int[] addr = new int[domProd.length];
00274
00275 for(int i = 1; i < addr.length; i++)
00276 addr[i] = nodeDomainIndices[this.nodeIndices.get(domProd[i])];
00277 addr[0] = 0;
00278 int realAddr = cpf.addr2realaddr(addr);
00279 addr[0] = 1;
00280 int diff = cpf.addr2realaddr(addr) - realAddr;
00281
00282 double[] cpt_entries = new double[domProd[0].getDomain().getOrder()];
00283 for(int i = 0; i < cpt_entries.length; i++){
00284 cpt_entries[i] = cpf.getDouble(realAddr);
00285 realAddr += diff;
00286 }
00287 return cpt_entries;
00288 }
00289
00290 public int getNodeIndex(BeliefNode node) {
00291 return nodeIndices.get(node);
00292 }
00293
00294 public void setDebugMode(boolean active) {
00295 debug = active;
00296 }
00297
00298 public void setVerbose(boolean verbose) {
00299 this.verbose = verbose;
00300 if(verbose)
00301 out = System.out;
00302 else
00303 out = new PrintStream(new java.io.OutputStream() { public void write(int b){} });
00304 }
00305
00306 public String getAlgorithmName() {
00307 return this.getClass().getSimpleName();
00308 }
00309
00310 public ParameterHandler getParameterHandler() {
00311 return paramHandler;
00312 }
00313
00318 protected void report(String s) {
00319 this.report.append(s);
00320 this.report.append('\n');
00321 }
00322 }