2 GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
6 See LICENSE for the license information
8 Unit tests for Discrete Factor Graphs.
18 from dfg_utils
import make_key, generate_transition_cpt, generate_observation_cpt
29 OrderingType = Ordering.OrderingType
33 """Tests for Discrete Factor Graphs."""
36 """Test constructing and evaluating a discrete factor graph."""
47 graph.add(P1, [0.9, 0.3])
48 graph.add(P2,
"0.9 0.6")
51 graph.add([P1, P2],
"4 1 10 4")
59 self.assertAlmostEqual(0.72,
graph(assignment))
62 graph.add(P3,
"0.9 0.2 0.5")
67 graph.add(keys,
"1 2 3 4 5 6 7 8 9 10 11 12")
75 self.assertAlmostEqual(4.32,
graph(assignment))
83 self.assertAlmostEqual(1.944,
graph(assignment))
86 product = graph.product()
87 self.assertAlmostEqual(1.944,
product(assignment))
90 """Test constructing and optizing a discrete factor graph."""
100 graph.add([A, C],
"3 1 1 3")
101 graph.add([C, B],
"3 1 1 3")
105 expectedValues[0] = 0
106 expectedValues[1] = 0
107 expectedValues[2] = 0
108 actualValues = graph.optimize()
109 self.assertEqual(
list(actualValues.items()),
list(expectedValues.items()))
112 """Test maximum probable explanation (MPE): same as optimize."""
115 C, A, B = (0, 2), (1, 2), (2, 2)
119 graph.add([C, A],
"0.2 0.8 0.3 0.7")
120 graph.add([C, B],
"0.1 0.9 0.4 0.6")
129 dag = graph.maxProduct(OrderingType.COLAMD)
130 actualMPE = dag.argmax()
131 self.assertEqual(
list(actualMPE.items()),
list(mpe.items()))
134 actualMPE2 = graph.optimize()
135 self.assertEqual(
list(actualMPE2.items()),
list(mpe.items()))
138 """Test sumProduct."""
141 C, A, B = (0, 2), (1, 2), (2, 2)
145 graph.add([C, A],
"0.2 0.8 0.3 0.7")
146 graph.add([C, B],
"0.1 0.9 0.4 0.6")
155 bayesNet = graph.sumProduct()
157 self.assertAlmostEqual(mpeProbability, 0.36)
160 for ordering_type
in [
163 OrderingType.NATURAL,
166 bayesNet = graph.sumProduct(ordering_type)
167 self.assertEqual(
bayesNet(mpe), mpeProbability)
173 Test for numerical underflow in EliminateMPE on long chains.
174 Adapted from the toy problem of @pcl15423
175 Ref: https://github.com/borglab/gtsam/issues/1448
183 Z = {index:
make_key(
"Z", index, num_obs + 1)
for index
in range(num_obs)}
187 for i
in reversed(
range(1, num_obs)):
189 X[i], [X[i - 1]], transition_cpt
191 graph.push_back(transition_conditional)
197 for i
in range(num_obs):
199 factor = obs_conditional.likelihood(i)
200 graph.push_back(factor)
202 mpe = graph.optimize()
203 vals = [mpe[X[i][0]]
for i
in range(num_obs)]
205 self.assertEqual(vals, [desired_state] * num_obs)
209 Test for numerical underflow in EliminateDiscrete on long chains.
210 Adapted from the toy problem of @pcl15423
211 Ref: https://github.com/borglab/gtsam/issues/1448
217 X = {index:
make_key(
"X", index,
len(states))
for index
in range(chain_length)}
221 transitions = np.diag([1.0, 0.5, 0.1])
222 transitions += 0.1/(num_states)
225 transitions /= np.sum(transitions, axis=0)
228 eigvals, eigvecs = np.linalg.eig(transitions)
229 stationary_idx = np.where(np.isclose(eigvals, 1.0))
230 stationary_dist = eigvecs[:, stationary_idx]
233 stationary_dist /= np.sum(stationary_dist)
239 for i
in reversed(
range(1, chain_length)):
241 X[i], [X[i - 1]], transition_cpt
243 graph.push_back(transition_conditional)
247 sum_product = graph.sumProduct(OrderingType.NATURAL)
250 last_marginal = sum_product.at(chain_length - 1)
256 if __name__ ==
"__main__":