00001 package edu.tum.cs.bayesnets.core;
00002
00003 import java.io.BufferedReader;
00004 import java.io.FileInputStream;
00005 import java.io.FileNotFoundException;
00006 import java.io.FileOutputStream;
00007 import java.io.IOException;
00008 import java.io.InputStreamReader;
00009 import java.util.ArrayList;
00010 import java.util.Arrays;
00011 import java.util.Comparator;
00012 import java.util.HashMap;
00013 import java.util.HashSet;
00014 import java.util.Map;
00015 import java.util.Random;
00016 import java.util.Set;
00017 import java.util.Stack;
00018 import java.util.regex.Matcher;
00019 import java.util.regex.Pattern;
00020
00021 import org.apache.log4j.Level;
00022 import org.apache.log4j.Logger;
00023
00024 import edu.ksu.cis.bnj.ver3.core.BeliefNetwork;
00025 import edu.ksu.cis.bnj.ver3.core.BeliefNode;
00026 import edu.ksu.cis.bnj.ver3.core.CPF;
00027 import edu.ksu.cis.bnj.ver3.core.Discrete;
00028 import edu.ksu.cis.bnj.ver3.core.DiscreteEvidence;
00029 import edu.ksu.cis.bnj.ver3.core.Domain;
00030 import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
00031 import edu.ksu.cis.bnj.ver3.inference.approximate.sampling.ForwardSampling;
00032 import edu.ksu.cis.bnj.ver3.inference.exact.Pearl;
00033 import edu.ksu.cis.bnj.ver3.plugin.IOPlugInLoader;
00034 import edu.ksu.cis.bnj.ver3.streams.Exporter;
00035 import edu.ksu.cis.bnj.ver3.streams.Importer;
00036 import edu.ksu.cis.bnj.ver3.streams.OmniFormatV1_Reader;
00037 import edu.ksu.cis.util.graph.algorithms.TopologicalSort;
00038 import edu.ksu.cis.util.graph.core.Graph;
00039 import edu.ksu.cis.util.graph.core.Vertex;
00040 import edu.tum.cs.bayesnets.core.io.Converter_ergo;
00041 import edu.tum.cs.bayesnets.core.io.Converter_hugin;
00042 import edu.tum.cs.bayesnets.core.io.Converter_pmml;
00043 import edu.tum.cs.bayesnets.core.io.Converter_xmlbif;
00044 import edu.tum.cs.bayesnets.inference.WeightedSample;
00045
00056 public class BeliefNetworkEx {
/** Limit on the number of attempts made while drawing one weighted sample. */
public static final int MAX_TRIALS = 5000;

/** Class-wide log4j logger; the static initialiser below sets its level to WARN. */
static final Logger logger = Logger.getLogger(BeliefNetworkEx.class);
static {
    logger.setLevel(Level.WARN);
}

/** The wrapped BNJ belief network that the methods of this class operate on. */
public BeliefNetwork bn;

/** Name of the file the network was loaded from (assigned by the filename constructor). */
protected String filename;

/** Attribute name -> names of all nodes that were registered for that attribute. */
protected Map<String, Set<String>> attributeToNodeNameMapping = new HashMap<String, Set<String>>();

/** Node name -> attribute name the node was registered with. */
protected Map<String, String> nodeNameToAttributeMapping = new HashMap<String, String>();
00092
/**
 * Wraps an already constructed belief network and registers each of its
 * nodes as an attribute under the node's own name.
 *
 * @param bn the network to wrap
 */
public BeliefNetworkEx(BeliefNetwork bn) {
    super();
    this.bn = bn;
    this.initAttributeMapping();
}
00101
/**
 * Loads a belief network from a file, registers the default attribute
 * mapping for its nodes and remembers the file name.
 *
 * @param filename path of the network file; the importer is chosen by {@link #load(String)}
 * @throws Exception if loading fails
 */
public BeliefNetworkEx(String filename) throws Exception {
    super();
    this.bn = BeliefNetworkEx.load(filename);
    this.initAttributeMapping();
    this.filename = filename;
}
00112
00116 public BeliefNetworkEx() {
00117 this.bn = new BeliefNetwork();
00118 }
00119
00123 protected void initAttributeMapping() {
00124 for (BeliefNode node: bn.getNodes()) {
00125 addAttributeMapping(node.getName(), node.getName());
00126 }
00127 }
00128
00135 protected void addAttributeMapping(String nodeName, String attributeName) {
00136 nodeNameToAttributeMapping.put(nodeName, attributeName);
00137 Set<String> nodeNames = attributeToNodeNameMapping.get(attributeName);
00138 if (nodeNames == null) {
00139 nodeNames = new HashSet<String>();
00140 attributeToNodeNameMapping.put(attributeName, nodeNames);
00141 }
00142 nodeNames.add(nodeName);
00143 }
00144
00150 public String getAttributeNameForNode(String nodeName) {
00151 return nodeNameToAttributeMapping.get(nodeName);
00152 }
00153
00159 public Set<String> getNodeNamesForAttribute(String attributeName) {
00160 return attributeToNodeNameMapping.get(attributeName);
00161 }
00162
00167 public void addNode(BeliefNode node) {
00168 bn.addBeliefNode(node);
00169 addAttributeMapping(node.getName(), node.getName());
00170 }
00171
00176 public BeliefNode addDecisionNode(String name) {
00177 BeliefNode node = new BeliefNode(name, new Discrete(new String[]{"True", "False"}));
00178 node.setType(BeliefNode.NODE_DECISION);
00179 bn.addBeliefNode(node);
00180 return node;
00181 }
00182
00183
00189 public BeliefNode addNode(String name) {
00190 return addNode(name, new Discrete(new String[]{"True", "False"}));
00191 }
00192
00200 public BeliefNode addNode(String name, Domain domain) {
00201 return addNode(name, domain, name);
00202 }
00203
00211 public BeliefNode addNode(String name, Domain domain, String attributeName) {
00212 BeliefNode node = new BeliefNode(name, domain);
00213 bn.addBeliefNode(node);
00214 addAttributeMapping(name, attributeName);
00215 logger.debug("Added node "+name+" with attributeName "+attributeName);
00216 return node;
00217 }
00218
00225 public void connect(String node1, String node2) throws Exception {
00226 try {
00227 logger.debug("connecting "+node1+" and "+node2);
00228 logger.debug("Memory free: "+Runtime.getRuntime().freeMemory()+"/"+Runtime.getRuntime().totalMemory());
00229 BeliefNode n1 = getNode(node1);
00230 BeliefNode n2 = getNode(node2);
00231 if(n1 == null || n2 == null)
00232 throw new Exception("One of the node names "+node1+" or "+node2+" is invalid!");
00233 logger.debug("Domainsize: "+n1.getDomain().getOrder()+"x"+n2.getDomain().getOrder());
00234 logger.debug("Doing the connect...");
00235 bn.connect(n1, n2);
00236 logger.debug("Memory free: "+Runtime.getRuntime().freeMemory()+"/"+Runtime.getRuntime().totalMemory());
00237 logger.debug("Connection done.");
00238 } catch(Exception e) {
00239 System.out.println("Exception occurred in connect!");
00240 e.printStackTrace(System.out);
00241 throw e;
00242 } catch(Error e2) {
00243 System.out.println("Error occurred");
00244 e2.printStackTrace(System.out);
00245 throw e2;
00246 }
00247 }
00248
00254 public void connect(BeliefNode parent, BeliefNode child, boolean adjustCPF) {
00255 Graph graph = bn.getGraph();
00256 graph.addDirectedEdge(parent.getOwner(), child.getOwner());
00257 if(adjustCPF) {
00258 Vertex[] parents = graph.getParents(child.getOwner());
00259 BeliefNode[] after = new BeliefNode[parents.length + 1];
00260 for (int i = 0; i < parents.length; i++)
00261 {
00262 after[i + 1] = ((BeliefNode) parents[i].getObject());
00263 }
00264 after[0] = child;
00265 CPF beforeCPF = child.getCPF();
00266 child.setCPF(beforeCPF.expand(after));
00267 }
00268 }
00269
00270 public void connect(BeliefNode parent, BeliefNode child) {
00271 connect(parent, child, true);
00272 }
00273
00279 public BeliefNode getNode(String name) {
00280 int idx = getNodeIndex(name);
00281 if(idx == -1)
00282 return null;
00283 return bn.getNodes()[idx];
00284 }
00285
00286 public BeliefNode getNode(int idx) {
00287 return bn.getNodes()[idx];
00288 }
00289
00295 public int getNodeIndex(String name) {
00296 BeliefNode[] nodes = bn.getNodes();
00297 for(int i = 0; i < nodes.length; i++)
00298 if(nodes[i].getName().equals(name))
00299 return i;
00300 return -1;
00301 }
00302
00308 public int[] getDomainProductNodeIndices(BeliefNode node) {
00309 BeliefNode[] nodes = node.getCPF().getDomainProduct();
00310 int[] nodeIndices = new int[nodes.length];
00311 for(int i = 0; i < nodes.length; i++)
00312 nodeIndices[i] = this.getNodeIndex(nodes[i].getName());
00313 return nodeIndices;
00314 }
00315
00321 public int[] getNodeDomainIndicesFromStrings(String[][] nodeAndDomains) {
00322 BeliefNode[] nodes = bn.getNodes();
00323 int[] nodeDomainIndices = new int[nodes.length];
00324 Arrays.fill(nodeDomainIndices, -1);
00325 for (String[] nodeAndDomain: nodeAndDomains) {
00326 if (nodeAndDomain == null || nodeAndDomain.length != 2)
00327 throw new IllegalArgumentException("Evidences not in the correct format: "+Arrays.toString(nodeAndDomain)+"!");
00328 int nodeIdx = getNodeIndex(nodeAndDomain[0]);
00329 if (nodeIdx < 0)
00330 throw new IllegalArgumentException("Variable with the name "+nodeAndDomain[0]+" not found!");
00331 if (nodeDomainIndices[nodeIdx] > 0)
00332 logger.warn("Evidence "+nodeAndDomain[0]+" set twice!");
00333 Discrete domain = (Discrete)nodes[nodeIdx].getDomain();
00334 int domainIdx = domain.findName(nodeAndDomain[1]);
00335 if (domainIdx < 0) {
00336 if (domain instanceof Discretized) {
00337 try {
00338 double value = Double.parseDouble(nodeAndDomain[1]);
00339 String domainStr = ((Discretized)domain).getNameFromContinuous(value);
00340 domainIdx = domain.findName(domainStr);
00341 } catch (Exception e) {
00342 throw new IllegalArgumentException("Cannot find evidence value "+nodeAndDomain[1]+" in domain "+domain+"!");
00343 }
00344 } else {
00345 throw new IllegalArgumentException("Cannot find evidence value "+nodeAndDomain[1]+" in domain "+domain+"!");
00346 }
00347 }
00348 nodeDomainIndices[nodeIdx]=domainIdx;
00349 }
00350 return nodeDomainIndices;
00351 }
00352
00353 public int getNodeIndex(BeliefNode node) {
00354 BeliefNode[] nodes = bn.getNodes();
00355 for(int i = 0; i < nodes.length; i++)
00356 if(nodes[i] == node)
00357 return i;
00358 return -1;
00359 }
00360
00367 public void setEvidence(String nodeName, String outcome) throws Exception {
00368 BeliefNode node = getNode(nodeName);
00369 if(node == null)
00370 throw new Exception("Invalid node reference: " + nodeName);
00371 Discrete domain = (Discrete) node.getDomain();
00372 int idx = domain.findName(outcome);
00373 if(idx == -1)
00374 throw new Exception("Outcome " + outcome + " not in domain of " + nodeName);
00375 node.setEvidence(new DiscreteEvidence(idx));
00376 }
00377
00386 public double getProbability(String[][] queries, String[][] evidences) throws Exception {
00387
00388
00389 if(queries.length == 1) {
00390
00391 BeliefNode[] nodes = bn.getNodes();
00392 for(int i = 0; i < nodes.length; i++)
00393 nodes[i].setEvidence(null);
00394
00395 if(evidences != null)
00396 for(int i = 0; i < evidences.length; i++) {
00397 setEvidence(evidences[i][0], evidences[i][1]);
00398 }
00399
00400 Pearl inf = new Pearl();
00401 inf.run(this.bn);
00402
00403 BeliefNode node = getNode(queries[0][0]);
00404 CPF cpf = inf.queryMarginal(node);
00405 BeliefNode[] dp = cpf.getDomainProduct();
00406 boolean done = false;
00407 int[] addr = cpf.realaddr2addr(0);
00408 while(!done) {
00409 for (int i = 0; i < addr.length; i++)
00410 if(dp[i].getDomain().getName(addr[i]).equals(queries[0][1])) {
00411 ValueDouble v = (ValueDouble) cpf.get(addr);
00412 return v.getValue();
00413 }
00414 done = cpf.addOne(addr);
00415 }
00416 throw new Exception("Outcome not in domain!");
00417
00418 }
00419 else {
00420 String[][] _queries = new String[1][2];
00421 String[][] _queries2 = new String[queries.length-1][2];
00422 _queries[0] = queries[0];
00423 int numEvidences = evidences == null ? 0 : evidences.length;
00424 String[][] _evidences = new String[numEvidences+queries.length-1][2];
00425 int idx = 0;
00426 for(int i = 1; i < queries.length; i++, idx++) {
00427 _evidences[idx] = queries[i];
00428 _queries2[idx] = queries[i];
00429 }
00430 for(int i = 0; i < numEvidences; i++, idx++)
00431 _evidences[idx] = evidences[i];
00432 return getProbability(_queries, _evidences) * getProbability(_queries2, evidences);
00433 }
00434 }
00435
00436 protected void printProbabilities(int node, Stack<String[]> evidence) throws Exception {
00437 BeliefNode[] nodes = bn.getNodes();
00438 if(node == nodes.length) {
00439 String[][] e = new String[evidence.size()][];
00440 evidence.toArray(e);
00441 double prob = getProbability(e, null);
00442 StringBuffer s = new StringBuffer();
00443 s.append(String.format("%6.2f%% ", 100*prob));
00444 int i = 0;
00445 for(String[] pair : evidence) {
00446 if(i > 0)
00447 s.append(", ");
00448 s.append(String.format("%s=%s", pair[0], pair[1]));
00449 i++;
00450 }
00451 System.out.println(s);
00452 return;
00453 }
00454 Domain dom = nodes[node].getDomain();
00455 for(int i = 0; i < dom.getOrder(); i++) {
00456 evidence.push(new String[]{nodes[node].getName(), dom.getName(i)});
00457 printProbabilities(node+1, evidence);
00458 evidence.pop();
00459 }
00460 }
00461
00462 public void printFullJoint() throws Exception {
00463 printProbabilities(0, new Stack<String[]>());
00464 }
00465
00469 public void printDomain() {
00470 BeliefNode[] nodes = bn.getNodes();
00471 for(int i = 0; i < nodes.length; i++) {
00472 System.out.print(nodes[i].getName());
00473 Discrete domain = (Discrete)nodes[i].getDomain();
00474 System.out.print(" {");
00475 int c = domain.getOrder();
00476 for(int j = 0; j < c; j++) {
00477 if(j > 0) System.out.print(", ");
00478 System.out.print(domain.getName(j));
00479 }
00480 System.out.println("}");
00481 }
00482 }
00483
00491 public static BeliefNetwork load(String filename, Importer importer) throws FileNotFoundException {
00492 FileInputStream fis = new FileInputStream(filename);
00493 OmniFormatV1_Reader ofv1w = new OmniFormatV1_Reader();
00494 importer.load(fis, ofv1w);
00495 return ofv1w.GetBeliefNetwork(0);
00496 }
00497
00504 public static BeliefNetwork load(String filename) throws Exception {
00505 registerDefaultPlugins();
00506 IOPlugInLoader iopl = IOPlugInLoader.getInstance();
00507 String ext = iopl.GetExt(filename);
00508 Importer imp = iopl.GetImporterByExt(ext);
00509 if(imp == null)
00510 throw new Exception("Unable to find an importer that can handle " + ext + " files.");
00511 return load(filename, imp);
00512 }
00513
00520 public static void save(BeliefNetwork net, String filename) throws Exception {
00521 registerDefaultPlugins();
00522 IOPlugInLoader iopl = IOPlugInLoader.getInstance();
00523 String ext = iopl.GetExt(filename);
00524 Exporter exporter = iopl.GetExportersByExt(ext);
00525 if(exporter == null)
00526 throw new Exception("Unable to find an exporter that can handle " + ext + " files.");
00527 save(net, filename, exporter);
00528 }
00529
00537 public static void save(BeliefNetwork net, String filename, Exporter exporter) throws FileNotFoundException {
00538 exporter.save(net, new FileOutputStream(filename));
00539
00540 }
00541
00548 public void save(String filename) throws Exception {
00549 save(this.bn, filename);
00550 }
00551
00558 public void save(String filename, Exporter exporter) throws FileNotFoundException {
00559 save(this.bn, filename, exporter);
00560 }
00561
00567 public void saveXMLBIF(String filename) throws FileNotFoundException {
00568 save(filename, new Converter_xmlbif());
00569 }
00570
00576 public void savePMML(String filename) throws FileNotFoundException {
00577 save(filename, new Converter_pmml());
00578 }
00579
00585 public void save() throws Exception {
00586 IOPlugInLoader pil = IOPlugInLoader.getInstance();
00587 if(filename == null)
00588 throw new Exception("Cannot save - filename not given!");
00589 Exporter exporter = pil.GetExportersByExt(pil.GetExt(filename));
00590 save(filename, exporter);
00591 }
00592
00602 public void sortNodeDomain(String nodeName, boolean numeric) throws Exception {
00603 BeliefNode node = getNode(nodeName);
00604 if(node == null)
00605 throw new Exception("Node not found");
00606 Discrete domain = (Discrete)node.getDomain();
00607 int ord = domain.getOrder();
00608 String[] strings = new String[ord];
00609 if(!numeric) {
00610 for(int i = 0; i < ord; i++)
00611 strings[i] = domain.getName(i);
00612 Arrays.sort(strings);
00613 }
00614 else {
00615 double[] values = new double[ord];
00616 for(int i = 0; i < ord; i++)
00617 values[i] = Double.parseDouble(domain.getName(i));
00618 double[] sorted_values = values.clone();
00619 Arrays.sort(sorted_values);
00620 for(int i = 0; i < ord; i++)
00621 for(int j = 0; j < ord; j++)
00622 if(sorted_values[i] == values[j])
00623 strings[i] = domain.getName(j);
00624 }
00625 bn.changeBeliefNodeDomain(node, new Discrete(strings));
00626 }
00627
00634 public Domain getDomain(String nodeName) {
00635 BeliefNode node = getNode(nodeName);
00636 if(node == null)
00637 return null;
00638 return node.getDomain();
00639 }
00640
00644 public void show() {
00645 registerDefaultPlugins();
00646 edu.ksu.cis.bnj.gui.GUIWindow window = new edu.ksu.cis.bnj.gui.GUIWindow();
00647 window.register();
00648 window.open(bn, filename);
00649 }
00650
00651 public static void registerDefaultPlugins() {
00652 IOPlugInLoader iopl = IOPlugInLoader.getInstance();
00653
00654 Converter_xmlbif xmlbif = new Converter_xmlbif();
00655 iopl.addPlugin(xmlbif, xmlbif);
00656
00657 Converter_pmml pmml = new Converter_pmml();
00658 iopl.addPlugin(pmml, pmml);
00659
00660 Converter_hugin hugin = new Converter_hugin();
00661 iopl.addPlugin(null, hugin);
00662
00663 Converter_ergo ergo = new Converter_ergo();
00664 iopl.addPlugin(ergo, null);
00665 }
00666
00672 public void show(String pluginDir) {
00673 IOPlugInLoader iopl = IOPlugInLoader.getInstance();
00674 iopl.loadPlugins(pluginDir);
00675 show();
00676 }
00677
00685 protected static String[][] readList(String list) throws java.lang.Exception {
00686 if(list == null)
00687 return null;
00688 String[] items = list.split(",");
00689 String[][] res = new String[items.length][2];
00690 for(int i = 0; i < items.length; i++) {
00691 res[i] = items[i].split("=");
00692 if(res[i].length != 2)
00693 throw new java.lang.Exception("syntax error!");
00694 }
00695 return res;
00696 }
00697
00701 public void queryShell() {
00702
00703 System.out.println("Domain:");
00704 printDomain();
00705 System.out.println("\nUsage: Pr[X=x, Y=y, ... | E=e, F=f, ...] (X,Y: query vars;\n" +
00706 " E,F: evidence vars;\n" +
00707 " x,y,e,f: outcomes\n" +
00708 " exit (close shell)");
00709
00710 BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
00711 for(;;) {
00712 try {
00713
00714 System.out.print("\n> ");
00715 String input = br.readLine();
00716
00717 if(input.equalsIgnoreCase("exit"))
00718 break;
00719
00720
00721 input = input.replaceAll("\\s+", "");
00722 Pattern p = Pattern.compile("Pr\\[([^\\]\\|]*)(?:\\|([^\\]]*))?\\]");
00723 Matcher m = p.matcher(input);
00724 if(!m.matches()) {
00725 System.out.println("syntax error!");
00726 }
00727 else {
00728 String[][] queries = readList(m.group(1));
00729 String[][] evidences = readList(m.group(2));
00730 try {
00731
00732 double result = getProbability(queries, evidences);
00733 System.out.println(result);
00734 }
00735 catch(Exception e) {
00736 System.out.println(e.getMessage());
00737 }
00738 }
00739 }
00740 catch(IOException e) {
00741 System.err.println(e.getMessage());
00742 }
00743 catch(java.lang.Exception e) {
00744 System.out.println(e.getMessage());
00745 }
00746 }
00747 }
00748
00758 public WeightedSample[] getAssignmentDistribution(String[][] evidences, String[] queryNodeNames, int numSamples) throws Exception {
00759 HashMap<WeightedSample, Double> sampleSums = new HashMap<WeightedSample, Double>();
00760
00761 int[] queryNodes = new int[queryNodeNames.length];
00762 for (int i=0; i<queryNodeNames.length; i++) {
00763 queryNodes[i]=getNodeIndex(queryNodeNames[i]);
00764 if (queryNodes[i] < 0)
00765 throw new IllegalArgumentException("Cannot find node with name "+queryNodeNames[i]);
00766 }
00767
00768 Random generator = new Random();
00769 for (int i=0; i<numSamples; i++) {
00770 WeightedSample sample = getWeightedSample(evidences, generator);
00771 if (sample == null && i == 0)
00772 return null;
00773 WeightedSample subSample = sample.subSample(queryNodes);
00774
00775 if (sampleSums.containsKey(subSample)) {
00776 sampleSums.put(subSample, sampleSums.get(subSample)+subSample.weight);
00777 } else {
00778 sampleSums.put(subSample, subSample.weight);
00779 }
00780 }
00781
00782 double sum = 0;
00783 for (WeightedSample sample: sampleSums.keySet()) {
00784 logger.debug(sample);
00785 double value = sampleSums.get(sample);
00786 sum += value;
00787 }
00788 WeightedSample[] samples = sampleSums.keySet().toArray(new WeightedSample[0]);
00789 for (WeightedSample sample: samples) {
00790 sample.weight = sampleSums.get(sample)/sum;
00791 }
00792
00793 Arrays.sort(samples, new Comparator<WeightedSample>() {
00794 public int compare(WeightedSample o1, WeightedSample o2) {
00795 return Double.compare(o2.weight, o1.weight);
00796 }
00797 });
00798
00799 return samples;
00800 }
00801
00806 public int[] getTopologicalOrder() {
00807 TopologicalSort topsort = new TopologicalSort();
00808 topsort.execute(bn.getGraph());
00809 return topsort.alpha;
00810 }
00811
00821 public double getCPTProbability(BeliefNode node, int[] nodeDomainIndices ) {
00822 CPF cpf = node.getCPF();
00823 int[] domainProduct = getDomainProductNodeIndices(node);
00824 int[] address = new int[domainProduct.length];
00825 for (int i=0; i<address.length; i++) {
00826 address[i]=nodeDomainIndices[domainProduct[i]];
00827 }
00828 int realAddress = cpf.addr2realaddr(address);
00829 return cpf.getDouble(realAddress);
00830 }
00831
00835 public void removeAllEvidences() {
00836
00837 for(BeliefNode node : bn.getNodes()) {
00838 node.setEvidence(null);
00839 }
00840 }
00841
00851 public double getSampledProbability(String[][] queries, String[][] evidences, int numSamples) throws Exception {
00852 String[] queryNodes = new String[queries.length];
00853 for (int i=0; i<queryNodes.length; i++) {
00854 queryNodes[i]=queries[i][0];
00855 }
00856 WeightedSample[] samples = getAssignmentDistribution(evidences, queryNodes, numSamples);
00857 double goodSum = 0;
00858 double allSum = 0;
00859 for (int i=0; i<samples.length; i++) {
00860 allSum += samples[i].weight;
00861 if (samples[i].checkAssignment(queries))
00862 goodSum += samples[i].weight;
00863 }
00864 return goodSum/allSum;
00865 }
00866
00875 public WeightedSample getWeightedSample(String[][] evidences, Random generator) throws Exception {
00876 if (generator == null) {
00877 generator = new Random();
00878 }
00879 return getWeightedSample(getTopologicalOrder(), evidence2DomainIndices(evidences), generator);
00880 }
00881
00882 public WeightedSample getWeightedSample(int[] nodeOrder, int[] evidenceDomainIndices, Random generator) throws Exception {
00883 BeliefNode[] nodes = bn.getNodes();
00884 int[] sampleDomainIndices = new int[nodes.length];
00885 boolean successful = false;
00886 double weight = 1.0;
00887 int trials=0;
00888 success:while (!successful) {
00889
00890 weight = 1.0;
00891 if (trials > MAX_TRIALS)
00892 return null;
00893 for (int i=0; i< nodeOrder.length; i++) {
00894 int nodeIdx = nodeOrder[i];
00895 int domainIdx = evidenceDomainIndices[nodeIdx];
00896 if (domainIdx >= 0) {
00897 sampleDomainIndices[nodeIdx] = domainIdx;
00898 nodes[nodeIdx].setEvidence(new DiscreteEvidence(domainIdx));
00899
00900 double prob = getCPTProbability(nodes[nodeIdx], sampleDomainIndices);
00901 if (prob == 0.0) {
00902
00903 removeAllEvidences();
00904 trials++;
00905 continue success;
00906 }
00907 weight *= prob;
00908 } else {
00909 domainIdx = ForwardSampling.sampleForward(nodes[nodeIdx], bn, generator);
00910 if (domainIdx < 0) {
00911 System.out.println("could not sample forward because of column with 0s in CPT of " + nodes[nodeIdx].getName());
00912 removeAllEvidences();
00913 trials++;
00914 continue success;
00915 }
00916 sampleDomainIndices[nodeIdx] = domainIdx;
00917 nodes[nodeIdx].setEvidence(new DiscreteEvidence(domainIdx));
00918 }
00919 }
00920 trials++;
00921 removeAllEvidences();
00922 successful = true;
00923 }
00924 return new WeightedSample(this, sampleDomainIndices, weight, null, trials);
00925 }
00926
00927 public int[] evidence2DomainIndices(String[][] evidences) {
00928 BeliefNode[] nodes = bn.getNodes();
00929 int[] evidenceDomainIndices = new int[nodes.length];
00930 Arrays.fill(evidenceDomainIndices, -1);
00931 for (String[] evidence: evidences) {
00932 if (evidence == null || evidence.length != 2)
00933 throw new IllegalArgumentException("Evidences not in the correct format: "+Arrays.toString(evidence)+"!");
00934 int nodeIdx = getNodeIndex(evidence[0]);
00935 if (nodeIdx < 0) {
00936 String error = "Variable with the name "+evidence[0]+" not found in model but mentioned in evidence!";
00937 System.err.println("Warning: " + error);
00938 continue;
00939
00940 }
00941 if (evidenceDomainIndices[nodeIdx] > 0)
00942 logger.warn("Evidence "+evidence[0]+" set twice!");
00943 Discrete domain = (Discrete)nodes[nodeIdx].getDomain();
00944 int domainIdx = domain.findName(evidence[1]);
00945 if (domainIdx < 0) {
00946 if (domain instanceof Discretized) {
00947 try {
00948 double value = Double.parseDouble(evidence[1]);
00949 String domainStr = ((Discretized)domain).getNameFromContinuous(value);
00950 domainIdx = domain.findName(domainStr);
00951 } catch (Exception e) {
00952 throw new IllegalArgumentException("Cannot find evidence value "+evidence[1]+" in domain "+domain+"!");
00953 }
00954 }
00955 else {
00956 throw new IllegalArgumentException("Cannot find evidence value "+evidence[1]+" in domain "+domain+" of node " + nodes[nodeIdx].getName());
00957 }
00958 }
00959 evidenceDomainIndices[nodeIdx]=domainIdx;
00960 }
00961 return evidenceDomainIndices;
00962 }
00963
00971 public HashMap<String,String> getSample(Random generator) throws Exception {
00972 if(generator == null)
00973 generator = new Random();
00974 HashMap<String,String> ret = new HashMap<String,String>();
00975
00976 TopologicalSort topsort = new TopologicalSort();
00977 topsort.execute(bn.getGraph());
00978 int[] order = topsort.alpha;
00979
00980 BeliefNode[] nodes = bn.getNodes();
00981 boolean succeeded = false;
00982 while(!succeeded) {
00983 ArrayList<BeliefNode> setEvidences = new ArrayList<BeliefNode>();
00984 for(int i = 0; i < order.length; i++) {
00985 BeliefNode node = nodes[order[i]];
00986 if(node.hasEvidence()) {
00987 throw new Exception("At least one node has evidence. You can only sample from the marginal distribution!");
00988 }
00989 int idxValue = ForwardSampling.sampleForward(node, bn, generator);
00990 if(idxValue == -1) {
00991
00992 succeeded = false;
00993 break;
00994 }
00995 succeeded = true;
00996 Domain dom = node.getDomain();
00997
00998 ret.put(node.getName(), dom.getName(idxValue));
00999 node.setEvidence(new DiscreteEvidence(idxValue));
01000 setEvidences.add(node);
01001 }
01002
01003 for(BeliefNode node : setEvidences) {
01004 node.setEvidence(null);
01005 }
01006 }
01007 return ret;
01008 }
01009
01010 public static String[] getDiscreteDomainAsArray(BeliefNode node) {
01011 Discrete domain = (Discrete)node.getDomain();
01012 String[] ret = new String[domain.getOrder()];
01013 for(int i = 0; i < ret.length; i++)
01014 ret[i] = domain.getName(i);
01015 return ret;
01016 }
01017
01018 public String[] getDiscreteDomainAsArray(String nodeName) {
01019 return getDiscreteDomainAsArray(getNode(nodeName));
01020 }
01021
01025 public void dump() {
01026 BeliefNode[] nodes = bn.getNodes();
01027 for (int i=0; i<nodes.length; i++) {
01028 logger.debug("Node "+i+": "+nodes[i].getName());
01029 logger.debug("\tAttribute: "+getAttributeNameForNode(nodes[i].getName()));
01030 }
01031 for (String attributeName: attributeToNodeNameMapping.keySet()) {
01032 logger.debug("Attribute "+attributeName+": "+attributeToNodeNameMapping.get(attributeName));
01033 }
01034 }
01035
01036 public abstract class CPTWalker {
01037 public abstract void tellSize(int childConfigs, int parentConfigs);
01038 public abstract void tellNodeOrder(BeliefNode n);
01039 public abstract void tellValue(double v);
01040 }
01041
01042 public void walkCPT(BeliefNode node, CPTWalker walker) {
01043 CPF cpf = node.getCPF();
01044 BeliefNode[] nodes = cpf.getDomainProduct();
01045 int parentConfigs = 1;
01046 for(int i = 1; i < nodes.length; i++)
01047 parentConfigs *= nodes[i].getDomain().getOrder();
01048 walker.tellSize(nodes[0].getDomain().getOrder(), parentConfigs);
01049 int[] addr = new int[cpf.getDomainProduct().length];
01050 walkCPT(walker, cpf, addr, 0);
01051 }
01052
01053 protected void walkCPT(CPTWalker walker, CPF cpf, int[] addr, int i) {
01054 BeliefNode[] nodes = cpf.getDomainProduct();
01055 if(i == addr.length) {
01056
01057 int realAddr = cpf.addr2realaddr(addr);
01058 double value = ((ValueDouble)cpf.get(realAddr)).getValue();
01059 walker.tellValue(value);
01060 }
01061 else {
01062 walker.tellNodeOrder(nodes[i]);
01063 Discrete dom = (Discrete)nodes[i].getDomain();
01064 for(int j = 0; j < dom.getOrder(); j++) {
01065 addr[i] = j;
01066 walkCPT(walker, cpf, addr, i+1);
01067 }
01068 }
01069 }
01070
01077 public int getDomainIndex(BeliefNode node, String value) {
01078 Discrete domain = (Discrete)node.getDomain();
01079 return domain.findName(value);
01080 }
01081
01087 public HashMap<BeliefNode, double[]> computePriors(int[] evidenceDomainIndices) {
01088 HashMap<BeliefNode, double[]> priors = new HashMap<BeliefNode, double[]>();
01089 BeliefNode[] nodes = bn.getNodes();
01090 int[] topOrder = getTopologicalOrder();
01091 for(int i : topOrder) {
01092 BeliefNode node = nodes[i];
01093 double[] dist = new double[node.getDomain().getOrder()];
01094 int evidence = evidenceDomainIndices != null ? evidenceDomainIndices[i] : -1;
01095 if(evidence >= 0) {
01096 for(int j = 0; j < dist.length; j++)
01097 dist[j] = evidence == j ? 1.0 : 0.0;
01098 }
01099 else {
01100 CPF cpf = node.getCPF();
01101 computePrior(priors, evidenceDomainIndices, cpf, 0, new int[cpf.getDomainProduct().length], dist);
01102 }
01103 priors.put(node, dist);
01104 }
01105 return priors;
01106 }
01107
01108 protected void computePrior(HashMap<BeliefNode, double[]> priors, int[] evidenceDomainIndices, CPF cpf, int i, int[] addr, double[] dist) {
01109 BeliefNode[] domProd = cpf.getDomainProduct();
01110 if(i == addr.length) {
01111 double p = cpf.getDouble(addr);
01112 for(int j = 1; j < addr.length; j++) {
01113 double[] parentPrior = priors.get(domProd[j]);
01114 p *= parentPrior[addr[j]];
01115 }
01116 dist[addr[0]] += p;
01117 return;
01118 }
01119 BeliefNode node = domProd[i];
01120 int nodeIdx = getNodeIndex(node);
01121 if(evidenceDomainIndices[nodeIdx] >= 0) {
01122 addr[i] = evidenceDomainIndices[nodeIdx];
01123 computePrior(priors, evidenceDomainIndices, cpf, i+1, addr, dist);
01124 }
01125 else {
01126 Domain dom = node.getDomain();
01127 for(int j = 0; j < dom.getOrder(); j++) {
01128 addr[i] = j;
01129 computePrior(priors, evidenceDomainIndices, cpf, i+1, addr, dist);
01130 }
01131 }
01132 }
01133
01139 public double getWorldProbability(int[] nodeDomainIndices) {
01140 BeliefNode[] nodes = bn.getNodes();
01141 double ret = 1.0;
01142 for(int i = 0; i < nodes.length; i++)
01143 ret *= getCPTProbability(nodes[i], nodeDomainIndices);
01144 return ret;
01145 }
01146
01147 public BeliefNode[] getNodes() {
01148 return bn.getNodes();
01149 }
01150
01154 public double getNumWorlds() {
01155 double num = 1;
01156 for(BeliefNode n : getNodes())
01157 num *= n.getDomain().getOrder();
01158 return num;
01159 }
01160 }