test_HybridBayesNet.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for HybridBayesNet.
9 Author: Frank Dellaert
10 """
11 # pylint: disable=invalid-name, no-name-in-module, no-member
12 
13 import math
14 import unittest
15 
16 import numpy as np
17 from gtsam.symbol_shorthand import A, X
18 from gtsam.utils.test_case import GtsamTestCase
19 
20 from gtsam import (DiscreteConditional, DiscreteValues,
21  GaussianConditional, HybridBayesNet,
22  HybridGaussianConditional, HybridValues, VectorValues,
23  noiseModel)
24 
25 
class TestHybridBayesNet(GtsamTestCase):
    """Unit tests for HybridBayesNet."""

    def test_evaluate(self):
        """Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia)."""
        asiaKey = A(0)
        # Discrete key as a (key, cardinality) pair: Asia takes 2 values.
        Asia = (asiaKey, 2)

        # Create the continuous conditional P(X0|X1).
        I_1x1 = np.eye(1)
        conditional = GaussianConditional.FromMeanAndStddev(
            X(0), 2 * I_1x1, X(1), [-4], 5.0)

        # Create the noise models for the two modes of P(X1|Asia).
        model0 = noiseModel.Diagonal.Sigmas([2.0])
        model1 = noiseModel.Diagonal.Sigmas([3.0])

        # Create the per-mode conditionals on X1; entry k is used when Asia=k.
        conditional0 = GaussianConditional(X(1), [5], I_1x1, model0)
        conditional1 = GaussianConditional(X(1), [2], I_1x1, model1)

        # Create hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
        bayesNet = HybridBayesNet()
        bayesNet.push_back(conditional)
        bayesNet.push_back(
            HybridGaussianConditional(Asia, [conditional0, conditional1]))
        # Prior on Asia with 99/1 odds, so P(Asia=0) = 0.99 below.
        bayesNet.push_back(DiscreteConditional(Asia, "99/1"))

        # Create values at which to evaluate: X0=-6, X1=1, Asia=0.
        values = HybridValues()
        continuous = VectorValues()
        continuous.insert(X(0), [-6])
        continuous.insert(X(1), [1])
        values.insert(continuous)
        discrete = DiscreteValues()
        discrete[asiaKey] = 0
        values.insert(discrete)

        # With Asia=0 the mixture selects conditional0, so the joint density
        # factors as P(X0|X1) * P(X1|Asia=0) * P(Asia=0).
        conditionalProbability = conditional.evaluate(values.continuous())
        mixtureProbability = conditional0.evaluate(values.continuous())
        expected = conditionalProbability * mixtureProbability * 0.99
        self.assertAlmostEqual(expected, bayesNet.evaluate(values), places=5)

        # Check logProbability is consistent with evaluate.
        self.assertAlmostEqual(bayesNet.logProbability(values),
                               math.log(bayesNet.evaluate(values)))

        # Check invariance for all conditionals, both through the
        # type-specific accessor (with its own values type) and through the
        # hybrid wrapper with HybridValues:
        self.check_invariance(bayesNet.at(0).asGaussian(), continuous)
        self.check_invariance(bayesNet.at(0).asGaussian(), values)
        self.check_invariance(bayesNet.at(0), values)

        self.check_invariance(bayesNet.at(1), values)

        self.check_invariance(bayesNet.at(2).asDiscrete(), discrete)
        self.check_invariance(bayesNet.at(2).asDiscrete(), values)
        self.check_invariance(bayesNet.at(2), values)

    def check_invariance(self, conditional, values):
        """Check invariance for given conditional.

        Asserts that the probability is non-negative, that it equals
        exp(logProbability), and that
        logProbability == -(negLogConstant() + error(values)).

        Args:
            conditional: Object exposing evaluate, logProbability, error and
                negLogConstant (Gaussian, discrete or hybrid conditionals in
                this file).
            values: Values accepted by ``conditional`` (VectorValues,
                DiscreteValues or HybridValues in this file).
        """
        probability = conditional.evaluate(values)
        # assertGreaterEqual reports the offending value on failure, unlike
        # assertTrue(probability >= 0.0).
        self.assertGreaterEqual(probability, 0.0)
        logProb = conditional.logProbability(values)
        self.assertAlmostEqual(probability, np.exp(logProb))
        expected = -(conditional.negLogConstant() + conditional.error(values))
        self.assertAlmostEqual(logProb, expected)
94 
95 
def _run_tests():
    """Run the unittest command-line test runner over this module."""
    unittest.main()


if __name__ == "__main__":
    _run_tests()
gtsam::HybridValues
Definition: HybridValues.h:37
test_HybridBayesNet.TestHybridBayesNet
Definition: test_HybridBayesNet.py:26
gtsam::HybridBayesNet
Definition: HybridBayesNet.h:37
gtsam::symbol_shorthand
Definition: inference/Symbol.h:147
X
#define X
Definition: icosphere.cpp:20
gtsam::VectorValues
Definition: VectorValues.h:74
test_HybridBayesNet.TestHybridBayesNet.test_evaluate
def test_evaluate(self)
Definition: test_HybridBayesNet.py:29
gtsam::GaussianConditional
Definition: GaussianConditional.h:40
gtsam::HybridGaussianConditional
A conditional of Gaussian conditionals indexed by discrete variables, as part of a Bayes network. ...
Definition: HybridGaussianConditional.h:54
gtsam::utils.test_case
Definition: test_case.py:1
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::utils.test_case.GtsamTestCase
Definition: test_case.py:16
test_HybridBayesNet.TestHybridBayesNet.check_invariance
def check_invariance(self, conditional, values)
Definition: test_HybridBayesNet.py:86


gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:06:54