"""
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415

See LICENSE for the license information

Unit tests for Discrete Bayes Nets.
"""

import math
import textwrap
import unittest

import gtsam
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteDistribution,
                   DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering)
37 """Tests for Discrete Bayes Nets.""" 40 """Test constructing a Bayes net.""" 42 bayesNet = DiscreteBayesNet()
43 Parent, Child = (0, 2), (1, 2)
44 empty = DiscreteKeys()
45 prior = DiscreteConditional(Parent, empty,
"6/4")
48 parents = DiscreteKeys()
49 parents.push_back(Parent)
50 conditional = DiscreteConditional(Child, parents,
"7/3 8/2")
51 bayesNet.add(conditional)
54 fg = DiscreteFactorGraph(bayesNet)
55 self.assertEqual(fg.size(), 2)
56 self.assertEqual(fg.at(1).
size(), 2)
59 """Test full Asia example.""" 61 asia = DiscreteBayesNet()
62 asia.add(Asia,
"99/1")
63 asia.add(Smoking,
"50/50")
65 asia.add(Tuberculosis, [Asia],
"99/1 95/5")
66 asia.add(LungCancer, [Smoking],
"99/1 90/10")
67 asia.add(Bronchitis, [Smoking],
"70/30 40/60")
69 asia.add(Either, [Tuberculosis, LungCancer],
"F T T T")
71 asia.add(XRay, [Either],
"95/5 2/98")
72 asia.add(Dyspnea, [Either, Bronchitis],
"9/1 2/8 3/7 1/9")
75 fg = DiscreteFactorGraph(asia)
81 chordal = fg.eliminateSequential(ordering)
82 expected2 = DiscreteDistribution(Bronchitis,
"11/9")
86 actualMPE = fg.optimize()
88 for key
in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
89 expectedMPE[key[0]] = 0
90 self.assertEqual(
list(actualMPE.items()),
91 list(expectedMPE.items()))
94 self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
98 fg.add(Dyspnea,
"0 1")
101 actualMPE2 = fg.optimize()
103 for key
in [XRay, Tuberculosis, Either, LungCancer]:
104 expectedMPE2[key[0]] = 0
105 for key
in [Asia, Dyspnea, Smoking, Bronchitis]:
106 expectedMPE2[key[0]] = 1
107 self.assertEqual(
list(actualMPE2.items()),
108 list(expectedMPE2.items()))
111 chordal2 = fg.eliminateSequential(ordering)
112 actualSample = chordal2.sample()
118 """Test evaluate/sampling/optimizing for Asia fragment.""" 121 fragment = DiscreteBayesNet()
122 fragment.add(Either, [Tuberculosis, LungCancer],
"F T T T")
123 fragment.add(Tuberculosis, [Asia],
"99/1 95/5")
124 fragment.add(LungCancer, [Smoking],
"99/1 90/10")
128 for key
in [Asia, Smoking]:
132 values = fragment.sample(given)
138 self.assertAlmostEqual(fragment.at(i).logProbability(values),
139 math.log(fragment.at(i).evaluate(values)))
140 self.assertAlmostEqual(fragment.logProbability(values),
141 math.log(fragment.evaluate(values)))
142 actual = fragment.sample(given)
148 """Check that dot works with position hints.""" 149 fragment = DiscreteBayesNet()
150 fragment.add(Either, [Tuberculosis, LungCancer],
"F T T T")
152 fragment.add(Tuberculosis, [MyAsia],
"99/1 95/5")
153 fragment.add(LungCancer, [Smoking],
"99/1 90/10")
157 ph: dict = writer.positionHints
159 writer.positionHints = ph
162 actual = fragment.dot(writer=writer)
163 expected_result =
"""\ 171 var6989586621679009792[label="a0", pos="0,2!"]; 174 var6989586621679009792->var3 178 self.assertEqual(actual, textwrap.dedent(expected_result))
if __name__ == "__main__":
    # Run the test case(s) in this module when executed as a script.
    unittest.main()

# Cross-reference tooltips captured with the extracted listing, kept verbatim
# as comments because they are not part of the module's code:
#   def gtsamAssertEquals(self, actual, expected, tol=1e-9)
#   Key symbol(unsigned char c, std::uint64_t j)
#   DotWriter is a helper class for writing graphviz .dot files.
#   Double_ range(const Point2_ &p, const Point2_ &q)
#   def test_constructor(self)