test_HybridBayesNet.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for HybridBayesNet.
9 Author: Frank Dellaert
10 """
11 # pylint: disable=invalid-name, no-name-in-module, no-member
12 
13 import math
14 import unittest
15 
16 import numpy as np
17 from gtsam.symbol_shorthand import A, X
18 from gtsam.utils.test_case import GtsamTestCase
19 
20 from gtsam import (DiscreteConditional, DiscreteKeys, DiscreteValues,
21  GaussianConditional, GaussianMixture, HybridBayesNet,
22  HybridValues, VectorValues, noiseModel)
23 
24 
26  """Unit tests for HybridBayesNet."""
27 
28  def test_evaluate(self):
29  """Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia)."""
30  asiaKey = A(0)
31  Asia = (asiaKey, 2)
32 
33  # Create the continuous conditional
34  I_1x1 = np.eye(1)
35  conditional = GaussianConditional.FromMeanAndStddev(
36  X(0), 2 * I_1x1, X(1), [-4], 5.0)
37 
38  # Create the noise models
39  model0 = noiseModel.Diagonal.Sigmas([2.0])
40  model1 = noiseModel.Diagonal.Sigmas([3.0])
41 
42  # Create the conditionals
43  conditional0 = GaussianConditional(X(1), [5], I_1x1, model0)
44  conditional1 = GaussianConditional(X(1), [2], I_1x1, model1)
45  discrete_keys = DiscreteKeys()
46  discrete_keys.push_back(Asia)
47 
48  # Create hybrid Bayes net.
49  bayesNet = HybridBayesNet()
50  bayesNet.push_back(conditional)
51  bayesNet.push_back(
52  GaussianMixture([X(1)], [], discrete_keys,
53  [conditional0, conditional1]))
54  bayesNet.push_back(DiscreteConditional(Asia, "99/1"))
55 
56  # Create values at which to evaluate.
57  values = HybridValues()
58  continuous = VectorValues()
59  continuous.insert(X(0), [-6])
60  continuous.insert(X(1), [1])
61  values.insert(continuous)
62  discrete = DiscreteValues()
63  discrete[asiaKey] = 0
64  values.insert(discrete)
65 
66  conditionalProbability = conditional.evaluate(values.continuous())
67  mixtureProbability = conditional0.evaluate(values.continuous())
68  self.assertAlmostEqual(conditionalProbability * mixtureProbability *
69  0.99,
70  bayesNet.evaluate(values),
71  places=5)
72 
73  # Check logProbability
74  self.assertAlmostEqual(bayesNet.logProbability(values),
75  math.log(bayesNet.evaluate(values)))
76 
77  # Check invariance for all conditionals:
78  self.check_invariance(bayesNet.at(0).asGaussian(), continuous)
79  self.check_invariance(bayesNet.at(0).asGaussian(), values)
80  self.check_invariance(bayesNet.at(0), values)
81 
82  self.check_invariance(bayesNet.at(1), values)
83 
84  self.check_invariance(bayesNet.at(2).asDiscrete(), discrete)
85  self.check_invariance(bayesNet.at(2).asDiscrete(), values)
86  self.check_invariance(bayesNet.at(2), values)
87 
88  def check_invariance(self, conditional, values):
89  """Check invariance for given conditional."""
90  probability = conditional.evaluate(values)
91  self.assertTrue(probability >= 0.0)
92  logProb = conditional.logProbability(values)
93  self.assertAlmostEqual(probability, np.exp(logProb))
94  expected = conditional.logNormalizationConstant() - \
95  conditional.error(values)
96  self.assertAlmostEqual(logProb, expected)
97 
98 
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()
def check_invariance(self, conditional, values)
#define X
Definition: icosphere.cpp:20


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:37:45