DecisionTree-inl.h
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
20 #pragma once
21 
23 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
37 
38 namespace gtsam {
39 
40  /****************************************************************************/
41  // Node
42  /****************************************************************************/
#ifdef DT_DEBUG_MEMORY
  // Debug-only node counter (printed by Choice's destructor); one definition
  // per <L, Y> instantiation.
  template <typename L, typename Y>
  int DecisionTree<L, Y>::Node::nrNodes{0};
#endif
47 
48  /****************************************************************************/
49  // Leaf
50  /****************************************************************************/
51  template <typename L, typename Y>
52  struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
55 
60 
62  Leaf() {}
63 
65  Leaf(const Y& constant, size_t nrAssignments = 1)
66  : constant_(constant), nrAssignments_(nrAssignments) {}
67 
69  const Y& constant() const {
70  return constant_;
71  }
72 
74  size_t nrAssignments() const { return nrAssignments_; }
75 
77  bool sameLeaf(const Leaf& q) const override {
78  return constant_ == q.constant_;
79  }
80 
82  bool sameLeaf(const Node& q) const override {
83  return (q.isLeaf() && q.sameLeaf(*this));
84  }
85 
87  bool equals(const Node& q, const CompareFunc& compare) const override {
88  const Leaf* other = dynamic_cast<const Leaf*>(&q);
89  if (!other) return false;
90  return compare(this->constant_, other->constant_);
91  }
92 
94  void print(const std::string& s, const LabelFormatter& labelFormatter,
95  const ValueFormatter& valueFormatter) const override {
96  std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
97  }
98 
100  void dot(std::ostream& os, const LabelFormatter& labelFormatter,
102  bool showZero) const override {
103  std::string value = valueFormatter(constant_);
104  if (showZero || value.compare("0"))
105  os << "\"" << this->id() << "\" [label=\"" << value
106  << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
107  }
108 
110  const Y& operator()(const Assignment<L>& x) const override {
111  return constant_;
112  }
113 
115  NodePtr apply(const Unary& op) const override {
116  NodePtr f(new Leaf(op(constant_), nrAssignments_));
117  return f;
118  }
119 
122  const Assignment<L>& assignment) const override {
123  NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
124  return f;
125  }
126 
127  // Apply binary operator "h = f op g" on Leaf node
128  // Note op is not assumed commutative so we need to keep track of order
129  // Simply calls apply on argument to call correct virtual method:
130  // fL.apply_f_op_g(gL) -> gL.apply_g_op_fL(fL) (below)
131  // fL.apply_f_op_g(gC) -> gC.apply_g_op_fL(fL) (Choice)
132  NodePtr apply_f_op_g(const Node& g, const Binary& op) const override {
133  return g.apply_g_op_fL(*this, op);
134  }
135 
136  // Applying binary operator to two leaves results in a leaf
137  NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
138  // fL op gL
139  NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
140  return h;
141  }
142 
143  // If second argument is a Choice node, call it's apply with leaf as second
144  NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
145  return fC.apply_fC_op_gL(*this, op); // operand order back to normal
146  }
147 
149  NodePtr choose(const L& label, size_t index) const override {
150  return NodePtr(new Leaf(constant(), nrAssignments()));
151  }
152 
153  bool isLeaf() const override { return true; }
154 
155  private:
157 
158 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
159 
160  friend class boost::serialization::access;
161  template <class ARCHIVE>
162  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
163  ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
164  ar& BOOST_SERIALIZATION_NVP(constant_);
165  ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
166  }
167 #endif
168  }; // Leaf
169 
170  /****************************************************************************/
171  // Choice
172  /****************************************************************************/
173  template<typename L, typename Y>
174  struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
177 
179  std::vector<NodePtr> branches_;
180 
181  private:
186  size_t allSame_;
187 
188  using ChoicePtr = std::shared_ptr<const Choice>;
189 
190  public:
192  Choice() {}
193 
194  ~Choice() override {
195 #ifdef DT_DEBUG_MEMORY
196  std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
197  << std::endl;
198 #endif
199  }
200 
202  static NodePtr Unique(const ChoicePtr& f) {
203 #ifndef GTSAM_DT_NO_PRUNING
204  if (f->allSame_) {
205  assert(f->branches().size() > 0);
206  NodePtr f0 = f->branches_[0];
207 
208  size_t nrAssignments = 0;
209  for(auto branch: f->branches()) {
210  assert(branch->isLeaf());
211  nrAssignments +=
212  std::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
213  }
214  NodePtr newLeaf(
215  new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
216  nrAssignments));
217  return newLeaf;
218  } else
219 #endif
220  return f;
221  }
222 
223  bool isLeaf() const override { return false; }
224 
226  Choice(const L& label, size_t count) :
227  label_(label), allSame_(true) {
228  branches_.reserve(count);
229  }
230 
232  Choice(const Choice& f, const Choice& g, const Binary& op) :
233  allSame_(true) {
234  // Choose what to do based on label
235  if (f.label() > g.label()) {
236  // f higher than g
237  label_ = f.label();
238  size_t count = f.nrChoices();
239  branches_.reserve(count);
240  for (size_t i = 0; i < count; i++)
241  push_back(f.branches_[i]->apply_f_op_g(g, op));
242  } else if (g.label() > f.label()) {
243  // f lower than g
244  label_ = g.label();
245  size_t count = g.nrChoices();
246  branches_.reserve(count);
247  for (size_t i = 0; i < count; i++)
248  push_back(g.branches_[i]->apply_g_op_fC(f, op));
249  } else {
250  // f same level as g
251  label_ = f.label();
252  size_t count = f.nrChoices();
253  branches_.reserve(count);
254  for (size_t i = 0; i < count; i++)
255  push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op));
256  }
257  }
258 
260  const L& label() const {
261  return label_;
262  }
263 
264  size_t nrChoices() const {
265  return branches_.size();
266  }
267 
268  const std::vector<NodePtr>& branches() const {
269  return branches_;
270  }
271 
273  void push_back(const NodePtr& node) {
274  // allSame_ is restricted to leaf nodes in a decision tree
275  if (allSame_ && !branches_.empty()) {
276  allSame_ = node->sameLeaf(*branches_.back());
277  }
278  branches_.push_back(node);
279  }
280 
282  void print(const std::string& s, const LabelFormatter& labelFormatter,
283  const ValueFormatter& valueFormatter) const override {
284  std::cout << s << " Choice(";
285  std::cout << labelFormatter(label_) << ") " << std::endl;
286  for (size_t i = 0; i < branches_.size(); i++) {
287  branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter);
288  }
289  }
290 
292  void dot(std::ostream& os, const LabelFormatter& labelFormatter,
294  bool showZero) const override {
295  os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
296  << "\"]\n";
297  size_t B = branches_.size();
298  for (size_t i = 0; i < B; i++) {
299  const NodePtr& branch = branches_[i];
300 
301  // Check if zero
302  if (!showZero) {
303  const Leaf* leaf = dynamic_cast<const Leaf*>(branch.get());
304  if (leaf && valueFormatter(leaf->constant()).compare("0")) continue;
305  }
306 
307  os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
308  if (B == 2 && i == 0) os << " [style=dashed]";
309  os << std::endl;
310  branch->dot(os, labelFormatter, valueFormatter, showZero);
311  }
312  }
313 
315  bool sameLeaf(const Leaf& q) const override {
316  return false;
317  }
318 
320  bool sameLeaf(const Node& q) const override {
321  return (q.isLeaf() && q.sameLeaf(*this));
322  }
323 
325  bool equals(const Node& q, const CompareFunc& compare) const override {
326  const Choice* other = dynamic_cast<const Choice*>(&q);
327  if (!other) return false;
328  if (this->label_ != other->label_) return false;
329  if (branches_.size() != other->branches_.size()) return false;
330  // we don't care about shared pointers being equal here
331  for (size_t i = 0; i < branches_.size(); i++)
332  if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
333  return false;
334  return true;
335  }
336 
338  const Y& operator()(const Assignment<L>& x) const override {
339 #ifndef NDEBUG
340  typename Assignment<L>::const_iterator it = x.find(label_);
341  if (it == x.end()) {
342  std::cout << "Trying to find value for " << label_ << std::endl;
343  throw std::invalid_argument(
344  "DecisionTree::operator(): value undefined for a label");
345  }
346 #endif
347  size_t index = x.at(label_);
348  NodePtr child = branches_[index];
349  return (*child)(x);
350  }
351 
353  Choice(const L& label, const Choice& f, const Unary& op) :
354  label_(label), allSame_(true) {
355  branches_.reserve(f.branches_.size()); // reserve space
356  for (const NodePtr& branch : f.branches_) {
357  push_back(branch->apply(op));
358  }
359  }
360 
371  Choice(const L& label, const Choice& f, const UnaryAssignment& op,
372  const Assignment<L>& assignment)
373  : label_(label), allSame_(true) {
374  branches_.reserve(f.branches_.size()); // reserve space
375 
376  Assignment<L> assignment_ = assignment;
377 
378  for (size_t i = 0; i < f.branches_.size(); i++) {
379  assignment_[label_] = i; // Set assignment for label to i
380 
381  const NodePtr branch = f.branches_[i];
382  push_back(branch->apply(op, assignment_));
383 
384  // Remove the assignment so we are backtracking
385  auto assignment_it = assignment_.find(label_);
386  assignment_.erase(assignment_it);
387  }
388  }
389 
391  NodePtr apply(const Unary& op) const override {
392  auto r = std::make_shared<Choice>(label_, *this, op);
393  return Unique(r);
394  }
395 
398  const Assignment<L>& assignment) const override {
399  auto r = std::make_shared<Choice>(label_, *this, op, assignment);
400  return Unique(r);
401  }
402 
403  // Apply binary operator "h = f op g" on Choice node
404  // Note op is not assumed commutative so we need to keep track of order
405  // Simply calls apply on argument to call correct virtual method:
406  // fC.apply_f_op_g(gL) -> gL.apply_g_op_fC(fC) -> (Leaf)
407  // fC.apply_f_op_g(gC) -> gC.apply_g_op_fC(fC) -> (below)
408  NodePtr apply_f_op_g(const Node& g, const Binary& op) const override {
409  return g.apply_g_op_fC(*this, op);
410  }
411 
412  // If second argument of binary op is Leaf node, recurse on branches
413  NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
414  auto h = std::make_shared<Choice>(label(), nrChoices());
415  for (auto&& branch : branches_)
416  h->push_back(fL.apply_f_op_g(*branch, op));
417  return Unique(h);
418  }
419 
420  // If second argument of binary op is Choice, call constructor
421  NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
422  auto h = std::make_shared<Choice>(fC, *this, op);
423  return Unique(h);
424  }
425 
426  // If second argument of binary op is Leaf
427  template<typename OP>
428  NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
429  auto h = std::make_shared<Choice>(label(), nrChoices());
430  for (auto&& branch : branches_)
431  h->push_back(branch->apply_f_op_g(gL, op));
432  return Unique(h);
433  }
434 
436  NodePtr choose(const L& label, size_t index) const override {
437  if (label_ == label) return branches_[index]; // choose branch
438 
439  // second case, not label of interest, just recurse
440  auto r = std::make_shared<Choice>(label_, branches_.size());
441  for (auto&& branch : branches_)
442  r->push_back(branch->choose(label, index));
443  return Unique(r);
444  }
445 
446  private:
448 
449 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
450 
451  friend class boost::serialization::access;
452  template <class ARCHIVE>
453  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
454  ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
455  ar& BOOST_SERIALIZATION_NVP(label_);
456  ar& BOOST_SERIALIZATION_NVP(branches_);
457  ar& BOOST_SERIALIZATION_NVP(allSame_);
458  }
459 #endif
460  }; // Choice
461 
462  /****************************************************************************/
463  // DecisionTree
464  /****************************************************************************/
  // NOTE(review): the declarator line of this definition is missing from this
  // copy of the file; only the template header and the closing brace remain.
  // Restore it before compiling.
  template<typename L, typename Y>
  }
468 
  // Constructor that stores `root` as this tree's root node pointer.
  // NOTE(review): the declarator line (name and parameter list) is missing
  // from this copy of the file. Judging from root_(root), the parameter is
  // presumably a NodePtr; TODO confirm against DecisionTree.h.
  template<typename L, typename Y>
    root_(root) {
  }
473 
474  /****************************************************************************/
  // Constant tree: the root is a single Leaf holding y.
  // NOTE(review): the declarator line is missing from this copy of the file.
  template<typename L, typename Y>
    root_ = NodePtr(new Leaf(y));
  }
479 
480  /****************************************************************************/
481  template <typename L, typename Y>
482  DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
483  auto a = std::make_shared<Choice>(label, 2);
484  NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
485  a->push_back(l1);
486  a->push_back(l2);
487  root_ = Choice::Unique(a);
488  }
489 
490  /****************************************************************************/
491  template <typename L, typename Y>
492  DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
493  const Y& y2) {
494  if (labelC.second != 2) throw std::invalid_argument(
495  "DecisionTree: binary constructor called with non-binary label");
496  auto a = std::make_shared<Choice>(labelC.first, 2);
497  NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
498  a->push_back(l1);
499  a->push_back(l2);
500  root_ = Choice::Unique(a);
501  }
502 
503  /****************************************************************************/
504  template<typename L, typename Y>
505  DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
506  const std::vector<Y>& ys) {
507  // call recursive Create
508  root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
509  }
510 
511  /****************************************************************************/
512  template<typename L, typename Y>
513  DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
514  const std::string& table) {
515  // Convert std::string to values of type Y
516  std::vector<Y> ys;
517  std::istringstream iss(table);
518  copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
519  back_inserter(ys));
520 
521  // now call recursive Create
522  root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
523  }
524 
525  /****************************************************************************/
526  template<typename L, typename Y>
527  template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
528  Iterator begin, Iterator end, const L& label) {
529  root_ = compose(begin, end, label);
530  }
531 
532  /****************************************************************************/
  // Compose a tree on `label` from two subtrees: f0 for value 0 and f1 for
  // value 1. compose() keeps labels ordered, so the resulting root need not
  // be `label`.
  // NOTE(review): the first line of the declarator (constructor name and the
  // leading parameter, presumably `label`) is missing from this copy.
  template<typename L, typename Y>
      const DecisionTree& f0, const DecisionTree& f1) {
    const std::vector<DecisionTree> functions{f0, f1};
    root_ = compose(functions.begin(), functions.end(), label);
  }
539 
540  /****************************************************************************/
  // Convert from a tree with value type X. Labels are kept unchanged and every
  // leaf value is mapped through Y_of_X.
  // NOTE(review): the declarator line naming the constructor and its first
  // parameter (`other`) is missing from this copy of the file.
  template <typename L, typename Y>
  template <typename X, typename Func>
      Func Y_of_X) {
    // Define functor for identity mapping of node label.
    auto L_of_L = [](const L& label) { return label; };
    root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
  }
549 
550  /****************************************************************************/
  // Convert from a tree with label type M and value type X. Labels are mapped
  // through `map` and values through Y_of_X. map.at throws
  // std::out_of_range for a label that is missing from the map.
  // NOTE(review): the declarator line naming the constructor and its first
  // parameter (`other`) is missing from this copy of the file.
  template <typename L, typename Y>
  template <typename M, typename X, typename Func>
      const std::map<M, L>& map, Func Y_of_X) {
    auto L_of_M = [&map](const M& label) -> L { return map.at(label); };
    root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
  }
558 
559  /****************************************************************************/
  // compose: build the tree that branches on `label`, using the i-th tree in
  // [begin, end) as the subtree for label value i.
  // Labels must stay ordered with the highest label at the root, so we cannot
  // simply create a root Choice on `label`. If some input tree branches on a
  // label at least as high as `label`, we instead branch on that highest label
  // and recurse once per value of it, which is more expensive.
  // Used by the iterator and two-tree constructors, by create() and by
  // convertFrom().
  // NOTE(review): the declarator line (return type and
  // "DecisionTree<L, Y>::compose(") is missing from this copy of the file.
  // NOTE(review): `end - begin` below requires random-access iterators.
  template <typename L, typename Y>
  template <typename Iterator>
      Iterator begin, Iterator end, const L& label) const {
    // find highest label among branches
    std::optional<L> highestLabel;
    size_t nrChoices = 0;
    for (Iterator it = begin; it != end; it++) {
      if (it->root_->isLeaf())
        continue;
      std::shared_ptr<const Choice> c =
          std::dynamic_pointer_cast<const Choice>(it->root_);
      if (!highestLabel || c->label() > *highestLabel) {
        highestLabel = c->label();
        nrChoices = c->nrChoices();
      }
    }

    // if label is already in correct order, just put together a choice on label
    if (!nrChoices || !highestLabel || label > *highestLabel) {
      auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
      for (Iterator it = begin; it != end; it++)
        choiceOnLabel->push_back(it->root_);
      return Choice::Unique(choiceOnLabel);
    } else {
      // Set up a new choice on the highest label
      auto choiceOnHighestLabel =
          std::make_shared<Choice>(*highestLabel, nrChoices);
      // now, for all possible values of highestLabel
      for (size_t index = 0; index < nrChoices; index++) {
        // make a new set of functions for composing by iterating over the given
        // functions, and selecting the appropriate branch.
        std::vector<DecisionTree> functions;
        for (Iterator it = begin; it != end; it++) {
          // restrict each input tree to highestLabel == index
          DecisionTree chosen = it->choose(*highestLabel, index);
          functions.push_back(chosen);
        }
        // We then recurse, for all values of the highest label
        NodePtr fi = compose(functions.begin(), functions.end(), label);
        choiceOnHighestLabel->push_back(fi);
      }
      return Choice::Unique(choiceOnHighestLabel);
    }
  }
610 
  /****************************************************************************/
  // "create" is a bit of a complicated thing, but very useful.
  // It takes a range of labels (with cardinalities) and a corresponding range
  // of values, and creates a decision tree, as follows:
  // - if there is only one label, it creates a choice node with the values in
  //   its leaves;
  // - otherwise, it splits the values evenly into one sub-range per value of
  //   the first label, recursively creates a tree over the remaining labels
  //   for each sub-range, and composes those trees under the first label.
  // Example:
  //   create([B A],[1 2 3 4]) would call
  //     create([A],[1 2])
  //     create([A],[3 4])
  //   and produce
  //     B=0
  //       A=0: 1
  //       A=1: 2
  //     B=1
  //       A=0: 3
  //       A=1: 4
  // The first label always indexes the slowest-varying position in the values.
  // Through the magic of "compose", create([A B],[1 2 3 4]) still produces a
  // tree with the highest label (B) at the root. However, it will be *way*
  // faster if labels are given highest to lowest.
  // NOTE(review): the declarator line (return type and
  // "DecisionTree<L, Y>::create(") is missing from this copy of the file.
  // NOTE(review): create() assumes:
  //   - [begin, end) is non-empty (it reads begin->second unconditionally);
  //   - every cardinality is positive (the recursive case divides by it).
  // Also, split = size / nrChoices truncates. If the value count is not a
  // multiple of the first cardinality, the trailing size % nrChoices values
  // are dropped without error, as long as the shorter sub-ranges still match
  // the lower cardinalities. Callers should validate the value count.
  template<typename L, typename Y>
  template<typename It, typename ValueIt>
      It begin, It end, ValueIt beginY, ValueIt endY) const {
    // get crucial counts
    size_t nrChoices = begin->second;
    size_t size = endY - beginY;

    // Find the next key to work on
    It labelC = begin + 1;
    if (labelC == end) {
      // Base case: only one key left
      // Create a simple choice node with values as leaves.
      if (size != nrChoices) {
        std::cout << "Trying to create DD on " << begin->first << std::endl;
        std::cout << "DecisionTree::create: expected " << nrChoices
                  << " values but got " << size << " instead" << std::endl;
        throw std::invalid_argument("DecisionTree::create invalid argument");
      }
      auto choice = std::make_shared<Choice>(begin->first, endY - beginY);
      for (ValueIt y = beginY; y != endY; y++)
        choice->push_back(NodePtr(new Leaf(*y)));
      return Choice::Unique(choice);
    }

    // Recursive case: perform "Shannon expansion"
    // Creates one tree (i.e., function) for each choice of current key
    // by calling create recursively, and then puts them all together.
    std::vector<DecisionTree> functions;
    size_t split = size / nrChoices;
    for (size_t i = 0; i < nrChoices; i++, beginY += split) {
      NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
      functions.emplace_back(f);
    }
    return compose(functions.begin(), functions.end(), begin->first);
  }
668 
669  /****************************************************************************/
  // Recursively convert a node of a DecisionTree<M, X> into a node of this
  // tree type. Leaf values are mapped with Y_of_X (keeping nrAssignments) and
  // labels with L_of_M. Because the mapped labels may no longer be ordered,
  // each Choice is rebuilt with compose(), which puts the highest label at
  // the root.
  // Throws std::invalid_argument for a node that is neither Leaf nor Choice.
  // NOTE(review): the declarator line (return type and
  // "DecisionTree<L, Y>::convertFrom(") is missing from this copy of the file.
  template <typename L, typename Y>
  template <typename M, typename X>
      const typename DecisionTree<M, X>::NodePtr& f,
      std::function<L(const M&)> L_of_M,
      std::function<Y(const X&)> Y_of_X) const {
    using LY = DecisionTree<L, Y>;

    // Ugliness below because apparently we can't have templated virtual
    // functions.
    // If leaf, apply unary conversion "op" and create a unique leaf.
    using MXLeaf = typename DecisionTree<M, X>::Leaf;
    if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
      return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments()));
    }

    // Check if Choice
    using MXChoice = typename DecisionTree<M, X>::Choice;
    auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
    if (!choice) throw std::invalid_argument(
        "DecisionTree::convertFrom: Invalid NodePtr");

    // get new label
    const M oldLabel = choice->label();
    const L newLabel = L_of_M(oldLabel);

    // put together via Shannon expansion otherwise not sorted.
    std::vector<LY> functions;
    for (auto&& branch : choice->branches()) {
      functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
    }
    return LY::compose(functions.begin(), functions.end(), newLabel);
  }
703 
704  /****************************************************************************/
715  template <typename L, typename Y>
716  struct Visit {
717  using F = std::function<void(const Y&)>;
718  explicit Visit(F f) : f(f) {}
719  F f;
720 
722  void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
723  using Leaf = typename DecisionTree<L, Y>::Leaf;
724  if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
725  return f(leaf->constant());
726 
727  using Choice = typename DecisionTree<L, Y>::Choice;
728  auto choice = std::dynamic_pointer_cast<const Choice>(node);
729  if (!choice)
730  throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
731  for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
732  }
733  };
734 
735  template <typename L, typename Y>
736  template <typename Func>
737  void DecisionTree<L, Y>::visit(Func f) const {
738  Visit<L, Y> visit(f);
739  visit(root_);
740  }
741 
742  /****************************************************************************/
  /**
   * Functor for a depth-first walk that calls `f` on every Leaf node (the
   * node itself, not just its value), visiting branches in order.
   * Throws std::invalid_argument for a node that is neither Leaf nor Choice.
   */
  template <typename L, typename Y>
  struct VisitLeaf {
    // NOTE(review): the declaration of the callback type `F` is missing from
    // this copy of the file. From f(*leaf) below it presumably takes a
    // const Leaf&; TODO confirm against the original.
    explicit VisitLeaf(F f) : f(f) {}
    F f;  // callback invoked once per leaf

    // Walk from `node`: call f on a leaf, otherwise descend into each branch.
    void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
      using Leaf = typename DecisionTree<L, Y>::Leaf;
      if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
        return f(*leaf);

      using Choice = typename DecisionTree<L, Y>::Choice;
      auto choice = std::dynamic_pointer_cast<const Choice>(node);
      if (!choice)
        throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr");
      for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
    }
  };
771 
772  template <typename L, typename Y>
773  template <typename Func>
774  void DecisionTree<L, Y>::visitLeaf(Func f) const {
775  VisitLeaf<L, Y> visit(f);
776  visit(root_);
777  }
778 
779  /****************************************************************************/
  /**
   * Functor for a depth-first walk that calls `f` on every leaf value together
   * with the assignment of labels along the path to that leaf.
   * Throws std::invalid_argument for a node that is neither Leaf nor Choice.
   */
  template <typename L, typename Y>
  struct VisitWith {
    using F = std::function<void(const Assignment<L>&, const Y&)>;
    explicit VisitWith(F f) : f(f) {}
    // NOTE(review): the declaration of the member `assignment` (the path
    // assignment used below) is missing from this copy of the file.
    F f;  // callback invoked once per leaf

    // Walk from `node`. Before recursing into branch i, record
    // label = i in `assignment`; remove it afterwards so siblings and the
    // caller see the assignment unchanged.
    void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
      using Leaf = typename DecisionTree<L, Y>::Leaf;
      if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
        return f(assignment, leaf->constant());

      using Choice = typename DecisionTree<L, Y>::Choice;
      auto choice = std::dynamic_pointer_cast<const Choice>(node);
      if (!choice)
        throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
      for (size_t i = 0; i < choice->nrChoices(); i++) {
        assignment[choice->label()] = i; // Set assignment for label to i

        (*this)(choice->branches()[i]); // recurse!

        // Remove the choice so we are backtracking
        auto choice_it = assignment.find(choice->label());
        assignment.erase(choice_it);
      }
    }
  };
814 
815  template <typename L, typename Y>
816  template <typename Func>
817  void DecisionTree<L, Y>::visitWith(Func f) const {
818  VisitWith<L, Y> visit(f);
819  visit(root_);
820  }
821 
822  /****************************************************************************/
  // Count the leaves by visiting every leaf value once.
  // NOTE(review): the declarator line (presumably the nrLeaves() signature) is
  // missing from this copy of the file.
  template <typename L, typename Y>
    size_t total = 0;
    visit([&total](const Y& node) { total += 1; });
    return total;
  }
829 
830  /****************************************************************************/
831  // fold is just done with a visit
832  template <typename L, typename Y>
833  template <typename Func, typename X>
834  X DecisionTree<L, Y>::fold(Func f, X x0) const {
835  visit([&](const Y& y) { x0 = f(y, x0); });
836  return x0;
837  }
838 
839  /****************************************************************************/
853  template <typename L, typename Y>
854  std::set<L> DecisionTree<L, Y>::labels() const {
855  std::set<L> unique;
856  auto f = [&](const Assignment<L>& assignment, const Y&) {
857  for (auto&& kv : assignment) {
858  unique.insert(kv.first);
859  }
860  };
861  visitWith(f);
862  return unique;
863  }
864 
865 /****************************************************************************/
  // Equality up to `compare`, delegated to the root nodes.
  // NOTE(review): the first line of the declarator (presumably naming
  // equals and its `other` parameter) is missing from this copy of the file.
  template <typename L, typename Y>
      const CompareFunc& compare) const {
    return root_->equals(*other.root_, compare);
  }
871 
872  template <typename L, typename Y>
873  void DecisionTree<L, Y>::print(const std::string& s,
874  const LabelFormatter& labelFormatter,
875  const ValueFormatter& valueFormatter) const {
876  root_->print(s, labelFormatter, valueFormatter);
877  }
878 
  // Equality of two trees, delegated to the root nodes' equals() with its
  // default comparison (presumably exact equality; TODO confirm in
  // DecisionTree.h).
  // NOTE(review): the declarator line is missing from this copy of the file.
  template<typename L, typename Y>
    return root_->equals(*other.root_);
  }
883 
  // Evaluate the tree for assignment x by delegating to the root node.
  // NOTE(review): the declarator line is missing from this copy of the file.
  template<typename L, typename Y>
    return root_->operator ()(x);
  }
888 
  // Apply a unary operator to every leaf value and return the new tree.
  // Choice::Unique may merge leaves that become equal.
  // Throws std::runtime_error on an empty tree.
  // NOTE(review): the declarator line is missing from this copy of the file.
  template<typename L, typename Y>
    // It is unclear what should happen if tree is empty:
    if (empty()) {
      throw std::runtime_error(
          "DecisionTree::apply(unary op) undefined for empty tree.");
    }
    return DecisionTree(root_->apply(op));
  }
898 
  // Apply a unary operator that receives, for each leaf, the assignment of
  // labels along the path to it. The walk starts from an empty assignment.
  // Throws std::runtime_error on an empty tree.
  // NOTE(review): the first line of the declarator is missing from this copy
  // of the file.
  // NOTE(review): the error message below says "unary op" although this is
  // the UnaryAssignment overload.
  template <typename L, typename Y>
      const UnaryAssignment& op) const {
    // It is unclear what should happen if tree is empty:
    if (empty()) {
      throw std::runtime_error(
          "DecisionTree::apply(unary op) undefined for empty tree.");
    }
    Assignment<L> assignment;
    return DecisionTree(root_->apply(op, assignment));
  }
911 
912  /****************************************************************************/
  // Apply a binary operator leaf-wise, with this tree as the left operand and
  // g as the right operand. op is not assumed to be commutative.
  // Throws std::runtime_error if either tree is empty.
  // NOTE(review): the first line of the declarator is missing from this copy
  // of the file.
  template<typename L, typename Y>
      const Binary& op) const {
    // It is unclear what should happen if either tree is empty:
    if (empty() || g.empty()) {
      throw std::runtime_error(
          "DecisionTree::apply(binary op) undefined for empty trees.");
    }
    // apply the operation on the root of both diagrams
    NodePtr h = root_->apply_f_op_g(*g.root_, op);
    // create a new class with the resulting root "h"
    DecisionTree result(h);
    return result;
  }
927 
928  /****************************************************************************/
  // combine: eliminate `label` by folding its branches together with `op`.
  // Picture the decision tree as a tree with a branch on `label` at some
  // depth. choose(label, index) returns a tree one level shallower along those
  // paths: the branch point on `label` is replaced by its subtree for value
  // `index`. We take these restricted trees for index = 0 .. cardinality-1
  // and combine them with `op`. With op = addition, for example, this sums
  // `label` out, which is marginalization as in Darwiche09book, p. 330.
  // NOTE(review): the first line of the declarator is missing from this copy
  // of the file.
  // NOTE(review): cardinality is not validated. Index 0 is chosen even when
  // cardinality == 0.
  template<typename L, typename Y>
      size_t cardinality, const Binary& op) const {
    DecisionTree result = choose(label, 0);
    for (size_t index = 1; index < cardinality; index++) {
      DecisionTree chosen = choose(label, index);
      result = result.apply(chosen, op);
    }
    return result;
  }
947 
948  /****************************************************************************/
  // Write the whole tree to `os` as a graphviz digraph named G.
  // NOTE(review): the parameter line declaring `valueFormatter` is missing
  // from this copy of the file.
  template <typename L, typename Y>
  void DecisionTree<L, Y>::dot(std::ostream& os,
                               const LabelFormatter& labelFormatter,
                               bool showZero) const {
    os << "digraph G {\n";
    root_->dot(os, labelFormatter, valueFormatter, showZero);
    os << " [ordering=out]}" << std::endl;
  }
958 
  // Write the tree in graphviz format to <name>.dot, then run the `dot`
  // command to render <name>.pdf. Throws std::runtime_error only when
  // system() itself returns -1; a failing `dot` command is not reported.
  // NOTE(review): the parameter line declaring `valueFormatter` is missing
  // from this copy of the file.
  // NOTE(review): `>&` is a bash-only redirection. system() runs the command
  // with /bin/sh, which on some platforms (e.g. dash) rejects it;
  // `> /dev/null 2>&1` is the portable form. `name` is also spliced into the
  // shell command unquoted, so it must come from a trusted source.
  template <typename L, typename Y>
  void DecisionTree<L, Y>::dot(const std::string& name,
                               const LabelFormatter& labelFormatter,
                               bool showZero) const {
    std::ofstream os((name + ".dot").c_str());
    dot(os, labelFormatter, valueFormatter, showZero);
    int result =
        system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
                   .c_str());
    if (result == -1)
      throw std::runtime_error("DecisionTree::dot system call failed");
  }
972 
  // Return the graphviz rendering of the tree as a string.
  // NOTE(review): the parameter line declaring `valueFormatter` is missing
  // from this copy of the file.
  template <typename L, typename Y>
  std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
                                      bool showZero) const {
    std::stringstream ss;
    dot(ss, labelFormatter, valueFormatter, showZero);
    return ss.str();
  }
981 
982 /******************************************************************************/
983 
984  } // namespace gtsam
const Y & operator()(const Assignment< L > &x) const override
evaluate
Decision Tree for use in DiscreteFactors.
bool compare
const char Y
Matrix< RealScalar, Dynamic, Dynamic > M
Definition: bench_gemm.cpp:51
Scalar * y
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const override
print (as a tree).
std::function< void(const Assignment< L > &, const Y &)> F
std::function< Y(const Assignment< L > &, const Y &)> UnaryAssignment
Definition: DecisionTree.h:63
NodePtr apply(const Unary &op) const override
apply unary operator.
size_t nrAssignments() const
Return the number of assignments contained within this leaf.
std::vector< NodePtr > branches_
std::string serialize(const T &input)
serializes to a string
std::function< Y(const Y &, const Y &)> Binary
Definition: DecisionTree.h:64
std::vector< std::string > labels
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
Definition: DecisionTree.h:398
static const T & choose(int layout, const T &col, const T &row)
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
gtsam::Key l2
VisitLeaf(F f)
Construct from folding function.
MatrixXd L
Definition: LLT_example.cpp:6
virtual Ptr apply_g_op_fL(const Leaf &, const Binary &) const =0
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
const std::vector< NodePtr > & branches() const
NodePtr apply_g_op_fL(const Leaf &fL, const Binary &op) const override
std::function< Y(const Y &)> Unary
Definition: DecisionTree.h:62
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const bfloat16 &a, const bfloat16 &b)
Definition: BFloat16.h:218
Choice(const Choice &f, const Choice &g, const Binary &op)
Construct from applying binary op to two Choice nodes.
static std::string valueFormatter(const double &v)
NodePtr apply_fC_op_gL(const Leaf &gL, OP op) const
Choice(const L &label, const Choice &f, const Unary &op)
Construct from applying unary op to a Choice node.
NodePtr apply(const UnaryAssignment &op, const Assignment< L > &assignment) const override
Apply unary operator with assignment.
static const Similarity3 id
virtual bool isLeaf() const =0
void g(const string &key, int i)
Definition: testBTree.cpp:41
const L & label() const
Return the label of this choice node.
T compose(const T &t1, const T &t2)
Definition: lieProxies.h:39
const char * c_str(Args &&...args)
Definition: internals.h:524
std::function< void(const Y &)> F
F f
folding function object.
NodePtr apply_g_op_fC(const Choice &fC, const Binary &op) const override
Leaf(const Y &constant, size_t nrAssignments=1)
Constructor from constant.
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition: DecisionTree.h:136
void split(const G &g, const PredecessorMap< KEY > &tree, G &Ab1, G &Ab2)
Definition: graph-inl.h:245
NodePtr choose(const L &label, size_t index) const override
Scalar EIGEN_BLAS_FUNC() dot(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy)
Choice(const L &label, size_t count)
Constructor, given choice label and mandatory expected branch count.
virtual bool sameLeaf(const Leaf &q) const =0
Values result
void operator()(const typename DecisionTree< L, Y >::NodePtr &node) const
Do a depth-first visit on the tree rooted at node.
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const override
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const override
NodePtr apply(const UnaryAssignment &op, const Assignment< L > &assignment) const override
Apply unary operator with assignment.
bool sameLeaf(const Node &q) const override
polymorphic equality: is q a leaf and is it the same as this leaf?
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const override
print
void push_back(const NodePtr &node)
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Leaf()
Default constructor for serialization.
static NodePtr Unique(const ChoicePtr &f)
If all branches of a choice node f are the same, just return a branch.
RealScalar s
EIGEN_DEVICE_FUNC const Scalar & q
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
NodePtr apply_g_op_fC(const Choice &fC, const Binary &op) const override
NodePtr choose(const L &label, size_t index) const override
DecisionTree apply(const Unary &op) const
static sharedNode Leaf(Key key, const SymbolicFactorGraph &factors)
typename Node::Ptr NodePtr
Definition: DecisionTree.h:133
void operator()(const typename DecisionTree< L, Y >::NodePtr &node)
Do a depth-first visit on the tree rooted at node.
static std::stringstream ss
Definition: testBTree.cpp:31
gtsam::Key l1
traits
Definition: chartTesting.h:28
NodePtr apply(const Unary &op) const override
std::function< bool(const Y &, const Y &)> CompareFunc
Definition: DecisionTree.h:59
static Symbol x0('x', 0)
const double h
std::function< std::string(L)> LabelFormatter
Definition: DecisionTree.h:57
bool empty() const
Check if tree is empty.
Definition: DecisionTree.h:240
bool isLeaf() const override
bool equals(const Node &q, const CompareFunc &compare) const override
equality
Visit(F f)
Construct from folding function.
ADT create(const Signature &signature)
Assignment< L > assignment
Assignment, mutating through recursion.
ofstream os("timeSchurFactors.csv")
ArrayXXf table(10, 4)
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
virtual Ptr apply_g_op_fC(const Choice &, const Binary &) const =0
std::function< std::string(Y)> ValueFormatter
Definition: DecisionTree.h:58
static EIGEN_DEPRECATED const end_t end
std::function< void(const typename DecisionTree< L, Y >::Leaf &)> F
NodePtr apply_g_op_fL(const Leaf &fL, const Binary &op) const override
bool isLeaf() const override
bool sameLeaf(const Leaf &q) const override
Choice-Leaf equality: always false.
void operator()(const typename DecisionTree< L, Y >::NodePtr &node) const
Do a depth-first visit on the tree rooted at node.
F f
folding function object.
Annotation for function names.
Definition: attr.h:48
std::shared_ptr< const Choice > ChoicePtr
bool equals(const Node &q, const CompareFunc &compare) const override
equality up to tolerance
std::pair< Key, size_t > LabelC
Definition: DecisionTree.h:67
VisitWith(F f)
Construct from folding function.
Choice(const L &label, const Choice &f, const UnaryAssignment &op, const Assignment< L > &assignment)
Constructor which accepts a UnaryAssignment op and the corresponding assignment.
#define X
Definition: icosphere.cpp:20
NodePtr apply_f_op_g(const Node &g, const Binary &op) const override
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
internal::enable_if< internal::valid_indexed_view_overload< RowIndices, ColIndices >::value &&internal::traits< typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::ReturnAsIndexedView, typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::type operator()(const RowIndices &rowIndices, const ColIndices &colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
bool sameLeaf(const Leaf &q) const override
Leaf-Leaf equality.
bool sameLeaf(const Node &q) const override
polymorphic equality: if q is a leaf, could be...
DecisionTree choose(const L &label, size_t index) const
Definition: DecisionTree.h:341
F f
folding function object.
const Y & constant() const
Return the constant.
const Y & operator()(const Assignment< L > &x) const override
NodePtr apply_f_op_g(const Node &g, const Binary &op) const override
Choice()
Default constructor for serialization.


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:34:09