Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
8  * See LICENSE for the license information
10  * -------------------------------------------------------------------------- */
12 /*
13  * @file testDecisionTree.cpp
14  * @brief Develop DecisionTree
15  * @author Frank Dellaert
16  * @author Can Erdogan
17  * @date Jan 30, 2012
18  */
20 // #define DT_DEBUG_MEMORY
21 // #define GTSAM_DT_NO_PRUNING
22 #define DISABLE_DOT
24 #include <gtsam/base/Testable.h>
29 #include <iomanip>
31 using std::vector;
32 using std::string;
33 using std::map;
34 using namespace gtsam;
36 template <typename T>
37 void dot(const T& f, const string& filename) {
38 #ifndef DISABLE_DOT
39  f.dot(filename);
40 #endif
41 }
43 #define DOT(x) (dot(x, #x))
/// Toy leaf value used to exercise DecisionTree with a user-defined type.
/// Members are zero-initialized so a default-constructed Crazy never holds
/// indeterminate values; it is still an aggregate, so Crazy{1, 2.0} works.
struct Crazy {
  int a{0};       ///< compared exactly by CrazyDecisionTree::equals
  double b{0.0};  ///< compared up to a tolerance by CrazyDecisionTree::equals
};
50 struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
52  void print(const std::string& s = "") const {
53  auto keyFormatter = [](const std::string& s) { return s; };
54  auto valueFormatter = [](const Crazy& v) {
55  std::stringstream ss;
56  ss << "{" << v.a << "," << std::setw(4) << std::setprecision(2) << v.b << "}";
57  return ss.str();
58  };
60  }
62  bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
63  auto compare = [tol](const Crazy& v, const Crazy& w) {
64  return v.a == w.a && std::abs(v.b - w.b) < tol;
65  };
67  }
68 };
70 // traits
71 namespace gtsam {
72 template <>
73 struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
74 } // namespace gtsam
78 /* ************************************************************************** */
79 // Test string labels and int range
80 /* ************************************************************************** */
82 struct DT : public DecisionTree<string, int> {
85  DT() = default;
87  DT(const Base& dt) : Base(dt) {}
90  void print(const std::string& s = "") const {
91  auto keyFormatter = [](const std::string& s) { return s; };
92  auto valueFormatter = [](const int& v) {
93  return std::to_string(v);
94  };
95  std::cout << s;
96  Base::print("", keyFormatter, valueFormatter);
97  }
99  bool equals(const Base& other, double tol = 1e-9) const {
100  auto compare = [](const int& v, const int& w) { return v == w; };
101  return Base::equals(other, compare);
102  }
103 };
105 // traits
106 namespace gtsam {
107 template <>
108 struct traits<DT> : public Testable<DT> {};
109 } // namespace gtsam
/// Integer arithmetic as static functions, so their addresses (&Ring::mul,
/// &Ring::id, ...) can be handed to apply() as unary/binary operators.
/// (Member functions defined in-class are implicitly inline.)
struct Ring {
  static int zero() { return 0; }
  static int one() { return 1; }
  static int id(const int& x) { return x; }
  static int add(const int& lhs, const int& rhs) { return lhs + rhs; }
  static int mul(const int& lhs, const int& rhs) { return lhs * rhs; }
};
121 /* ************************************************************************** */
122 // test DT
124  // Create labels
125  string A("A"), B("B"), C("C");
127  // create a value
128  Assignment<string> x00, x01, x10, x11;
129  x00[A] = 0, x00[B] = 0;
130  x01[A] = 0, x01[B] = 1;
131  x10[A] = 1, x10[B] = 0;
132  x11[A] = 1, x11[B] = 1;
134  // empty
135  DT empty;
137  // A
138  DT a(A, 0, 5);
139  LONGS_EQUAL(0, a(x00))
140  LONGS_EQUAL(5, a(x10))
141  DOT(a);
143  // pruned
144  DT p(A, 2, 2);
145  LONGS_EQUAL(2, p(x00))
146  LONGS_EQUAL(2, p(x10))
147  DOT(p);
149  // \neg B
150  DT notb(B, 5, 0);
151  LONGS_EQUAL(5, notb(x00))
152  LONGS_EQUAL(5, notb(x10))
153  DOT(notb);
155  // Check supplying empty trees yields an exception
156  CHECK_EXCEPTION(gtsam::apply(empty, &Ring::id), std::runtime_error);
157  CHECK_EXCEPTION(gtsam::apply(empty, a, &Ring::mul), std::runtime_error);
158  CHECK_EXCEPTION(gtsam::apply(a, empty, &Ring::mul), std::runtime_error);
160  // apply, two nodes, in natural order
161  DT anotb = apply(a, notb, &Ring::mul);
162  LONGS_EQUAL(0, anotb(x00))
163  LONGS_EQUAL(0, anotb(x01))
164  LONGS_EQUAL(25, anotb(x10))
165  LONGS_EQUAL(0, anotb(x11))
166  DOT(anotb);
168  // check pruning
169  DT pnotb = apply(p, notb, &Ring::mul);
170  LONGS_EQUAL(10, pnotb(x00))
171  LONGS_EQUAL(0, pnotb(x01))
172  LONGS_EQUAL(10, pnotb(x10))
173  LONGS_EQUAL(0, pnotb(x11))
174  DOT(pnotb);
176  // check pruning
177  DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
178  LONGS_EQUAL(0, zeros(x00))
179  LONGS_EQUAL(0, zeros(x01))
180  LONGS_EQUAL(0, zeros(x10))
181  LONGS_EQUAL(0, zeros(x11))
182  DOT(zeros);
184  // apply, two nodes, in switched order
185  DT notba = apply(a, notb, &Ring::mul);
186  LONGS_EQUAL(0, notba(x00))
187  LONGS_EQUAL(0, notba(x01))
188  LONGS_EQUAL(25, notba(x10))
189  LONGS_EQUAL(0, notba(x11))
190  DOT(notba);
192  // Test choose 0
193  DT actual0 = notba.choose(A, 0);
194  EXPECT(assert_equal(DT(0.0), actual0));
195  DOT(actual0);
197  // Test choose 1
198  DT actual1 = notba.choose(A, 1);
199  EXPECT(assert_equal(DT(B, 25, 0), actual1));
200  DOT(actual1);
202  // apply, two nodes at same level
203  DT a_and_a = apply(a, a, &Ring::mul);
204  LONGS_EQUAL(0, a_and_a(x00))
205  LONGS_EQUAL(0, a_and_a(x01))
206  LONGS_EQUAL(25, a_and_a(x10))
207  LONGS_EQUAL(25, a_and_a(x11))
208  DOT(a_and_a);
210  // create a function on C
211  DT c(C, 0, 5);
213  // and a model assigning stuff to C
214  Assignment<string> x101;
215  x101[A] = 1, x101[B] = 0, x101[C] = 1;
217  // mul notba with C
218  DT notbac = apply(notba, c, &Ring::mul);
219  LONGS_EQUAL(125, notbac(x101))
220  DOT(notbac);
222  // mul now in different order
223  DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
224  LONGS_EQUAL(125, acnotb(x101))
225  DOT(acnotb);
226 }
228 /* ************************************************************************** */
229 // test Conversion of values
/// Leaf converter for the conversion tests: any nonzero int maps to true.
bool bool_of_int(const int& y) { return y != 0; }
233 TEST(DecisionTree, ConvertValuesOnly) {
234  // Create labels
235  string A("A"), B("B");
237  // apply, two nodes, in natural order
238  DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul);
240  // convert
243  // Check a value
244  Assignment<string> x00;
245  x00["A"] = 0, x00["B"] = 0;
246  EXPECT(!f2(x00));
247 }
249 /* ************************************************************************** */
250 // test Conversion of both values and labels.
/// Alternative label type for the ConvertBoth test, which remaps the string
/// labels "A" and "B" onto X and Y.
enum Label {
  U,
  V,
  X,
  Y,
  Z
};
254 TEST(DecisionTree, ConvertBoth) {
255  // Create labels
256  string A("A"), B("B");
258  // apply, two nodes, in natural order
259  DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul);
261  // convert
262  map<string, Label> ordering;
263  ordering[A] = X;
264  ordering[B] = Y;
265  LabelBoolTree f2(f1, ordering, &bool_of_int);
267  // Check some values
268  Assignment<Label> x00, x01, x10, x11;
269  x00[X] = 0, x00[Y] = 0;
270  x01[X] = 0, x01[Y] = 1;
271  x10[X] = 1, x10[Y] = 0;
272  x11[X] = 1, x11[Y] = 1;
273  EXPECT(!f2(x00));
274  EXPECT(!f2(x01));
275  EXPECT(f2(x10));
276  EXPECT(!f2(x11));
277 }
279 /* ************************************************************************** */
280 // test Compose expansion
281 TEST(DecisionTree, Compose) {
282  // Create labels
283  string A("A"), B("B"), C("C");
285  // Put two stumps on A together
286  DT f1(B, DT(A, 0, 1), DT(A, 2, 3));
288  // Create from string
289  vector<DT::LabelC> keys{DT::LabelC(A, 2), DT::LabelC(B, 2)};
290  DT f2(keys, "0 2 1 3");
291  EXPECT(assert_equal(f2, f1, 1e-9));
293  // Put this AB tree together with another one
294  DT f3(keys, "4 6 5 7");
295  DT f4(C, f1, f3);
296  DOT(f4);
298  // a bigger tree
299  keys.push_back(DT::LabelC(C, 2));
300  DT f5(keys, "0 4 2 6 1 5 3 7");
301  EXPECT(assert_equal(f5, f4, 1e-9));
302  DOT(f5);
303 }
305 /* ************************************************************************** */
306 // Check we can create a decision tree of containers.
307 TEST(DecisionTree, Containers) {
308  using Container = std::vector<double>;
309  using StringContainerTree = DecisionTree<string, Container>;
311  // Check default constructor
312  StringContainerTree tree;
314  // Create small two-level tree
315  string A("A"), B("B");
316  DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
318  // Check conversion
319  auto container_of_int = [](const int& i) {
320  Container c;
321  c.emplace_back(i);
322  return c;
323  };
324  StringContainerTree converted(stringIntTree, container_of_int);
325 }
327 /* ************************************************************************** */
328 // Test nrAssignments.
329 TEST(DecisionTree, NrAssignments) {
330  const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
331  DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
332  EXPECT(tree.root_->isLeaf());
333  auto leaf = std::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
334  EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
336  DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
337  /* The tree is
338  Choice(C)
339  0 Choice(B)
340  0 0 Leaf 1
341  0 1 Choice(A)
342  0 1 0 Leaf 1
343  0 1 1 Leaf 2
344  1 Choice(B)
345  1 0 Choice(A)
346  1 0 0 Leaf 3
347  1 0 1 Leaf 4
348  1 1 Leaf 5
349  */
351  auto root = std::dynamic_pointer_cast<const DT::Choice>(tree2.root_);
352  CHECK(root);
353  auto choice0 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
354  CHECK(choice0);
355  EXPECT(choice0->branches()[0]->isLeaf());
356  auto choice00 = std::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
357  CHECK(choice00);
358  EXPECT_LONGS_EQUAL(2, choice00->nrAssignments());
360  auto choice1 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[1]);
361  CHECK(choice1);
362  auto choice10 = std::dynamic_pointer_cast<const DT::Choice>(choice1->branches()[0]);
363  CHECK(choice10);
364  auto choice11 = std::dynamic_pointer_cast<const DT::Leaf>(choice1->branches()[1]);
365  CHECK(choice11);
366  EXPECT(choice11->isLeaf());
367  EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
368 }
370 /* ************************************************************************** */
371 // Test visit.
373  // Create small two-level tree
374  string A("A"), B("B");
375  DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
376  double sum = 0.0;
377  auto visitor = [&](int y) { sum += y; };
378  tree.visit(visitor);
379  EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
380 }
382 /* ************************************************************************** */
383 // Test visit, with Choices argument.
384 TEST(DecisionTree, visitWith) {
385  // Create small two-level tree
386  string A("A"), B("B");
387  DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
388  double sum = 0.0;
389  auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
390  tree.visitWith(visitor);
391  EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
392 }
394 /* ************************************************************************** */
395 // Test visit, with Choices argument.
396 TEST(DecisionTree, VisitWithPruned) {
397  // Create pruned tree
398  std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
399  std::vector<std::pair<string, size_t>> labels = {C, B, A};
400  std::vector<int> nodes = {0, 0, 2, 3, 4, 4, 6, 7};
401  DT tree(labels, nodes);
403  std::vector<Assignment<string>> choices;
404  auto func = [&](const Assignment<string>& choice, const int& d) {
405  choices.push_back(choice);
406  };
407  tree.visitWith(func);
409  EXPECT_LONGS_EQUAL(6, choices.size());
411  Assignment<string> expectedAssignment;
413  expectedAssignment = {{"B", 0}, {"C", 0}};
414  EXPECT(expectedAssignment == choices.at(0));
416  expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}};
417  EXPECT(expectedAssignment == choices.at(1));
419  expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}};
420  EXPECT(expectedAssignment == choices.at(2));
422  expectedAssignment = {{"B", 0}, {"C", 1}};
423  EXPECT(expectedAssignment == choices.at(3));
425  expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 1}};
426  EXPECT(expectedAssignment == choices.at(4));
428  expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}};
429  EXPECT(expectedAssignment == choices.at(5));
430 }
432 /* ************************************************************************** */
433 // Test fold.
435  // Create small two-level tree
436  string A("A"), B("B");
437  DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
438  auto add = [](const int& y, double x) { return y + x; };
439  double sum = tree.fold(add, 0.0);
440  EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
441 }
443 /* ************************************************************************** */
444 // Test retrieving all labels.
446  // Create small two-level tree
447  string A("A"), B("B");
448  DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
449  auto labels = tree.labels();
450  EXPECT_LONGS_EQUAL(2, labels.size());
451 }
453 /* ************************************************************************** */
454 // Test unzip method.
457  using DT1 = DecisionTree<string, int>;
458  using DT2 = DecisionTree<string, string>;
460  // Create small two-level tree
461  string A("A"), B("B"), C("C");
462  DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
463  DTP(A, {2, "two"}, {1337, "l33t"}));
465  const auto [dt1, dt2] = unzip(tree);
467  DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
468  DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
470  EXPECT(tree1.equals(dt1));
471  EXPECT(tree2.equals(dt2));
472 }
474 /* ************************************************************************** */
475 // Test thresholding.
476 TEST(DecisionTree, threshold) {
477  // Create three level tree
478  const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
479  DT::LabelC("A", 2)};
480  DT tree(keys, "0 1 2 3 4 5 6 7");
482  // Check number of leaves equal to zero
483  auto count = [](const int& value, int count) {
484  return value == 0 ? count + 1 : count;
485  };
486  EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
488  // Now threshold
489  auto threshold = [](int value) { return value < 5 ? 0 : value; };
490  DT thresholded(tree, threshold);
492  // Check number of leaves equal to zero now = 2
493  // Note: it is 2, because the pruned branches are counted as 1!
494  EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
495 }
497 /* ************************************************************************** */
498 // Test apply with assignment.
499 TEST(DecisionTree, ApplyWithAssignment) {
500  // Create three level tree
501  const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
502  DT::LabelC("A", 2)};
503  DT tree(keys, "1 2 3 4 5 6 7 8");
506  keys, "0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08");
507  double threshold = 0.045;
509  // We test pruning one tree by indexing into another.
510  auto pruner = [&](const Assignment<string>& choices, const int& x) {
511  // Prune out all the leaves with even numbers
512  if (probTree(choices) < threshold) {
513  return 0;
514  } else {
515  return x;
516  }
517  };
518  DT prunedTree = tree.apply(pruner);
520  DT expectedTree(keys, "0 0 0 0 5 6 7 8");
521  EXPECT(assert_equal(expectedTree, prunedTree));
523  size_t count = 0;
524  auto counter = [&](const Assignment<string>& choices, const int& x) {
525  count += 1;
526  return x;
527  };
528  DT prunedTree2 = prunedTree.apply(counter);
530  // Check if apply doesn't enumerate all leaves.
531  EXPECT_LONGS_EQUAL(5, count);
532 }
534 /* ************************************************************************* */
535 int main() {
536  TestResult tr;
537  return TestRegistry::runAllTests(tr);
538 }
539 /* ************************************************************************* */
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
#define CHECK(condition)
Definition: Test.h:108
bool compare
static int id(const int &a)
static int mul(const int &a, const int &b)
static int one()
Scalar * y
Concept check for values that can be used in unit tests.
void visit(Func f) const
Visit all leaves in depth-first fashion.
static int runAllTests(TestResult &result)
double dot(const V1 &a, const V2 &b)
Definition: Vector.h:195
signatures for conditional densities
std::vector< std::string > labels
KeyVector nodes
Definition: testMFAS.cpp:28
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
Definition: DecisionTree.h:398
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
double f2(const Vector2 &x)
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
static std::string valueFormatter(const double &v)
#define CHECK_EXCEPTION(condition, exception_name)
Definition: Test.h:118
std::set< L > labels() const
static enum @1107 ordering
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
const double dt
static int zero()
void visitWith(Func f) const
Visit all leaves in depth-first fashion.
std::pair< DecisionTree< L, T1 >, DecisionTree< L, T2 > > unzip(const DecisionTree< L, std::pair< T1, T2 > > &input)
unzip a DecisionTree with std::pair values.
Definition: DecisionTree.h:425
bool bool_of_int(const int &y)
#define EXPECT(condition)
Definition: Test.h:150
Array< int, Dynamic, 1 > v
bool equals(const Base &other, double tol=1e-9) const
Equality method customized to int node type.
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Array< double, 1, 3 > e(1./3., 0.5, 2.)
RealScalar s
X fold(Func f, X x0) const
Fold a binary function over the tree, returning accumulator.
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
#define DOT(x)
void print(const std::string &s="") const
print to stdout
DecisionTree apply(const Unary &op) const
static sharedNode Leaf(Key key, const SymbolicFactorGraph &factors)
const G & b
Definition: Group.h:86
RowVector3d w
static std::stringstream ss
Definition: testBTree.cpp:31
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
#define LONGS_EQUAL(expected, actual)
Definition: Test.h:134
int main()
Definition: chartTesting.h:28
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
static int add(const int &a, const int &b)
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
double f4(double x, double y, double z)
float * p
double f3(double x1, double x2)
void print(const std::string &s="") const
print to stdout
graph add(PriorFactor< Pose2 >(1, priorMean, priorNoise))
DecisionTree< Label, bool > LabelBoolTree
const G double tol
Definition: Group.h:86
std::pair< string, size_t > LabelC
Definition: DecisionTree.h:67
TEST(SmartFactorBase, Pinhole)
const KeyVector keys
DecisionTree< string, bool > StringBoolTree
bool equals(const CrazyDecisionTree &other, double tol=1e-9) const
Equality method customized to Crazy node type.
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
#define abs(x)
Definition: datatypes.h:17
DT(const Base &dt)
DecisionTree choose(const L &label, size_t index) const
Definition: DecisionTree.h:341
Definition: Testable.h:176

autogenerated on Tue Jul 4 2023 02:38:01