2 GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
6 See LICENSE for the license information
8 Unit tests for Discrete Bayes trees.
21 from gtsam
import (DiscreteBayesNet, DiscreteBayesTreeClique,
22 DiscreteConditional, DiscreteFactorGraph, DiscreteValues,
27 """Tests for Discrete Bayes Nets."""
30 """Test Multifrontal elimination."""
33 keys = [(j, 2)
for j
in range(15)]
38 bayesNet.add(keys[0], [keys[8], keys[12]],
"2/3 1/4 3/2 4/1")
39 bayesNet.add(keys[1], [keys[8], keys[12]],
"4/1 2/3 3/2 1/4")
40 bayesNet.add(keys[2], [keys[9], keys[12]],
"1/4 8/2 2/3 4/1")
41 bayesNet.add(keys[3], [keys[9], keys[12]],
"1/4 2/3 3/2 4/1")
43 bayesNet.add(keys[4], [keys[10], keys[13]],
"2/3 1/4 3/2 4/1")
44 bayesNet.add(keys[5], [keys[10], keys[13]],
"4/1 2/3 3/2 1/4")
45 bayesNet.add(keys[6], [keys[11], keys[13]],
"1/4 3/2 2/3 4/1")
46 bayesNet.add(keys[7], [keys[11], keys[13]],
"1/4 2/3 3/2 4/1")
48 bayesNet.add(keys[8], [keys[12], keys[14]],
"T 1/4 3/2 4/1")
49 bayesNet.add(keys[9], [keys[12], keys[14]],
"4/1 2/3 F 1/4")
50 bayesNet.add(keys[10], [keys[13], keys[14]],
"1/4 3/2 2/3 4/1")
51 bayesNet.add(keys[11], [keys[13], keys[14]],
"1/4 2/3 3/2 4/1")
53 bayesNet.add(keys[12], [keys[14]],
"3/1 3/1")
54 bayesNet.add(keys[13], [keys[14]],
"1/3 3/1")
56 bayesNet.add(keys[14],
"1/3")
65 bayesTree = factorGraph.eliminateMultifrontal(ordering)
75 self.assertIsInstance(root, DiscreteBayesTreeClique)
76 self.assertTrue(root.isRoot())
77 self.assertIsInstance(root.conditional(), DiscreteConditional)
86 value_at_zeros = bayesTree.evaluate(zero_values)
87 self.assertAlmostEqual(value_at_zeros, 0.0)
90 values_star = factorGraph.optimize()
91 max_value = bayesTree.evaluate(values_star)
92 self.assertAlmostEqual(max_value, 0.002548)
96 self.assertAlmostEqual(max_value, 0.002548)
98 self.assertFalse(bayesTree.empty())
99 self.assertEqual(12, bayesTree.size())
101 @unittest.skip(
"Too Slow")
103 """Check that we can have a multi-frontal lookup table."""
106 x1, x2, x3 = (
X(1), 3), (
X(2), 3), (
X(3), 3)
107 a1, a2 = (A(1), 2), (A(2), 2)
110 graph.add([x1], np.array([1, 0, 0]))
111 graph.add([x3], np.array([0, 0, 1]))
121 graph.add([x1, a1, x2], table)
122 graph.add([x2, a2, x3], table)
125 ordering =
Ordering(keys=[A(2),
X(3),
X(1), A(1),
X(2)])
126 lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
129 assert lookup.size() == 2
130 lookup_x1_a1_x2 = lookup[
X(1)].conditional()
131 assert lookup_x1_a1_x2.nrFrontals() == 3
134 self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 1.0)
140 self.assertAlmostEqual(lookup_x1_a1_x2(values), 1.0)
142 lookup_a2_x3 = lookup[
X(3)].conditional()
144 sum_x2 = lookup_a2_x3.sum(2)
147 self.assertAlmostEqual(sum_x2(values), 0)
149 self.assertAlmostEqual(sum_x2(values), 1.0)
151 self.assertAlmostEqual(sum_x2(values), 2.0)
152 assert lookup_a2_x3.nrFrontals() == 2
158 self.assertAlmostEqual(lookup_a2_x3(values), 1.0)
161 """Test creating a Bayes tree directly from cliques."""
164 A, B, C = (0, 2), (1, 2), (2, 2)
165 bayesNet.add(A,
"1/3")
166 bayesNet.add(B, [A],
"1/3 3/1")
167 bayesNet.add(C, [B],
"3/1 3/1")
176 bayesTree.insertRoot(clique2)
177 bayesTree.addClique(clique1, clique2)
178 bayesTree.addClique(clique0, clique1)
188 self.assertAlmostEqual(expected, bayesNet.evaluate(values))
191 if __name__ ==
"__main__":