test_DiscreteFactorGraph.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for Discrete Factor Graphs.
9 Author: Frank Dellaert
10 """
11 
12 # pylint: disable=no-name-in-module, invalid-name
13 
14 import unittest
15 
16 import numpy as np
17 from gtsam.utils.test_case import GtsamTestCase
18 
19 from gtsam import (DecisionTreeFactor, DiscreteConditional,
20  DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering,
21  Symbol)
22 
# Shorthand for the elimination-ordering enum used by the tests below
# (e.g. OrderingType.COLAMD, .METIS, .NATURAL, .CUSTOM).
OrderingType = Ordering.OrderingType
24 
25 
class TestDiscreteFactorGraph(GtsamTestCase):
    """Tests for Discrete Factor Graphs."""

    @staticmethod
    def _make_key(character, index, cardinality):
        """Return a discrete key ``(key, cardinality)`` for ``Symbol(character, index)``.

        Mimics the behavior of the gtbook.Variables discrete_series function.
        """
        return (Symbol(character, index).key(), cardinality)

    @staticmethod
    def _cpt_string(table):
        """Format a 2-D table as a DiscreteConditional CPT specification string.

        Each row becomes ``"p0/p1/.../pk"`` and rows are joined by single spaces,
        i.e. one row per assignment of the parent variable.
        """
        return " ".join("/".join(str(x) for x in row) for row in table)

    def test_evaluation(self):
        """Test constructing and evaluating a discrete factor graph."""

        # Three keys, given as (key, cardinality) pairs
        P1 = (0, 2)
        P2 = (1, 2)
        P3 = (2, 3)

        # Create the DiscreteFactorGraph
        graph = DiscreteFactorGraph()

        # Add two unary factors (priors)
        graph.add(P1, [0.9, 0.3])
        graph.add(P2, "0.9 0.6")

        # Add a binary factor
        graph.add([P1, P2], "4 1 10 4")

        # Instantiate Values
        assignment = DiscreteValues()
        assignment[0] = 1
        assignment[1] = 1

        # Check if graph evaluation works (0.3*0.6*4)
        self.assertAlmostEqual(.72, graph(assignment))

        # Extend the graph with a third node, adding a unary and a ternary factor
        graph.add(P3, "0.9 0.2 0.5")
        keys = DiscreteKeys()
        keys.push_back(P1)
        keys.push_back(P2)
        keys.push_back(P3)
        graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12")

        # Below assignment selects the ternary table entry at 0-based index 7
        # (value 8); the binary table entry for (P1=1, P2=0) is 10.
        assignment[0] = 1
        assignment[1] = 0
        assignment[2] = 1

        # Check if graph evaluation works (0.3*0.9*10*0.2*8)
        self.assertAlmostEqual(4.32, graph(assignment))

        # Below assignment selects the ternary table entry at 0-based index 3
        # (value 4); the binary table entry for (P1=0, P2=1) is 1.
        assignment[0] = 0
        assignment[1] = 1
        assignment[2] = 0

        # Check if graph evaluation works (0.9*0.6*1*0.9*4)
        self.assertAlmostEqual(1.944, graph(assignment))

        # Check if graph product works
        product = graph.product()
        self.assertAlmostEqual(1.944, product(assignment))

    def test_optimize(self):
        """Test constructing and optimizing a discrete factor graph."""

        # Three keys
        C = (0, 2)
        B = (1, 2)
        A = (2, 2)

        # A simple factor graph (A)-fAC-(C)-fBC-(B)
        # with smoothness priors
        graph = DiscreteFactorGraph()
        graph.add([A, C], "3 1 1 3")
        graph.add([C, B], "3 1 1 3")

        # Test optimization
        expectedValues = DiscreteValues()
        expectedValues[0] = 0
        expectedValues[1] = 0
        expectedValues[2] = 0
        actualValues = graph.optimize()
        self.assertEqual(list(actualValues.items()),
                         list(expectedValues.items()))

    def test_MPE(self):
        """Test maximum probable explanation (MPE): same as optimize."""

        # Declare a bunch of keys
        C, A, B = (0, 2), (1, 2), (2, 2)

        # Create Factor graph
        graph = DiscreteFactorGraph()
        graph.add([C, A], "0.2 0.8 0.3 0.7")
        graph.add([C, B], "0.1 0.9 0.4 0.6")

        # We know MPE
        mpe = DiscreteValues()
        mpe[0] = 0
        mpe[1] = 1
        mpe[2] = 1

        # Use maxProduct
        dag = graph.maxProduct(OrderingType.COLAMD)
        actualMPE = dag.argmax()
        self.assertEqual(list(actualMPE.items()),
                         list(mpe.items()))

        # All in one
        actualMPE2 = graph.optimize()
        self.assertEqual(list(actualMPE2.items()),
                         list(mpe.items()))

    def test_sumProduct(self):
        """Test sumProduct."""

        # Declare a bunch of keys
        C, A, B = (0, 2), (1, 2), (2, 2)

        # Create Factor graph
        graph = DiscreteFactorGraph()
        graph.add([C, A], "0.2 0.8 0.3 0.7")
        graph.add([C, B], "0.1 0.9 0.4 0.6")

        # We know MPE
        mpe = DiscreteValues()
        mpe[0] = 0
        mpe[1] = 1
        mpe[2] = 1

        # Use default sumProduct
        bayesNet = graph.sumProduct()
        mpeProbability = bayesNet(mpe)
        self.assertAlmostEqual(mpeProbability, 0.36)  # regression

        # Use sumProduct with each ordering type. Different elimination orders
        # perform the arithmetic in a different sequence, so compare up to
        # floating-point round-off rather than bit-for-bit.
        for ordering_type in [OrderingType.COLAMD, OrderingType.METIS,
                              OrderingType.NATURAL, OrderingType.CUSTOM]:
            with self.subTest(ordering_type=ordering_type):
                bayesNet = graph.sumProduct(ordering_type)
                self.assertAlmostEqual(bayesNet(mpe), mpeProbability)

    def test_MPE_chain(self):
        """
        Test for numerical underflow in EliminateMPE on long chains.
        Adapted from the toy problem of @pcl15423
        Ref: https://github.com/borglab/gtsam/issues/1448
        """
        num_states = 3
        num_obs = 200
        desired_state = 1

        X = {index: self._make_key("X", index, num_states) for index in range(num_obs)}
        Z = {index: self._make_key("Z", index, num_obs + 1) for index in range(num_obs)}
        graph = DiscreteFactorGraph()

        # Mostly identity transition matrix
        transitions = np.eye(num_states)

        # Give every transition a small nonzero weight.
        # NOTE(review): the original author noted the MPE otherwise always came
        # out as state 0; unverified.
        transitions += 0.1 / num_states

        transition_cpt = self._cpt_string(transitions)

        for i in reversed(range(1, num_obs)):
            transition_conditional = DiscreteConditional(X[i], [X[i - 1]], transition_cpt)
            graph.push_back(transition_conditional)

        # Contrived example such that the desired state gives measurements
        # [0, num_obs) with equal probability, but all other states always
        # give measurement num_obs.
        obs = np.zeros((num_states, num_obs + 1))
        obs[:, -1] = 1
        obs[desired_state, :-1] = 1
        obs[desired_state, -1] = 0
        obs_cpt = self._cpt_string(obs)

        # Contrived example where each measurement is its own index
        for i in range(num_obs):
            obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
            graph.push_back(obs_conditional.likelihood(i))

        mpe = graph.optimize()
        vals = [mpe[X[i][0]] for i in range(num_obs)]

        self.assertEqual(vals, [desired_state] * num_obs)

    def test_sumProduct_chain(self):
        """
        Test for numerical underflow in EliminateDiscrete on long chains.
        Adapted from the toy problem of @pcl15423
        Ref: https://github.com/borglab/gtsam/issues/1448
        """
        num_states = 3
        chain_length = 400

        X = {index: self._make_key("X", index, num_states) for index in range(chain_length)}
        graph = DiscreteFactorGraph()

        # Construct test transition matrix
        transitions = np.diag([1.0, 0.5, 0.1])
        transitions += 0.1 / num_states

        # Ensure that the transition matrix is Markov (columns sum to 1)
        transitions /= np.sum(transitions, axis=0)

        # The stationary distribution is the eigenvector for eigenvalue 1.
        # Take the single eigenvalue closest to 1 (an isclose/where mask could
        # match zero or several), and drop the imaginary part, since
        # np.linalg.eig returns complex arrays if any eigenvalue is complex.
        eigvals, eigvecs = np.linalg.eig(transitions)
        stationary_idx = np.argmin(np.abs(eigvals - 1.0))
        stationary_dist = np.real(eigvecs[:, stationary_idx])

        # Normalize so the entries sum to 1 (this also fixes the arbitrary sign
        # of the eigenvector).
        stationary_dist = stationary_dist / np.sum(stationary_dist)
        expected = DecisionTreeFactor(X[chain_length - 1], stationary_dist)

        # The transition matrix parsed by DiscreteConditional is a row-wise CPT
        transition_cpt = self._cpt_string(transitions.T)

        for i in reversed(range(1, chain_length)):
            transition_conditional = DiscreteConditional(X[i], [X[i - 1]], transition_cpt)
            graph.push_back(transition_conditional)

        # Run sum product using natural ordering so the resulting Bayes net has the form:
        # X_0 <- X_1 <- ... <- X_n
        sum_product = graph.sumProduct(OrderingType.NATURAL)

        # The last conditional in the Bayes net is the marginal on the last variable
        last_marginal = sum_product.at(chain_length - 1)

        # Ensure marginal probabilities are close to the stationary distribution
        self.gtsamAssertEquals(expected, last_marginal)
279 
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:45
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
list
Definition: pytypes.h:2166
test_DiscreteFactorGraph.TestDiscreteFactorGraph.test_MPE
def test_MPE(self)
Definition: test_DiscreteFactorGraph.py:106
test_DiscreteFactorGraph.TestDiscreteFactorGraph
Definition: test_DiscreteFactorGraph.py:26
test_DiscreteFactorGraph.TestDiscreteFactorGraph.test_MPE_chain
def test_MPE_chain(self)
Definition: test_DiscreteFactorGraph.py:162
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
gtsam::utils.test_case.GtsamTestCase.gtsamAssertEquals
def gtsamAssertEquals(self, actual, expected, tol=1e-9)
Definition: test_case.py:19
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
different_sigmas::bayesNet
const HybridBayesNet bayesNet
Definition: testHybridBayesNet.cpp:242
test_DiscreteFactorGraph.TestDiscreteFactorGraph.test_evaluation
def test_evaluation(self)
Definition: test_DiscreteFactorGraph.py:29
test_DiscreteFactorGraph.TestDiscreteFactorGraph.test_sumProduct
def test_sumProduct(self)
Definition: test_DiscreteFactorGraph.py:134
gtsam::utils.test_case
Definition: test_case.py:1
test_DiscreteFactorGraph.TestDiscreteFactorGraph.test_optimize
def test_optimize(self)
Definition: test_DiscreteFactorGraph.py:83
str
Definition: pytypes.h:1558
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
product
void product(const MatrixType &m)
Definition: product.h:20
gtsam::utils.test_case.GtsamTestCase
Definition: test_case.py:16
test_DiscreteFactorGraph.TestDiscreteFactorGraph.test_sumProduct_chain
def test_sumProduct_chain(self)
Definition: test_DiscreteFactorGraph.py:222
len
size_t len(handle h)
Get the length of a Python object.
Definition: pytypes.h:2446
graph
NonlinearFactorGraph graph
Definition: doc/Code/OdometryExample.cpp:2
gtsam::Symbol
Definition: inference/Symbol.h:37


gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:06:54