testDecisionTree.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * @file testDecisionTree.cpp
14  * @brief Develop DecisionTree
15  * @author Frank Dellaert
16  * @author Can Erdogan
17  * @date Jan 30, 2012
18  */
19 
20 // #define DT_DEBUG_MEMORY
21 // #define GTSAM_DT_NO_PRUNING
22 #define DISABLE_DOT
24 #include <gtsam/base/Testable.h>
28 
29 #include <iomanip>
30 
31 using std::vector;
32 using std::string;
33 using std::map;
34 using namespace gtsam;
35 
/**
 * Write `f` to a GraphViz file by forwarding to its `dot` member.
 * When DISABLE_DOT is defined the call is compiled out and this is a no-op.
 * @param f        object providing a `dot(filename)` member
 * @param filename name forwarded unchanged to `f.dot`
 */
template <typename T>
void dot(const T& f, const std::string& filename) {
#if !defined(DISABLE_DOT)
  f.dot(filename);
#endif
}

/// Dump `x` using the variable's own name as the file name, e.g. DOT(a) -> "a".
#define DOT(x) (dot(x, #x))
44 
45 struct Crazy {
46  int a;
47  double b;
48 };
49 
50 struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
52  void print(const std::string& s = "") const {
53  auto keyFormatter = [](const std::string& s) { return s; };
54  auto valueFormatter = [](const Crazy& v) {
55  std::stringstream ss;
56  ss << "{" << v.a << "," << std::setw(4) << std::setprecision(2) << v.b << "}";
57  return ss.str();
58  };
60  }
62  bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
63  auto compare = [tol](const Crazy& v, const Crazy& w) {
64  return v.a == w.a && std::abs(v.b - w.b) < tol;
65  };
67  }
68 };
69 
70 // traits
71 namespace gtsam {
72 template <>
73 struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
74 } // namespace gtsam
75 
77 
78 /* ************************************************************************** */
79 // Test string labels and int range
80 /* ************************************************************************** */
81 
82 struct DT : public DecisionTree<string, int> {
85  DT() = default;
86 
87  DT(const Base& dt) : Base(dt) {}
88 
90  void print(const std::string& s = "") const {
91  auto keyFormatter = [](const std::string& s) { return s; };
92  auto valueFormatter = [](const int& v) {
93  return std::to_string(v);
94  };
95  std::cout << s;
96  Base::print("", keyFormatter, valueFormatter);
97  }
99  bool equals(const Base& other, double tol = 1e-9) const {
100  auto compare = [](const int& v, const int& w) { return v == w; };
101  return Base::equals(other, compare);
102  }
103 };
104 
105 // traits
106 namespace gtsam {
107 template <>
108 struct traits<DT> : public Testable<DT> {};
109 } // namespace gtsam
110 
112 
/// Integer ring operations, passed as unary/binary ops to apply() in tests.
struct Ring {
  /// Additive identity.
  static int zero() { return 0; }
  /// Multiplicative identity.
  static int one() { return 1; }
  /// Identity map, used as a unary operator.
  static int id(const int& value) { return value; }
  /// Sum of the two operands.
  static int add(const int& lhs, const int& rhs) { return lhs + rhs; }
  /// Product of the two operands.
  static int mul(const int& lhs, const int& rhs) { return lhs * rhs; }
};
120 
121 /* ************************************************************************** */
122 // test DT
124  // Create labels
125  string A("A"), B("B"), C("C");
126 
127  // create a value
128  Assignment<string> x00, x01, x10, x11;
129  x00[A] = 0, x00[B] = 0;
130  x01[A] = 0, x01[B] = 1;
131  x10[A] = 1, x10[B] = 0;
132  x11[A] = 1, x11[B] = 1;
133 
134  // empty
135  DT empty;
136 
137  // A
138  DT a(A, 0, 5);
139  LONGS_EQUAL(0, a(x00))
140  LONGS_EQUAL(5, a(x10))
141  DOT(a);
142 
143  // pruned
144  DT p(A, 2, 2);
145  LONGS_EQUAL(2, p(x00))
146  LONGS_EQUAL(2, p(x10))
147  DOT(p);
148 
149  // \neg B
150  DT notb(B, 5, 0);
151  LONGS_EQUAL(5, notb(x00))
152  LONGS_EQUAL(5, notb(x10))
153  DOT(notb);
154 
155  // Check supplying empty trees yields an exception
156  CHECK_EXCEPTION(gtsam::apply(empty, &Ring::id), std::runtime_error);
157  CHECK_EXCEPTION(gtsam::apply(empty, a, &Ring::mul), std::runtime_error);
158  CHECK_EXCEPTION(gtsam::apply(a, empty, &Ring::mul), std::runtime_error);
159 
160  // apply, two nodes, in natural order
161  DT anotb = apply(a, notb, &Ring::mul);
162  LONGS_EQUAL(0, anotb(x00))
163  LONGS_EQUAL(0, anotb(x01))
164  LONGS_EQUAL(25, anotb(x10))
165  LONGS_EQUAL(0, anotb(x11))
166  DOT(anotb);
167 
168  // check pruning
169  DT pnotb = apply(p, notb, &Ring::mul);
170  LONGS_EQUAL(10, pnotb(x00))
171  LONGS_EQUAL(0, pnotb(x01))
172  LONGS_EQUAL(10, pnotb(x10))
173  LONGS_EQUAL(0, pnotb(x11))
174  DOT(pnotb);
175 
176  // check pruning
177  DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
178  LONGS_EQUAL(0, zeros(x00))
179  LONGS_EQUAL(0, zeros(x01))
180  LONGS_EQUAL(0, zeros(x10))
181  LONGS_EQUAL(0, zeros(x11))
182  DOT(zeros);
183 
184  // apply, two nodes, in switched order
185  DT notba = apply(a, notb, &Ring::mul);
186  LONGS_EQUAL(0, notba(x00))
187  LONGS_EQUAL(0, notba(x01))
188  LONGS_EQUAL(25, notba(x10))
189  LONGS_EQUAL(0, notba(x11))
190  DOT(notba);
191 
192  // Test choose 0
193  DT actual0 = notba.choose(A, 0);
194  EXPECT(assert_equal(DT(0.0), actual0));
195  DOT(actual0);
196 
197  // Test choose 1
198  DT actual1 = notba.choose(A, 1);
199  EXPECT(assert_equal(DT(B, 25, 0), actual1));
200  DOT(actual1);
201 
202  // apply, two nodes at same level
203  DT a_and_a = apply(a, a, &Ring::mul);
204  LONGS_EQUAL(0, a_and_a(x00))
205  LONGS_EQUAL(0, a_and_a(x01))
206  LONGS_EQUAL(25, a_and_a(x10))
207  LONGS_EQUAL(25, a_and_a(x11))
208  DOT(a_and_a);
209 
210  // create a function on C
211  DT c(C, 0, 5);
212 
213  // and a model assigning stuff to C
214  Assignment<string> x101;
215  x101[A] = 1, x101[B] = 0, x101[C] = 1;
216 
217  // mul notba with C
218  DT notbac = apply(notba, c, &Ring::mul);
219  LONGS_EQUAL(125, notbac(x101))
220  DOT(notbac);
221 
222  // mul now in different order
223  DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
224  LONGS_EQUAL(125, acnotb(x101))
225  DOT(acnotb);
226 }
227 
228 /* ************************************************************************** */
229 // test Conversion of values
/// Leaf converter used by the conversion tests: true iff the int is non-zero.
bool bool_of_int(const int& y) { return y != 0; }
232 
233 TEST(DecisionTree, ConvertValuesOnly) {
234  // Create labels
235  string A("A"), B("B");
236 
237  // apply, two nodes, in natural order
238  DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul);
239 
240  // convert
242 
243  // Check a value
244  Assignment<string> x00;
245  x00["A"] = 0, x00["B"] = 0;
246  EXPECT(!f2(x00));
247 }
248 
249 /* ************************************************************************** */
250 // test Conversion of both values and labels.
251 enum Label { U, V, X, Y, Z };
253 
254 TEST(DecisionTree, ConvertBoth) {
255  // Create labels
256  string A("A"), B("B");
257 
258  // apply, two nodes, in natural order
259  DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul);
260 
261  // convert
262  map<string, Label> ordering;
263  ordering[A] = X;
264  ordering[B] = Y;
265  LabelBoolTree f2(f1, ordering, &bool_of_int);
266 
267  // Check some values
268  Assignment<Label> x00, x01, x10, x11;
269  x00[X] = 0, x00[Y] = 0;
270  x01[X] = 0, x01[Y] = 1;
271  x10[X] = 1, x10[Y] = 0;
272  x11[X] = 1, x11[Y] = 1;
273  EXPECT(!f2(x00));
274  EXPECT(!f2(x01));
275  EXPECT(f2(x10));
276  EXPECT(!f2(x11));
277 }
278 
279 /* ************************************************************************** */
280 // test Compose expansion
281 TEST(DecisionTree, Compose) {
282  // Create labels
283  string A("A"), B("B"), C("C");
284 
285  // Put two stumps on A together
286  DT f1(B, DT(A, 0, 1), DT(A, 2, 3));
287 
288  // Create from string
289  vector<DT::LabelC> keys{DT::LabelC(A, 2), DT::LabelC(B, 2)};
290  DT f2(keys, "0 2 1 3");
291  EXPECT(assert_equal(f2, f1, 1e-9));
292 
293  // Put this AB tree together with another one
294  DT f3(keys, "4 6 5 7");
295  DT f4(C, f1, f3);
296  DOT(f4);
297 
298  // a bigger tree
299  keys.push_back(DT::LabelC(C, 2));
300  DT f5(keys, "0 4 2 6 1 5 3 7");
301  EXPECT(assert_equal(f5, f4, 1e-9));
302  DOT(f5);
303 }
304 
305 /* ************************************************************************** */
306 // Check we can create a decision tree of containers.
307 TEST(DecisionTree, Containers) {
308  using Container = std::vector<double>;
309  using StringContainerTree = DecisionTree<string, Container>;
310 
311  // Check default constructor
312  StringContainerTree tree;
313 
314  // Create small two-level tree
315  string A("A"), B("B");
316  DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
317 
318  // Check conversion
319  auto container_of_int = [](const int& i) {
320  Container c;
321  c.emplace_back(i);
322  return c;
323  };
324  StringContainerTree converted(stringIntTree, container_of_int);
325 }
326 
327 /* ************************************************************************** */
328 // Test nrAssignments.
329 TEST(DecisionTree, NrAssignments) {
330  const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
331  DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
332  EXPECT(tree.root_->isLeaf());
333  auto leaf = std::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
334  EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
335 
336  DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
337  /* The tree is
338  Choice(C)
339  0 Choice(B)
340  0 0 Leaf 1
341  0 1 Choice(A)
342  0 1 0 Leaf 1
343  0 1 1 Leaf 2
344  1 Choice(B)
345  1 0 Choice(A)
346  1 0 0 Leaf 3
347  1 0 1 Leaf 4
348  1 1 Leaf 5
349  */
350 
351  auto root = std::dynamic_pointer_cast<const DT::Choice>(tree2.root_);
352  CHECK(root);
353  auto choice0 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
354  CHECK(choice0);
355  EXPECT(choice0->branches()[0]->isLeaf());
356  auto choice00 = std::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
357  CHECK(choice00);
358  EXPECT_LONGS_EQUAL(2, choice00->nrAssignments());
359 
360  auto choice1 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[1]);
361  CHECK(choice1);
362  auto choice10 = std::dynamic_pointer_cast<const DT::Choice>(choice1->branches()[0]);
363  CHECK(choice10);
364  auto choice11 = std::dynamic_pointer_cast<const DT::Leaf>(choice1->branches()[1]);
365  CHECK(choice11);
366  EXPECT(choice11->isLeaf());
367  EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
368 }
369 
370 /* ************************************************************************** */
371 // Test visit.
373  // Create small two-level tree
374  string A("A"), B("B");
375  DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
376  double sum = 0.0;
377  auto visitor = [&](int y) { sum += y; };
378  tree.visit(visitor);
379  EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
380 }
381 
382 /* ************************************************************************** */
383 // Test visit, with Choices argument.
384 TEST(DecisionTree, visitWith) {
385  // Create small two-level tree
386  string A("A"), B("B");
387  DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
388  double sum = 0.0;
389  auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
390  tree.visitWith(visitor);
391  EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
392 }
393 
394 /* ************************************************************************** */
395 // Test visit, with Choices argument.
396 TEST(DecisionTree, VisitWithPruned) {
397  // Create pruned tree
398  std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
399  std::vector<std::pair<string, size_t>> labels = {C, B, A};
400  std::vector<int> nodes = {0, 0, 2, 3, 4, 4, 6, 7};
401  DT tree(labels, nodes);
402 
403  std::vector<Assignment<string>> choices;
404  auto func = [&](const Assignment<string>& choice, const int& d) {
405  choices.push_back(choice);
406  };
407  tree.visitWith(func);
408 
409  EXPECT_LONGS_EQUAL(6, choices.size());
410 
411  Assignment<string> expectedAssignment;
412 
413  expectedAssignment = {{"B", 0}, {"C", 0}};
414  EXPECT(expectedAssignment == choices.at(0));
415 
416  expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}};
417  EXPECT(expectedAssignment == choices.at(1));
418 
419  expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}};
420  EXPECT(expectedAssignment == choices.at(2));
421 
422  expectedAssignment = {{"B", 0}, {"C", 1}};
423  EXPECT(expectedAssignment == choices.at(3));
424 
425  expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 1}};
426  EXPECT(expectedAssignment == choices.at(4));
427 
428  expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}};
429  EXPECT(expectedAssignment == choices.at(5));
430 }
431 
432 /* ************************************************************************** */
433 // Test fold.
435  // Create small two-level tree
436  string A("A"), B("B");
437  DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
438  auto add = [](const int& y, double x) { return y + x; };
439  double sum = tree.fold(add, 0.0);
440  EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
441 }
442 
443 /* ************************************************************************** */
444 // Test retrieving all labels.
446  // Create small two-level tree
447  string A("A"), B("B");
448  DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
449  auto labels = tree.labels();
450  EXPECT_LONGS_EQUAL(2, labels.size());
451 }
452 
453 /* ************************************************************************** */
454 // Test unzip method.
457  using DT1 = DecisionTree<string, int>;
458  using DT2 = DecisionTree<string, string>;
459 
460  // Create small two-level tree
461  string A("A"), B("B"), C("C");
462  DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
463  DTP(A, {2, "two"}, {1337, "l33t"}));
464 
465  const auto [dt1, dt2] = unzip(tree);
466 
467  DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
468  DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
469 
470  EXPECT(tree1.equals(dt1));
471  EXPECT(tree2.equals(dt2));
472 }
473 
474 /* ************************************************************************** */
475 // Test thresholding.
476 TEST(DecisionTree, threshold) {
477  // Create three level tree
478  const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
479  DT::LabelC("A", 2)};
480  DT tree(keys, "0 1 2 3 4 5 6 7");
481 
482  // Check number of leaves equal to zero
483  auto count = [](const int& value, int count) {
484  return value == 0 ? count + 1 : count;
485  };
486  EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
487 
488  // Now threshold
489  auto threshold = [](int value) { return value < 5 ? 0 : value; };
490  DT thresholded(tree, threshold);
491 
492  // Check number of leaves equal to zero now = 2
493  // Note: it is 2, because the pruned branches are counted as 1!
494  EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
495 }
496 
497 /* ************************************************************************** */
498 // Test apply with assignment.
499 TEST(DecisionTree, ApplyWithAssignment) {
500  // Create three level tree
501  const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
502  DT::LabelC("A", 2)};
503  DT tree(keys, "1 2 3 4 5 6 7 8");
504 
506  keys, "0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08");
507  double threshold = 0.045;
508 
509  // We test pruning one tree by indexing into another.
510  auto pruner = [&](const Assignment<string>& choices, const int& x) {
511  // Prune out all the leaves with even numbers
512  if (probTree(choices) < threshold) {
513  return 0;
514  } else {
515  return x;
516  }
517  };
518  DT prunedTree = tree.apply(pruner);
519 
520  DT expectedTree(keys, "0 0 0 0 5 6 7 8");
521  EXPECT(assert_equal(expectedTree, prunedTree));
522 
523  size_t count = 0;
524  auto counter = [&](const Assignment<string>& choices, const int& x) {
525  count += 1;
526  return x;
527  };
528  DT prunedTree2 = prunedTree.apply(counter);
529 
530  // Check if apply doesn't enumerate all leaves.
531  EXPECT_LONGS_EQUAL(5, count);
532 }
533 
534 /* ************************************************************************* */
535 int main() {
536  TestResult tr;
537  return TestRegistry::runAllTests(tr);
538 }
539 /* ************************************************************************* */
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
#define CHECK(condition)
Definition: Test.h:108
bool compare
static int id(const int &a)
static int mul(const int &a, const int &b)
static int one()
Scalar * y
Concept check for values that can be used in unit tests.
void visit(Func f) const
Visit all leaves in depth-first fashion.
static int runAllTests(TestResult &result)
double dot(const V1 &a, const V2 &b)
Definition: Vector.h:195
signatures for conditional densities
std::vector< std::string > labels
KeyVector nodes
Definition: testMFAS.cpp:28
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
Definition: DecisionTree.h:398
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
double f2(const Vector2 &x)
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
static std::string valueFormatter(const double &v)
#define CHECK_EXCEPTION(condition, exception_name)
Definition: Test.h:118
std::set< L > labels() const
static enum @1107 ordering
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
const double dt
static int zero()
void visitWith(Func f) const
Visit all leaves in depth-first fashion.
std::pair< DecisionTree< L, T1 >, DecisionTree< L, T2 > > unzip(const DecisionTree< L, std::pair< T1, T2 > > &input)
unzip a DecisionTree with std::pair values.
Definition: DecisionTree.h:425
bool bool_of_int(const int &y)
#define EXPECT(condition)
Definition: Test.h:150
Array< int, Dynamic, 1 > v
bool equals(const Base &other, double tol=1e-9) const
Equality method customized to int node type.
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Array< double, 1, 3 > e(1./3., 0.5, 2.)
RealScalar s
X fold(Func f, X x0) const
Fold a binary function over the tree, returning accumulator.
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
#define DOT(x)
void print(const std::string &s="") const
print to stdout
DecisionTree apply(const Unary &op) const
static sharedNode Leaf(Key key, const SymbolicFactorGraph &factors)
const G & b
Definition: Group.h:86
RowVector3d w
static std::stringstream ss
Definition: testBTree.cpp:31
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
#define LONGS_EQUAL(expected, actual)
Definition: Test.h:134
int main()
traits
Definition: chartTesting.h:28
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
static int add(const int &a, const int &b)
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
double f4(double x, double y, double z)
float * p
double f3(double x1, double x2)
void print(const std::string &s="") const
print to stdout
graph add(PriorFactor< Pose2 >(1, priorMean, priorNoise))
DecisionTree< Label, bool > LabelBoolTree
const G double tol
Definition: Group.h:86
std::pair< string, size_t > LabelC
Definition: DecisionTree.h:67
TEST(SmartFactorBase, Pinhole)
const KeyVector keys
DecisionTree< string, bool > StringBoolTree
bool equals(const CrazyDecisionTree &other, double tol=1e-9) const
Equality method customized to Crazy node type.
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
#define abs(x)
Definition: datatypes.h:17
DT(const Base &dt)
DecisionTree choose(const L &label, size_t index) const
Definition: DecisionTree.h:341
#define GTSAM_CONCEPT_TESTABLE_INST(T)
Definition: Testable.h:176


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:38:01