test_DiscreteBayesNet.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for Discrete Bayes Nets.
9 Author: Frank Dellaert
10 """
11 
12 # pylint: disable=no-name-in-module, invalid-name
13 
14 import math
15 import textwrap
16 import unittest
17 
18 from gtsam.utils.test_case import GtsamTestCase
19 
20 import gtsam
21 from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteDistribution,
22  DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering)
23 
# Some keys:
# Discrete keys for the eight variables of the Asia network used by the tests.
# Each key is a pair whose first entry is the variable index (the tests index
# DiscreteValues with key[0]); the second entry is presumably the number of
# states of the variable (2 = binary) -- TODO confirm against gtsam DiscreteKey.
Asia = (0, 2)
Smoking = (4, 2)
Tuberculosis = (3, 2)
LungCancer = (6, 2)

Bronchitis = (7, 2)
Either = (5, 2)
XRay = (2, 2)
Dyspnea = (1, 2)
34 
35 
37  """Tests for Discrete Bayes Nets."""
38 
def test_constructor(self):
    """Build a two-node Bayes net and check its conversion to a factor graph."""
    parent_key, child_key = (0, 2), (1, 2)
    net = DiscreteBayesNet()

    # Prior on the parent: a conditional with an empty set of parents.
    net.add(DiscreteConditional(parent_key, DiscreteKeys(), "6/4"))

    # Conditional on the child given the parent.
    child_parents = DiscreteKeys()
    child_parents.push_back(parent_key)
    net.add(DiscreteConditional(child_key, child_parents, "7/3 8/2"))

    # The factor graph built from the net has one factor per conditional,
    # and the second factor involves two keys.
    graph = DiscreteFactorGraph(net)
    self.assertEqual(graph.size(), 2)
    self.assertEqual(graph.at(1).size(), 2)
57 
def test_Asia(self):
    """Exercise elimination, MPE and sampling on the full Asia network."""
    net = DiscreteBayesNet()

    # Variables without parents.
    net.add(Asia, "99/1")
    net.add(Smoking, "50/50")

    # Variables with a single parent.
    net.add(Tuberculosis, [Asia], "99/1 95/5")
    net.add(LungCancer, [Smoking], "99/1 90/10")
    net.add(Bronchitis, [Smoking], "70/30 40/60")

    # Variable with two parents, given as a T/F table.
    net.add(Either, [Tuberculosis, LungCancer], "F T T T")

    net.add(XRay, [Either], "95/5 2/98")
    net.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9")

    # Convert to factor graph
    graph = DiscreteFactorGraph(net)

    # Eliminate sequentially using the key order 0..7.
    elimination_order = Ordering()
    for index in range(8):
        elimination_order.push_back(index)
    eliminated = graph.eliminateSequential(elimination_order)
    expected_last = DiscreteDistribution(Bronchitis, "11/9")
    self.gtsamAssertEquals(eliminated.at(7), expected_last)

    # Without evidence the MPE assigns 0 to every variable.
    mpe = graph.optimize()
    expected_mpe = DiscreteValues()
    all_keys = (Asia, Dyspnea, XRay, Tuberculosis,
                Smoking, Either, LungCancer, Bronchitis)
    for key in all_keys:
        expected_mpe[key[0]] = 0
    self.assertEqual(list(mpe.items()), list(expected_mpe.items()))

    # Check value for MPE is the same
    self.assertAlmostEqual(net(mpe), graph(mpe))

    # add evidence, we were in Asia and we have dyspnea
    graph.add(Asia, "0 1")
    graph.add(Dyspnea, "0 1")

    # solve again, now with evidence
    mpe_with_evidence = graph.optimize()
    expected_with_evidence = DiscreteValues()
    assignments = (
        (0, (XRay, Tuberculosis, Either, LungCancer)),
        (1, (Asia, Dyspnea, Smoking, Bronchitis)),
    )
    for state, keys in assignments:
        for key in keys:
            expected_with_evidence[key[0]] = state
    self.assertEqual(list(mpe_with_evidence.items()),
                     list(expected_with_evidence.items()))

    # now sample from it
    eliminated_with_evidence = graph.eliminateSequential(elimination_order)
    sample = eliminated_with_evidence.sample()
    # TODO(kartikarcot): Resolve the len function issue. Probably
    # due to a use of initializer list which is not supported in CPP17
    # self.assertEqual(len(sample), 8)
116 
def test_fragment(self):
    """Check sampling and (log-)probability evaluation on an Asia fragment."""
    # Create a reverse-topologically sorted fragment:
    net = DiscreteBayesNet()
    net.add(Either, [Tuberculosis, LungCancer], "F T T T")
    net.add(Tuberculosis, [Asia], "99/1 95/5")
    net.add(LungCancer, [Smoking], "99/1 90/10")

    # Partial assignment: only Asia and Smoking are given, both set to 0.
    evidence = DiscreteValues()
    for index, _ in (Asia, Smoking):
        evidence[index] = 0

    # Sampling fills in the remaining variables of the fragment.
    sampled = net.sample(evidence)
    # TODO(kartikarcot): Resolve the len function issue. Probably
    # due to a use of initializer list which is not supported in CPP17
    # self.assertEqual(len(sampled), 5)

    # logProbability must agree with log(evaluate), per conditional ...
    for i in range(3):
        conditional = net.at(i)
        self.assertAlmostEqual(conditional.logProbability(sampled),
                               math.log(conditional.evaluate(sampled)))
    # ... and for the fragment as a whole.
    self.assertAlmostEqual(net.logProbability(sampled),
                           math.log(net.evaluate(sampled)))

    resampled = net.sample(evidence)
    # TODO(kartikarcot): Resolve the len function issue. Probably
    # due to a use of initializer list which is not supported in CPP17
    # self.assertEqual(len(resampled), 5)
146 
def test_dot(self):
    """Check that dot works with position hints.

    Builds a three-conditional fragment in which one parent key is a
    gtsam symbol, sets a position hint for that symbol's character on a
    DotWriter, and compares the text returned by ``dot`` with a fixed
    expected graph description.
    """
    fragment = DiscreteBayesNet()
    fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
    MyAsia = gtsam.symbol('a', 0), 2  # use a symbol!
    fragment.add(Tuberculosis, [MyAsia], "99/1 95/5")
    fragment.add(LungCancer, [Smoking], "99/1 90/10")

    # Make sure we can *update* position hints
    # The hints are read into a local dict, modified, and assigned back;
    # presumably the property hands out a copy, so in-place mutation alone
    # would not reach the writer -- TODO confirm against the wrapper.
    writer = gtsam.DotWriter()
    ph: dict = writer.positionHints
    ph['a'] = 2  # hint at symbol position
    writer.positionHints = ph

    # Check the output of dot
    actual = fragment.dot(writer=writer)
    # NOTE(review): textwrap.dedent only strips whitespace common to every
    # line. In this rendering each line of the literal carries a single
    # leading space, so any nested indentation in the real dot output would
    # have to be present here too -- confirm against the upstream file.
    expected_result = """\
 digraph {
 size="5,5";

 var3[label="3"];
 var4[label="4"];
 var5[label="5"];
 var6[label="6"];
 var6989586621679009792[label="a0", pos="0,2!"];

 var4->var6
 var6989586621679009792->var3
 var3->var5
 var6->var5
 }"""
    self.assertEqual(actual, textwrap.dedent(expected_result))
179 
180 
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
def gtsamAssertEquals(self, actual, expected, tol=1e-9)
Definition: test_case.py:18
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
Key symbol(unsigned char c, std::uint64_t j)
DotWriter is a helper class for writing graphviz .dot files.
Definition: DotWriter.h:36
Definition: pytypes.h:1979
Double_ range(const Point2_ &p, const Point2_ &q)


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:37:45