00001 package edu.tum.cs.bayesnets.inference;
00002
00003 import java.util.Random;
00004
00005 import edu.tum.cs.bayesnets.core.BeliefNetworkEx;
00006
00010 public class LikelihoodWeightingWithUncertainEvidence extends LikelihoodWeighting {
00011 protected final double evidenceProbability = 0.8;
00012
00013 public LikelihoodWeightingWithUncertainEvidence(BeliefNetworkEx bn) throws Exception {
00014 super(bn);
00015 }
00016
00017 public WeightedSample getWeightedSample(WeightedSample s, int[] nodeOrder, int[] evidenceDomainIndices) throws Exception {
00018 Random rand = new Random();
00019 s.trials = 0;
00020 boolean successful = false;
00021 loop: while(!successful) {
00022 s.weight = 1.0;
00023 s.trials++;
00024 if(s.trials > this.maxTrials) {
00025 if(!this.skipFailedSteps)
00026 throw new Exception("Could not obtain a countable sample in the maximum allowed number of trials (" + maxTrials + ")");
00027 else
00028 return null;
00029 }
00030
00031 for(int i=0; i < nodeOrder.length; i++) {
00032 int nodeIdx = nodeOrder[i];
00033 int domainIdx = evidenceDomainIndices[nodeIdx];
00034
00035 if(domainIdx >= 0) {
00036
00037 double choiceProb = evidenceProbability;
00038 if(rand.nextDouble() > evidenceProbability) {
00039 int numOtherChoices = nodes[nodeIdx].getDomain().getOrder()-1;
00040 if(numOtherChoices > 0) {
00041 int newDomIdx = rand.nextInt(numOtherChoices);
00042 if(newDomIdx >= domainIdx)
00043 newDomIdx++;
00044 domainIdx = newDomIdx;
00045 choiceProb = (1-evidenceProbability)/numOtherChoices;
00046 }
00047 }
00048 s.weight *= choiceProb;
00049
00050 nodes[nodeIdx].getDomain();
00051
00052 s.nodeDomainIndices[nodeIdx] = domainIdx;
00053 double prob = getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices);
00054 if(prob == 0.0) {
00055 if(debug)
00056 out.println("!!! evidence probability was 0 at node " + nodes[nodeIdx] + " in step " + (dist.steps+1));
00057 continue loop;
00058 }
00059 s.weight *= prob;
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078 }
00079
00080 else {
00081 domainIdx = sampleForward(nodes[nodeIdx], s.nodeDomainIndices);
00082 if(domainIdx < 0) {
00083 if(debug)
00084 out.println("!!! could not sample forward because of column with only 0s in CPT of " + nodes[nodeIdx].getName() + " in step " + (dist.steps+1));
00085 bn.removeAllEvidences();
00086 continue loop;
00087 }
00088 s.nodeDomainIndices[nodeIdx] = domainIdx;
00089 }
00090 }
00091 successful = true;
00092 }
00093 return s;
00094 }
00095 }