DecisionTree.h
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
20 #pragma once
21 
22 #include <gtsam/base/Testable.h>
23 #include <gtsam/base/types.h>
25 
26 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
27 #include <boost/serialization/nvp.hpp>
28 #endif
29 #include <memory>
30 #include <functional>
31 #include <iostream>
32 #include <map>
33 #include <set>
34 #include <sstream>
35 #include <string>
36 #include <utility>
37 #include <vector>
38 
39 namespace gtsam {
40 
48  template<typename L, typename Y>
49  class DecisionTree {
50  protected:
52  static bool DefaultCompare(const Y& a, const Y& b) {
53  return a == b;
54  }
55 
56  public:
57  using LabelFormatter = std::function<std::string(L)>;
58  using ValueFormatter = std::function<std::string(Y)>;
59  using CompareFunc = std::function<bool(const Y&, const Y&)>;
60 
62  using Unary = std::function<Y(const Y&)>;
63  using UnaryAssignment = std::function<Y(const Assignment<L>&, const Y&)>;
64  using Binary = std::function<Y(const Y&, const Y&)>;
65 
67  using LabelC = std::pair<L, size_t>;
68 
70  struct Leaf;
71  struct Choice;
72 
74  struct Node {
75  using Ptr = std::shared_ptr<const Node>;
76 
77 #ifdef DT_DEBUG_MEMORY
78  static int nrNodes;
79 #endif
80 
81  // Constructor
82  Node() {
83 #ifdef DT_DEBUG_MEMORY
84  std::cout << ++nrNodes << " constructed " << id() << std::endl;
85  std::cout.flush();
86 #endif
87  }
88 
89  // Destructor
90  virtual ~Node() {
91 #ifdef DT_DEBUG_MEMORY
92  std::cout << --nrNodes << " destructed " << id() << std::endl;
93  std::cout.flush();
94 #endif
95  }
96 
97  // Unique ID for dot files
98  const void* id() const { return this; }
99 
100  // everything else is virtual, no documentation here as internal
101  virtual void print(const std::string& s,
102  const LabelFormatter& labelFormatter,
103  const ValueFormatter& valueFormatter) const = 0;
104  virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
105  const ValueFormatter& valueFormatter,
106  bool showZero) const = 0;
107  virtual bool sameLeaf(const Leaf& q) const = 0;
108  virtual bool sameLeaf(const Node& q) const = 0;
109  virtual bool equals(const Node& other, const CompareFunc& compare =
110  &DefaultCompare) const = 0;
111  virtual const Y& operator()(const Assignment<L>& x) const = 0;
112  virtual Ptr apply(const Unary& op) const = 0;
113  virtual Ptr apply(const UnaryAssignment& op,
114  const Assignment<L>& assignment) const = 0;
115  virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
116  virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
117  virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
118  virtual Ptr choose(const L& label, size_t index) const = 0;
119  virtual bool isLeaf() const = 0;
120 
121  private:
122 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
123 
124  friend class boost::serialization::access;
125  template <class ARCHIVE>
126  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
127 #endif
128  };
131  public:
133  using NodePtr = typename Node::Ptr;
134 
137 
138  protected:
142  template<typename It, typename ValueIt>
143  NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
144 
155  template <typename M, typename X>
157  std::function<L(const M&)> L_of_M,
158  std::function<Y(const X&)> Y_of_X) const;
159 
160  public:
163 
165  DecisionTree();
166 
168  explicit DecisionTree(const Y& y);
169 
171  DecisionTree(const L& label, const Y& y1, const Y& y2);
172 
174  DecisionTree(const LabelC& label, const Y& y1, const Y& y2);
175 
177  DecisionTree(const std::vector<LabelC>& labelCs, const std::vector<Y>& ys);
178 
180  DecisionTree(const std::vector<LabelC>& labelCs, const std::string& table);
181 
183  template<typename Iterator>
184  DecisionTree(Iterator begin, Iterator end, const L& label);
185 
187  DecisionTree(const L& label, const DecisionTree& f0,
188  const DecisionTree& f1);
189 
197  template <typename X, typename Func>
198  DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
199 
210  template <typename M, typename X, typename Func>
211  DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& map,
212  Func Y_of_X);
213 
217 
225  void print(const std::string& s, const LabelFormatter& labelFormatter,
226  const ValueFormatter& valueFormatter) const;
227 
228  // Testable
229  bool equals(const DecisionTree& other,
230  const CompareFunc& compare = &DefaultCompare) const;
231 
235 
237  virtual ~DecisionTree() = default;
238 
240  bool empty() const { return !root_; }
241 
243  bool operator==(const DecisionTree& q) const;
244 
246  const Y& operator()(const Assignment<L>& x) const;
247 
262  template <typename Func>
263  void visit(Func f) const;
264 
279  template <typename Func>
280  void visitLeaf(Func f) const;
281 
296  template <typename Func>
297  void visitWith(Func f) const;
298 
300  size_t nrLeaves() const;
301 
317  template <typename Func, typename X>
318  X fold(Func f, X x0) const;
319 
321  std::set<L> labels() const;
322 
324  DecisionTree apply(const Unary& op) const;
325 
334  DecisionTree apply(const UnaryAssignment& op) const;
335 
337  DecisionTree apply(const DecisionTree& g, const Binary& op) const;
338 
341  DecisionTree choose(const L& label, size_t index) const {
342  NodePtr newRoot = root_->choose(label, index);
343  return DecisionTree(newRoot);
344  }
345 
347  DecisionTree combine(const L& label, size_t cardinality,
348  const Binary& op) const;
349 
351  DecisionTree combine(const LabelC& labelC, const Binary& op) const {
352  return combine(labelC.first, labelC.second, op);
353  }
354 
356  void dot(std::ostream& os, const LabelFormatter& labelFormatter,
357  const ValueFormatter& valueFormatter, bool showZero = true) const;
358 
360  void dot(const std::string& name, const LabelFormatter& labelFormatter,
361  const ValueFormatter& valueFormatter, bool showZero = true) const;
362 
364  std::string dot(const LabelFormatter& labelFormatter,
365  const ValueFormatter& valueFormatter,
366  bool showZero = true) const;
367 
370 
371  // internal use only
372  explicit DecisionTree(const NodePtr& root);
373 
374  // internal use only
375  template<typename Iterator> NodePtr
376  compose(Iterator begin, Iterator end, const L& label) const;
377 
379 
380  private:
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
  // Boost.Serialization hook: the tree is archived as its root node pointer.
  friend class boost::serialization::access;
  template <class ARCHIVE>
  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
    // Same as BOOST_SERIALIZATION_NVP(root_): name/value pair tagged "root_".
    ar & boost::serialization::make_nvp("root_", root_);
  }
#endif
389  }; // DecisionTree
390 
391  template <class L, class Y>
392  struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
393 
396  template<typename L, typename Y>
399  const typename DecisionTree<L, Y>::Unary& op) {
400  return f.apply(op);
401  }
402 
404  template<typename L, typename Y>
406  const typename DecisionTree<L, Y>::UnaryAssignment& op) {
407  return f.apply(op);
408  }
409 
411  template<typename L, typename Y>
413  const DecisionTree<L, Y>& g,
414  const typename DecisionTree<L, Y>::Binary& op) {
415  return f.apply(g, op);
416  }
417 
424  template <typename L, typename T1, typename T2>
425  std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
426  const DecisionTree<L, std::pair<T1, T2> >& input) {
427  return {
428  DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
429  DecisionTree<L, T2>(input, [](std::pair<T1, T2> i) { return i.second; })
430  };
431  }
432 
433 } // namespace gtsam
virtual ~DecisionTree()=default
Make virtual.
NodePtr convertFrom(const typename DecisionTree< M, X >::NodePtr &f, std::function< L(const M &)> L_of_M, std::function< Y(const X &)> Y_of_X) const
Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
bool compare
const char Y
const void * id() const
Definition: DecisionTree.h:98
Matrix< RealScalar, Dynamic, Dynamic > M
Definition: bench_gemm.cpp:51
Typedefs for easier changing of types.
Scalar * y
std::function< double(const Assignment< Key > &, const double &)> UnaryAssignment
Definition: DecisionTree.h:63
Concept check for values that can be used in unit tests.
void visit(Func f) const
Visit all leaves in depth-first fashion.
std::string serialize(const T &input)
serializes to a string
std::function< double(const double &, const double &)> Binary
Definition: DecisionTree.h:64
An assignment from labels to a discrete value index (size_t)
virtual Ptr choose(const L &label, size_t index) const =0
MatrixXd L
Definition: LLT_example.cpp:6
virtual Ptr apply_g_op_fL(const Leaf &, const Binary &) const =0
std::function< double(const double &)> Unary
Definition: DecisionTree.h:62
static std::string valueFormatter(const double &v)
std::set< L > labels() const
virtual bool isLeaf() const =0
void g(const string &key, int i)
Definition: testBTree.cpp:41
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition: DecisionTree.h:136
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
virtual void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const =0
virtual const Y & operator()(const Assignment< L > &x) const =0
void visitWith(Func f) const
Visit all leaves in depth-first fashion.
virtual bool sameLeaf(const Leaf &q) const =0
size_t nrLeaves() const
Return the number of leaves in the tree.
std::pair< DecisionTree< L, T1 >, DecisionTree< L, T2 > > unzip(const DecisionTree< L, std::pair< T1, T2 > > &input)
unzip a DecisionTree with std::pair values.
Definition: DecisionTree.h:425
void visitLeaf(Func f) const
Visit all leaves in depth-first fashion.
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
virtual void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const =0
RealScalar s
EIGEN_DEVICE_FUNC const Scalar & q
X fold(Func f, X x0) const
Fold a binary function over the tree, returning accumulator.
DecisionTree apply(const Unary &op) const
static sharedNode Leaf(Key key, const SymbolicFactorGraph &factors)
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const
const G & b
Definition: Group.h:86
DecisionTree combine(const LabelC &labelC, const Binary &op) const
Definition: DecisionTree.h:351
traits
Definition: chartTesting.h:28
std::function< bool(const double &, const double &)> CompareFunc
Definition: DecisionTree.h:59
static Symbol x0('x', 0)
std::shared_ptr< const Node > Ptr
Definition: DecisionTree.h:75
std::function< std::string(Key)> LabelFormatter
Definition: DecisionTree.h:57
bool empty() const
Check if tree is empty.
Definition: DecisionTree.h:240
ofstream os("timeSchurFactors.csv")
ArrayXXf table(10, 4)
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
virtual Ptr apply_g_op_fC(const Choice &, const Binary &) const =0
bool operator==(const DecisionTree &q) const
std::function< std::string(double)> ValueFormatter
Definition: DecisionTree.h:58
static EIGEN_DEPRECATED const end_t end
Annotation for function names.
Definition: attr.h:48
virtual bool equals(const Node &other, const CompareFunc &compare=&DefaultCompare) const =0
std::pair< Key, size_t > LabelC
Definition: DecisionTree.h:67
virtual Ptr apply_f_op_g(const Node &, const Binary &) const =0
#define X
Definition: icosphere.cpp:20
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
NodePtr compose(Iterator begin, Iterator end, const L &label) const
virtual Ptr apply(const Unary &op) const =0
DecisionTree choose(const L &label, size_t index) const
Definition: DecisionTree.h:341
static bool DefaultCompare(const Y &a, const Y &b)
Default method for comparison of two objects of type Y.
Definition: DecisionTree.h:52


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:34:09