00001 package edu.tum.cs.srl.bayesnets.inference; 00002 00003 import edu.tum.cs.bayesnets.core.BeliefNetworkEx; 00004 import edu.tum.cs.bayesnets.inference.ITimeLimitedInference; 00005 import edu.tum.cs.bayesnets.inference.SampledDistribution; 00006 import edu.tum.cs.srl.bayesnets.bln.AbstractGroundBLN; 00007 00013 public class BNSampler extends Sampler implements ITimeLimitedInference { 00014 protected int maxTrials; 00018 protected boolean skipFailedSteps; 00019 protected Class<? extends edu.tum.cs.bayesnets.inference.Sampler> samplerClass; 00020 protected edu.tum.cs.bayesnets.inference.Sampler sampler; 00024 protected int[] evidenceDomainIndices; 00025 00026 public BNSampler(AbstractGroundBLN gbln, Class<? extends edu.tum.cs.bayesnets.inference.Sampler> samplerClass) throws Exception { 00027 super(gbln); 00028 maxTrials = 5000; 00029 this.paramHandler.add("maxTrials", "setMaxTrials"); 00030 this.paramHandler.add("skipFailedSteps", "setSkipFailedSteps"); 00031 this.samplerClass = samplerClass; 00032 } 00033 00034 public void setMaxTrials(int maxTrials) { 00035 this.maxTrials = maxTrials; 00036 } 00037 00038 public void setSkipFailedSteps(boolean canSkip) { 00039 this.skipFailedSteps = canSkip; 00040 } 00041 00042 @Override 00043 public SampledDistribution _infer() throws Exception { 00044 // create full evidence 00045 String[][] evidence = this.gbln.getDatabase().getEntriesAsArray(); 00046 evidenceDomainIndices = gbln.getFullEvidence(evidence); 00047 00048 // initialize sampler 00049 if(verbose) System.out.println("initializing..."); 00050 sampler = getSampler(); 00051 paramHandler.addSubhandler(sampler.getParameterHandler()); 00052 sampler.setEvidence(evidenceDomainIndices); 00053 sampler.setQueryVars(queryVars); 00054 sampler.setDebugMode(debug); 00055 sampler.setNumSamples(numSamples); 00056 sampler.setInfoInterval(infoInterval); 00057 sampler.setMaxTrials(maxTrials); 00058 sampler.setSkipFailedSteps(skipFailedSteps); 00059 00060 // run inference 00061 if(verbose) System.out.printf("running %s...\n", sampler.getAlgorithmName()); 00062 SampledDistribution dist = sampler.infer(); 00063 return dist; 00064 } 00065 00066 protected edu.tum.cs.bayesnets.inference.Sampler getSampler() throws Exception { 00067 return samplerClass.getConstructor(BeliefNetworkEx.class).newInstance(gbln.getGroundNetwork()); 00068 } 00069 00070 @Override 00071 public String getAlgorithmName() { 00072 return "BNInference:" + samplerClass.getSimpleName(); 00073 } 00074 00075 public SampledDistribution pollResults() throws CloneNotSupportedException { 00076 if(sampler == null) 00077 return null; 00078 return sampler.pollResults(); 00079 } 00080 }