2 GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, 
    3 Atlanta, Georgia 30332-0415 
    6 See LICENSE for the license information 
    8 Unit tests for Discrete Bayes trees. 
   21 from gtsam 
import (DiscreteBayesNet, DiscreteBayesTreeClique,
 
   22                    DiscreteConditional, DiscreteFactorGraph, DiscreteValues,
 
   27     """Tests for Discrete Bayes Nets.""" 
   30         """Test Multifrontal elimination.""" 
   33         keys = [(j, 2) 
for j 
in range(15)]
 
   38         bayesNet.add(keys[0], [keys[8], keys[12]], 
"2/3 1/4 3/2 4/1")
 
   39         bayesNet.add(keys[1], [keys[8], keys[12]], 
"4/1 2/3 3/2 1/4")
 
   40         bayesNet.add(keys[2], [keys[9], keys[12]], 
"1/4 8/2 2/3 4/1")
 
   41         bayesNet.add(keys[3], [keys[9], keys[12]], 
"1/4 2/3 3/2 4/1")
 
   43         bayesNet.add(keys[4], [keys[10], keys[13]], 
"2/3 1/4 3/2 4/1")
 
   44         bayesNet.add(keys[5], [keys[10], keys[13]], 
"4/1 2/3 3/2 1/4")
 
   45         bayesNet.add(keys[6], [keys[11], keys[13]], 
"1/4 3/2 2/3 4/1")
 
   46         bayesNet.add(keys[7], [keys[11], keys[13]], 
"1/4 2/3 3/2 4/1")
 
   48         bayesNet.add(keys[8], [keys[12], keys[14]], 
"T 1/4 3/2 4/1")
 
   49         bayesNet.add(keys[9], [keys[12], keys[14]], 
"4/1 2/3 F 1/4")
 
   50         bayesNet.add(keys[10], [keys[13], keys[14]], 
"1/4 3/2 2/3 4/1")
 
   51         bayesNet.add(keys[11], [keys[13], keys[14]], 
"1/4 2/3 3/2 4/1")
 
   53         bayesNet.add(keys[12], [keys[14]], 
"3/1 3/1")
 
   54         bayesNet.add(keys[13], [keys[14]], 
"1/3 3/1")
 
   56         bayesNet.add(keys[14], 
"1/3")
 
   65         bayesTree = factorGraph.eliminateMultifrontal(ordering)
 
   75         self.assertIsInstance(root, DiscreteBayesTreeClique)
 
   76         self.assertTrue(root.isRoot())
 
   77         self.assertIsInstance(root.conditional(), DiscreteConditional)
 
   86         value_at_zeros = bayesTree.evaluate(zero_values)
 
   87         self.assertAlmostEqual(value_at_zeros, 0.0)
 
   90         values_star = factorGraph.optimize()
 
   91         max_value = bayesTree.evaluate(values_star)
 
   92         self.assertAlmostEqual(max_value, 0.002548)
 
   96         self.assertAlmostEqual(max_value, 0.002548)
 
   98         self.assertFalse(bayesTree.empty())
 
   99         self.assertEqual(12, bayesTree.size())
 
  101     @unittest.skip(
"Too Slow")
 
  103         """Check that we can have a multi-frontal lookup table.""" 
  106         x1, x2, x3 = (
X(1), 3), (
X(2), 3), (
X(3), 3)
 
  107         a1, a2 = (A(1), 2), (A(2), 2)
 
  110         graph.add([x1], np.array([1, 0, 0]))
 
  111         graph.add([x3], np.array([0, 0, 1]))
 
  121         graph.add([x1, a1, x2], table)
 
  122         graph.add([x2, a2, x3], table)
 
  125         ordering = 
Ordering(keys=[A(2), 
X(3), 
X(1), A(1), 
X(2)])
 
  126         lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
 
  129         assert lookup.size() == 2
 
  130         lookup_x1_a1_x2 = lookup[
X(1)].conditional()
 
  131         assert lookup_x1_a1_x2.nrFrontals() == 3
 
  134         self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 1.0)
 
  140         self.assertAlmostEqual(lookup_x1_a1_x2(values), 1.0)
 
  142         lookup_a2_x3 = lookup[
X(3)].conditional()
 
  144         sum_x2 = lookup_a2_x3.sum(2)
 
  147         self.assertAlmostEqual(sum_x2(values), 0)
 
  149         self.assertAlmostEqual(sum_x2(values), 1.0)  
 
  151         self.assertAlmostEqual(sum_x2(values), 2.0)  
 
  152         assert lookup_a2_x3.nrFrontals() == 2
 
  158         self.assertAlmostEqual(lookup_a2_x3(values), 1.0)  
 
  161         """Test creating a Bayes tree directly from cliques.""" 
  164         A, B, C = (0, 2), (1, 2), (2, 2)
 
  165         bayesNet.add(A, 
"1/3")
 
  166         bayesNet.add(B, [A], 
"1/3 3/1")
 
  167         bayesNet.add(C, [B], 
"3/1 3/1")
 
  176         bayesTree.insertRoot(clique2)
 
  177         bayesTree.addClique(clique1, clique2)
 
  178         bayesTree.addClique(clique0, clique1)
 
  188         self.assertAlmostEqual(expected, bayesNet.evaluate(values))
 
  191 if __name__ == 
"__main__":