testDecisionTreeFactor.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * testDecisionTreeFactor.cpp
14  *
15  * @date Feb 5, 2012
16  * @author Frank Dellaert
17  * @author Duy-Nguyen Ta
18  */
19 
21 #include <gtsam/base/Testable.h>
26 
27 using namespace std;
28 using namespace gtsam;
29 
30 /* ************************************************************************* */
31 TEST( DecisionTreeFactor, constructors)
32 {
33  // Declare a bunch of keys
34  DiscreteKey X(0,2), Y(1,3), Z(2,2);
35 
36  // Create factors
37  DecisionTreeFactor f1(X, {2, 8});
38  DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
39  DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
40  EXPECT_LONGS_EQUAL(1,f1.size());
41  EXPECT_LONGS_EQUAL(2,f2.size());
42  EXPECT_LONGS_EQUAL(3,f3.size());
43 
45  values[0] = 1; // x
46  values[1] = 2; // y
47  values[2] = 1; // z
48  EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
49  EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
50  EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
51 
52  // Assert that error = -log(value)
53  EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
54 }
55 
56 /* ************************************************************************* */
57 TEST(DecisionTreeFactor, multiplication) {
58  DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
59 
60  // Multiply with a DiscreteDistribution, i.e., Bayes Law!
61  DiscreteDistribution prior(v1 % "1/3");
62  DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
63  DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
64  CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1));
65  CHECK(assert_equal(expected, f1 * prior));
66 
67  // Multiply two factors
68  DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
69  DecisionTreeFactor actual = f1 * f2;
70  DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
71  CHECK(assert_equal(expected2, actual));
72 }
73 
74 /* ************************************************************************* */
76 {
77  DiscreteKey v0(0,3), v1(1,2);
78  DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
79 
80  DecisionTreeFactor expected(v1, "9 12");
81  DecisionTreeFactor::shared_ptr actual = f1.sum(1);
82  CHECK(assert_equal(expected, *actual, 1e-5));
83 
84  DecisionTreeFactor expected2(v1, "5 6");
85  DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
86  CHECK(assert_equal(expected2, *actual2));
87 
88  DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
89  DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
90 }
91 
92 /* ************************************************************************* */
93 // Check enumerate yields the correct list of assignment/value pairs.
94 TEST(DecisionTreeFactor, enumerate) {
95  DiscreteKey A(12, 3), B(5, 2);
96  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
97  auto actual = f.enumerate();
98  std::vector<std::pair<DiscreteValues, double>> expected;
100  for (size_t a : {0, 1, 2}) {
101  for (size_t b : {0, 1}) {
102  values[12] = a;
103  values[5] = b;
104  expected.emplace_back(values, f(values));
105  }
106  }
107  EXPECT(actual == expected);
108 }
109 
110 /* ************************************************************************* */
111 // Check pruning of the decision tree works as expected.
113  DiscreteKey A(1, 2), B(2, 2), C(3, 2);
114  DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
115 
116  // Only keep the leaves with the top 5 values.
117  size_t maxNrAssignments = 5;
118  auto pruned5 = f.prune(maxNrAssignments);
119 
120  // Pruned leaves should be 0
121  DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
122  EXPECT(assert_equal(expected, pruned5));
123 
124  // Check for more extreme pruning where we only keep the top 2 leaves
125  maxNrAssignments = 2;
126  auto pruned2 = f.prune(maxNrAssignments);
127  DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
128  EXPECT(assert_equal(expected2, pruned2));
129 
130  DiscreteKey D(4, 2);
131  DecisionTreeFactor factor(
132  D & C & B & A,
133  "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
134  "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
135 
136  DecisionTreeFactor expected3(
137  D & C & B & A,
138  "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
139  "0.999952870000 1.0 1.0 1.0 1.0");
140  maxNrAssignments = 5;
141  auto pruned3 = factor.prune(maxNrAssignments);
142  EXPECT(assert_equal(expected3, pruned3));
143 }
144 
145 /* ************************************************************************* */
146 TEST(DecisionTreeFactor, DotWithNames) {
147  DiscreteKey A(12, 3), B(5, 2);
148  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
149  auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
150 
151  for (bool showZero:{true, false}) {
152  string actual = f.dot(formatter, showZero);
153  // pretty weak test, as ids are pointers and not stable across platforms.
154  string expected = "digraph G {";
155  EXPECT(actual.substr(0, 11) == expected);
156  }
157 }
158 
159 /* ************************************************************************* */
160 // Check markdown representation looks as expected.
162  DiscreteKey A(12, 3), B(5, 2);
163  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
164  string expected =
165  "|A|B|value|\n"
166  "|:-:|:-:|:-:|\n"
167  "|0|0|1|\n"
168  "|0|1|2|\n"
169  "|1|0|3|\n"
170  "|1|1|4|\n"
171  "|2|0|5|\n"
172  "|2|1|6|\n";
173  auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
174  string actual = f.markdown(formatter);
175  EXPECT(actual == expected);
176 }
177 
178 /* ************************************************************************* */
179 // Check markdown representation with a value formatter.
180 TEST(DecisionTreeFactor, markdownWithValueFormatter) {
181  DiscreteKey A(12, 3), B(5, 2);
182  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
183  string expected =
184  "|A|B|value|\n"
185  "|:-:|:-:|:-:|\n"
186  "|Zero|-|1|\n"
187  "|Zero|+|2|\n"
188  "|One|-|3|\n"
189  "|One|+|4|\n"
190  "|Two|-|5|\n"
191  "|Two|+|6|\n";
192  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
193  DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
194  {5, {"-", "+"}}};
195  string actual = f.markdown(keyFormatter, names);
196  EXPECT(actual == expected);
197 }
198 
199 /* ************************************************************************* */
200 // Check html representation with a value formatter.
201 TEST(DecisionTreeFactor, htmlWithValueFormatter) {
202  DiscreteKey A(12, 3), B(5, 2);
203  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
204  string expected =
205  "<div>\n"
206  "<table class='DecisionTreeFactor'>\n"
207  " <thead>\n"
208  " <tr><th>A</th><th>B</th><th>value</th></tr>\n"
209  " </thead>\n"
210  " <tbody>\n"
211  " <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
212  " <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
213  " <tr><th>One</th><th>-</th><td>3</td></tr>\n"
214  " <tr><th>One</th><th>+</th><td>4</td></tr>\n"
215  " <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
216  " <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
217  " </tbody>\n"
218  "</table>\n"
219  "</div>";
220  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
221  DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
222  {5, {"-", "+"}}};
223  string actual = f.html(keyFormatter, names);
224  EXPECT(actual == expected);
225 }
226 
227 /* ************************************************************************* */
228 int main() {
229  TestResult tr;
230  return TestRegistry::runAllTests(tr);
231 }
232 /* ************************************************************************* */
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
const gtsam::Symbol key('X', 0)
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
#define CHECK(condition)
Definition: Test.h:108
const char Y
Vector v2
Scalar * b
Definition: benchVecAdd.cpp:17
Concept check for values that can be used in unit tests.
static int runAllTests(TestResult &result)
Vector v1
signatures for conditional densities
Point2 prior(const Point2 &x)
Prior on a single pose.
Definition: simulated2D.h:88
Matrix expected
Definition: testMatrix.cpp:971
std::string html(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as html table.
size_t size() const
Definition: Factor.h:159
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
shared_ptr max(size_t nrFrontals) const
Create new factor by maximizing over all values with the same separator.
leaf::MyValues values
Definition: BFloat16.h:88
double f2(const Vector2 &x)
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
EIGEN_DEVICE_FUNC const LogReturnType log() const
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
const KeyFormatter & formatter
#define Z
Definition: icosphere.cpp:21
#define EXPECT(condition)
Definition: Test.h:150
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Array< double, 1, 3 > e(1./3., 0.5, 2.)
void dot(std::ostream &os, const KeyFormatter &keyFormatter=DefaultKeyFormatter, bool showZero=true) const
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
TEST(DecisionTreeFactor, constructors)
traits
Definition: chartTesting.h:28
std::shared_ptr< DecisionTreeFactor > shared_ptr
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
static const double v0
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
double f3(double x1, double x2)
shared_ptr sum(size_t nrFrontals) const
Create new factor by summing all values with the same separator values.
std::vector< std::pair< DiscreteValues, double > > enumerate() const
Enumerate all values into a map from values to double.
#define X
Definition: icosphere.cpp:20
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as markdown table.
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:102


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:38:01