test_DiscreteBayesNet.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for Discrete Bayes Nets.
9 Author: Frank Dellaert
10 """
11 
12 # pylint: disable=no-name-in-module, invalid-name
13 
14 import math
15 import textwrap
16 import unittest
17 
18 from gtsam.utils.test_case import GtsamTestCase
19 
20 import gtsam
21 from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteDistribution,
22  DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering)
23 
# Some keys: each entry is a (variable key, number of values) pair. Every
# variable of the Asia network is binary, matching two-entry specs like "99/1".
Asia, Smoking, Tuberculosis, LungCancer = (0, 2), (4, 2), (3, 2), (6, 2)

Bronchitis, Either, XRay, Dyspnea = (7, 2), (5, 2), (2, 2), (1, 2)
34 
35 
class TestDiscreteConditional(GtsamTestCase):
    """Tests for Discrete Conditional"""

    def setUp(self):
        """Create a binary key and a DiscreteKeys holding its single parent."""
        self.key = (0, 2)
        self.parent = (1, 2)
        # The container has to exist before push_back; without this line
        # setUp fails with AttributeError on self.parents.
        self.parents = DiscreteKeys()
        self.parents.push_back(self.parent)

    def test_sample(self):
        """Tests to check sampling in DiscreteConditionals"""
        rng = gtsam.MT19937(11)
        niters = 1000

        # Sample with only 1 variable: the spec "7/3" gives value 1 a
        # probability of 0.3, so the mean of the samples should be near 0.3.
        conditional = DiscreteConditional(self.key, "7/3")
        p = sum(conditional.sample(rng) for _ in range(niters))
        self.assertAlmostEqual(p / niters, 0.3, places=1)

        # Sample with variable and parent: with the parent fixed to 1 the
        # second row "2/8" applies, so the mean should be near 0.8.
        conditional = DiscreteConditional(self.key, self.parents, "7/3 2/8")
        parentValues = gtsam.DiscreteValues()
        parentValues[self.parent[0]] = 1
        p = sum(conditional.sample(parentValues, rng) for _ in range(niters))
        self.assertAlmostEqual(p / niters, 0.8, places=1)
69 
70 
class TestDiscreteBayesNet(GtsamTestCase):
    """Tests for Discrete Bayes Nets."""

    def test_constructor(self):
        """Test constructing a Bayes net."""

        bayesNet = DiscreteBayesNet()
        Parent, Child = (0, 2), (1, 2)
        empty = DiscreteKeys()
        prior = DiscreteConditional(Parent, empty, "6/4")
        bayesNet.add(prior)

        parents = DiscreteKeys()
        parents.push_back(Parent)
        conditional = DiscreteConditional(Child, parents, "7/3 8/2")
        bayesNet.add(conditional)

        # Check conversion to factor graph: one factor per conditional, and
        # the factor built from P(Child | Parent) has size 2.
        fg = DiscreteFactorGraph(bayesNet)
        self.assertEqual(fg.size(), 2)
        self.assertEqual(fg.at(1).size(), 2)

    def test_Asia(self):
        """Test full Asia example."""

        asia = DiscreteBayesNet()
        asia.add(Asia, "99/1")
        asia.add(Smoking, "50/50")

        asia.add(Tuberculosis, [Asia], "99/1 95/5")
        asia.add(LungCancer, [Smoking], "99/1 90/10")
        asia.add(Bronchitis, [Smoking], "70/30 40/60")

        asia.add(Either, [Tuberculosis, LungCancer], "F T T T")

        asia.add(XRay, [Either], "95/5 2/98")
        asia.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9")

        # Convert to factor graph
        fg = DiscreteFactorGraph(asia)

        # Create solver and eliminate in key order 0..7; Bronchitis (key 7)
        # is eliminated last, so conditional 7 is its marginal.
        ordering = Ordering()
        for j in range(8):
            ordering.push_back(j)
        chordal = fg.eliminateSequential(ordering)
        expected2 = DiscreteDistribution(Bronchitis, "11/9")
        self.gtsamAssertEquals(chordal.at(7), expected2)

        # solve
        actualMPE = fg.optimize()
        expectedMPE = DiscreteValues()
        for key in [
                Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer,
                Bronchitis
        ]:
            expectedMPE[key[0]] = 0
        self.assertEqual(list(actualMPE.items()), list(expectedMPE.items()))

        # Check value for MPE is the same
        self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))

        # add evidence, we were in Asia and we have dyspnea
        # ("0 1" factors rule out value 0 for those variables)
        fg.add(Asia, "0 1")
        fg.add(Dyspnea, "0 1")

        # solve again, now with evidence
        actualMPE2 = fg.optimize()
        expectedMPE2 = DiscreteValues()
        for key in [XRay, Tuberculosis, Either, LungCancer]:
            expectedMPE2[key[0]] = 0
        for key in [Asia, Dyspnea, Smoking, Bronchitis]:
            expectedMPE2[key[0]] = 1
        self.assertEqual(list(actualMPE2.items()), list(expectedMPE2.items()))

        # now sample from it: every one of the 8 variables gets a value.
        chordal2 = fg.eliminateSequential(ordering)
        actualSample = chordal2.sample()
        # TODO(kartikarcot): switch back to len() once the wrapper's len
        # issue is resolved; items() already works on DiscreteValues above.
        self.assertEqual(len(list(actualSample.items())), 8)

    def test_fragment(self):
        """Test evaluate/sampling/optimizing for Asia fragment."""

        # Create a reverse-topologically sorted fragment:
        fragment = DiscreteBayesNet()
        fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
        fragment.add(Tuberculosis, [Asia], "99/1 95/5")
        fragment.add(LungCancer, [Smoking], "99/1 90/10")

        # Create assignment with missing values:
        given = DiscreteValues()
        for key in [Asia, Smoking]:
            given[key[0]] = 0

        # Now sample from fragment: the 2 given values plus the 3 frontal
        # variables of the fragment should be present.
        values = fragment.sample(given)
        # TODO(kartikarcot): switch back to len() once the wrapper's len
        # issue is resolved.
        self.assertEqual(len(list(values.items())), 5)

        for i in range(3):
            self.assertAlmostEqual(
                fragment.at(i).logProbability(values),
                math.log(fragment.at(i).evaluate(values)))
        self.assertAlmostEqual(fragment.logProbability(values),
                               math.log(fragment.evaluate(values)))
        actual = fragment.sample(given)
        self.assertEqual(len(list(actual.items())), 5)

    def test_dot(self):
        """Check that dot works with position hints."""
        fragment = DiscreteBayesNet()
        fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
        MyAsia = gtsam.symbol('a', 0), 2  # use a symbol!
        fragment.add(Tuberculosis, [MyAsia], "99/1 95/5")
        fragment.add(LungCancer, [Smoking], "99/1 90/10")

        # Make sure we can *update* position hints. The dict is read, edited
        # and assigned back (presumably the property returns a copy).
        writer = gtsam.DotWriter()
        ph: dict = writer.positionHints
        ph['a'] = 2  # hint at symbol position
        writer.positionHints = ph

        # Check the output of dot. The large var id is the integer key of
        # symbol('a', 0), which is rendered with the label "a0".
        actual = fragment.dot(writer=writer)
        expected_result = """\
            digraph {
              size="5,5";

              var3[label="3"];
              var4[label="4"];
              var5[label="5"];
              var6[label="6"];
              var6989586621679009792[label="a0", pos="0,2!"];

              var4->var6
              var6989586621679009792->var3
              var3->var5
              var6->var5
            }"""
        self.assertEqual(actual, textwrap.dedent(expected_result))
216 
217 
if __name__ == "__main__":
    # Only when executed directly as a script (not when imported by a test
    # runner): discover and run every TestCase in this module.
    unittest.main()
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
test_DiscreteBayesNet.TestDiscreteBayesNet.test_dot
def test_dot(self)
Definition: test_DiscreteBayesNet.py:184
list
Definition: pytypes.h:2168
gtsam::DiscreteDistribution
Definition: DiscreteDistribution.h:33
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
gtsam::utils.test_case.GtsamTestCase.gtsamAssertEquals
def gtsamAssertEquals(self, actual, expected, tol=1e-9)
Definition: test_case.py:19
size
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
test_DiscreteBayesNet.TestDiscreteConditional.setUp
def setUp(self)
Definition: test_DiscreteBayesNet.py:39
test_DiscreteBayesNet.TestDiscreteBayesNet.test_fragment
def test_fragment(self)
Definition: test_DiscreteBayesNet.py:153
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
test_DiscreteBayesNet.TestDiscreteConditional.test_sample
def test_sample(self)
Definition: test_DiscreteBayesNet.py:45
gtsam::symbol
Key symbol(unsigned char c, std::uint64_t j)
Definition: inference/Symbol.h:139
gtsam::utils.test_case
Definition: test_case.py:1
test_DiscreteBayesNet.TestDiscreteBayesNet.test_Asia
def test_Asia(self)
Definition: test_DiscreteBayesNet.py:93
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:38
test_DiscreteBayesNet.TestDiscreteConditional
Definition: test_DiscreteBayesNet.py:36
asia
Definition: testDiscreteSearch.cpp:28
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
test_DiscreteBayesNet.TestDiscreteBayesNet
Definition: test_DiscreteBayesNet.py:71
test_DiscreteBayesNet.TestDiscreteConditional.key
key
Definition: test_DiscreteBayesNet.py:40
gtsam::utils.test_case.GtsamTestCase
Definition: test_case.py:16
test_DiscreteBayesNet.TestDiscreteConditional.parents
parents
Definition: test_DiscreteBayesNet.py:42
gtsam::Ordering
Definition: inference/Ordering.h:33
test_DiscreteBayesNet.TestDiscreteBayesNet.test_constructor
def test_constructor(self)
Definition: test_DiscreteBayesNet.py:74
test_DiscreteBayesNet.TestDiscreteConditional.parent
parent
Definition: test_DiscreteBayesNet.py:41
gtsam::DotWriter
DotWriter is a helper class for writing graphviz .dot files.
Definition: DotWriter.h:36


gtsam
Author(s):
autogenerated on Wed May 28 2025 03:06:00