test_DiscreteFactorGraph.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for Discrete Factor Graphs.
9 Author: Frank Dellaert
10 """
11 
12 # pylint: disable=no-name-in-module, invalid-name
13 
14 import unittest
15 
16 import numpy as np
17 from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
18 from gtsam.utils.test_case import GtsamTestCase
19 
# Short alias for the ordering-type enum; the tests below use it to choose the
# elimination ordering passed to maxProduct / sumProduct.
OrderingType = Ordering.OrderingType
21 
22 
24  """Tests for Discrete Factor Graphs."""
25 
26  def test_evaluation(self):
27  """Test constructing and evaluating a discrete factor graph."""
28 
29  # Three keys
30  P1 = (0, 2)
31  P2 = (1, 2)
32  P3 = (2, 3)
33 
34  # Create the DiscreteFactorGraph
35  graph = DiscreteFactorGraph()
36 
37  # Add two unary factors (priors)
38  graph.add(P1, [0.9, 0.3])
39  graph.add(P2, "0.9 0.6")
40 
41  # Add a binary factor
42  graph.add([P1, P2], "4 1 10 4")
43 
44  # Instantiate Values
45  assignment = DiscreteValues()
46  assignment[0] = 1
47  assignment[1] = 1
48 
49  # Check if graph evaluation works ( 0.3*0.6*4 )
50  self.assertAlmostEqual(.72, graph(assignment))
51 
52  # Create a new test with third node and adding unary and ternary factor
53  graph.add(P3, "0.9 0.2 0.5")
54  keys = DiscreteKeys()
55  keys.push_back(P1)
56  keys.push_back(P2)
57  keys.push_back(P3)
58  graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12")
59 
60  # Below assignment selects the 8th index in the ternary factor table
61  assignment[0] = 1
62  assignment[1] = 0
63  assignment[2] = 1
64 
65  # Check if graph evaluation works (0.3*0.9*1*0.2*8)
66  self.assertAlmostEqual(4.32, graph(assignment))
67 
68  # Below assignment selects the 3rd index in the ternary factor table
69  assignment[0] = 0
70  assignment[1] = 1
71  assignment[2] = 0
72 
73  # Check if graph evaluation works (0.9*0.6*1*0.9*4)
74  self.assertAlmostEqual(1.944, graph(assignment))
75 
76  # Check if graph product works
77  product = graph.product()
78  self.assertAlmostEqual(1.944, product(assignment))
79 
80  def test_optimize(self):
81  """Test constructing and optizing a discrete factor graph."""
82 
83  # Three keys
84  C = (0, 2)
85  B = (1, 2)
86  A = (2, 2)
87 
88  # A simple factor graph (A)-fAC-(C)-fBC-(B)
89  # with smoothness priors
90  graph = DiscreteFactorGraph()
91  graph.add([A, C], "3 1 1 3")
92  graph.add([C, B], "3 1 1 3")
93 
94  # Test optimization
95  expectedValues = DiscreteValues()
96  expectedValues[0] = 0
97  expectedValues[1] = 0
98  expectedValues[2] = 0
99  actualValues = graph.optimize()
100  self.assertEqual(list(actualValues.items()),
101  list(expectedValues.items()))
102 
103  def test_MPE(self):
104  """Test maximum probable explanation (MPE): same as optimize."""
105 
106  # Declare a bunch of keys
107  C, A, B = (0, 2), (1, 2), (2, 2)
108 
109  # Create Factor graph
110  graph = DiscreteFactorGraph()
111  graph.add([C, A], "0.2 0.8 0.3 0.7")
112  graph.add([C, B], "0.1 0.9 0.4 0.6")
113 
114  # We know MPE
115  mpe = DiscreteValues()
116  mpe[0] = 0
117  mpe[1] = 1
118  mpe[2] = 1
119 
120  # Use maxProduct
121  dag = graph.maxProduct(OrderingType.COLAMD)
122  actualMPE = dag.argmax()
123  self.assertEqual(list(actualMPE.items()),
124  list(mpe.items()))
125 
126  # All in one
127  actualMPE2 = graph.optimize()
128  self.assertEqual(list(actualMPE2.items()),
129  list(mpe.items()))
130 
131  def test_sumProduct(self):
132  """Test sumProduct."""
133 
134  # Declare a bunch of keys
135  C, A, B = (0, 2), (1, 2), (2, 2)
136 
137  # Create Factor graph
138  graph = DiscreteFactorGraph()
139  graph.add([C, A], "0.2 0.8 0.3 0.7")
140  graph.add([C, B], "0.1 0.9 0.4 0.6")
141 
142  # We know MPE
143  mpe = DiscreteValues()
144  mpe[0] = 0
145  mpe[1] = 1
146  mpe[2] = 1
147 
148  # Use default sumProduct
149  bayesNet = graph.sumProduct()
150  mpeProbability = bayesNet(mpe)
151  self.assertAlmostEqual(mpeProbability, 0.36) # regression
152 
153  # Use sumProduct
154  for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
155  OrderingType.CUSTOM]:
156  bayesNet = graph.sumProduct(ordering_type)
157  self.assertEqual(bayesNet(mpe), mpeProbability)
158 
159  def test_MPE_chain(self):
160  """
161  Test for numerical underflow in EliminateMPE on long chains.
162  Adapted from the toy problem of @pcl15423
163  Ref: https://github.com/borglab/gtsam/issues/1448
164  """
165  num_states = 3
166  num_obs = 200
167  desired_state = 1
168  states = list(range(num_states))
169 
170  # Helper function to mimic the behavior of gtbook.Variables discrete_series function
171  def make_key(character, index, cardinality):
172  symbol = Symbol(character, index)
173  key = symbol.key()
174  return (key, cardinality)
175 
176  X = {index: make_key("X", index, len(states)) for index in range(num_obs)}
177  Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
178  graph = DiscreteFactorGraph()
179 
180  # Mostly identity transition matrix
181  transitions = np.eye(num_states)
182 
183  # Needed otherwise mpe is always state 0?
184  transitions += 0.1/(num_states)
185 
186  transition_cpt = []
187  for i in range(0, num_states):
188  transition_row = "/".join([str(x) for x in transitions[i]])
189  transition_cpt.append(transition_row)
190  transition_cpt = " ".join(transition_cpt)
191 
192  for i in reversed(range(1, num_obs)):
193  transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
194  graph.push_back(transition_conditional)
195 
196  # Contrived example such that the desired state gives measurements [0, num_obs) with equal probability
197  # but all other states always give measurement num_obs
198  obs = np.zeros((num_states, num_obs+1))
199  obs[:,-1] = 1
200  obs[desired_state,0: -1] = 1
201  obs[desired_state,-1] = 0
202  obs_cpt_list = []
203  for i in range(0, num_states):
204  obs_row = "/".join([str(z) for z in obs[i]])
205  obs_cpt_list.append(obs_row)
206  obs_cpt = " ".join(obs_cpt_list)
207 
208  # Contrived example where each measurement is its own index
209  for i in range(0, num_obs):
210  obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
211  factor = obs_conditional.likelihood(i)
212  graph.push_back(factor)
213 
214  mpe = graph.optimize()
215  vals = [mpe[X[i][0]] for i in range(num_obs)]
216 
217  self.assertEqual(vals, [desired_state]*num_obs)
218 
220  """
221  Test for numerical underflow in EliminateDiscrete on long chains.
222  Adapted from the toy problem of @pcl15423
223  Ref: https://github.com/borglab/gtsam/issues/1448
224  """
225  num_states = 3
226  chain_length = 400
227  desired_state = 1
228  states = list(range(num_states))
229 
230  # Helper function to mimic the behavior of gtbook.Variables discrete_series function
231  def make_key(character, index, cardinality):
232  symbol = Symbol(character, index)
233  key = symbol.key()
234  return (key, cardinality)
235 
236  X = {index: make_key("X", index, len(states)) for index in range(chain_length)}
237  graph = DiscreteFactorGraph()
238 
239  # Construct test transition matrix
240  transitions = np.diag([1.0, 0.5, 0.1])
241  transitions += 0.1/(num_states)
242 
243  # Ensure that the transition matrix is Markov (columns sum to 1)
244  transitions /= np.sum(transitions, axis=0)
245 
246  # The stationary distribution is the eigenvector corresponding to eigenvalue 1
247  eigvals, eigvecs = np.linalg.eig(transitions)
248  stationary_idx = np.where(np.isclose(eigvals, 1.0))
249  stationary_dist = eigvecs[:, stationary_idx]
250 
251  # Ensure that the stationary distribution is positive and normalized
252  stationary_dist /= np.sum(stationary_dist)
253  expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
254 
255  # The transition matrix parsed by DiscreteConditional is a row-wise CPT
256  transitions = transitions.T
257  transition_cpt = []
258  for i in range(0, num_states):
259  transition_row = "/".join([str(x) for x in transitions[i]])
260  transition_cpt.append(transition_row)
261  transition_cpt = " ".join(transition_cpt)
262 
263  for i in reversed(range(1, chain_length)):
264  transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
265  graph.push_back(transition_conditional)
266 
267  # Run sum product using natural ordering so the resulting Bayes net has the form:
268  # X_0 <- X_1 <- ... <- X_n
269  sum_product = graph.sumProduct(OrderingType.NATURAL)
270 
271  # Get the DiscreteConditional representing the marginal on the last factor
272  last_marginal = sum_product.at(chain_length - 1)
273 
274  # Ensure marginal probabilities are close to the stationary distribution
275  self.gtsamAssertEquals(expected, last_marginal)
276 
277 if __name__ == "__main__":
278  unittest.main()
def gtsamAssertEquals(self, actual, expected, tol=1e-9)
Definition: test_case.py:18
NonlinearFactorGraph graph
Definition: pytypes.h:1403
Definition: pytypes.h:1979
Double_ range(const Point2_ &p, const Point2_ &q)
size_t len(handle h)
Get the length of a Python object.
Definition: pytypes.h:2244
void product(const MatrixType &m)
Definition: product.h:20
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:37:45