test_DecisionTreeFactor.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for DecisionTreeFactors.
9 Author: Frank Dellaert
10 """
11 
12 # pylint: disable=no-name-in-module, invalid-name
13 
14 import unittest
15 
16 from gtsam.utils.test_case import GtsamTestCase
17 
18 from gtsam import (DecisionTreeFactor, DiscreteDistribution, DiscreteValues,
19  Ordering)
20 
21 
class TestDecisionTreeFactor(GtsamTestCase):
    """Tests for DecisionTreeFactors."""

    def setUp(self):
        # Discrete keys are (key, cardinality) pairs: A is key 12 taking 3
        # values, B is key 5 taking 2 values.
        self.A = (12, 3)
        self.B = (5, 2)
        # One value per joint assignment of (A, B): 3 * 2 = 6 entries, listed
        # with B varying fastest (see the expected table in test_markdown).
        self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")

    def test_from_floats(self):
        """Test whether we can construct a factor from floats."""
        actual = DecisionTreeFactor([self.A, self.B], [1., 2., 3., 4., 5., 6.])
        self.gtsamAssertEquals(actual, self.factor)

    def test_enumerate(self):
        """Test whether we can enumerate the factor."""
        actual = self.factor.enumerate()
        # Each enumerated entry is a pair whose second element is the factor
        # value; only the values are checked here, in enumeration order.
        _, values = zip(*actual)
        self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

    def test_multiplication(self):
        """Test whether multiplication works with overloading."""
        # Three binary keys.
        v0 = (0, 2)
        v1 = (1, 2)
        v2 = (2, 2)

        # Multiply with a DiscreteDistribution, i.e., Bayes Law!
        # The expected values are f1 scaled by the prior [1, 3] normalized to
        # [0.25, 0.75] over v1; e.g. 2 * 0.75 = 1.5.
        prior = DiscreteDistribution(v1, [1, 3])
        f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
        expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
        self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected)
        # Multiplication must also work with the distribution on the right.
        self.gtsamAssertEquals(f1 * prior, expected)

        # Multiply two factors: the product is over the union of keys
        # (v0, v1, v2), e.g. f1(0, 1) * f2(1, 0) = 2 * 7 = 14.
        f2 = DecisionTreeFactor([v1, v2], "5 6 7 8")
        actual = f1 * f2
        expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32")
        self.gtsamAssertEquals(actual, expected2)

    def test_methods(self):
        """Test whether we can call methods in python."""
        # double operator()(const DiscreteValues& values) const;
        values = DiscreteValues()
        values[self.A[0]] = 0
        values[self.B[0]] = 0
        self.assertIsInstance(self.factor(values), float)

        # size_t cardinality(Key j) const;
        self.assertIsInstance(self.factor.cardinality(self.A[0]), int)

        # DecisionTreeFactor operator/(const DecisionTreeFactor& f) const;
        self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor)

        # DecisionTreeFactor* sum(size_t nrFrontals) const;
        self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor)

        # DecisionTreeFactor* sum(const Ordering& keys) const;
        ordering = Ordering()
        ordering.push_back(self.A[0])
        self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor)

        # DecisionTreeFactor* max(size_t nrFrontals) const;
        self.assertIsInstance(self.factor.max(1), DecisionTreeFactor)

    def test_markdown(self):
        """Test the _repr_markdown_ method with a custom key formatter."""

        # One row per joint assignment of (A, B), in the same order as the
        # values passed to the factor in setUp.
        expected = (
            "|A|B|value|\n"
            "|:-:|:-:|:-:|\n"
            "|0|0|1|\n"
            "|0|1|2|\n"
            "|1|0|3|\n"
            "|1|1|4|\n"
            "|2|0|5|\n"
            "|2|1|6|\n"
        )

        def formatter(x: int) -> str:
            # Label key 12 (self.A) as "A" and any other key (here self.B) "B".
            return "A" if x == 12 else "B"

        actual = self.factor._repr_markdown_(formatter)
        self.assertEqual(actual, expected)
103 
104 
if __name__ == "__main__":
    # Run this module's test cases when executed directly as a script.
    unittest.main()
test_DecisionTreeFactor.TestDecisionTreeFactor.test_enumerate
def test_enumerate(self)
Definition: test_DecisionTreeFactor.py:35
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:45
test_DecisionTreeFactor.TestDecisionTreeFactor.B
B
Definition: test_DecisionTreeFactor.py:27
list
Definition: pytypes.h:2166
gtsam::DiscreteDistribution
Definition: DiscreteDistribution.h:33
test_DecisionTreeFactor.TestDecisionTreeFactor.factor
factor
Definition: test_DecisionTreeFactor.py:28
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
gtsam::utils.test_case.GtsamTestCase.gtsamAssertEquals
def gtsamAssertEquals(self, actual, expected, tol=1e-9)
Definition: test_case.py:19
test_DecisionTreeFactor.TestDecisionTreeFactor.setUp
def setUp(self)
Definition: test_DecisionTreeFactor.py:25
test_DecisionTreeFactor.TestDecisionTreeFactor.test_multiplication
def test_multiplication(self)
Definition: test_DecisionTreeFactor.py:41
gtsam::utils.test_case
Definition: test_case.py:1
test_DecisionTreeFactor.TestDecisionTreeFactor
Definition: test_DecisionTreeFactor.py:22
test_DecisionTreeFactor.TestDecisionTreeFactor.test_methods
def test_methods(self)
Definition: test_DecisionTreeFactor.py:60
test_DecisionTreeFactor.TestDecisionTreeFactor.test_from_floats
def test_from_floats(self)
Definition: test_DecisionTreeFactor.py:30
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::utils.test_case.GtsamTestCase
Definition: test_case.py:16
test_DecisionTreeFactor.TestDecisionTreeFactor.A
A
Definition: test_DecisionTreeFactor.py:26
max
#define max(a, b)
Definition: datatypes.h:20
gtsam::Ordering
Definition: inference/Ordering.h:33
test_DecisionTreeFactor.TestDecisionTreeFactor.test_markdown
def test_markdown(self)
Definition: test_DecisionTreeFactor.py:85


gtsam
Author(s):
autogenerated on Fri Jan 10 2025 04:06:51