test_DecisionTreeFactor.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for DecisionTreeFactors.
9 Author: Frank Dellaert
10 """
11 
12 # pylint: disable=no-name-in-module, invalid-name
13 
14 import unittest
15 
16 from gtsam import (DecisionTreeFactor, DiscreteDistribution, DiscreteValues,
17  Ordering)
18 from gtsam.utils.test_case import GtsamTestCase
19 
20 
class TestDecisionTreeFactor(GtsamTestCase):
    """Tests for DecisionTreeFactors."""

    def setUp(self):
        """Build the shared factor used by several tests.

        Discrete keys are ``(key, cardinality)`` tuples: ``A`` is key 12 with
        3 values and ``B`` is key 5 with 2 values.  The markdown table checked
        in ``test_markdown`` shows the table string "1 2 3 4 5 6" fills the
        3 * 2 assignments with ``B`` varying fastest.
        """
        self.A = (12, 3)
        self.B = (5, 2)
        self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")

    def test_enumerate(self):
        """Test that enumerate() returns every table entry in order.

        Each enumerated item is a pair whose second element is the factor
        value; the first element (presumably the variable assignment) is
        ignored here.
        """
        actual = self.factor.enumerate()
        _, values = zip(*actual)
        self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

    def test_multiplication(self):
        """Test whether multiplication works with overloading."""
        v0 = (0, 2)
        v1 = (1, 2)
        v2 = (2, 2)

        # Multiply with a DiscreteDistribution, i.e., Bayes Law!
        # The expected table implies the prior [1, 3] acts as [0.25, 0.75]
        # (normalized), scaling each entry of f1 by the prior on v1's value.
        prior = DiscreteDistribution(v1, [1, 3])
        f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
        expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
        # Check both operand orders: the product should not depend on it.
        self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected)
        self.gtsamAssertEquals(f1 * prior, expected)

        # Multiply two factors: the result is the pointwise product over the
        # union of their keys (v0, v1, v2), e.g. f1(0,1) * f2(1,0) = 2 * 7.
        f2 = DecisionTreeFactor([v1, v2], "5 6 7 8")
        actual = f1 * f2
        expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32")
        self.gtsamAssertEquals(actual, expected2)

    def test_methods(self):
        """Test whether we can call methods in python."""
        # double operator()(const DiscreteValues& values) const;
        # Assign value index 0 to both variables and evaluate the factor.
        values = DiscreteValues()
        values[self.A[0]] = 0
        values[self.B[0]] = 0
        self.assertIsInstance(self.factor(values), float)

        # size_t cardinality(Key j) const;
        self.assertIsInstance(self.factor.cardinality(self.A[0]), int)

        # DecisionTreeFactor operator/(const DecisionTreeFactor& f) const;
        self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor)

        # DecisionTreeFactor* sum(size_t nrFrontals) const;
        self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor)

        # DecisionTreeFactor* sum(const Ordering& keys) const;
        ordering = Ordering()
        ordering.push_back(self.A[0])
        self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor)

        # DecisionTreeFactor* max(size_t nrFrontals) const;
        self.assertIsInstance(self.factor.max(1), DecisionTreeFactor)

    def test_markdown(self):
        """Test the output of the _repr_markdown_ method."""

        # One row per assignment of (A, B), with B varying fastest.
        expected = (
            "|A|B|value|\n"
            "|:-:|:-:|:-:|\n"
            "|0|0|1|\n"
            "|0|1|2|\n"
            "|1|0|3|\n"
            "|1|1|4|\n"
            "|2|0|5|\n"
            "|2|1|6|\n"
        )

        def formatter(x: int) -> str:
            # Key 12 is self.A; every other key (here only 5, self.B) is "B".
            return "A" if x == 12 else "B"

        actual = self.factor._repr_markdown_(formatter)
        self.assertEqual(actual, expected)
96 
97 
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Cross-references:
- #define max(a, b) — Definition: datatypes.h:20
- def gtsamAssertEquals(self, actual, expected, tol=1e-9) — Definition: test_case.py:18
- const KeyFormatter & formatter — Definition: pytypes.h:1979


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:37:45