testDecisionTreeFactor.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * testDecisionTreeFactor.cpp
14  *
15  * @date Feb 5, 2012
16  * @author Frank Dellaert
17  * @author Duy-Nguyen Ta
18  */
19 
21 #include <gtsam/base/Testable.h>
26 
27 using namespace std;
28 using namespace gtsam;
29 
30 /* ************************************************************************* */
31 TEST(DecisionTreeFactor, ConstructorsMatch) {
32  // Declare two keys
33  DiscreteKey X(0, 2), Y(1, 3);
34 
35  // Create with vector and with string
36  const std::vector<double> table {2, 5, 3, 6, 4, 7};
38  DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
40 }
41 
42 /* ************************************************************************* */
43 TEST( DecisionTreeFactor, constructors)
44 {
45  // Declare a bunch of keys
46  DiscreteKey X(0,2), Y(1,3), Z(2,2);
47 
48  // Create factors
49  DecisionTreeFactor f1(X, {2, 8});
50  DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
51  DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
52  EXPECT_LONGS_EQUAL(1,f1.size());
53  EXPECT_LONGS_EQUAL(2,f2.size());
54  EXPECT_LONGS_EQUAL(3,f3.size());
55 
56  DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
57  EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
58  EXPECT_DOUBLES_EQUAL(7, f2(x121), 1e-9);
59  EXPECT_DOUBLES_EQUAL(75, f3(x121), 1e-9);
60 
61  // Assert that error = -log(value)
62  EXPECT_DOUBLES_EQUAL(-log(f1(x121)), f1.error(x121), 1e-9);
63 
64  // Construct from DiscreteConditional
65  DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
66  DecisionTreeFactor f4(conditional);
67  EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
68 }
69 
70 /* ************************************************************************* */
72  // Declare a bunch of keys
73  DiscreteKey X(0,2), Y(1,3), Z(2,2);
74 
75  // Create factors
76  DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
77 
78  auto errors = f.errorTree();
79  // regression
81  {X, Y, Z},
82  vector<double>{-0.69314718, -1.6094379, -1.0986123, -1.7917595,
83  -1.3862944, -1.9459101, -3.2188758, -4.0073332, -3.5553481,
84  -4.1743873, -3.8066625, -4.3174881});
85  EXPECT(assert_equal(expected, errors, 1e-6));
86 }
87 
88 /* ************************************************************************* */
89 TEST(DecisionTreeFactor, multiplication) {
90  DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
91 
92  // Multiply with a DiscreteDistribution, i.e., Bayes Law!
93  DiscreteDistribution prior(v1 % "1/3");
94  DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
95  DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
98 
99  // Multiply two factors
100  DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
101  DecisionTreeFactor actual = f1 * f2;
102  DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
103  CHECK(assert_equal(expected2, actual));
104 }
105 
106 /* ************************************************************************* */
108 {
109  DiscreteKey v0(0,3), v1(1,2);
110  DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
111 
112  DecisionTreeFactor expected(v1, "9 12");
113  DecisionTreeFactor::shared_ptr actual = f1.sum(1);
114  CHECK(assert_equal(expected, *actual, 1e-5));
115 
116  DecisionTreeFactor expected2(v1, "5 6");
117  DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
118  CHECK(assert_equal(expected2, *actual2));
119 
120  DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
121  DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
122 }
123 
124 /* ************************************************************************* */
125 // Check enumerate yields the correct list of assignment/value pairs.
127  DiscreteKey A(12, 3), B(5, 2);
128  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
129  auto actual = f.enumerate();
130  std::vector<std::pair<DiscreteValues, double>> expected;
132  for (size_t a : {0, 1, 2}) {
133  for (size_t b : {0, 1}) {
134  values[12] = a;
135  values[5] = b;
136  expected.emplace_back(values, f(values));
137  }
138  }
139  EXPECT(actual == expected);
140 }
141 
142 /* ************************************************************************* */
143 // Check pruning of the decision tree works as expected.
145  DiscreteKey A(1, 2), B(2, 2), C(3, 2);
146  DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
147 
148  // Only keep the leaves with the top 5 values.
149  size_t maxNrAssignments = 5;
150  auto pruned5 = f.prune(maxNrAssignments);
151 
152  // Pruned leaves should be 0
153  DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
154  EXPECT(assert_equal(expected, pruned5));
155 
156  // Check for more extreme pruning where we only keep the top 2 leaves
157  maxNrAssignments = 2;
158  auto pruned2 = f.prune(maxNrAssignments);
159  DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
160  EXPECT(assert_equal(expected2, pruned2));
161 
162  DiscreteKey D(4, 2);
163  DecisionTreeFactor factor(
164  D & C & B & A,
165  "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
166  "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
167 
168  DecisionTreeFactor expected3(
169  D & C & B & A,
170  "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
171  "0.999952870000 1.0 1.0 1.0 1.0");
172  maxNrAssignments = 5;
173  auto pruned3 = factor.prune(maxNrAssignments);
174  EXPECT(assert_equal(expected3, pruned3));
175 }
176 
177 /* ************************************************************************* */
178 TEST(DecisionTreeFactor, DotWithNames) {
179  DiscreteKey A(12, 3), B(5, 2);
180  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
181  auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
182 
183  for (bool showZero:{true, false}) {
184  string actual = f.dot(formatter, showZero);
185  // pretty weak test, as ids are pointers and not stable across platforms.
186  string expected = "digraph G {";
187  EXPECT(actual.substr(0, 11) == expected);
188  }
189 }
190 
191 /* ************************************************************************* */
192 // Check markdown representation looks as expected.
194  DiscreteKey A(12, 3), B(5, 2);
195  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
196  string expected =
197  "|A|B|value|\n"
198  "|:-:|:-:|:-:|\n"
199  "|0|0|1|\n"
200  "|0|1|2|\n"
201  "|1|0|3|\n"
202  "|1|1|4|\n"
203  "|2|0|5|\n"
204  "|2|1|6|\n";
205  auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
206  string actual = f.markdown(formatter);
207  EXPECT(actual == expected);
208 }
209 
210 /* ************************************************************************* */
211 // Check markdown representation with a value formatter.
212 TEST(DecisionTreeFactor, markdownWithValueFormatter) {
213  DiscreteKey A(12, 3), B(5, 2);
214  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
215  string expected =
216  "|A|B|value|\n"
217  "|:-:|:-:|:-:|\n"
218  "|Zero|-|1|\n"
219  "|Zero|+|2|\n"
220  "|One|-|3|\n"
221  "|One|+|4|\n"
222  "|Two|-|5|\n"
223  "|Two|+|6|\n";
224  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
225  DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
226  {5, {"-", "+"}}};
227  string actual = f.markdown(keyFormatter, names);
228  EXPECT(actual == expected);
229 }
230 
231 /* ************************************************************************* */
232 // Check html representation with a value formatter.
233 TEST(DecisionTreeFactor, htmlWithValueFormatter) {
234  DiscreteKey A(12, 3), B(5, 2);
235  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
236  string expected =
237  "<div>\n"
238  "<table class='DecisionTreeFactor'>\n"
239  " <thead>\n"
240  " <tr><th>A</th><th>B</th><th>value</th></tr>\n"
241  " </thead>\n"
242  " <tbody>\n"
243  " <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
244  " <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
245  " <tr><th>One</th><th>-</th><td>3</td></tr>\n"
246  " <tr><th>One</th><th>+</th><td>4</td></tr>\n"
247  " <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
248  " <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
249  " </tbody>\n"
250  "</table>\n"
251  "</div>";
252  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
253  DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
254  {5, {"-", "+"}}};
255  string actual = f.html(keyFormatter, names);
256  EXPECT(actual == expected);
257 }
258 
259 /* ************************************************************************* */
260 int main() {
261  TestResult tr;
262  return TestRegistry::runAllTests(tr);
263 }
264 /* ************************************************************************* */
TestRegistry::runAllTests
static int runAllTests(TestResult &result)
Definition: TestRegistry.cpp:27
gtsam::markdown
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
Definition: DiscreteValues.cpp:130
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:44
B
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
Y
const char Y
Definition: test/EulerAngles.cpp:31
D
MatrixXcd D
Definition: EigenSolver_EigenSolver_MatrixType.cpp:14
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
v0
static const double v0
Definition: testCal3DFisheye.cpp:31
EXPECT_LONGS_EQUAL
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
Testable.h
Concept check for values that can be used in unit tests.
EXPECT
#define EXPECT(condition)
Definition: Test.h:150
TestHarness.h
b
Scalar * b
Definition: benchVecAdd.cpp:17
gtsam::DiscreteDistribution
Definition: DiscreteDistribution.h:33
gtsam::DecisionTreeFactor::prune
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
Definition: DecisionTreeFactor.cpp:371
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
f2
double f2(const Vector2 &x)
Definition: testNumericalDerivative.cpp:56
log
const EIGEN_DEVICE_FUNC LogReturnType log() const
Definition: ArrayCwiseUnaryOps.h:128
X
#define X
Definition: icosphere.cpp:20
TEST
TEST(DecisionTreeFactor, ConstructorsMatch)
Definition: testDecisionTreeFactor.cpp:31
gtsam::DecisionTreeFactor::shared_ptr
std::shared_ptr< DecisionTreeFactor > shared_ptr
Definition: DecisionTreeFactor.h:50
A
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
gtsam::AlgebraicDecisionTree< Key >
table
ArrayXXf table(10, 4)
main
int main()
Definition: testDecisionTreeFactor.cpp:260
Signature.h
signatures for conditional densities
cholesky::expected
Matrix expected
Definition: testMatrix.cpp:971
EXPECT_DOUBLES_EQUAL
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
serializationTestHelpers.h
TestResult
Definition: TestResult.h:26
key
const gtsam::Symbol key('X', 0)
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
process_shonan_timing_results.names
dictionary names
Definition: process_shonan_timing_results.py:175
simulated2D::prior
Point2 prior(const Point2 &x)
Prior on a single pose.
Definition: simulated2D.h:88
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
a
ArrayXXi a
Definition: Array_initializer_list_23_cxx11.cpp:1
C
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
gtsam
traits
Definition: chartTesting.h:28
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
f3
double f3(double x1, double x2)
Definition: testNumericalDerivative.cpp:76
leaf::values
leaf::MyValues values
DiscreteDistribution.h
CHECK
#define CHECK(condition)
Definition: Test.h:108
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
std
Definition: BFloat16.h:88
v2
Vector v2
Definition: testSerializationBase.cpp:39
f4
double f4(double x, double y, double z)
Definition: testNumericalDerivative.cpp:105
gtsam::assert_equal
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
unary::f1
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
Definition: testExpression.cpp:79
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
Z
#define Z
Definition: icosphere.cpp:21
DecisionTreeFactor.h
v1
Vector v1
Definition: testSerializationBase.cpp:38


gtsam
Author(s):
autogenerated on Tue Jun 25 2024 03:05:42