2 GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, 
    3 Atlanta, Georgia 30332-0415 
    6 See LICENSE for the license information 
    8 Unit tests for Discrete Factor Graphs. 
   18 from dfg_utils 
import make_key, generate_transition_cpt, generate_observation_cpt
 
   29 OrderingType = Ordering.OrderingType
 
   33     """Tests for Discrete Factor Graphs.""" 
   36         """Test constructing and evaluating a discrete factor graph.""" 
   47         graph.add(P1, [0.9, 0.3])
 
   48         graph.add(P2, 
"0.9 0.6")
 
   51         graph.add([P1, P2], 
"4 1 10 4")
 
   59         self.assertAlmostEqual(0.72, 
graph(assignment))
 
   62         graph.add(P3, 
"0.9 0.2 0.5")
 
   67         graph.add(keys, 
"1 2 3 4 5 6 7 8 9 10 11 12")
 
   75         self.assertAlmostEqual(4.32, 
graph(assignment))
 
   83         self.assertAlmostEqual(1.944, 
graph(assignment))
 
   86         product = graph.product()
 
   87         self.assertAlmostEqual(1.944, 
product(assignment))
 
   90         """Test constructing and optimizing a discrete factor graph.""" 
  100         graph.add([A, C], 
"3 1 1 3")
 
  101         graph.add([C, B], 
"3 1 1 3")
 
  105         expectedValues[0] = 0
 
  106         expectedValues[1] = 0
 
  107         expectedValues[2] = 0
 
  108         actualValues = graph.optimize()
 
  109         self.assertEqual(
list(actualValues.items()), 
list(expectedValues.items()))
 
  112         """Test maximum probable explanation (MPE): same as optimize.""" 
  115         C, A, B = (0, 2), (1, 2), (2, 2)
 
  119         graph.add([C, A], 
"0.2 0.8 0.3 0.7")
 
  120         graph.add([C, B], 
"0.1 0.9 0.4 0.6")
 
  129         dag = graph.maxProduct(OrderingType.COLAMD)
 
  130         actualMPE = dag.argmax()
 
  131         self.assertEqual(
list(actualMPE.items()), 
list(mpe.items()))
 
  134         actualMPE2 = graph.optimize()
 
  135         self.assertEqual(
list(actualMPE2.items()), 
list(mpe.items()))
 
  138         """Test sumProduct.""" 
  141         C, A, B = (0, 2), (1, 2), (2, 2)
 
  145         graph.add([C, A], 
"0.2 0.8 0.3 0.7")
 
  146         graph.add([C, B], 
"0.1 0.9 0.4 0.6")
 
  155         bayesNet = graph.sumProduct()
 
  157         self.assertAlmostEqual(mpeProbability, 0.36)  
 
  160         for ordering_type 
in [
 
  163             OrderingType.NATURAL,
 
  166             bayesNet = graph.sumProduct(ordering_type)
 
  167             self.assertEqual(
bayesNet(mpe), mpeProbability)
 
  173         Test for numerical underflow in EliminateMPE on long chains. 
  174         Adapted from the toy problem of @pcl15423 
  175         Ref: https://github.com/borglab/gtsam/issues/1448 
  183         Z = {index: 
make_key(
"Z", index, num_obs + 1) 
for index 
in range(num_obs)}
 
  187         for i 
in reversed(
range(1, num_obs)):
 
  189                 X[i], [X[i - 1]], transition_cpt
 
  191             graph.push_back(transition_conditional)
 
  197         for i 
in range(num_obs):
 
  199             factor = obs_conditional.likelihood(i)
 
  200             graph.push_back(factor)
 
  202         mpe = graph.optimize()
 
  203         vals = [mpe[X[i][0]] 
for i 
in range(num_obs)]
 
  205         self.assertEqual(vals, [desired_state] * num_obs)
 
  209         Test for numerical underflow in EliminateDiscrete on long chains. 
  210         Adapted from the toy problem of @pcl15423 
  211         Ref: https://github.com/borglab/gtsam/issues/1448 
  217         X = {index: 
make_key(
"X", index, 
len(states)) 
for index 
in range(chain_length)}
 
  221         transitions = np.diag([1.0, 0.5, 0.1])
 
  222         transitions += 0.1/(num_states)
 
  225         transitions /= np.sum(transitions, axis=0)
 
  228         eigvals, eigvecs = np.linalg.eig(transitions)
 
  229         stationary_idx = np.where(np.isclose(eigvals, 1.0))
 
  230         stationary_dist = eigvecs[:, stationary_idx]
 
  233         stationary_dist /= np.sum(stationary_dist)
 
  239         for i 
in reversed(
range(1, chain_length)):
 
  241                 X[i], [X[i - 1]], transition_cpt
 
  243             graph.push_back(transition_conditional)
 
  247         sum_product = graph.sumProduct(OrderingType.NATURAL)
 
  250         last_marginal = sum_product.at(chain_length - 1)
 
  256 if __name__ == 
"__main__":