DecisionTree-inl.h
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
20 #pragma once
21 
23 
24 #include <algorithm>
25 
26 #include <cmath>
27 #include <fstream>
28 #include <list>
29 #include <map>
30 #include <set>
31 #include <sstream>
32 #include <string>
33 #include <vector>
34 #include <optional>
35 #include <cassert>
36 #include <iterator>
37 
38 namespace gtsam {
39 
40  /****************************************************************************/
41  // Node
42  /****************************************************************************/
#ifdef DT_DEBUG_MEMORY
  // Out-of-line definition of the static node counter reported by the
  // DT_DEBUG_MEMORY diagnostics (e.g. printed when a Choice is destroyed).
  // It starts at zero.
  template <typename L, typename Y>
  int DecisionTree<L, Y>::Node::nrNodes{0};
#endif
47 
48  /****************************************************************************/
49  // Leaf
50  /****************************************************************************/
51  template <typename L, typename Y>
52  struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
55 
57  Leaf() {}
58 
60  Leaf(const Y& constant) : constant_(constant) {}
61 
63  const Y& constant() const {
64  return constant_;
65  }
66 
68  bool sameLeaf(const Leaf& q) const override {
69  return constant_ == q.constant_;
70  }
71 
73  bool sameLeaf(const Node& q) const override {
74  return (q.isLeaf() && q.sameLeaf(*this));
75  }
76 
78  bool equals(const Node& q, const CompareFunc& compare) const override {
79  const Leaf* other = dynamic_cast<const Leaf*>(&q);
80  if (!other) return false;
81  return compare(this->constant_, other->constant_);
82  }
83 
85  void print(const std::string& s, const LabelFormatter& labelFormatter,
86  const ValueFormatter& valueFormatter) const override {
87  std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
88  }
89 
91  void dot(std::ostream& os, const LabelFormatter& labelFormatter,
93  bool showZero) const override {
94  std::string value = valueFormatter(constant_);
95  if (showZero || value.compare("0"))
96  os << "\"" << this->id() << "\" [label=\"" << value
97  << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
98  }
99 
101  const Y& operator()(const Assignment<L>& x) const override {
102  return constant_;
103  }
104 
106  NodePtr apply(const Unary& op) const override {
107  NodePtr f(new Leaf(op(constant_)));
108  return f;
109  }
110 
113  const Assignment<L>& assignment) const override {
114  NodePtr f(new Leaf(op(assignment, constant_)));
115  return f;
116  }
117 
118  // Apply binary operator "h = f op g" on Leaf node
119  // Note op is not assumed commutative so we need to keep track of order
120  // Simply calls apply on argument to call correct virtual method:
121  // fL.apply_f_op_g(gL) -> gL.apply_g_op_fL(fL) (below)
122  // fL.apply_f_op_g(gC) -> gC.apply_g_op_fL(fL) (Choice)
123  NodePtr apply_f_op_g(const Node& g, const Binary& op) const override {
124  return g.apply_g_op_fL(*this, op);
125  }
126 
127  // Applying binary operator to two leaves results in a leaf
128  NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
129  // fL op gL
130  NodePtr h(new Leaf(op(fL.constant_, constant_)));
131  return h;
132  }
133 
134  // If second argument is a Choice node, call it's apply with leaf as second
135  NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
136  return fC.apply_fC_op_gL(*this, op); // operand order back to normal
137  }
138 
140  NodePtr choose(const L& label, size_t index) const override {
141  return NodePtr(new Leaf(constant()));
142  }
143 
144  bool isLeaf() const override { return true; }
145 
146  private:
148 
149 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
150 
151  friend class boost::serialization::access;
152  template <class ARCHIVE>
153  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
154  ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
155  ar& BOOST_SERIALIZATION_NVP(constant_);
156  }
157 #endif
158  }; // Leaf
159 
160  /****************************************************************************/
161  // Choice
162  /****************************************************************************/
163  template<typename L, typename Y>
164  struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
167 
169  std::vector<NodePtr> branches_;
170 
171  private:
176  size_t allSame_;
177 
178  using ChoicePtr = std::shared_ptr<const Choice>;
179 
180  public:
182  Choice() {}
183 
184  ~Choice() override {
185 #ifdef DT_DEBUG_MEMORY
186  std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
187  << std::endl;
188 #endif
189  }
190 
207  static NodePtr Unique(const NodePtr& node) {
208  if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
209  // Choice node, we recurse!
210  // Make non-const copy so we can update
211  auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
212 
213  // Iterate over all the branches
214  for (size_t i = 0; i < choice->nrChoices(); i++) {
215  auto branch = choice->branches_[i];
216  f->push_back(Unique(branch));
217  }
218 
219 #ifdef GTSAM_DT_MERGING
220  // If all the branches are the same, we can merge them into one
221  if (f->allSame_) {
222  assert(f->branches().size() > 0);
223  NodePtr f0 = f->branches_[0];
224 
225  NodePtr newLeaf(
226  new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant()));
227  return newLeaf;
228  }
229 #endif
230  return f;
231  } else {
232  // Leaf node, return as is
233  return node;
234  }
235  }
236 
237  bool isLeaf() const override { return false; }
238 
240  Choice(const L& label, size_t count) :
241  label_(label), allSame_(true) {
242  branches_.reserve(count);
243  }
244 
246  Choice(const Choice& f, const Choice& g, const Binary& op) :
247  allSame_(true) {
248  // Choose what to do based on label
249  if (f.label() > g.label()) {
250  // f higher than g
251  label_ = f.label();
252  size_t count = f.nrChoices();
253  branches_.reserve(count);
254  for (size_t i = 0; i < count; i++)
255  push_back(f.branches_[i]->apply_f_op_g(g, op));
256  } else if (g.label() > f.label()) {
257  // f lower than g
258  label_ = g.label();
259  size_t count = g.nrChoices();
260  branches_.reserve(count);
261  for (size_t i = 0; i < count; i++)
262  push_back(g.branches_[i]->apply_g_op_fC(f, op));
263  } else {
264  // f same level as g
265  label_ = f.label();
266  size_t count = f.nrChoices();
267  branches_.reserve(count);
268  for (size_t i = 0; i < count; i++)
269  push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op));
270  }
271  }
272 
274  const L& label() const {
275  return label_;
276  }
277 
278  size_t nrChoices() const {
279  return branches_.size();
280  }
281 
282  const std::vector<NodePtr>& branches() const {
283  return branches_;
284  }
285 
287  void push_back(const NodePtr& node) {
288  // allSame_ is restricted to leaf nodes in a decision tree
289  if (allSame_ && !branches_.empty()) {
290  allSame_ = node->sameLeaf(*branches_.back());
291  }
292  branches_.push_back(node);
293  }
294 
296  void print(const std::string& s, const LabelFormatter& labelFormatter,
297  const ValueFormatter& valueFormatter) const override {
298  std::cout << s << " Choice(";
299  std::cout << labelFormatter(label_) << ") " << std::endl;
300  for (size_t i = 0; i < branches_.size(); i++) {
301  branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter);
302  }
303  }
304 
306  void dot(std::ostream& os, const LabelFormatter& labelFormatter,
308  bool showZero) const override {
309  os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
310  << "\"]\n";
311  size_t B = branches_.size();
312  for (size_t i = 0; i < B; i++) {
313  const NodePtr& branch = branches_[i];
314 
315  // Check if zero
316  if (!showZero) {
317  const Leaf* leaf = dynamic_cast<const Leaf*>(branch.get());
318  if (leaf && valueFormatter(leaf->constant()).compare("0")) continue;
319  }
320 
321  os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
322  if (B == 2 && i == 0) os << " [style=dashed]";
323  os << std::endl;
324  branch->dot(os, labelFormatter, valueFormatter, showZero);
325  }
326  }
327 
329  bool sameLeaf(const Leaf& q) const override {
330  return false;
331  }
332 
334  bool sameLeaf(const Node& q) const override {
335  return (q.isLeaf() && q.sameLeaf(*this));
336  }
337 
339  bool equals(const Node& q, const CompareFunc& compare) const override {
340  const Choice* other = dynamic_cast<const Choice*>(&q);
341  if (!other) return false;
342  if (this->label_ != other->label_) return false;
343  if (branches_.size() != other->branches_.size()) return false;
344  // we don't care about shared pointers being equal here
345  for (size_t i = 0; i < branches_.size(); i++)
346  if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
347  return false;
348  return true;
349  }
350 
352  const Y& operator()(const Assignment<L>& x) const override {
353 #ifndef NDEBUG
354  typename Assignment<L>::const_iterator it = x.find(label_);
355  if (it == x.end()) {
356  std::cout << "Trying to find value for " << label_ << std::endl;
357  throw std::invalid_argument(
358  "DecisionTree::operator(): value undefined for a label");
359  }
360 #endif
361  size_t index = x.at(label_);
362  NodePtr child = branches_[index];
363  return (*child)(x);
364  }
365 
367  Choice(const L& label, const Choice& f, const Unary& op) :
368  label_(label), allSame_(true) {
369  branches_.reserve(f.branches_.size()); // reserve space
370  for (const NodePtr& branch : f.branches_) {
371  push_back(branch->apply(op));
372  }
373  }
374 
385  Choice(const L& label, const Choice& f, const UnaryAssignment& op,
386  const Assignment<L>& assignment)
387  : label_(label), allSame_(true) {
388  branches_.reserve(f.branches_.size()); // reserve space
389 
390  Assignment<L> assignment_ = assignment;
391 
392  for (size_t i = 0; i < f.branches_.size(); i++) {
393  assignment_[label_] = i; // Set assignment for label to i
394 
395  const NodePtr branch = f.branches_[i];
396  push_back(branch->apply(op, assignment_));
397 
398  // Remove the assignment so we are backtracking
399  auto assignment_it = assignment_.find(label_);
400  assignment_.erase(assignment_it);
401  }
402  }
403 
405  NodePtr apply(const Unary& op) const override {
406  auto r = std::make_shared<Choice>(label_, *this, op);
407  return Unique(r);
408  }
409 
412  const Assignment<L>& assignment) const override {
413  auto r = std::make_shared<Choice>(label_, *this, op, assignment);
414  return Unique(r);
415  }
416 
417  // Apply binary operator "h = f op g" on Choice node
418  // Note op is not assumed commutative so we need to keep track of order
419  // Simply calls apply on argument to call correct virtual method:
420  // fC.apply_f_op_g(gL) -> gL.apply_g_op_fC(fC) -> (Leaf)
421  // fC.apply_f_op_g(gC) -> gC.apply_g_op_fC(fC) -> (below)
422  NodePtr apply_f_op_g(const Node& g, const Binary& op) const override {
423  return g.apply_g_op_fC(*this, op);
424  }
425 
426  // If second argument of binary op is Leaf node, recurse on branches
427  NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
428  auto h = std::make_shared<Choice>(label(), nrChoices());
429  for (auto&& branch : branches_)
430  h->push_back(fL.apply_f_op_g(*branch, op));
431  return Unique(h);
432  }
433 
434  // If second argument of binary op is Choice, call constructor
435  NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
436  auto h = std::make_shared<Choice>(fC, *this, op);
437  return Unique(h);
438  }
439 
440  // If second argument of binary op is Leaf
441  template<typename OP>
442  NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
443  auto h = std::make_shared<Choice>(label(), nrChoices());
444  for (auto&& branch : branches_)
445  h->push_back(branch->apply_f_op_g(gL, op));
446  return Unique(h);
447  }
448 
450  NodePtr choose(const L& label, size_t index) const override {
451  if (label_ == label) return branches_[index]; // choose branch
452 
453  // second case, not label of interest, just recurse
454  auto r = std::make_shared<Choice>(label_, branches_.size());
455  for (auto&& branch : branches_) {
456  r->push_back(branch->choose(label, index));
457  }
458 
459  return Unique(r);
460  }
461 
462  private:
464 
465 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
466 
467  friend class boost::serialization::access;
468  template <class ARCHIVE>
469  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
470  ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
471  ar& BOOST_SERIALIZATION_NVP(label_);
472  ar& BOOST_SERIALIZATION_NVP(branches_);
473  ar& BOOST_SERIALIZATION_NVP(allSame_);
474  }
475 #endif
476  }; // Choice
477 
478  /****************************************************************************/
479  // DecisionTree
480  /****************************************************************************/
// NOTE(review): the definition that should follow this template header is
// missing from this listing; only the header line survives.
481  template<typename L, typename Y>
483 
// Constructor from an existing root node (declarator line missing from this
// listing): stores `root` as root_. The node pointer is copied; the node
// itself is not cloned.
484  template<typename L, typename Y>
486  root_(root) {}
487 
488  /****************************************************************************/
// Constructor from a single value `y` (declarator line missing from this
// listing): the tree is one Leaf holding y.
489  template<typename L, typename Y>
491  root_ = NodePtr(new Leaf(y));
492  }
493 
494  /****************************************************************************/
495  template <typename L, typename Y>
496  DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
497  auto a = std::make_shared<Choice>(label, 2);
498  NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
499  a->push_back(l1);
500  a->push_back(l2);
501  root_ = Choice::Unique(a);
502  }
503 
504  /****************************************************************************/
// Binary constructor from a (label, cardinality) pair `labelC`. Throws
// std::invalid_argument unless the cardinality is 2. Otherwise it builds a
// Choice on labelC.first with leaves y1 (index 0) and y2 (index 1) and
// canonicalizes it with Choice::Unique.
// NOTE(review): the declarator line (taking labelC and y1) is missing from
// this listing.
505  template <typename L, typename Y>
507  const Y& y2) {
508  if (labelC.second != 2) throw std::invalid_argument(
509  "DecisionTree: binary constructor called with non-binary label");
510  auto a = std::make_shared<Choice>(labelC.first, 2);
511  NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
512  a->push_back(l1);
513  a->push_back(l2);
514  root_ = Choice::Unique(a);
515  }
516 
517  /****************************************************************************/
518  template<typename L, typename Y>
519  DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
520  const std::vector<Y>& ys) {
521  // call recursive Create
522  root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
523  }
524 
525  /****************************************************************************/
526  template<typename L, typename Y>
527  DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
528  const std::string& table) {
529  // Convert std::string to values of type Y
530  std::vector<Y> ys;
531  std::istringstream iss(table);
532  copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
533  back_inserter(ys));
534 
535  // now call recursive Create
536  root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
537  }
538 
539  /****************************************************************************/
540  template<typename L, typename Y>
541  template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
542  Iterator begin, Iterator end, const L& label) {
543  root_ = compose(begin, end, label);
544  }
545 
546  /****************************************************************************/
// Binary composition: the tree equals f0 when `label` is 0 and f1 when it is
// 1. It is built with compose(), so label ordering is respected.
// NOTE(review): the declarator line (taking `label`) is missing from this
// listing.
547  template<typename L, typename Y>
549  const DecisionTree& f0, const DecisionTree& f1) {
550  const std::vector<DecisionTree> functions{f0, f1};
551  root_ = compose(functions.begin(), functions.end(), label);
552  }
553 
554  /****************************************************************************/
// Converting constructor: copies the structure of `other` (same label type
// L), keeps every label as is, and maps each leaf value through Y_of_X.
// NOTE(review): the declarator line (taking `other`) is missing from this
// listing.
555  template <typename L, typename Y>
556  template <typename X, typename Func>
558  Func Y_of_X) {
559  // Define functor for identity mapping of node label.
560  auto L_of_L = [](const L& label) { return label; };
561  root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
562  }
563 
564  /****************************************************************************/
// Converting constructor: relabels `other` (label type M) through `map`
// (M -> L) and maps each leaf value through Y_of_X. map.at() throws
// std::out_of_range for a label that is not in `map`.
// NOTE(review): the declarator line (taking `other`) is missing from this
// listing.
565  template <typename L, typename Y>
566  template <typename M, typename X, typename Func>
568  const std::map<M, L>& map, Func Y_of_X) {
569  auto L_of_M = [&map](const M& label) -> L { return map.at(label); };
570  root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
571  }
572 
573  /****************************************************************************/
574  // Called by two constructors above.
575  // Takes a label and a corresponding range of decision trees, and creates a
576  // new decision tree. However, the order of the labels needs to be respected,
577  // so we cannot just create a root Choice node on the label: if the label is
578  // not the highest label, we need a complicated and expensive recursive call.
// Two cases:
//  - No input tree has a Choice root, or `label` is higher than every root
//    label: return a Choice on `label` with one branch per input tree.
//  - Otherwise: Shannon-expand on the highest root label. For each of its
//    values, restrict every input tree with choose() and compose the
//    restricted trees recursively under `label`.
// NOTE(review): the declarator line of compose() is missing from this listing.
// NOTE(review): nrChoices is taken from the first Choice found with the
// highest label; the code assumes every Choice on that label has the same
// number of branches.
579  template <typename L, typename Y>
580  template <typename Iterator>
582  Iterator begin, Iterator end, const L& label) const {
583  // find highest label among branches
584  std::optional<L> highestLabel;
585  size_t nrChoices = 0;
586  for (Iterator it = begin; it != end; it++) {
587  if (it->root_->isLeaf())
588  continue;
// NOTE(review): `c` is not null-checked; a non-leaf root that is not a Choice
// would be dereferenced as null below.
589  std::shared_ptr<const Choice> c =
590  std::dynamic_pointer_cast<const Choice>(it->root_);
591  if (!highestLabel || c->label() > *highestLabel) {
592  highestLabel = c->label();
593  nrChoices = c->nrChoices();
594  }
595  }
596 
597  // if label is already in correct order, just put together a choice on label
// (`!nrChoices` is also true when the highest Choice has no branches.)
598  if (!nrChoices || !highestLabel || label > *highestLabel) {
599  auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
600  for (Iterator it = begin; it != end; it++)
601  choiceOnLabel->push_back(it->root_);
602  // If no reordering, no need to call Choice::Unique
603  return choiceOnLabel;
604  } else {
605  // Set up a new choice on the highest label
606  auto choiceOnHighestLabel =
607  std::make_shared<Choice>(*highestLabel, nrChoices);
608  // now, for all possible values of highestLabel
609  for (size_t index = 0; index < nrChoices; index++) {
610  // make a new set of functions for composing by iterating over the given
611  // functions, and selecting the appropriate branch.
612  std::vector<DecisionTree> functions;
613  for (Iterator it = begin; it != end; it++) {
614  // by restricting each input function to value `index` of highestLabel
615  DecisionTree chosen = it->choose(*highestLabel, index);
616  functions.push_back(chosen);
617  }
618  // We then recurse, for all values of the highest label
619  NodePtr fi = compose(functions.begin(), functions.end(), label);
620  choiceOnHighestLabel->push_back(fi);
621  }
622  return choiceOnHighestLabel;
623  }
624  }
625 
626  /****************************************************************************/
627  // "build" is a bit of a complicated thing, but very useful.
628  // It takes a range of labels and a corresponding range of values,
629  // and builds a decision tree, as follows:
630  // - if there is only one label, creates a choice node with values in leaves
631  // - otherwise, it evenly splits up the range of values and creates a tree for
632  // each sub-range, and assigns that tree to first label's choices
633  // Example:
634  // build([B A],[1 2 3 4]) would call
635  // build([A],[1 2])
636  // build([A],[3 4])
637  // and produce
638  // B=0
639  // A=0: 1
640  // A=1: 2
641  // B=1
642  // A=0: 3
643  // A=1: 4
644  // Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce
645  // exactly the same tree as above: the highest label is always the root.
646  // However, it will be *way* faster if labels are given highest to lowest.
// NOTE(review): the declarator line of build() is missing from this listing.
// NOTE(review): preconditions are not checked here:
//  - begin != end (begin is dereferenced unconditionally);
//  - a non-zero cardinality in the recursive case (size / nrChoices).
// When the number of values is not a multiple of nrChoices, the last
// size % nrChoices values are never passed to any sub-range. Unless a deeper
// size check fails, they are silently ignored.
647  template<typename L, typename Y>
648  template<typename It, typename ValueIt>
650  It begin, It end, ValueIt beginY, ValueIt endY) const {
651  // get crucial counts
652  size_t nrChoices = begin->second;
653  size_t size = endY - beginY;
654 
655  // Find the next key to work on
656  It labelC = begin + 1;
657  if (labelC == end) {
658  // Base case: only one key left
659  // Create a simple choice node with values as leaves.
660  if (size != nrChoices) {
661  std::cout << "Trying to create DD on " << begin->first << std::endl;
662  std::cout << "DecisionTree::create: expected " << nrChoices
663  << " values but got " << size << " instead" << std::endl;
664  throw std::invalid_argument("DecisionTree::create invalid argument");
665  }
666  auto choice = std::make_shared<Choice>(begin->first, endY - beginY);
667  for (ValueIt y = beginY; y != endY; y++) {
668  choice->push_back(NodePtr(new Leaf(*y)));
669  }
670  return choice;
671  }
672 
673  // Recursive case: perform "Shannon expansion"
674  // Creates one tree (i.e., function) for each choice of current key
675  // by calling build recursively, and then composes them under the current key.
676  std::vector<DecisionTree> functions;
677  size_t split = size / nrChoices;
678  for (size_t i = 0; i < nrChoices; i++, beginY += split) {
679  NodePtr f = build<It, ValueIt>(labelC, end, beginY, beginY + split);
680  functions.emplace_back(f);
681  }
682  return compose(functions.begin(), functions.end(), begin->first);
683  }
684 
685  /****************************************************************************/
686  // Top-level factory method, which takes a range of labels and a corresponding
687  // range of values, and creates a decision tree.
// It calls build() and canonicalizes the result with Choice::Unique when the
// root is a Choice. A Leaf root is returned as is.
// NOTE(review): the declarator line of create() is missing from this listing.
688  template<typename L, typename Y>
689  template<typename It, typename ValueIt>
691  It begin, It end, ValueIt beginY, ValueIt endY) const {
692  auto node = build(begin, end, beginY, endY);
693  if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
694  return Choice::Unique(choice);
695  } else {
696  return node;
697  }
698  }
699 
700  /****************************************************************************/
// Recursively convert a node of a DecisionTree<M, X> into a node of this tree
// type:
//  - leaf values are mapped through Y_of_X;
//  - labels are mapped through L_of_M;
//  - converted branches are recombined with compose(), so the result is
//    correctly ordered under the new labels, then canonicalized with
//    Choice::Unique.
// Throws std::invalid_argument if `f` is neither a Leaf nor a Choice.
// NOTE(review): the return-type/declarator line is missing from this listing.
701  template <typename L, typename Y>
702  template <typename M, typename X>
704  const typename DecisionTree<M, X>::NodePtr& f,
705  std::function<L(const M&)> L_of_M,
706  std::function<Y(const X&)> Y_of_X) const {
707  using LY = DecisionTree<L, Y>;
708 
709  // C++ does not allow templated virtual member functions, so the leaf/choice
710  // dispatch is done here with dynamic pointer casts instead.
711  // If leaf, apply the value conversion Y_of_X and create a new leaf.
712  using MXLeaf = typename DecisionTree<M, X>::Leaf;
713  if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
714  return NodePtr(new Leaf(Y_of_X(leaf->constant())));
715  }
716 
717  // Check if Choice
718  using MXChoice = typename DecisionTree<M, X>::Choice;
719  auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
720  if (!choice) throw std::invalid_argument(
721  "DecisionTree::convertFrom: Invalid NodePtr");
722 
723  // get new label
724  const M oldLabel = choice->label();
725  const L newLabel = L_of_M(oldLabel);
726 
727  // put together via Shannon expansion, since the new labels may not be sorted.
728  std::vector<LY> functions;
729  for (auto&& branch : choice->branches()) {
730  functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
731  }
732  return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel));
733  }
734 
735  /****************************************************************************/
746  template <typename L, typename Y>
747  struct Visit {
748  using F = std::function<void(const Y&)>;
749  explicit Visit(F f) : f(f) {}
750  F f;
751 
753  void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
754  using Leaf = typename DecisionTree<L, Y>::Leaf;
755  if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
756  return f(leaf->constant());
757 
758  using Choice = typename DecisionTree<L, Y>::Choice;
759  auto choice = std::dynamic_pointer_cast<const Choice>(node);
760  if (!choice)
761  throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
762  for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
763  }
764  };
765 
766  template <typename L, typename Y>
767  template <typename Func>
768  void DecisionTree<L, Y>::visit(Func f) const {
769  Visit<L, Y> visit(f);
770  visit(root_);
771  }
772 
773  /****************************************************************************/
783  template <typename L, typename Y>
784  struct VisitLeaf {
785  using F = std::function<void(const typename DecisionTree<L, Y>::Leaf&)>;
786  explicit VisitLeaf(F f) : f(f) {}
787  F f;
788 
790  void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
791  using Leaf = typename DecisionTree<L, Y>::Leaf;
792  if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
793  return f(*leaf);
794 
795  using Choice = typename DecisionTree<L, Y>::Choice;
796  auto choice = std::dynamic_pointer_cast<const Choice>(node);
797  if (!choice)
798  throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr");
799  for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
800  }
801  };
802 
803  template <typename L, typename Y>
804  template <typename Func>
805  void DecisionTree<L, Y>::visitLeaf(Func f) const {
806  VisitLeaf<L, Y> visit(f);
807  visit(root_);
808  }
809 
810  /****************************************************************************/
817  template <typename L, typename Y>
818  struct VisitWith {
819  using F = std::function<void(const Assignment<L>&, const Y&)>;
820  explicit VisitWith(F f) : f(f) {}
822  F f;
823 
825  void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
826  using Leaf = typename DecisionTree<L, Y>::Leaf;
827  if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
828  return f(assignment, leaf->constant());
829 
830  using Choice = typename DecisionTree<L, Y>::Choice;
831  auto choice = std::dynamic_pointer_cast<const Choice>(node);
832  if (!choice)
833  throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
834  for (size_t i = 0; i < choice->nrChoices(); i++) {
835  assignment[choice->label()] = i; // Set assignment for label to i
836 
837  (*this)(choice->branches()[i]); // recurse!
838 
839  // Remove the choice so we are backtracking
840  auto choice_it = assignment.find(choice->label());
841  assignment.erase(choice_it);
842  }
843  }
844  };
845 
846  template <typename L, typename Y>
847  template <typename Func>
848  void DecisionTree<L, Y>::visitWith(Func f) const {
849  VisitWith<L, Y> visit(f);
850  visit(root_);
851  }
852 
853  /****************************************************************************/
// Count the leaves of the tree: visit() calls the lambda once per leaf
// reached, and the leaf value itself is ignored.
// NOTE(review): the declarator line is missing from this listing.
854  template <typename L, typename Y>
856  size_t total = 0;
857  visit([&total](const Y& node) { total += 1; });
858  return total;
859  }
860 
861  /****************************************************************************/
862  // fold is just done with a visit
863  template <typename L, typename Y>
864  template <typename Func, typename X>
865  X DecisionTree<L, Y>::fold(Func f, X x0) const {
866  visit([&](const Y& y) { x0 = f(y, x0); });
867  return x0;
868  }
869 
870  /****************************************************************************/
884  template <typename L, typename Y>
885  std::set<L> DecisionTree<L, Y>::labels() const {
886  std::set<L> unique;
887  auto f = [&](const Assignment<L>& assignment, const Y&) {
888  for (auto&& kv : assignment) {
889  unique.insert(kv.first);
890  }
891  };
892  visitWith(f);
893  return unique;
894  }
895 
896 /****************************************************************************/
897  template <typename L, typename Y>
899  const CompareFunc& compare) const {
900  return root_->equals(*other.root_, compare);
901  }
902 
903  template <typename L, typename Y>
904  void DecisionTree<L, Y>::print(const std::string& s,
905  const LabelFormatter& labelFormatter,
906  const ValueFormatter& valueFormatter) const {
907  root_->print(s, labelFormatter, valueFormatter);
908  }
909 
// Equality via the root node's equals() called with only one argument, i.e.
// whatever default comparison that overload provides (declared elsewhere).
// NOTE(review): the declarator line is missing from this listing, and a null
// root_ is not guarded here.
910  template<typename L, typename Y>
912  return root_->equals(*other.root_);
913  }
914 
// Evaluate the tree for assignment `x` by delegating to the root node, which
// follows the branch chosen by x at each Choice down to a leaf.
// NOTE(review): the declarator line is missing from this listing.
915  template<typename L, typename Y>
917  return root_->operator ()(x);
918  }
919 
// Return a new tree with `op` applied to every leaf value. The nodes produced
// by Choice::apply are canonicalized with Unique. Throws std::runtime_error
// for an empty tree.
// NOTE(review): the declarator line is missing from this listing.
920  template<typename L, typename Y>
922  // It is unclear what should happen if tree is empty:
923  if (empty()) {
924  throw std::runtime_error(
925  "DecisionTree::apply(unary op) undefined for empty tree.");
926  }
927  return DecisionTree(root_->apply(op));
928  }
929 
// Return a new tree with `op` applied to every leaf value. op also receives
// the assignment of labels along the path to that leaf, starting here from
// an empty assignment. Throws std::runtime_error for an empty tree.
// NOTE(review): the declarator line is missing from this listing.
931  template <typename L, typename Y>
933  const UnaryAssignment& op) const {
934  // It is unclear what should happen if tree is empty:
935  if (empty()) {
936  throw std::runtime_error(
937  "DecisionTree::apply(unary op) undefined for empty tree.");
938  }
939  Assignment<L> assignment;
940  return DecisionTree(root_->apply(op, assignment));
941  }
942 
943  /****************************************************************************/
// Combine this tree (f) with g leaf-wise: h = f op g. Throws
// std::runtime_error if either tree is empty.
// NOTE(review): the declarator line and the statement that constructs
// `result` from `h` are missing from this listing.
944  template<typename L, typename Y>
946  const Binary& op) const {
947  // It is unclear what should happen if either tree is empty:
948  if (empty() || g.empty()) {
949  throw std::runtime_error(
950  "DecisionTree::apply(binary op) undefined for empty trees.");
951  }
952  // apply the operation on the roots of both diagrams (double dispatch keeps f on the left)
953  NodePtr h = root_->apply_f_op_g(*g.root_, op);
954  // wrap the resulting root "h" in a new DecisionTree
956  return result;
957  }
958 
959  /****************************************************************************/
960  // The way this works:
961  // We have an ADT, picture it as a tree.
962  // At a certain depth, we have a branch on "label".
963  // The function "choose(label,index)" will return a tree of one less depth,
964  // where there is no more branch on "label": only the subtree under that
965  // branch point corresponding to the value "index" is left instead.
966  // The function below gets all these smaller trees and "ops" them together.
967  // This implements marginalization in Darwiche09book, pg 330
// NOTE(review): the declarator line is missing from this listing. Also,
// choose(label, 0) is called even when cardinality is 0.
968  template<typename L, typename Y>
970  size_t cardinality, const Binary& op) const {
971  DecisionTree result = choose(label, 0);
972  for (size_t index = 1; index < cardinality; index++) {
973  DecisionTree chosen = choose(label, index);
974  result = result.apply(chosen, op);
975  }
976  return result;
977  }
978 
979  /****************************************************************************/
980  template <typename L, typename Y>
981  void DecisionTree<L, Y>::dot(std::ostream& os,
982  const LabelFormatter& labelFormatter,
984  bool showZero) const {
985  os << "digraph G {\n";
986  root_->dot(os, labelFormatter, valueFormatter, showZero);
987  os << " [ordering=out]}" << std::endl;
988  }
989 
990  template <typename L, typename Y>
991  void DecisionTree<L, Y>::dot(const std::string& name,
992  const LabelFormatter& labelFormatter,
994  bool showZero) const {
995  std::ofstream os((name + ".dot").c_str());
996  dot(os, labelFormatter, valueFormatter, showZero);
997  int result =
998  system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
999  .c_str());
1000  if (result == -1)
1001  throw std::runtime_error("DecisionTree::dot system call failed");
1002  }
1003 
1004  template <typename L, typename Y>
1005  std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
1007  bool showZero) const {
1008  std::stringstream ss;
1009  dot(ss, labelFormatter, valueFormatter, showZero);
1010  return ss.str();
1011  }
1012 
1013 /******************************************************************************/
1014 
1015  } // namespace gtsam
gtsam::DecisionTree::Choice::branches_
std::vector< NodePtr > branches_
Definition: DecisionTree-inl.h:169
gtsam::DecisionTree::Leaf::apply_f_op_g
NodePtr apply_f_op_g(const Node &g, const Binary &op) const override
Definition: DecisionTree-inl.h:123
create
ADT create(const Signature &signature)
Definition: testAlgebraicDecisionTree.cpp:129
compare
bool compare
Definition: SolverComparer.cpp:98
gtsam::DecisionTree::Choice::sameLeaf
bool sameLeaf(const Node &q) const override
polymorphic equality: if q is a leaf, could be...
Definition: DecisionTree-inl.h:334
gtsam::VisitWith::F
std::function< void(const Assignment< L > &, const Y &)> F
Definition: DecisionTree-inl.h:819
Eigen::internal::print
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
Definition: NEON/PacketMath.h:3115
gtsam::DecisionTree::Choice::choose
NodePtr choose(const L &label, size_t index) const override
Definition: DecisionTree-inl.h:450
gtsam::DecisionTree::LabelFormatter
std::function< std::string(L)> LabelFormatter
Definition: DecisionTree.h:71
name
Annotation for function names.
Definition: attr.h:51
gtsam::DecisionTree::Choice::print
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const override
print (as a tree).
Definition: DecisionTree-inl.h:296
gtsam::DecisionTree::Leaf::apply
NodePtr apply(const UnaryAssignment &op, const Assignment< L > &assignment) const override
Apply unary operator with assignment.
Definition: DecisionTree-inl.h:112
Y
const char Y
Definition: test/EulerAngles.cpp:31
gtsam::DecisionTree::CompareFunc
std::function< bool(const Y &, const Y &)> CompareFunc
Definition: DecisionTree.h:73
Leaf
static sharedNode Leaf(Key key, const SymbolicFactorGraph &factors)
Definition: testSymbolicEliminationTree.cpp:78
gtsam::Visit::F
std::function< void(const Y &)> F
Definition: DecisionTree-inl.h:748
s
RealScalar s
Definition: level1_cplx_impl.h:126
gtsam::VisitWith
Definition: DecisionTree-inl.h:818
l2
gtsam::Key l2
Definition: testLinearContainerFactor.cpp:24
c
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
gtsam::DecisionTree::equals
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
Definition: DecisionTree-inl.h:898
gtsam::DecisionTree::Choice::apply_g_op_fC
NodePtr apply_g_op_fC(const Choice &fC, const Binary &op) const override
Definition: DecisionTree-inl.h:435
gtsam::DecisionTree::Leaf::apply_g_op_fC
NodePtr apply_g_op_fC(const Choice &fC, const Binary &op) const override
Definition: DecisionTree-inl.h:135
x
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
Definition: gnuplot_common_settings.hh:12
gtsam::DecisionTree::Choice::dot
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const override
Definition: DecisionTree-inl.h:306
gtsam::DecisionTree::ValueFormatter
std::function< std::string(Y)> ValueFormatter
Definition: DecisionTree.h:72
gtsam::Visit
Definition: DecisionTree-inl.h:747
B
Definition: test_numpy_dtypes.cpp:299
noxfile.build
None build(nox.Session session)
Definition: noxfile.py:96
gtsam::DecisionTree::Choice::Choice
Choice(const L &label, const Choice &f, const Unary &op)
Construct from applying unary op to a Choice node.
Definition: DecisionTree-inl.h:367
gtsam::DecisionTree::Choice
Definition: DecisionTree-inl.h:164
gtsam::DecisionTree::DecisionTree
DecisionTree()
Definition: DecisionTree-inl.h:482
copy
int EIGEN_BLAS_FUNC() copy(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy)
Definition: level1_impl.h:29
X
#define X
Definition: icosphere.cpp:20
gtsam::DecisionTree::Choice::~Choice
~Choice() override
Definition: DecisionTree-inl.h:184
gtsam::DecisionTree::Leaf::print
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const override
print
Definition: DecisionTree-inl.h:85
h
const double h
Definition: testSimpleHelicopter.cpp:19
gtsam::DecisionTree::choose
DecisionTree choose(const L &label, size_t index) const
Definition: DecisionTree.h:370
os
ofstream os("timeSchurFactors.csv")
gtsam::VisitLeaf
Definition: DecisionTree-inl.h:784
gtsam::DecisionTree::Leaf::Leaf
Leaf(const Y &constant)
Constructor from constant.
Definition: DecisionTree-inl.h:60
gtsam::DecisionTree::Leaf::apply
NodePtr apply(const Unary &op) const override
Definition: DecisionTree-inl.h:106
result
Values result
Definition: OdometryOptimize.cpp:8
size
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
dot
Scalar EIGEN_BLAS_FUNC() dot(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy)
Definition: level1_real_impl.h:49
ss
static std::stringstream ss
Definition: testBTree.cpp:31
gtsam::VisitLeaf::f
F f
folding function object.
Definition: DecisionTree-inl.h:787
gtsam::DecisionTree::Choice::Choice
Choice(const L &label, size_t count)
Constructor, given choice label and mandatory expected branch count.
Definition: DecisionTree-inl.h:240
labels
std::vector< std::string > labels
Definition: dense_solvers.cpp:11
gtsam::DecisionTree::Leaf::Leaf
Leaf()
Default constructor for serialization.
Definition: DecisionTree-inl.h:57
gtsam::DecisionTree::Choice::isLeaf
bool isLeaf() const override
Definition: DecisionTree-inl.h:237
gtsam::DecisionTree::Unary
std::function< Y(const Y &)> Unary
Definition: DecisionTree.h:76
gtsam::DecisionTree::Choice::Unique
static NodePtr Unique(const NodePtr &node)
Merge branches with equal leaf values for every choice node in a decision tree. If all branches are t...
Definition: DecisionTree-inl.h:207
id
static const Similarity3 id
Definition: testSimilarity3.cpp:44
y1
double y1(double x)
Definition: j1.c:199
gtsam::VisitWith::assignment
Assignment< L > assignment
Assignment, mutating through recursion.
Definition: DecisionTree-inl.h:821
table
ArrayXXf table(10, 4)
gtsam::DecisionTree::Leaf::sameLeaf
bool sameLeaf(const Leaf &q) const override
Leaf-Leaf equality.
Definition: DecisionTree-inl.h:68
gtsam::DecisionTree::Leaf::constant
const Y & constant() const
Return the constant.
Definition: DecisionTree-inl.h:63
Eigen::numext::q
EIGEN_DEVICE_FUNC const Scalar & q
Definition: SpecialFunctionsImpl.h:1984
gtsam::DecisionTree::Choice::allSame_
size_t allSame_
Definition: DecisionTree-inl.h:176
gtsam::VisitWith::f
F f
folding function object.
Definition: DecisionTree-inl.h:822
operator()
internal::enable_if< internal::valid_indexed_view_overload< RowIndices, ColIndices >::value &&internal::traits< typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::ReturnAsIndexedView, typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::type operator()(const RowIndices &rowIndices, const ColIndices &colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
Definition: IndexedViewMethods.h:73
gtsam::DecisionTree::Binary
std::function< Y(const Y &, const Y &)> Binary
Definition: DecisionTree.h:78
gtsam::VisitWith::operator()
void operator()(const typename DecisionTree< L, Y >::NodePtr &node)
Do a depth-first visit on the tree rooted at node.
Definition: DecisionTree-inl.h:825
DecisionTree.h
Decision Tree for use in DiscreteFactors.
gtsam::DecisionTree::Choice::apply_f_op_g
NodePtr apply_f_op_g(const Node &g, const Binary &op) const override
Definition: DecisionTree-inl.h:422
x0
static Symbol x0('x', 0)
gtsam::DecisionTree::Leaf::apply_g_op_fL
NodePtr apply_g_op_fL(const Leaf &fL, const Binary &op) const override
Definition: DecisionTree-inl.h:128
gtsam::DecisionTree::Choice::Choice
Choice()
Default constructor for serialization.
Definition: DecisionTree-inl.h:182
L
MatrixXd L
Definition: LLT_example.cpp:6
gtsam::DecisionTree::Choice::operator()
const Y & operator()(const Assignment< L > &x) const override
evaluate
Definition: DecisionTree-inl.h:352
gtsam::DecisionTree::Choice::apply
NodePtr apply(const Unary &op) const override
apply unary operator.
Definition: DecisionTree-inl.h:405
gtsam::Assignment
Definition: Assignment.h:37
gtsam::DecisionTree::Choice::label_
L label_
Definition: DecisionTree-inl.h:166
gtsam::VisitLeaf::operator()
void operator()(const typename DecisionTree< L, Y >::NodePtr &node) const
Do a depth-first visit on the tree rooted at node.
Definition: DecisionTree-inl.h:790
gtsam::DecisionTree::Leaf::isLeaf
bool isLeaf() const override
Definition: DecisionTree-inl.h:144
gtsam::DecisionTree::Choice::sameLeaf
bool sameLeaf(const Leaf &q) const override
Choice-Leaf equality: always false.
Definition: DecisionTree-inl.h:329
gtsam::DecisionTree::Choice::branches
const std::vector< NodePtr > & branches() const
Definition: DecisionTree-inl.h:282
g
void g(const string &key, int i)
Definition: testBTree.cpp:41
y
Scalar * y
Definition: level1_cplx_impl.h:124
gtsam::DecisionTree::Choice::Choice
Choice(const L &label, const Choice &f, const UnaryAssignment &op, const Assignment< L > &assignment)
Constructor which accepts a UnaryAssignment op and the corresponding assignment.
Definition: DecisionTree-inl.h:385
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
gtsam::DecisionTree::Choice::ChoicePtr
std::shared_ptr< const Choice > ChoicePtr
Definition: DecisionTree-inl.h:178
gtsam::DecisionTree
a decision tree is a function from assignments to values.
Definition: DecisionTree.h:63
gtsam::DecisionTree::Leaf::equals
bool equals(const Node &q, const CompareFunc &compare) const override
equality up to tolerance
Definition: DecisionTree-inl.h:78
gtsam::DecisionTree::Choice::apply_fC_op_gL
NodePtr apply_fC_op_gL(const Leaf &gL, OP op) const
Definition: DecisionTree-inl.h:442
gtsam::DecisionTree::Choice::equals
bool equals(const Node &q, const CompareFunc &compare) const override
equality
Definition: DecisionTree-inl.h:339
gtsam::DecisionTree::Leaf
Definition: DecisionTree-inl.h:52
a
ArrayXXi a
Definition: Array_initializer_list_23_cxx11.cpp:1
Eigen::bfloat16_impl::operator==
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const bfloat16 &a, const bfloat16 &b)
Definition: BFloat16.h:218
empty
Definition: test_copy_move.cpp:19
gtsam::Visit::operator()
void operator()(const typename DecisionTree< L, Y >::NodePtr &node) const
Do a depth-first visit on the tree rooted at node.
Definition: DecisionTree-inl.h:753
gtsam
traits
Definition: chartTesting.h:28
gtsam::VisitLeaf::F
std::function< void(const typename DecisionTree< L, Y >::Leaf &)> F
Definition: DecisionTree-inl.h:785
gtsam::DecisionTree::Leaf::choose
NodePtr choose(const L &label, size_t index) const override
Definition: DecisionTree-inl.h:140
OP
#define OP(X)
Definition: gtsam/3rdparty/Eigen/blas/common.h:47
gtsam::testing::compose
T compose(const T &t1, const T &t2)
Definition: lieProxies.h:39
l1
gtsam::Key l1
Definition: testLinearContainerFactor.cpp:24
gtsam::DecisionTree::UnaryAssignment
std::function< Y(const Assignment< L > &, const Y &)> UnaryAssignment
Definition: DecisionTree.h:77
gtsam::DecisionTree< Key, double >::LabelC
std::pair< Key, size_t > LabelC
Definition: DecisionTree.h:81
gtsam::DecisionTree::Choice::apply
NodePtr apply(const UnaryAssignment &op, const Assignment< L > &assignment) const override
Apply unary operator with assignment.
Definition: DecisionTree-inl.h:411
gtsam::DecisionTree::Leaf::constant_
Y constant_
Definition: DecisionTree-inl.h:54
gtsam::DecisionTree::Choice::label
const L & label() const
Return the label of this choice node.
Definition: DecisionTree-inl.h:274
c_str
const char * c_str(Args &&...args)
Definition: internals.h:599
gtsam::DecisionTree::Choice::push_back
void push_back(const NodePtr &node)
Definition: DecisionTree-inl.h:287
leaf
Definition: testExpressionFactor.cpp:42
gtsam::apply
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
Definition: DecisionTree.h:427
gtsam::DecisionTree::Node
Definition: DecisionTree.h:88
gtsam::DecisionTree::Choice::apply_g_op_fL
NodePtr apply_g_op_fL(const Leaf &fL, const Binary &op) const override
Definition: DecisionTree-inl.h:427
gtsam::split
void split(const G &g, const PredecessorMap< KEY > &tree, G &Ab1, G &Ab2)
Definition: graph-inl.h:245
gtsam::DecisionTree::Leaf::operator()
const Y & operator()(const Assignment< L > &x) const override
Definition: DecisionTree-inl.h:101
gtsam::DecisionTree::NodePtr
typename Node::Ptr NodePtr
Definition: DecisionTree.h:147
unary::f1
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
Definition: testExpression.cpp:79
gtsam::DecisionTree::Choice::nrChoices
size_t nrChoices() const
Definition: DecisionTree-inl.h:278
gtsam::DecisionTree::Choice::Choice
Choice(const Choice &f, const Choice &g, const Binary &op)
Construct from applying binary op to two Choice nodes.
Definition: DecisionTree-inl.h:246
Eigen::placeholders::end
static const EIGEN_DEPRECATED end_t end
Definition: IndexedViewHelper.h:181
gtsam::DecisionTree::Leaf::dot
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const override
Definition: DecisionTree-inl.h:91
gtsam::valueFormatter
static std::string valueFormatter(const double &v)
Definition: DecisionTreeFactor.cpp:266
gtsam::DecisionTree::Leaf::sameLeaf
bool sameLeaf(const Node &q) const override
polymorphic equality: is q a leaf and is it the same as this leaf?
Definition: DecisionTree-inl.h:73
choose
static const T & choose(int layout, const T &col, const T &row)
Definition: cxx11_tensor_block_access.cpp:27
test_callbacks.value
value
Definition: test_callbacks.py:158
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
pybind_wrapper_test_script.other
other
Definition: pybind_wrapper_test_script.py:42
gtsam::Visit::f
F f
folding function object.
Definition: DecisionTree-inl.h:750
M
Matrix< RealScalar, Dynamic, Dynamic > M
Definition: bench_gemm.cpp:51


gtsam
Author(s):
autogenerated on Tue Jun 25 2024 03:00:46