testDecisionTreeFactor.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * testDecisionTreeFactor.cpp
14  *
15  * @date Feb 5, 2012
16  * @author Frank Dellaert
17  * @author Duy-Nguyen Ta
18  */
19 
21 #include <gtsam/base/Testable.h>
27 #include <gtsam/inference/Key.h>
29 
30 using namespace std;
31 using namespace gtsam;
32 
33 /* ************************************************************************* */
// ConstructorsMatch: builds a factor over keys X (cardinality 2) and Y
// (cardinality 3) from the string table "2 5 3 6 4 7". Per the comment below,
// a second factor is meant to be built from the equivalent std::vector
// `table` so the two construction paths can be compared.
// NOTE(review): this listing jumps 39 -> 41 -> 43, so the vector-based
// construction and the check comparing the two factors are not visible here;
// in what remains `table` is unused — confirm against the original source.
34 TEST(DecisionTreeFactor, ConstructorsMatch) {
35  // Declare two keys
36  DiscreteKey X(0, 2), Y(1, 3);
37 
38  // Create with vector and with string
39  const std::vector<double> table{2, 5, 3, 6, 4, 7};
41  DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
43 }
44 
45 /* ************************************************************************* */
46 TEST(DecisionTreeFactor, constructors) {
47  // Declare a bunch of keys
48  DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
49 
50  // Create factors
51  DecisionTreeFactor f1(X, {2, 8});
52  DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
53  DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
54  EXPECT_LONGS_EQUAL(1, f1.size());
55  EXPECT_LONGS_EQUAL(2, f2.size());
56  EXPECT_LONGS_EQUAL(3, f3.size());
57 
58  DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
59  EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
60  EXPECT_DOUBLES_EQUAL(7, f2(x121), 1e-9);
61  EXPECT_DOUBLES_EQUAL(75, f3(x121), 1e-9);
62 
63  // Assert that error = -log(value)
64  EXPECT_DOUBLES_EQUAL(-log(f1(x121)), f1.error(x121), 1e-9);
65 
66  // Construct from DiscreteConditional
67  DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
68  DecisionTreeFactor f4(conditional);
69  EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
70 }
71 
72 /* ************************************************************************* */
// errorTree() regression: builds a factor over X, Y, Z (cardinalities 2, 3, 2)
// and compares f.errorTree() with a hard-coded tree at tolerance 1e-6.
// Each expected leaf equals -log of the corresponding table entry
// (-log 2 = -0.69314718, -log 5 = -1.6094379, ...), consistent with the
// "error = -log(value)" relation asserted in the `constructors` test.
// NOTE(review): listing line 73 (the TEST header) and line 82 (the start of
// the declaration of `expected`, presumably an AlgebraicDecisionTree<Key>
// given errorTree()'s return type) are missing from this view.
74  // Declare a bunch of keys
75  DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
76 
77  // Create factors
78  DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
79 
80  auto errors = f.errorTree();
81  // regression
83  {X, Y, Z},
84  vector<double>{-0.69314718, -1.6094379, -1.0986123, -1.7917595,
85  -1.3862944, -1.9459101, -3.2188758, -4.0073332, -3.5553481,
86  -4.1743873, -3.8066625, -4.3174881});
87  EXPECT(assert_equal(expected, errors, 1e-6));
88 }
89 
90 /* ************************************************************************* */
// multiplication: checks products of factors over binary keys v0, v1, v2.
// First part: the prior "1/3" on v1 is P(v1) = (1/4, 3/4). Scaling the v1
// entries of f1 = "1 2 3 4" by it gives `expected` = "0.25 1.5 0.75 3".
// NOTE(review): listing lines 98-99 are missing, so the product with `prior`
// and its check are not visible. In this view `prior` and `expected` are
// declared but never used.
91 TEST(DecisionTreeFactor, multiplication) {
92  DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
93 
94  // Multiply with a DiscreteDistribution, i.e., Bayes Law!
95  DiscreteDistribution prior(v1 % "1/3");
96  DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
97  DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
100 
101  // Multiply two factors
// Each entry of expected2 is f1(v0, v1) * f2(v1, v2), e.g. for
// (v0=0, v1=1, v2=0): 2 * 7 = 14.
102  DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
103  DecisionTreeFactor actual = f1 * f2;
104  DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
105  CHECK(assert_equal(expected2, actual));
106 }
107 
108 /* ************************************************************************* */
110  DiscreteKey v0(0, 3), v1(1, 2);
111  DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
112 
113  DecisionTreeFactor expected(v1, "9 12");
114  DecisionTreeFactor::shared_ptr actual = f1.sum(1);
115  CHECK(assert_equal(expected, *actual, 1e-5));
116 
117  DecisionTreeFactor expected2(v1, "5 6");
118  DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
119  CHECK(assert_equal(expected2, *actual2));
120 
121  DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
122  DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
123 }
124 
125 /* ************************************************************************* */
126 // Check enumerate yields the correct list of assignment/value pairs.
// The expected list is built with A (key 12) in the outer loop and B (key 5)
// in the inner loop. The final std::vector comparison is order-sensitive, so
// it requires enumerate() to list assignments with A varying slowest.
// NOTE(review): listing lines 127 (the TEST header) and 132 (the declaration
// of `values`, presumably a DiscreteValues) are missing from this view.
128  DiscreteKey A(12, 3), B(5, 2);
129  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
130  auto actual = f.enumerate();
131  std::vector<std::pair<DiscreteValues, double>> expected;
133  for (size_t a : {0, 1, 2}) {
134  for (size_t b : {0, 1}) {
135  values[12] = a;
136  values[5] = b;
137  expected.emplace_back(values, f(values));
138  }
139  }
140  EXPECT(actual == expected);
141 }
142 
// Shared fixture for the ComputeThreshold and Prune tests, which pull it in
// with `using namespace pruning_fixture`.
// - `f` is a factor over three binary keys A, B, C whose entries are 1..8.
// - `factor` is a factor over D, C, B, A (all binary). Its 16 entries are
//   exact zeros, values near 0.61, and values at or near 1.0.
// NOTE(review): listing line 149 is missing. Per the cross-reference index it
// opens the declaration `DecisionTreeFactor factor(` whose arguments follow.
143 namespace pruning_fixture {
144 
145 DiscreteKey A(1, 2), B(2, 2), C(3, 2);
146 DecisionTreeFactor f(A& B& C, "1 5 3 7 2 6 4 8");
147 
148 DiscreteKey D(4, 2);
150  D& C & B & A,
151  "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
152  "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
153 
154 } // namespace pruning_fixture
155 
156 /* ************************************************************************* */
157 // Check if computing the correct threshold works.
158 TEST(DecisionTreeFactor, ComputeThreshold) {
159  using namespace pruning_fixture;
160 
161  // Only keep the leaves with the top 5 values.
162  double threshold = f.computeThreshold(5);
163  EXPECT_DOUBLES_EQUAL(4.0, threshold, 1e-9);
164 
165  // Check for more extreme pruning where we only keep the top 2 leaves
166  threshold = f.computeThreshold(2);
167  EXPECT_DOUBLES_EQUAL(7.0, threshold, 1e-9);
168 
169  threshold = factor.computeThreshold(5);
170  EXPECT_DOUBLES_EQUAL(0.99995287, threshold, 1e-9);
171 
172  threshold = factor.computeThreshold(3);
173  EXPECT_DOUBLES_EQUAL(1.0, threshold, 1e-9);
174 
175  threshold = factor.computeThreshold(6);
176  EXPECT_DOUBLES_EQUAL(0.61247742, threshold, 1e-9);
177 }
178 
179 /* ************************************************************************* */
180 // Check pruning of the decision tree works as expected.
182  using namespace pruning_fixture;
183 
184  // Only keep the leaves with the top 5 values.
185  size_t maxNrAssignments = 5;
186  auto pruned5 = f.prune(maxNrAssignments);
187 
188  // Pruned leaves should be 0
189  DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
190  EXPECT(assert_equal(expected, pruned5));
191 
192  // Check for more extreme pruning where we only keep the top 2 leaves
193  maxNrAssignments = 2;
194  auto pruned2 = f.prune(maxNrAssignments);
195  DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
196  EXPECT(assert_equal(expected2, pruned2));
197 
198  DecisionTreeFactor expected3(D & C & B & A,
199  "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
200  "0.999952870000 1.0 1.0 1.0 1.0");
201  maxNrAssignments = 5;
202  auto pruned3 = factor.prune(maxNrAssignments);
203  EXPECT(assert_equal(expected3, pruned3));
204 }
205 
206 /* ************************************************************************** */
207 // Asia Bayes Network
208 /* ************************************************************************** */
209 
210 #define DISABLE_DOT
211 
212 void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
213 #ifndef DISABLE_DOT
214  std::vector<std::string> names = {"A", "S", "T", "L", "B", "E", "X", "D"};
215  auto formatter = [&](Key key) { return names[key]; };
216  f.dot(filename, formatter, true);
217 #endif
218 }
219 
222  DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
223  return p;
224 }
225 
226 /* ************************************************************************* */
227 // test Asia Joint
229  DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
230  D(7, 2);
231 
232  gttic_(asiaCPTs);
233  DecisionTreeFactor pA = create(A % "99/1");
234  DecisionTreeFactor pS = create(S % "50/50");
235  DecisionTreeFactor pT = create(T | A = "99/1 95/5");
236  DecisionTreeFactor pL = create(L | S = "99/1 90/10");
237  DecisionTreeFactor pB = create(B | S = "70/30 40/60");
238  DecisionTreeFactor pE = create((E | T, L) = "F T T T");
239  DecisionTreeFactor pX = create(X | E = "95/5 2/98");
240  DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
241 
242  // Create joint
243  gttic_(asiaJoint);
244  DecisionTreeFactor joint = pA;
245  maybeSaveDotFile(joint, "Asia-A");
246  joint = joint * pS;
247  maybeSaveDotFile(joint, "Asia-AS");
248  joint = joint * pT;
249  maybeSaveDotFile(joint, "Asia-AST");
250  joint = joint * pL;
251  maybeSaveDotFile(joint, "Asia-ASTL");
252  joint = joint * pB;
253  maybeSaveDotFile(joint, "Asia-ASTLB");
254  joint = joint * pE;
255  maybeSaveDotFile(joint, "Asia-ASTLBE");
256  joint = joint * pX;
257  maybeSaveDotFile(joint, "Asia-ASTLBEX");
258  joint = joint * pD;
259  maybeSaveDotFile(joint, "Asia-ASTLBEXD");
260 
261  // Check that discrete keys are as expected
262  EXPECT(assert_equal(joint.discreteKeys(), {A, S, T, L, B, E, X, D}));
263 
264  // Check that summing out variables maintains the keys even if merged, as is
265  // the case with S.
266  auto noAB = joint.sum(Ordering{A.first, B.first});
267  EXPECT(assert_equal(noAB->discreteKeys(), {S, T, L, E, X, D}));
268 }
269 
270 /* ************************************************************************* */
271 TEST(DecisionTreeFactor, DotWithNames) {
272  DiscreteKey A(12, 3), B(5, 2);
273  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
274  auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
275 
276  for (bool showZero : {true, false}) {
277  string actual = f.dot(formatter, showZero);
278  // pretty weak test, as ids are pointers and not stable across platforms.
279  string expected = "digraph G {";
280  EXPECT(actual.substr(0, 11) == expected);
281  }
282 }
283 
284 /* ************************************************************************* */
285 // Check markdown representation looks as expected.
287  DiscreteKey A(12, 3), B(5, 2);
288  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
289  string expected =
290  "|A|B|value|\n"
291  "|:-:|:-:|:-:|\n"
292  "|0|0|1|\n"
293  "|0|1|2|\n"
294  "|1|0|3|\n"
295  "|1|1|4|\n"
296  "|2|0|5|\n"
297  "|2|1|6|\n";
298  auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
299  string actual = f.markdown(formatter);
300  EXPECT(actual == expected);
301 }
302 
303 /* ************************************************************************* */
304 // Check markdown representation with a value formatter.
305 TEST(DecisionTreeFactor, markdownWithValueFormatter) {
306  DiscreteKey A(12, 3), B(5, 2);
307  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
308  string expected =
309  "|A|B|value|\n"
310  "|:-:|:-:|:-:|\n"
311  "|Zero|-|1|\n"
312  "|Zero|+|2|\n"
313  "|One|-|3|\n"
314  "|One|+|4|\n"
315  "|Two|-|5|\n"
316  "|Two|+|6|\n";
317  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
318  DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
319  {5, {"-", "+"}}};
320  string actual = f.markdown(keyFormatter, names);
321  EXPECT(actual == expected);
322 }
323 
324 /* ************************************************************************* */
325 // Check html representation with a value formatter.
326 TEST(DecisionTreeFactor, htmlWithValueFormatter) {
327  DiscreteKey A(12, 3), B(5, 2);
328  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
329  string expected =
330  "<div>\n"
331  "<table class='DecisionTreeFactor'>\n"
332  " <thead>\n"
333  " <tr><th>A</th><th>B</th><th>value</th></tr>\n"
334  " </thead>\n"
335  " <tbody>\n"
336  " <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
337  " <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
338  " <tr><th>One</th><th>-</th><td>3</td></tr>\n"
339  " <tr><th>One</th><th>+</th><td>4</td></tr>\n"
340  " <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
341  " <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
342  " </tbody>\n"
343  "</table>\n"
344  "</div>";
345  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
346  DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
347  {5, {"-", "+"}}};
348  string actual = f.html(keyFormatter, names);
349  EXPECT(actual == expected);
350 }
351 
352 /* ************************************************************************* */
353 int main() {
354  TestResult tr;
355  return TestRegistry::runAllTests(tr);
356 }
357 /* ************************************************************************* */
gtsam::Signature::cpt
std::vector< double > cpt() const
Definition: Signature.cpp:69
TestRegistry::runAllTests
static int runAllTests(TestResult &result)
Definition: TestRegistry.cpp:27
gtsam::DiscreteFactor::errorTree
virtual AlgebraicDecisionTree< Key > errorTree() const
Compute error for each assignment and return as a tree.
Definition: DiscreteFactor.cpp:59
gtsam::markdown
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
Definition: DiscreteValues.cpp:130
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:45
B
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
test_constructor::f1
auto f1
Definition: testHybridNonlinearFactor.cpp:56
gtsam::DecisionTreeFactor::computeThreshold
double computeThreshold(const size_t N) const
Compute the probability value which is the threshold above which only N leaves are present.
Definition: DecisionTreeFactor.cpp:410
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
v0
static const double v0
Definition: testCal3DFisheye.cpp:31
EXPECT_LONGS_EQUAL
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
Testable.h
Concept check for values that can be used in unit tests.
EXPECT
#define EXPECT(condition)
Definition: Test.h:150
TestHarness.h
gtsam::DecisionTreeFactor::dot
void dot(std::ostream &os, const KeyFormatter &keyFormatter=DefaultKeyFormatter, bool showZero=true) const
Definition: DecisionTreeFactor.cpp:255
b
Scalar * b
Definition: benchVecAdd.cpp:17
asiaCPTs::pX
ADT pX
Definition: testAlgebraicDecisionTree.cpp:159
gtsam::DiscreteDistribution
Definition: DiscreteDistribution.h:33
gtsam::Y
GaussianFactorGraphValuePair Y
Definition: HybridGaussianProductFactor.cpp:29
pruning_fixture::f
DecisionTreeFactor f(A &B &C, "1 5 3 7 2 6 4 8")
asiaCPTs::pA
ADT pA
Definition: testAlgebraicDecisionTree.cpp:153
pruning_fixture::C
DiscreteKey C(3, 2)
gtsam::DecisionTreeFactor::prune
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
Definition: DecisionTreeFactor.cpp:464
asiaCPTs
Definition: testAlgebraicDecisionTree.cpp:149
maybeSaveDotFile
void maybeSaveDotFile(const DecisionTreeFactor &f, const string &filename)
Definition: testDecisionTreeFactor.cpp:212
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
f2
double f2(const Vector2 &x)
Definition: testNumericalDerivative.cpp:58
pruning_fixture::D
DiscreteKey D(4, 2)
T
Eigen::Triplet< double > T
Definition: Tutorial_sparse_example.cpp:6
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:245
log
const EIGEN_DEVICE_FUNC LogReturnType log() const
Definition: ArrayCwiseUnaryOps.h:128
Ordering.h
Variable ordering for the elimination algorithm.
X
#define X
Definition: icosphere.cpp:20
TEST
TEST(DecisionTreeFactor, ConstructorsMatch)
Definition: testDecisionTreeFactor.cpp:34
create
DecisionTreeFactor create(const Signature &signature)
Definition: testDecisionTreeFactor.cpp:221
gtsam::DecisionTreeFactor::shared_ptr
std::shared_ptr< DecisionTreeFactor > shared_ptr
Definition: DecisionTreeFactor.h:51
A
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
gtsam::AlgebraicDecisionTree< Key >
pruning_fixture::factor
DecisionTreeFactor factor(D &C &B &A, "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0")
Key.h
asiaCPTs::pE
ADT pE
Definition: testAlgebraicDecisionTree.cpp:158
table
ArrayXXf table(10, 4)
main
int main()
Definition: testDecisionTreeFactor.cpp:353
gttic_
#define gttic_(label)
Definition: timing.h:245
relicense.filename
filename
Definition: relicense.py:57
DiscreteFactor.h
Signature.h
signatures for conditional densities
cholesky::expected
Matrix expected
Definition: testMatrix.cpp:971
L
MatrixXd L
Definition: LLT_example.cpp:6
asiaCPTs::pT
ADT pT
Definition: testAlgebraicDecisionTree.cpp:155
asiaCPTs::pS
ADT pS
Definition: testAlgebraicDecisionTree.cpp:154
Eigen::Triplet< double >
EXPECT_DOUBLES_EQUAL
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
asiaCPTs::pD
ADT pD
Definition: testAlgebraicDecisionTree.cpp:160
gtsam::DecisionTreeFactor::sum
shared_ptr sum(size_t nrFrontals) const
Create new factor by summing all values with the same separator values.
Definition: DecisionTreeFactor.h:166
serializationTestHelpers.h
TestResult
Definition: TestResult.h:26
key
const gtsam::Symbol key('X', 0)
E
DiscreteKey E(5, 2)
process_shonan_timing_results.names
dictionary names
Definition: process_shonan_timing_results.py:175
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
a
ArrayXXi a
Definition: Array_initializer_list_23_cxx11.cpp:1
gtsam
traits
Definition: SFMdata.h:40
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
f3
double f3(double x1, double x2)
Definition: testNumericalDerivative.cpp:78
DiscreteDistribution.h
CHECK
#define CHECK(condition)
Definition: Test.h:108
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
std
Definition: BFloat16.h:88
p
float * p
Definition: Tutorial_Map_using.cpp:9
v2
Vector v2
Definition: testSerializationBase.cpp:39
gtsam::DecisionTreeFactor::html
std::string html(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as html table.
Definition: DecisionTreeFactor.cpp:307
pruning_fixture
Definition: testDecisionTreeFactor.cpp:143
f4
double f4(double x, double y, double z)
Definition: testNumericalDerivative.cpp:107
gtsam::assert_equal
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:41
gtsam::DecisionTreeFactor::enumerate
std::vector< std::pair< DiscreteValues, double > > enumerate() const
Enumerate all values into a map from values to double.
Definition: DecisionTreeFactor.cpp:187
gtsam::Signature::discreteKeys
DiscreteKeys discreteKeys() const
Definition: Signature.cpp:55
different_sigmas::prior
const auto prior
Definition: testHybridBayesNet.cpp:238
gtsam::DiscreteFactor::discreteKeys
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
Definition: DiscreteFactor.cpp:37
asiaCPTs::pB
ADT pB
Definition: testAlgebraicDecisionTree.cpp:157
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
Z
#define Z
Definition: icosphere.cpp:21
gtsam::Ordering
Definition: inference/Ordering.h:33
asiaCPTs::pL
ADT pL
Definition: testAlgebraicDecisionTree.cpp:156
DecisionTreeFactor.h
gtsam::DecisionTreeFactor::markdown
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as markdown table.
Definition: DecisionTreeFactor.cpp:276
gtsam::Signature
Definition: Signature.h:54
S
DiscreteKey S(1, 2)
v1
Vector v1
Definition: testSerializationBase.cpp:38


gtsam
Author(s):
autogenerated on Wed Jan 1 2025 04:06:10