test_DiscreteFactorGraph.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for Discrete Factor Graphs.
9 Author: Frank Dellaert
10 """
11 
12 # pylint: disable=no-name-in-module, invalid-name
13 
14 import unittest
15 
16 import numpy as np
17 from gtsam.utils.test_case import GtsamTestCase
18 from dfg_utils import make_key, generate_transition_cpt, generate_observation_cpt
19 
20 from gtsam import (
21  DecisionTreeFactor,
22  DiscreteConditional,
23  DiscreteFactorGraph,
24  DiscreteKeys,
25  DiscreteValues,
26  Ordering,
27 )
28 
OrderingType = Ordering.OrderingType


class TestDiscreteFactorGraph(GtsamTestCase):
    """Tests for Discrete Factor Graphs."""

    def test_evaluation(self):
        """Test constructing and evaluating a discrete factor graph."""

        # Three keys, each a (key, cardinality) pair
        P1 = (0, 2)
        P2 = (1, 2)
        P3 = (2, 3)

        # Create the DiscreteFactorGraph
        graph = DiscreteFactorGraph()

        # Add two unary factors (priors)
        graph.add(P1, [0.9, 0.3])
        graph.add(P2, "0.9 0.6")

        # Add a binary factor; table order is (P1, P2) = 00, 01, 10, 11
        graph.add([P1, P2], "4 1 10 4")

        # Instantiate Values
        assignment = DiscreteValues()
        assignment[0] = 1
        assignment[1] = 1

        # Check if graph evaluation works (0.3 * 0.6 * 4)
        self.assertAlmostEqual(0.72, graph(assignment))

        # Create a new test with third node and adding unary and ternary factor
        graph.add(P3, "0.9 0.2 0.5")
        keys = DiscreteKeys()
        keys.push_back(P1)
        keys.push_back(P2)
        keys.push_back(P3)
        graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12")

        # Below assignment selects table value 8 (zero-based index 7)
        # in the ternary factor
        assignment[0] = 1
        assignment[1] = 0
        assignment[2] = 1

        # Check if graph evaluation works (0.3 * 0.9 * 10 * 0.2 * 8)
        self.assertAlmostEqual(4.32, graph(assignment))

        # Below assignment selects table value 4 (zero-based index 3)
        # in the ternary factor
        assignment[0] = 0
        assignment[1] = 1
        assignment[2] = 0

        # Check if graph evaluation works (0.9 * 0.6 * 1 * 0.9 * 4)
        self.assertAlmostEqual(1.944, graph(assignment))

        # Check if graph product works
        product = graph.product()
        self.assertAlmostEqual(1.944, product(assignment))

    def test_optimize(self):
        """Test constructing and optimizing a discrete factor graph."""

        # Three keys, each a (key, cardinality) pair
        C = (0, 2)
        B = (1, 2)
        A = (2, 2)

        # A simple factor graph (A)-fAC-(C)-fBC-(B)
        # with smoothness priors
        graph = DiscreteFactorGraph()
        graph.add([A, C], "3 1 1 3")
        graph.add([C, B], "3 1 1 3")

        # Test optimization
        expectedValues = DiscreteValues()
        expectedValues[0] = 0
        expectedValues[1] = 0
        expectedValues[2] = 0
        actualValues = graph.optimize()
        self.assertEqual(list(actualValues.items()), list(expectedValues.items()))

    def test_MPE(self):
        """Test maximum probable explanation (MPE): same as optimize."""

        # Declare a bunch of keys
        C, A, B = (0, 2), (1, 2), (2, 2)

        # Create Factor graph
        graph = DiscreteFactorGraph()
        graph.add([C, A], "0.2 0.8 0.3 0.7")
        graph.add([C, B], "0.1 0.9 0.4 0.6")

        # We know MPE: C=0, A=1, B=1 scores 0.8 * 0.9 = 0.72
        mpe = DiscreteValues()
        mpe[0] = 0
        mpe[1] = 1
        mpe[2] = 1

        # Use maxProduct
        dag = graph.maxProduct(OrderingType.COLAMD)
        actualMPE = dag.argmax()
        self.assertEqual(list(actualMPE.items()), list(mpe.items()))

        # All in one
        actualMPE2 = graph.optimize()
        self.assertEqual(list(actualMPE2.items()), list(mpe.items()))

    def test_sumProduct(self):
        """Test sumProduct."""

        # Declare a bunch of keys
        C, A, B = (0, 2), (1, 2), (2, 2)

        # Create Factor graph
        graph = DiscreteFactorGraph()
        graph.add([C, A], "0.2 0.8 0.3 0.7")
        graph.add([C, B], "0.1 0.9 0.4 0.6")

        # Known MPE assignment, used here as the point at which the
        # resulting Bayes net is evaluated
        mpe = DiscreteValues()
        mpe[0] = 0
        mpe[1] = 1
        mpe[2] = 1

        # Use default sumProduct.
        # Regression value: (0.8 * 0.9) / 2, where 2 is the sum of the
        # unnormalized product over all assignments.
        bayesNet = graph.sumProduct()
        mpeProbability = bayesNet(mpe)
        self.assertAlmostEqual(mpeProbability, 0.36)  # regression

        # Every elimination ordering must give the same probability. Compare
        # approximately: different orderings multiply and sum in different
        # orders, so results need not be bit-identical.
        for ordering_type in [
            OrderingType.COLAMD,
            OrderingType.METIS,
            OrderingType.NATURAL,
            OrderingType.CUSTOM,
        ]:
            bayesNet = graph.sumProduct(ordering_type)
            self.assertAlmostEqual(bayesNet(mpe), mpeProbability)
168 
169 
class TestChains(GtsamTestCase):
    """Tests for numerical underflow when eliminating long discrete chains."""

    def test_MPE_chain(self):
        """
        Test for numerical underflow in EliminateMPE on long chains.
        Adapted from the toy problem of @pcl15423
        Ref: https://github.com/borglab/gtsam/issues/1448
        """
        num_states = 3
        num_obs = 200
        desired_state = 1

        X = {index: make_key("X", index, num_states) for index in range(num_obs)}
        Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
        graph = DiscreteFactorGraph()

        # Transition conditionals P(X[i] | X[i-1]), pushed from the end of
        # the chain backwards
        transition_cpt = generate_transition_cpt(num_states)
        for i in reversed(range(1, num_obs)):
            transition_conditional = DiscreteConditional(
                X[i], [X[i - 1]], transition_cpt
            )
            graph.push_back(transition_conditional)

        # Contrived example such that the desired state gives measurements
        # [0, num_obs) with equal probability, but all other states always
        # give measurement num_obs
        obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state)
        # Contrived example where each measurement is its own index
        for i in range(num_obs):
            obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
            factor = obs_conditional.likelihood(i)
            graph.push_back(factor)

        mpe = graph.optimize()
        vals = [mpe[X[i][0]] for i in range(num_obs)]

        self.assertEqual(vals, [desired_state] * num_obs)

    def test_sumProduct_chain(self):
        """
        Test for numerical underflow in EliminateDiscrete on long chains.
        Adapted from the toy problem of @pcl15423
        Ref: https://github.com/borglab/gtsam/issues/1448
        """
        num_states = 3
        chain_length = 400

        X = {
            index: make_key("X", index, num_states)
            for index in range(chain_length)
        }
        graph = DiscreteFactorGraph()

        # Construct test transition matrix
        transitions = np.diag([1.0, 0.5, 0.1])
        transitions += 0.1 / num_states

        # Ensure that the transition matrix is Markov (columns sum to 1)
        transitions /= np.sum(transitions, axis=0)

        # The stationary distribution is the eigenvector for eigenvalue 1.
        # Take the eigenvalue closest to 1 instead of relying on np.isclose
        # matching exactly one entry. Drop any imaginary part that
        # np.linalg.eig may return for a non-symmetric matrix.
        eigvals, eigvecs = np.linalg.eig(transitions)
        stationary_idx = np.argmin(np.abs(eigvals - 1.0))
        stationary_dist = np.real(eigvecs[:, stationary_idx])

        # Normalize to sum to 1; this also removes the eigenvector's
        # arbitrary sign, making the distribution positive
        stationary_dist = stationary_dist / np.sum(stationary_dist)
        expected = DecisionTreeFactor(X[chain_length - 1], stationary_dist)

        # The transition matrix parsed by DiscreteConditional is a row-wise
        # CPT, so pass the transpose of the column-stochastic matrix
        transition_cpt = generate_transition_cpt(num_states, transitions.T)

        for i in reversed(range(1, chain_length)):
            transition_conditional = DiscreteConditional(
                X[i], [X[i - 1]], transition_cpt
            )
            graph.push_back(transition_conditional)

        # Run sum product using natural ordering so the resulting Bayes net
        # has the form: X_0 <- X_1 <- ... <- X_n
        sum_product = graph.sumProduct(OrderingType.NATURAL)

        # Get the DiscreteConditional representing the marginal on the last
        # variable of the chain
        last_marginal = sum_product.at(chain_length - 1)

        # Ensure marginal probabilities are close to the stationary distribution
        self.gtsamAssertEquals(expected, last_marginal)


if __name__ == "__main__":
    unittest.main()
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:45
asia::bayesNet
static const DiscreteBayesNet bayesNet
Definition: testDiscreteSearch.cpp:30
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
list
Definition: pytypes.h:2166
test_DiscreteFactorGraph.TestDiscreteFactorGraph.test_MPE
def test_MPE(self)
Definition: test_DiscreteFactorGraph.py:111
test_DiscreteFactorGraph.TestDiscreteFactorGraph
Definition: test_DiscreteFactorGraph.py:32
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
gtsam::utils.test_case.GtsamTestCase.gtsamAssertEquals
def gtsamAssertEquals(self, actual, expected, tol=1e-9)
Definition: test_case.py:19
test_DiscreteFactorGraph.TestChains.test_MPE_chain
def test_MPE_chain(self)
Definition: test_DiscreteFactorGraph.py:171
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
dfg_utils.generate_observation_cpt
def generate_observation_cpt(num_states, num_obs, desired_state)
Definition: dfg_utils.py:27
dfg_utils.generate_transition_cpt
def generate_transition_cpt(num_states, transitions=None)
Definition: dfg_utils.py:14
test_DiscreteFactorGraph.TestChains.test_sumProduct_chain
def test_sumProduct_chain(self)
Definition: test_DiscreteFactorGraph.py:207
test_DiscreteFactorGraph.TestDiscreteFactorGraph.test_evaluation
def test_evaluation(self)
Definition: test_DiscreteFactorGraph.py:35
test_DiscreteFactorGraph.TestDiscreteFactorGraph.test_sumProduct
def test_sumProduct(self)
Definition: test_DiscreteFactorGraph.py:137
gtsam::utils.test_case
Definition: test_case.py:1
test_DiscreteFactorGraph.TestDiscreteFactorGraph.test_optimize
def test_optimize(self)
Definition: test_DiscreteFactorGraph.py:89
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
test_DiscreteFactorGraph.TestChains
Definition: test_DiscreteFactorGraph.py:170
product
void product(const MatrixType &m)
Definition: product.h:20
gtsam::utils.test_case.GtsamTestCase
Definition: test_case.py:16
len
size_t len(handle h)
Get the length of a Python object.
Definition: pytypes.h:2446
dfg_utils.make_key
def make_key(character, index, cardinality)
Definition: dfg_utils.py:5
graph
NonlinearFactorGraph graph
Definition: doc/Code/OdometryExample.cpp:2


gtsam
Author(s):
autogenerated on Mon Mar 10 2025 03:07:20