GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
See LICENSE for the license information
Unit tests for Discrete Bayes trees.
21 from gtsam
import (DiscreteBayesNet, DiscreteBayesTreeClique,
22 DiscreteConditional, DiscreteFactorGraph, DiscreteValues,
27 """Tests for Discrete Bayes Nets."""
30 """Test Multifrontal elimination."""
33 keys = [(j, 2)
for j
in range(15)]
38 bayesNet.add(keys[0], [keys[8], keys[12]],
"2/3 1/4 3/2 4/1")
39 bayesNet.add(keys[1], [keys[8], keys[12]],
"4/1 2/3 3/2 1/4")
40 bayesNet.add(keys[2], [keys[9], keys[12]],
"1/4 8/2 2/3 4/1")
41 bayesNet.add(keys[3], [keys[9], keys[12]],
"1/4 2/3 3/2 4/1")
43 bayesNet.add(keys[4], [keys[10], keys[13]],
"2/3 1/4 3/2 4/1")
44 bayesNet.add(keys[5], [keys[10], keys[13]],
"4/1 2/3 3/2 1/4")
45 bayesNet.add(keys[6], [keys[11], keys[13]],
"1/4 3/2 2/3 4/1")
46 bayesNet.add(keys[7], [keys[11], keys[13]],
"1/4 2/3 3/2 4/1")
48 bayesNet.add(keys[8], [keys[12], keys[14]],
"T 1/4 3/2 4/1")
49 bayesNet.add(keys[9], [keys[12], keys[14]],
"4/1 2/3 F 1/4")
50 bayesNet.add(keys[10], [keys[13], keys[14]],
"1/4 3/2 2/3 4/1")
51 bayesNet.add(keys[11], [keys[13], keys[14]],
"1/4 2/3 3/2 4/1")
53 bayesNet.add(keys[12], [keys[14]],
"3/1 3/1")
54 bayesNet.add(keys[13], [keys[14]],
"1/3 3/1")
56 bayesNet.add(keys[14],
"1/3")
65 bayesTree = factorGraph.eliminateMultifrontal(ordering)
75 self.assertIsInstance(root, DiscreteBayesTreeClique)
76 self.assertTrue(root.isRoot())
77 self.assertIsInstance(root.conditional(), DiscreteConditional)
86 value_at_zeros = bayesTree.evaluate(zero_values)
87 self.assertAlmostEqual(value_at_zeros, 0.0)
90 values_star = factorGraph.optimize()
91 max_value = bayesTree.evaluate(values_star)
92 self.assertAlmostEqual(max_value, 0.002548)
95 max_value = bayesTree(values_star)
96 self.assertAlmostEqual(max_value, 0.002548)
98 self.assertFalse(bayesTree.empty())
99 self.assertEqual(12, bayesTree.size())
102 """Check that we can have a multi-frontal lookup table."""
105 x1, x2, x3 = (
X(1), 3), (
X(2), 3), (
X(3), 3)
106 a1, a2 = (A(1), 2), (A(2), 2)
109 graph.add([x1], np.array([1, 0, 0]))
110 graph.add([x3], np.array([0, 0, 1]))
120 graph.add([x1, a1, x2], table)
121 graph.add([x2, a2, x3], table)
124 ordering =
Ordering(keys=[A(2),
X(3),
X(1), A(1),
X(2)])
125 lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
128 assert lookup.size() == 2
129 lookup_x1_a1_x2 = lookup[
X(1)].conditional()
130 assert lookup_x1_a1_x2.nrFrontals() == 3
133 self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 1.0)
139 self.assertAlmostEqual(lookup_x1_a1_x2(values), 1.0)
141 lookup_a2_x3 = lookup[
X(3)].conditional()
143 sum_x2 = lookup_a2_x3.sum(2)
146 self.assertAlmostEqual(sum_x2(values), 0)
148 self.assertAlmostEqual(sum_x2(values), 1.0)
150 self.assertAlmostEqual(sum_x2(values), 2.0)
151 assert lookup_a2_x3.nrFrontals() == 2
157 self.assertAlmostEqual(lookup_a2_x3(values), 1.0)
159 if __name__ ==
"__main__":