test_DiscreteBayesTree.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for Discrete Bayes trees.
9 Author: Frank Dellaert
10 """
11 
12 # pylint: disable=no-name-in-module, invalid-name
13 
14 import unittest
15 
16 import numpy as np
17 from gtsam.symbol_shorthand import A, X
18 from gtsam.utils.test_case import GtsamTestCase
19 
20 import gtsam
21 from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
22  DiscreteConditional, DiscreteFactorGraph, DiscreteValues,
23  Ordering)
24 
25 
27  """Tests for Discrete Bayes trees."""
28 
29  def test_elimination(self):
30  """Test Multifrontal elimination."""
31 
32  # Define DiscreteKey pairs.
33  keys = [(j, 2) for j in range(15)]
34 
35  # Create thin-tree Bayes net.
36  bayesNet = DiscreteBayesNet()
37 
38  bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1")
39  bayesNet.add(keys[1], [keys[8], keys[12]], "4/1 2/3 3/2 1/4")
40  bayesNet.add(keys[2], [keys[9], keys[12]], "1/4 8/2 2/3 4/1")
41  bayesNet.add(keys[3], [keys[9], keys[12]], "1/4 2/3 3/2 4/1")
42 
43  bayesNet.add(keys[4], [keys[10], keys[13]], "2/3 1/4 3/2 4/1")
44  bayesNet.add(keys[5], [keys[10], keys[13]], "4/1 2/3 3/2 1/4")
45  bayesNet.add(keys[6], [keys[11], keys[13]], "1/4 3/2 2/3 4/1")
46  bayesNet.add(keys[7], [keys[11], keys[13]], "1/4 2/3 3/2 4/1")
47 
48  bayesNet.add(keys[8], [keys[12], keys[14]], "T 1/4 3/2 4/1")
49  bayesNet.add(keys[9], [keys[12], keys[14]], "4/1 2/3 F 1/4")
50  bayesNet.add(keys[10], [keys[13], keys[14]], "1/4 3/2 2/3 4/1")
51  bayesNet.add(keys[11], [keys[13], keys[14]], "1/4 2/3 3/2 4/1")
52 
53  bayesNet.add(keys[12], [keys[14]], "3/1 3/1")
54  bayesNet.add(keys[13], [keys[14]], "1/3 3/1")
55 
56  bayesNet.add(keys[14], "1/3")
57 
58  # Create a factor graph out of the Bayes net.
59  factorGraph = DiscreteFactorGraph(bayesNet)
60 
61  # Create a BayesTree out of the factor graph.
62  ordering = Ordering()
63  for j in range(15):
64  ordering.push_back(j)
65  bayesTree = factorGraph.eliminateMultifrontal(ordering)
66 
67  # Uncomment these for visualization:
68  # print(bayesTree)
69  # for key in range(15):
70  # bayesTree[key].printSignature()
71  # bayesTree.saveGraph("test_DiscreteBayesTree.dot")
72 
73  # The root is P( 8 12 14), we can retrieve it by key:
74  root = bayesTree[8]
75  self.assertIsInstance(root, DiscreteBayesTreeClique)
76  self.assertTrue(root.isRoot())
77  self.assertIsInstance(root.conditional(), DiscreteConditional)
78 
79  # Test all methods in DiscreteBayesTree
80  self.gtsamAssertEquals(bayesTree, bayesTree)
81 
82  # Check value at 0
83  zero_values = DiscreteValues()
84  for j in range(15):
85  zero_values[j] = 0
86  value_at_zeros = bayesTree.evaluate(zero_values)
87  self.assertAlmostEqual(value_at_zeros, 0.0)
88 
89  # Check value at max
90  values_star = factorGraph.optimize()
91  max_value = bayesTree.evaluate(values_star)
92  self.assertAlmostEqual(max_value, 0.002548)
93 
94  # Check operator sugar
95  max_value = bayesTree(values_star)
96  self.assertAlmostEqual(max_value, 0.002548)
97 
98  self.assertFalse(bayesTree.empty())
99  self.assertEqual(12, bayesTree.size())
100 
101  @unittest.skip("Too Slow")
103  """Check that we can have a multi-frontal lookup table."""
104  # Make a small planning-like graph: 3 states, 2 actions
105  graph = DiscreteFactorGraph()
106  x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3)
107  a1, a2 = (A(1), 2), (A(2), 2)
108 
109  # Constraint on start and goal
110  graph.add([x1], np.array([1, 0, 0]))
111  graph.add([x3], np.array([0, 0, 1]))
112 
113  # Should I stay or should I go?
114  # "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
115  r = 10
116  table = np.array([
117  r, 0, 0, 0, r, 0, # x1 = 0
118  0, r, 0, 0, 0, r, # x1 = 1
119  0, 0, r, 0, 0, r # x1 = 2
120  ])
121  graph.add([x1, a1, x2], table)
122  graph.add([x2, a2, x3], table)
123 
124  # Eliminate for MPE (maximum probable explanation).
125  ordering = Ordering(keys=[A(2), X(3), X(1), A(1), X(2)])
126  lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
127 
128  # Check that the lookup table is correct
129  assert lookup.size() == 2
130  lookup_x1_a1_x2 = lookup[X(1)].conditional()
131  assert lookup_x1_a1_x2.nrFrontals() == 3
132  # Check that sum is 1.0 (not 100, as we now normalize to prevent underflow)
133  empty = gtsam.DiscreteValues()
134  self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 1.0)
135  # And that only non-zero reward is for x1 a1 x2 == 0 1 1
136  values = DiscreteValues()
137  values[X(1)] = 0
138  values[A(1)] = 1
139  values[X(2)] = 1
140  self.assertAlmostEqual(lookup_x1_a1_x2(values), 1.0)
141 
142  lookup_a2_x3 = lookup[X(3)].conditional()
143  # Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
144  sum_x2 = lookup_a2_x3.sum(2)
145  values = DiscreteValues()
146  values[X(2)] = 0
147  self.assertAlmostEqual(sum_x2(values), 0)
148  values[X(2)] = 1
149  self.assertAlmostEqual(sum_x2(values), 1.0) # not 10, as we normalize
150  values[X(2)] = 2
151  self.assertAlmostEqual(sum_x2(values), 2.0) # not 20, as we normalize
152  assert lookup_a2_x3.nrFrontals() == 2
153  # And that the non-zero rewards are for x2 a2 x3 == 1 1 2
154  values = DiscreteValues()
155  values[X(2)] = 1
156  values[A(2)] = 1
157  values[X(3)] = 2
158  self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
159 
161  """Test creating a Bayes tree directly from cliques."""
162  # Create a BayesNet
163  bayesNet = DiscreteBayesNet()
164  A, B, C = (0, 2), (1, 2), (2, 2)
165  bayesNet.add(A, "1/3")
166  bayesNet.add(B, [A], "1/3 3/1")
167  bayesNet.add(C, [B], "3/1 3/1")
168 
169  # Create cliques directly
170  clique2 = DiscreteBayesTreeClique(DiscreteConditional(C, [B], "3/1 3/1"))
171  clique1 = DiscreteBayesTreeClique(DiscreteConditional(B, [A], "1/3 3/1"))
172  clique0 = DiscreteBayesTreeClique(DiscreteConditional(A, "1/3"))
173 
174  # Create a BayesTree
175  bayesTree = gtsam.DiscreteBayesTree()
176  bayesTree.insertRoot(clique2)
177  bayesTree.addClique(clique1, clique2)
178  bayesTree.addClique(clique0, clique1)
179 
180  # Check that the BayesTree is correct
181  values = DiscreteValues()
182  values[0] = 1
183  values[1] = 1
184  values[2] = 1
185 
186  # regression
187  expected = .046875
188  self.assertAlmostEqual(expected, bayesNet.evaluate(values))
189 
190 
191 if __name__ == "__main__":
192  unittest.main()
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
asia::bayesTree
static const DiscreteBayesTree bayesTree
Definition: testDiscreteSearch.cpp:40
gtsam::symbol_shorthand
Definition: inference/Symbol.h:147
X
#define X
Definition: icosphere.cpp:20
gtsam::utils.test_case.GtsamTestCase.gtsamAssertEquals
def gtsamAssertEquals(self, actual, expected, tol=1e-9)
Definition: test_case.py:19
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
gtsam::utils.test_case
Definition: test_case.py:1
test_DiscreteBayesTree.TestDiscreteBayesNet
Definition: test_DiscreteBayesTree.py:26
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::DiscreteBayesTree
A Bayes tree representing a Discrete distribution.
Definition: DiscreteBayesTree.h:73
gtsam::utils.test_case.GtsamTestCase
Definition: test_case.py:16
test_DiscreteBayesTree.TestDiscreteBayesNet.test_discrete_bayes_tree_lookup
def test_discrete_bayes_tree_lookup(self)
Definition: test_DiscreteBayesTree.py:102
gtsam::DiscreteBayesTreeClique
Definition: DiscreteBayesTree.h:39
gtsam::Ordering
Definition: inference/Ordering.h:33
test_DiscreteBayesTree.TestDiscreteBayesNet.test_direct_from_cliques
def test_direct_from_cliques(self)
Definition: test_DiscreteBayesTree.py:160
test_DiscreteBayesTree.TestDiscreteBayesNet.test_elimination
def test_elimination(self)
Definition: test_DiscreteBayesTree.py:29


gtsam
Author(s):
autogenerated on Wed Mar 19 2025 03:06:17