00001 package edu.tum.cs.srl.bayesnets.inference;
00002
import java.util.ArrayList;
import java.util.Collections;
import java.util.Vector;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.tum.cs.bayesnets.inference.SampledDistribution;
import edu.tum.cs.inference.IParameterHandler;
import edu.tum.cs.inference.ParameterHandler;
import edu.tum.cs.srl.bayesnets.bln.AbstractGroundBLN;
import edu.tum.cs.util.Stopwatch;
00014
00019 public abstract class Sampler implements IParameterHandler {
00020 protected boolean debug = false;
00021 protected boolean verbose = true;
00022 protected int numSamples = 1000;
00023 protected int infoInterval = 100;
00024 protected ParameterHandler paramHandler;
00025 protected Vector<Integer> queryVars;
00026 protected AbstractGroundBLN gbln;
00027 double samplingTime;
00028
00029 public Sampler(AbstractGroundBLN gbln) throws Exception {
00030 this.gbln = gbln;
00031 paramHandler = new ParameterHandler(this);
00032 paramHandler.add("maxSteps", "setNumSamples");
00033 paramHandler.add("numSamples", "setNumSamples");
00034 paramHandler.add("infoInterval", "setInfoInterval");
00035 paramHandler.add("debug", "setDebugMode");
00036 paramHandler.add("verbose", "setVerbose");
00037 }
00038
00044 public Vector<InferenceResult> getResults(SampledDistribution dist) {
00045 Vector<InferenceResult> results = new Vector<InferenceResult>();
00046 for(Integer i : queryVars)
00047 results.add(new InferenceResult(dist, i));
00048 return results;
00049 }
00050
00051 public void printResults(SampledDistribution dist) {
00052 ArrayList<InferenceResult> results = new ArrayList<InferenceResult>(getResults(dist));
00053 Collections.sort(results);
00054 for(InferenceResult res : results)
00055 res.print();
00056 }
00057
00058 public double getSamplingTime() {
00059 return samplingTime;
00060 }
00061
00062 public void setNumSamples(int n) {
00063 numSamples = n;
00064 }
00065
00066 public void setInfoInterval(int n) {
00067 infoInterval = n;
00068 }
00069
00070 public SampledDistribution infer() throws Exception {
00071 Stopwatch sw = new Stopwatch();
00072 sw.start();
00073 SampledDistribution ret = _infer();
00074 samplingTime = sw.getElapsedTimeSecs();
00075 return ret;
00076 }
00077
00078 protected abstract SampledDistribution _infer() throws Exception;
00079
00080 public Vector<InferenceResult> inferQueries() throws Exception {
00081 return getResults(infer());
00082 }
00083
00084 public String getAlgorithmName() {
00085 return this.getClass().getSimpleName();
00086 }
00087
00088 public void setDebugMode(boolean active) {
00089 debug = active;
00090 }
00091
00092 public void setVerbose(boolean verbose) {
00093 this.verbose = verbose;
00094 }
00095
00096 public ParameterHandler getParameterHandler() {
00097 return paramHandler;
00098 }
00099
00100 public void setQueries(Iterable<String> queries) {
00101
00102 Vector<Pattern> patterns = new Vector<Pattern>();
00103 for(String query : queries) {
00104 String p = query;
00105 p = Pattern.compile("([,\\(])([a-z][^,\\)]*)").matcher(p).replaceAll("$1.*?");
00106 p = p.replace("(", "\\(").replace(")", "\\)") + ".*";
00107 patterns.add(Pattern.compile(p));
00108
00109 }
00110
00111
00112
00113 BeliefNode[] nodes = gbln.getGroundNetwork().getNodes();
00114 queryVars = new Vector<Integer>();
00115 for(int i = 0; i < nodes.length; i++)
00116 for(Pattern pattern : patterns)
00117 if(pattern.matcher(nodes[i].getName()).matches()) {
00118 queryVars.add(i);
00119 break;
00120 }
00121 }
00122 }