2 GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
6 See LICENSE for the license information
8 Unit tests for Discrete Bayes Nets.
21 from gtsam
import (DiscreteBayesNet, DiscreteConditional, DiscreteDistribution,
22 DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering)
37 """Tests for Discrete Conditional"""
46 """Tests to check sampling in DiscreteConditionals"""
47 rng = gtsam.MT19937(11)
54 for _
in range(niters):
55 p += conditional.sample(rng)
57 self.assertAlmostEqual(p / niters, 0.3, 1)
64 parentValues[self.
parent[0]] = 1
65 for _
in range(niters):
66 p += conditional.sample(parentValues, rng)
68 self.assertAlmostEqual(p / niters, 0.8, 1)
72 """Tests for Discrete Bayes Nets."""
75 """Test constructing a Bayes net."""
78 Parent, Child = (0, 2), (1, 2)
84 parents.push_back(Parent)
86 bayesNet.add(conditional)
90 self.assertEqual(fg.size(), 2)
91 self.assertEqual(fg.at(1).
size(), 2)
94 """Test full Asia example."""
97 asia.add(Asia,
"99/1")
98 asia.add(Smoking,
"50/50")
100 asia.add(Tuberculosis, [Asia],
"99/1 95/5")
101 asia.add(LungCancer, [Smoking],
"99/1 90/10")
102 asia.add(Bronchitis, [Smoking],
"70/30 40/60")
104 asia.add(Either, [Tuberculosis, LungCancer],
"F T T T")
106 asia.add(XRay, [Either],
"95/5 2/98")
107 asia.add(Dyspnea, [Either, Bronchitis],
"9/1 2/8 3/7 1/9")
115 ordering.push_back(j)
116 chordal = fg.eliminateSequential(ordering)
121 actualMPE = fg.optimize()
124 Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer,
127 expectedMPE[key[0]] = 0
128 self.assertEqual(
list(actualMPE.items()),
list(expectedMPE.items()))
131 self.assertAlmostEqual(
asia(actualMPE), fg(actualMPE))
135 fg.add(Dyspnea,
"0 1")
138 actualMPE2 = fg.optimize()
140 for key
in [XRay, Tuberculosis, Either, LungCancer]:
141 expectedMPE2[key[0]] = 0
142 for key
in [Asia, Dyspnea, Smoking, Bronchitis]:
143 expectedMPE2[key[0]] = 1
144 self.assertEqual(
list(actualMPE2.items()),
list(expectedMPE2.items()))
147 chordal2 = fg.eliminateSequential(ordering)
148 actualSample = chordal2.sample()
154 """Test evaluate/sampling/optimizing for Asia fragment."""
158 fragment.add(Either, [Tuberculosis, LungCancer],
"F T T T")
159 fragment.add(Tuberculosis, [Asia],
"99/1 95/5")
160 fragment.add(LungCancer, [Smoking],
"99/1 90/10")
164 for key
in [Asia, Smoking]:
168 values = fragment.sample(given)
174 self.assertAlmostEqual(
175 fragment.at(i).logProbability(values),
176 math.log(fragment.at(i).evaluate(values)))
177 self.assertAlmostEqual(fragment.logProbability(values),
178 math.log(fragment.evaluate(values)))
179 actual = fragment.sample(given)
185 """Check that dot works with position hints."""
187 fragment.add(Either, [Tuberculosis, LungCancer],
"F T T T")
189 fragment.add(Tuberculosis, [MyAsia],
"99/1 95/5")
190 fragment.add(LungCancer, [Smoking],
"99/1 90/10")
194 ph: dict = writer.positionHints
196 writer.positionHints = ph
199 actual = fragment.dot(writer=writer)
200 expected_result =
"""\
208 var6989586621679009792[label="a0", pos="0,2!"];
211 var6989586621679009792->var3
215 self.assertEqual(actual, textwrap.dedent(expected_result))
218 if __name__ ==
"__main__":