2 GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
6 See LICENSE for the license information
8 Unit tests for Discrete Factor Graphs.
17 from gtsam
import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
20 OrderingType = Ordering.OrderingType
24 """Tests for Discrete Factor Graphs."""
27 """Test constructing and evaluating a discrete factor graph."""
38 graph.add(P1, [0.9, 0.3])
39 graph.add(P2,
"0.9 0.6")
42 graph.add([P1, P2],
"4 1 10 4")
50 self.assertAlmostEqual(.72,
graph(assignment))
53 graph.add(P3,
"0.9 0.2 0.5")
58 graph.add(keys,
"1 2 3 4 5 6 7 8 9 10 11 12")
66 self.assertAlmostEqual(4.32,
graph(assignment))
74 self.assertAlmostEqual(1.944,
graph(assignment))
77 product = graph.product()
78 self.assertAlmostEqual(1.944,
product(assignment))
81 """Test constructing and optizing a discrete factor graph."""
91 graph.add([A, C],
"3 1 1 3")
92 graph.add([C, B],
"3 1 1 3")
99 actualValues = graph.optimize()
100 self.assertEqual(
list(actualValues.items()),
101 list(expectedValues.items()))
104 """Test maximum probable explanation (MPE): same as optimize."""
107 C, A, B = (0, 2), (1, 2), (2, 2)
111 graph.add([C, A],
"0.2 0.8 0.3 0.7")
112 graph.add([C, B],
"0.1 0.9 0.4 0.6")
121 dag = graph.maxProduct(OrderingType.COLAMD)
122 actualMPE = dag.argmax()
123 self.assertEqual(
list(actualMPE.items()),
127 actualMPE2 = graph.optimize()
128 self.assertEqual(
list(actualMPE2.items()),
132 """Test sumProduct."""
135 C, A, B = (0, 2), (1, 2), (2, 2)
139 graph.add([C, A],
"0.2 0.8 0.3 0.7")
140 graph.add([C, B],
"0.1 0.9 0.4 0.6")
149 bayesNet = graph.sumProduct()
151 self.assertAlmostEqual(mpeProbability, 0.36)
154 for ordering_type
in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
155 OrderingType.CUSTOM]:
156 bayesNet = graph.sumProduct(ordering_type)
157 self.assertEqual(
bayesNet(mpe), mpeProbability)
161 Test for numerical underflow in EliminateMPE on long chains.
162 Adapted from the toy problem of @pcl15423
163 Ref: https://github.com/borglab/gtsam/issues/1448
171 def make_key(character, index, cardinality):
172 symbol =
Symbol(character, index)
174 return (key, cardinality)
176 X = {index: make_key(
"X", index,
len(states))
for index
in range(num_obs)}
177 Z = {index: make_key(
"Z", index, num_obs + 1)
for index
in range(num_obs)}
181 transitions = np.eye(num_states)
184 transitions += 0.1/(num_states)
187 for i
in range(0, num_states):
188 transition_row =
"/".join([
str(x)
for x
in transitions[i]])
189 transition_cpt.append(transition_row)
190 transition_cpt =
" ".join(transition_cpt)
192 for i
in reversed(
range(1, num_obs)):
194 graph.push_back(transition_conditional)
198 obs = np.zeros((num_states, num_obs+1))
200 obs[desired_state,0: -1] = 1
201 obs[desired_state,-1] = 0
203 for i
in range(0, num_states):
204 obs_row =
"/".join([
str(z)
for z
in obs[i]])
205 obs_cpt_list.append(obs_row)
206 obs_cpt =
" ".join(obs_cpt_list)
209 for i
in range(0, num_obs):
211 factor = obs_conditional.likelihood(i)
212 graph.push_back(factor)
214 mpe = graph.optimize()
215 vals = [mpe[X[i][0]]
for i
in range(num_obs)]
217 self.assertEqual(vals, [desired_state]*num_obs)
221 Test for numerical underflow in EliminateDiscrete on long chains.
222 Adapted from the toy problem of @pcl15423
223 Ref: https://github.com/borglab/gtsam/issues/1448
231 def make_key(character, index, cardinality):
232 symbol =
Symbol(character, index)
234 return (key, cardinality)
236 X = {index: make_key(
"X", index,
len(states))
for index
in range(chain_length)}
240 transitions = np.diag([1.0, 0.5, 0.1])
241 transitions += 0.1/(num_states)
244 transitions /= np.sum(transitions, axis=0)
247 eigvals, eigvecs = np.linalg.eig(transitions)
248 stationary_idx = np.where(np.isclose(eigvals, 1.0))
249 stationary_dist = eigvecs[:, stationary_idx]
252 stationary_dist /= np.sum(stationary_dist)
256 transitions = transitions.T
258 for i
in range(0, num_states):
259 transition_row =
"/".join([
str(x)
for x
in transitions[i]])
260 transition_cpt.append(transition_row)
261 transition_cpt =
" ".join(transition_cpt)
263 for i
in reversed(
range(1, chain_length)):
265 graph.push_back(transition_conditional)
269 sum_product = graph.sumProduct(OrderingType.NATURAL)
272 last_marginal = sum_product.at(chain_length - 1)
277 if __name__ ==
"__main__":