2 GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
6 See LICENSE for the license information
8 Unit tests for Discrete Factor Graphs.
19 from gtsam
import (DecisionTreeFactor, DiscreteConditional,
20 DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering,
23 OrderingType = Ordering.OrderingType
27 """Tests for Discrete Factor Graphs."""
30 """Test constructing and evaluating a discrete factor graph."""
41 graph.add(P1, [0.9, 0.3])
42 graph.add(P2,
"0.9 0.6")
45 graph.add([P1, P2],
"4 1 10 4")
53 self.assertAlmostEqual(.72,
graph(assignment))
56 graph.add(P3,
"0.9 0.2 0.5")
61 graph.add(keys,
"1 2 3 4 5 6 7 8 9 10 11 12")
69 self.assertAlmostEqual(4.32,
graph(assignment))
77 self.assertAlmostEqual(1.944,
graph(assignment))
80 product = graph.product()
81 self.assertAlmostEqual(1.944,
product(assignment))
84 """Test constructing and optizing a discrete factor graph."""
94 graph.add([A, C],
"3 1 1 3")
95 graph.add([C, B],
"3 1 1 3")
100 expectedValues[1] = 0
101 expectedValues[2] = 0
102 actualValues = graph.optimize()
103 self.assertEqual(
list(actualValues.items()),
104 list(expectedValues.items()))
107 """Test maximum probable explanation (MPE): same as optimize."""
110 C, A, B = (0, 2), (1, 2), (2, 2)
114 graph.add([C, A],
"0.2 0.8 0.3 0.7")
115 graph.add([C, B],
"0.1 0.9 0.4 0.6")
124 dag = graph.maxProduct(OrderingType.COLAMD)
125 actualMPE = dag.argmax()
126 self.assertEqual(
list(actualMPE.items()),
130 actualMPE2 = graph.optimize()
131 self.assertEqual(
list(actualMPE2.items()),
135 """Test sumProduct."""
138 C, A, B = (0, 2), (1, 2), (2, 2)
142 graph.add([C, A],
"0.2 0.8 0.3 0.7")
143 graph.add([C, B],
"0.1 0.9 0.4 0.6")
152 bayesNet = graph.sumProduct()
154 self.assertAlmostEqual(mpeProbability, 0.36)
157 for ordering_type
in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
158 OrderingType.CUSTOM]:
159 bayesNet = graph.sumProduct(ordering_type)
160 self.assertEqual(
bayesNet(mpe), mpeProbability)
164 Test for numerical underflow in EliminateMPE on long chains.
165 Adapted from the toy problem of @pcl15423
166 Ref: https://github.com/borglab/gtsam/issues/1448
174 def make_key(character, index, cardinality):
175 symbol =
Symbol(character, index)
177 return (key, cardinality)
179 X = {index: make_key(
"X", index,
len(states))
for index
in range(num_obs)}
180 Z = {index: make_key(
"Z", index, num_obs + 1)
for index
in range(num_obs)}
184 transitions = np.eye(num_states)
187 transitions += 0.1/(num_states)
190 for i
in range(0, num_states):
191 transition_row =
"/".join([
str(x)
for x
in transitions[i]])
192 transition_cpt.append(transition_row)
193 transition_cpt =
" ".join(transition_cpt)
195 for i
in reversed(
range(1, num_obs)):
197 graph.push_back(transition_conditional)
201 obs = np.zeros((num_states, num_obs+1))
203 obs[desired_state,0: -1] = 1
204 obs[desired_state,-1] = 0
206 for i
in range(0, num_states):
207 obs_row =
"/".join([
str(z)
for z
in obs[i]])
208 obs_cpt_list.append(obs_row)
209 obs_cpt =
" ".join(obs_cpt_list)
212 for i
in range(0, num_obs):
214 factor = obs_conditional.likelihood(i)
215 graph.push_back(factor)
217 mpe = graph.optimize()
218 vals = [mpe[X[i][0]]
for i
in range(num_obs)]
220 self.assertEqual(vals, [desired_state]*num_obs)
224 Test for numerical underflow in EliminateDiscrete on long chains.
225 Adapted from the toy problem of @pcl15423
226 Ref: https://github.com/borglab/gtsam/issues/1448
234 def make_key(character, index, cardinality):
235 symbol =
Symbol(character, index)
237 return (key, cardinality)
239 X = {index: make_key(
"X", index,
len(states))
for index
in range(chain_length)}
243 transitions = np.diag([1.0, 0.5, 0.1])
244 transitions += 0.1/(num_states)
247 transitions /= np.sum(transitions, axis=0)
250 eigvals, eigvecs = np.linalg.eig(transitions)
251 stationary_idx = np.where(np.isclose(eigvals, 1.0))
252 stationary_dist = eigvecs[:, stationary_idx]
255 stationary_dist /= np.sum(stationary_dist)
259 transitions = transitions.T
261 for i
in range(0, num_states):
262 transition_row =
"/".join([
str(x)
for x
in transitions[i]])
263 transition_cpt.append(transition_row)
264 transition_cpt =
" ".join(transition_cpt)
266 for i
in reversed(
range(1, chain_length)):
268 graph.push_back(transition_conditional)
272 sum_product = graph.sumProduct(OrderingType.NATURAL)
275 last_marginal = sum_product.at(chain_length - 1)
280 if __name__ ==
"__main__":