test_DiscreteConditional.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for Discrete Conditionals.
9 Author: Varun Agrawal
10 """
11 
12 # pylint: disable=no-name-in-module, invalid-name
13 
14 import unittest
15 
16 from gtsam.utils.test_case import GtsamTestCase
17 
18 from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
19 
# DiscreteKeys shared by the tests below, written as (key, cardinality)
# pairs.  Every variable here is binary, hence cardinality 2.
A = (0, 2)
B = (1, 2)
C = (2, 2)
D = (4, 2)
E = (3, 2)
26 
27 
class TestDiscreteConditional(GtsamTestCase):
    """Tests for Discrete Conditionals."""

    def test_single_value_versions(self):
        """Check the single-value overloads of likelihood() and sample()."""
        X = (0, 2)
        Y = (1, 3)
        # One "P(X=0)/P(X=1)" ratio per value of the ternary parent Y.
        conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5")

        # likelihood(x) is a factor on Y holding P(X=x | Y) for each Y value.
        actual0 = conditional.likelihood(0)
        expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5")
        self.gtsamAssertEquals(actual0, expected0, 1e-9)

        actual1 = conditional.likelihood(1)
        expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
        self.gtsamAssertEquals(actual1, expected1, 1e-9)

        # Sample X given the parent value 2 (only valid as a value of the
        # ternary Y); the result must be a legal value of the binary X.
        actual = conditional.sample(2)
        self.assertIsInstance(actual, int)
        self.assertIn(actual, range(X[1]))

    def test_multiply(self):
        """Check calculation of joint P(A,B)"""
        conditional = DiscreteConditional(A, [B], "1/2 2/1")
        prior = DiscreteConditional(B, "1/2")

        # P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
        for actual in [prior * conditional, conditional * prior]:
            self.assertEqual(2, actual.nrFrontals())
            for v, value in actual.enumerate():
                # The enumerated value must match point evaluation, and both
                # must equal the product of the two input factors.
                self.assertAlmostEqual(value, actual(v))
                self.assertAlmostEqual(value, conditional(v) * prior(v))

    def test_multiply2(self):
        """Check calculation of conditional joint P(A,B|C)"""
        A_given_B = DiscreteConditional(A, [B], "1/3 3/1")
        B_given_C = DiscreteConditional(B, [C], "1/3 3/1")

        # P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
        for actual in [A_given_B * B_given_C, B_given_C * A_given_B]:
            self.assertEqual(2, actual.nrFrontals())
            self.assertEqual(1, actual.nrParents())
            for v, value in actual.enumerate():
                self.assertAlmostEqual(value, actual(v))
                self.assertAlmostEqual(value, A_given_B(v) * B_given_C(v))

    def test_multiply4(self):
        """Check calculation of joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)"""
        A_given_B = DiscreteConditional(A, [B], "1/3 3/1")
        B_given_D = DiscreteConditional(B, [D], "1/3 3/1")
        AB_given_D = A_given_B * B_given_D
        C_given_DE = DiscreteConditional(C, [D, E], "4/1 1/1 1/1 1/4")

        # P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D)
        for actual in [AB_given_D * C_given_DE, C_given_DE * AB_given_D]:
            self.assertEqual(3, actual.nrFrontals())
            self.assertEqual(2, actual.nrParents())
            for v, value in actual.enumerate():
                self.assertAlmostEqual(value, actual(v))
                self.assertAlmostEqual(
                    value, AB_given_D(v) * C_given_DE(v))

    def test_marginals(self):
        """Check that marginalizing the joint P(A,B) recovers P(B) and P(A)."""
        conditional = DiscreteConditional(A, [B], "1/2 2/1")
        prior = DiscreteConditional(B, "1/2")
        pAB = prior * conditional
        self.gtsamAssertEquals(prior, pAB.marginal(B[0]))

        # P(A=0) = 1/3 * 1/3 + 2/3 * 2/3 = 5/9 and P(A=1) = 4/9, i.e. "5/4".
        pA = DiscreteConditional(A, "5/4")
        self.gtsamAssertEquals(pA, pAB.marginal(A[0]))

    def test_markdown(self):
        """Test the _repr_markdown_ method."""

        # Local keys shadow the module-level ones: A is key 2, B is key 1,
        # and C (key 0) is ternary.
        A = (2, 2)
        B = (1, 2)
        C = (0, 3)
        parents = DiscreteKeys()
        parents.push_back(B)
        parents.push_back(C)

        conditional = DiscreteConditional(A, parents,
                                          "0/1 1/3 1/1 3/1 0/1 1/0")
        expected = (" *P(A|B,C):*\n\n"
                    "|*B*|*C*|0|1|\n"
                    "|:-:|:-:|:-:|:-:|\n"
                    "|0|0|0|1|\n"
                    "|0|1|0.25|0.75|\n"
                    "|0|2|0.5|0.5|\n"
                    "|1|0|0.75|0.25|\n"
                    "|1|1|0|1|\n"
                    "|1|2|1|0|\n")

        def formatter(x: int) -> str:
            # Map keys 0, 1, 2 to the variable names used in `expected`.
            names = ["C", "B", "A"]
            return names[x]

        actual = conditional._repr_markdown_(formatter)
        self.assertEqual(actual, expected)
122 
123 
if __name__ == "__main__":
    # Executed directly as a script: let unittest discover and run the
    # TestCase classes defined in this module.
    unittest.main()
test_DiscreteConditional.TestDiscreteConditional.test_multiply
def test_multiply(self)
Definition: test_DiscreteConditional.py:47
test_DiscreteConditional.TestDiscreteConditional.test_markdown
def test_markdown(self)
Definition: test_DiscreteConditional.py:94
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:45
test_DiscreteConditional.TestDiscreteConditional.test_marginals
def test_marginals(self)
Definition: test_DiscreteConditional.py:85
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
gtsam::utils.test_case.GtsamTestCase.gtsamAssertEquals
def gtsamAssertEquals(self, actual, expected, tol=1e-9)
Definition: test_case.py:19
test_DiscreteConditional.TestDiscreteConditional
Definition: test_DiscreteConditional.py:28
gtsam::utils.test_case
Definition: test_case.py:1
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
test_DiscreteConditional.TestDiscreteConditional.test_multiply2
def test_multiply2(self)
Definition: test_DiscreteConditional.py:58
gtsam::utils.test_case.GtsamTestCase
Definition: test_case.py:16
different_sigmas::prior
const auto prior
Definition: testHybridBayesNet.cpp:238
test_DiscreteConditional.TestDiscreteConditional.test_single_value_versions
def test_single_value_versions(self)
Definition: test_DiscreteConditional.py:31
test_DiscreteConditional.TestDiscreteConditional.test_multiply4
def test_multiply4(self)
Definition: test_DiscreteConditional.py:70


gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:06:54