test_DiscreteBayesTree.py
Go to the documentation of this file.
"""
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved

See LICENSE for the license information

Unit tests for Discrete Bayes trees.
Author: Frank Dellaert
"""

# pylint: disable=no-name-in-module, invalid-name

import unittest

import numpy as np
from gtsam.symbol_shorthand import A, X
from gtsam.utils.test_case import GtsamTestCase

import gtsam
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
                   DiscreteConditional, DiscreteFactorGraph, DiscreteValues,
                   Ordering)


class TestDiscreteBayesNet(GtsamTestCase):
    """Tests for Discrete Bayes trees.

    Note: despite the class name, these tests exercise the Bayes trees
    produced by multifrontal elimination of a DiscreteFactorGraph.
    """

    def test_elimination(self):
        """Test Multifrontal elimination."""

        # Define DiscreteKey pairs: (key, cardinality) with binary variables.
        keys = [(j, 2) for j in range(15)]

        # Create thin-tree Bayes net.
        bayesNet = DiscreteBayesNet()

        bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1")
        bayesNet.add(keys[1], [keys[8], keys[12]], "4/1 2/3 3/2 1/4")
        bayesNet.add(keys[2], [keys[9], keys[12]], "1/4 8/2 2/3 4/1")
        bayesNet.add(keys[3], [keys[9], keys[12]], "1/4 2/3 3/2 4/1")

        bayesNet.add(keys[4], [keys[10], keys[13]], "2/3 1/4 3/2 4/1")
        bayesNet.add(keys[5], [keys[10], keys[13]], "4/1 2/3 3/2 1/4")
        bayesNet.add(keys[6], [keys[11], keys[13]], "1/4 3/2 2/3 4/1")
        bayesNet.add(keys[7], [keys[11], keys[13]], "1/4 2/3 3/2 4/1")

        bayesNet.add(keys[8], [keys[12], keys[14]], "T 1/4 3/2 4/1")
        bayesNet.add(keys[9], [keys[12], keys[14]], "4/1 2/3 F 1/4")
        bayesNet.add(keys[10], [keys[13], keys[14]], "1/4 3/2 2/3 4/1")
        bayesNet.add(keys[11], [keys[13], keys[14]], "1/4 2/3 3/2 4/1")

        bayesNet.add(keys[12], [keys[14]], "3/1 3/1")
        bayesNet.add(keys[13], [keys[14]], "1/3 3/1")

        bayesNet.add(keys[14], "1/3")

        # Create a factor graph out of the Bayes net.
        factorGraph = DiscreteFactorGraph(bayesNet)

        # Create a BayesTree out of the factor graph, eliminating keys 0..14
        # in ascending order.
        ordering = Ordering()
        for j in range(15):
            ordering.push_back(j)
        bayesTree = factorGraph.eliminateMultifrontal(ordering)

        # Uncomment these for visualization:
        # print(bayesTree)
        # for key in range(15):
        #     bayesTree[key].printSignature()
        # bayesTree.saveGraph("test_DiscreteBayesTree.dot")

        # The root is P( 8 12 14), we can retrieve it by key:
        root = bayesTree[8]
        self.assertIsInstance(root, DiscreteBayesTreeClique)
        self.assertTrue(root.isRoot())
        self.assertIsInstance(root.conditional(), DiscreteConditional)

        # A Bayes tree must compare equal to itself.
        self.gtsamAssertEquals(bayesTree, bayesTree)

        # Check value at the all-zero assignment.
        zero_values = DiscreteValues()
        for j in range(15):
            zero_values[j] = 0
        value_at_zeros = bayesTree.evaluate(zero_values)
        self.assertAlmostEqual(value_at_zeros, 0.0)

        # Check value at the MPE assignment returned by optimize().
        values_star = factorGraph.optimize()
        max_value = bayesTree.evaluate(values_star)
        self.assertAlmostEqual(max_value, 0.002548)

        # Check operator sugar: calling the tree is the same as evaluate().
        max_value = bayesTree(values_star)
        self.assertAlmostEqual(max_value, 0.002548)

        self.assertFalse(bayesTree.empty())
        self.assertEqual(12, bayesTree.size())

    def test_discrete_bayes_tree_lookup(self):
        """Check that we can have a multi-frontal lookup table."""
        # Make a small planning-like graph: 3 states, 2 actions
        graph = DiscreteFactorGraph()
        x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3)
        a1, a2 = (A(1), 2), (A(2), 2)

        # Constraint on start and goal
        graph.add([x1], np.array([1, 0, 0]))
        graph.add([x3], np.array([0, 0, 1]))

        # Should I stay or should I go?
        # "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
        r = 10
        table = np.array([
            r, 0, 0, 0, r, 0,  # x1 = 0
            0, r, 0, 0, 0, r,  # x1 = 1
            0, 0, r, 0, 0, r   # x1 = 2
        ])
        graph.add([x1, a1, x2], table)
        graph.add([x2, a2, x3], table)

        # Eliminate for MPE (most probable explanation).
        ordering = Ordering(keys=[A(2), X(3), X(1), A(1), X(2)])
        lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)

        # Check that the lookup table is correct.
        # Use TestCase assertions (not bare `assert`, which -O strips).
        self.assertEqual(lookup.size(), 2)
        lookup_x1_a1_x2 = lookup[X(1)].conditional()
        self.assertEqual(lookup_x1_a1_x2.nrFrontals(), 3)
        # Check that sum is 1.0 (not 100, as we now normalize to prevent underflow)
        empty = DiscreteValues()
        self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 1.0)
        # And that only non-zero reward is for x1 a1 x2 == 0 1 1
        values = DiscreteValues()
        values[X(1)] = 0
        values[A(1)] = 1
        values[X(2)] = 1
        self.assertAlmostEqual(lookup_x1_a1_x2(values), 1.0)

        lookup_a2_x3 = lookup[X(3)].conditional()
        # Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
        sum_x2 = lookup_a2_x3.sum(2)
        values = DiscreteValues()
        values[X(2)] = 0
        self.assertAlmostEqual(sum_x2(values), 0)
        values[X(2)] = 1
        self.assertAlmostEqual(sum_x2(values), 1.0)  # not 10, as we normalize
        values[X(2)] = 2
        self.assertAlmostEqual(sum_x2(values), 2.0)  # not 20, as we normalize
        self.assertEqual(lookup_a2_x3.nrFrontals(), 2)
        # And that the non-zero rewards are for x2 a2 x3 == 1 1 2
        values = DiscreteValues()
        values[X(2)] = 1
        values[A(2)] = 1
        values[X(3)] = 2
        self.assertAlmostEqual(lookup_a2_x3(values), 1.0)  # not 10...
158 
159 if __name__ == "__main__":
160  unittest.main()
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
gtsam::symbol_shorthand
Definition: inference/Symbol.h:147
X
#define X
Definition: icosphere.cpp:20
gtsam::utils.test_case.GtsamTestCase.gtsamAssertEquals
def gtsamAssertEquals(self, actual, expected, tol=1e-9)
Definition: test_case.py:19
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
gtsam::utils.test_case
Definition: test_case.py:1
test_DiscreteBayesTree.TestDiscreteBayesNet
Definition: test_DiscreteBayesTree.py:26
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::utils.test_case.GtsamTestCase
Definition: test_case.py:16
test_DiscreteBayesTree.TestDiscreteBayesNet.test_discrete_bayes_tree_lookup
def test_discrete_bayes_tree_lookup(self)
Definition: test_DiscreteBayesTree.py:101
gtsam::Ordering
Definition: inference/Ordering.h:33
test_DiscreteBayesTree.TestDiscreteBayesNet.test_elimination
def test_elimination(self)
Definition: test_DiscreteBayesTree.py:29


gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:06:54