test_DiscreteConditional.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for Discrete Conditionals.
9 Author: Varun Agrawal
10 """
11 
12 # pylint: disable=no-name-in-module, invalid-name
13 
14 import unittest
15 
16 from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
17 from gtsam.utils.test_case import GtsamTestCase
18 
# Module-level DiscreteKey values shared by the tests below, each written as
# a (key, cardinality) tuple. All five variables are binary.
A = (0, 2)
B = (1, 2)
C = (2, 2)
D = (4, 2)
E = (3, 2)
25 
26 
class TestDiscreteConditional(GtsamTestCase):
    """Tests for Discrete Conditionals."""

    def test_single_value_versions(self):
        """Check the single-value forms of likelihood() and sample().

        Uses a conditional P(X|Y) with a binary frontal variable X and a
        ternary parent Y.
        """
        X = (0, 2)
        Y = (1, 3)
        # One "p0/p1" ratio pair per value of the parent Y (Y = 0, 1, 2).
        conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5")

        # Fixing X=0 leaves a factor over Y that holds P(X=0|Y).
        actual0 = conditional.likelihood(0)
        expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5")
        self.gtsamAssertEquals(actual0, expected0, 1e-9)

        # Fixing X=1 leaves a factor over Y that holds P(X=1|Y).
        actual1 = conditional.likelihood(1)
        expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5")
        self.gtsamAssertEquals(actual1, expected1, 1e-9)

        # Sampling with a single int argument must yield a plain int.
        # NOTE(review): the argument 2 is presumably the value of parent Y.
        actual = conditional.sample(2)
        self.assertIsInstance(actual, int)

    def test_multiply(self):
        """Check calculation of joint P(A,B)."""
        conditional = DiscreteConditional(A, [B], "1/2 2/1")
        prior = DiscreteConditional(B, "1/2")

        # P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
        for actual in [prior * conditional, conditional * prior]:
            self.assertEqual(2, actual.nrFrontals())
            # Compare every joint assignment with the product of the inputs.
            for v, _ in actual.enumerate():
                self.assertAlmostEqual(actual(v), conditional(v) * prior(v))

    def test_multiply2(self):
        """Check calculation of conditional joint P(A,B|C)."""
        A_given_B = DiscreteConditional(A, [B], "1/3 3/1")
        B_given_C = DiscreteConditional(B, [C], "1/3 3/1")

        # P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
        for actual in [A_given_B * B_given_C, B_given_C * A_given_B]:
            self.assertEqual(2, actual.nrFrontals())
            self.assertEqual(1, actual.nrParents())
            for v, _ in actual.enumerate():
                self.assertAlmostEqual(actual(v), A_given_B(v) * B_given_C(v))

    def test_multiply4(self):
        """Check calculation of joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)."""
        A_given_B = DiscreteConditional(A, [B], "1/3 3/1")
        B_given_D = DiscreteConditional(B, [D], "1/3 3/1")
        AB_given_D = A_given_B * B_given_D
        C_given_DE = DiscreteConditional(C, [D, E], "4/1 1/1 1/1 1/4")

        # P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D)
        for actual in [AB_given_D * C_given_DE, C_given_DE * AB_given_D]:
            self.assertEqual(3, actual.nrFrontals())
            self.assertEqual(2, actual.nrParents())
            for v, _ in actual.enumerate():
                self.assertAlmostEqual(
                    actual(v), AB_given_D(v) * C_given_DE(v))

    def test_marginals(self):
        """Check the marginals of the joint P(A,B) = P(B) P(A|B)."""
        conditional = DiscreteConditional(A, [B], "1/2 2/1")
        prior = DiscreteConditional(B, "1/2")
        pAB = prior * conditional
        # Marginalizing A out of the joint must give back the prior on B.
        self.gtsamAssertEquals(prior, pAB.marginal(B[0]))

        # P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1)
        #        = 1/3 * 1/3 + 2/3 * 2/3 = 5/9, so P(A=1) = 4/9 -> "5/4".
        pA = DiscreteConditional(A, "5/4")
        self.gtsamAssertEquals(pA, pAB.marginal(A[0]))

    def test_markdown(self):
        """Test the _repr_markdown_ output of a DiscreteConditional."""
        # Local (key, cardinality) tuples that shadow the module-level
        # A, B, C; note that C is ternary here.
        A = (2, 2)
        B = (1, 2)
        C = (0, 3)
        parents = DiscreteKeys()
        parents.push_back(B)
        parents.push_back(C)

        # One ratio per (B, C) assignment, with B varying slowest, matching
        # the row order of the expected table below.
        conditional = DiscreteConditional(A, parents,
                                          "0/1 1/3 1/1 3/1 0/1 1/0")
        expected = (" *P(A|B,C):*\n\n"
                    "|*B*|*C*|0|1|\n"
                    "|:-:|:-:|:-:|:-:|\n"
                    "|0|0|0|1|\n"
                    "|0|1|0.25|0.75|\n"
                    "|0|2|0.5|0.5|\n"
                    "|1|0|0.75|0.25|\n"
                    "|1|1|0|1|\n"
                    "|1|2|1|0|\n")

        def formatter(x: int) -> str:
            """Map a key to its variable name (0 -> C, 1 -> B, 2 -> A)."""
            names = ["C", "B", "A"]
            return names[x]

        actual = conditional._repr_markdown_(formatter)
        self.assertEqual(actual, expected)
121 
122 
if __name__ == "__main__":
    # Executed directly as a script: load and run every test case defined
    # in this module (module="__main__" is unittest.main's default).
    unittest.main(module="__main__")
test_DiscreteConditional.TestDiscreteConditional.test_multiply
def test_multiply(self)
Definition: test_DiscreteConditional.py:46
test_DiscreteConditional.TestDiscreteConditional.test_markdown
def test_markdown(self)
Definition: test_DiscreteConditional.py:93
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:44
test_DiscreteConditional.TestDiscreteConditional.test_marginals
def test_marginals(self)
Definition: test_DiscreteConditional.py:84
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
gtsam::utils.test_case.GtsamTestCase.gtsamAssertEquals
def gtsamAssertEquals(self, actual, expected, tol=1e-9)
Definition: test_case.py:18
test_DiscreteConditional.TestDiscreteConditional
Definition: test_DiscreteConditional.py:27
gtsam::utils.test_case
Definition: test_case.py:1
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
test_DiscreteConditional.TestDiscreteConditional.test_multiply2
def test_multiply2(self)
Definition: test_DiscreteConditional.py:57
gtsam::utils.test_case.GtsamTestCase
Definition: test_case.py:15
different_sigmas::prior
const auto prior
Definition: testHybridBayesNet.cpp:238
test_DiscreteConditional.TestDiscreteConditional.test_single_value_versions
def test_single_value_versions(self)
Definition: test_DiscreteConditional.py:30
test_DiscreteConditional.TestDiscreteConditional.test_multiply4
def test_multiply4(self)
Definition: test_DiscreteConditional.py:69


gtsam
Author(s):
autogenerated on Sat Nov 16 2024 04:07:05