2 GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
6 See LICENSE for the license information
8 Unit tests for Discrete Bayes Nets.
21 from gtsam
import (DiscreteBayesNet, DiscreteConditional, DiscreteDistribution,
22 DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering)
37 """Tests for Discrete Bayes Nets."""
40 """Test constructing a Bayes net."""
43 Parent, Child = (0, 2), (1, 2)
49 parents.push_back(Parent)
51 bayesNet.add(conditional)
55 self.assertEqual(fg.size(), 2)
56 self.assertEqual(fg.at(1).
size(), 2)
59 """Test full Asia example."""
62 asia.add(Asia,
"99/1")
63 asia.add(Smoking,
"50/50")
65 asia.add(Tuberculosis, [Asia],
"99/1 95/5")
66 asia.add(LungCancer, [Smoking],
"99/1 90/10")
67 asia.add(Bronchitis, [Smoking],
"70/30 40/60")
69 asia.add(Either, [Tuberculosis, LungCancer],
"F T T T")
71 asia.add(XRay, [Either],
"95/5 2/98")
72 asia.add(Dyspnea, [Either, Bronchitis],
"9/1 2/8 3/7 1/9")
81 chordal = fg.eliminateSequential(ordering)
86 actualMPE = fg.optimize()
88 for key
in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
89 expectedMPE[key[0]] = 0
90 self.assertEqual(
list(actualMPE.items()),
91 list(expectedMPE.items()))
94 self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
98 fg.add(Dyspnea,
"0 1")
101 actualMPE2 = fg.optimize()
103 for key
in [XRay, Tuberculosis, Either, LungCancer]:
104 expectedMPE2[key[0]] = 0
105 for key
in [Asia, Dyspnea, Smoking, Bronchitis]:
106 expectedMPE2[key[0]] = 1
107 self.assertEqual(
list(actualMPE2.items()),
108 list(expectedMPE2.items()))
111 chordal2 = fg.eliminateSequential(ordering)
112 actualSample = chordal2.sample()
118 """Test evaluate/sampling/optimizing for Asia fragment."""
122 fragment.add(Either, [Tuberculosis, LungCancer],
"F T T T")
123 fragment.add(Tuberculosis, [Asia],
"99/1 95/5")
124 fragment.add(LungCancer, [Smoking],
"99/1 90/10")
128 for key
in [Asia, Smoking]:
132 values = fragment.sample(given)
138 self.assertAlmostEqual(fragment.at(i).logProbability(values),
139 math.log(fragment.at(i).evaluate(values)))
140 self.assertAlmostEqual(fragment.logProbability(values),
141 math.log(fragment.evaluate(values)))
142 actual = fragment.sample(given)
148 """Check that dot works with position hints."""
150 fragment.add(Either, [Tuberculosis, LungCancer],
"F T T T")
152 fragment.add(Tuberculosis, [MyAsia],
"99/1 95/5")
153 fragment.add(LungCancer, [Smoking],
"99/1 90/10")
157 ph: dict = writer.positionHints
159 writer.positionHints = ph
162 actual = fragment.dot(writer=writer)
163 expected_result =
"""\
171 var6989586621679009792[label="a0", pos="0,2!"];
174 var6989586621679009792->var3
178 self.assertEqual(actual, textwrap.dedent(expected_result))
181 if __name__ ==
"__main__":