DecisionTree.h
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
20 #pragma once
21 
22 #include <gtsam/base/Testable.h>
23 #include <gtsam/base/types.h>
25 
26 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
27 #include <boost/serialization/nvp.hpp>
28 #endif
29 #include <memory>
30 #include <functional>
31 #include <iostream>
32 #include <map>
33 #include <set>
34 #include <sstream>
35 #include <string>
36 #include <utility>
37 #include <vector>
38 
39 namespace gtsam {
40 
62  template<typename L, typename Y>
63  class DecisionTree {
64  protected:
66  static bool DefaultCompare(const Y& a, const Y& b) {
67  return a == b;
68  }
69 
70  public:
71  using LabelFormatter = std::function<std::string(L)>;
72  using ValueFormatter = std::function<std::string(Y)>;
73  using CompareFunc = std::function<bool(const Y&, const Y&)>;
74 
76  using Unary = std::function<Y(const Y&)>;
77  using UnaryAssignment = std::function<Y(const Assignment<L>&, const Y&)>;
78  using Binary = std::function<Y(const Y&, const Y&)>;
79 
81  using LabelC = std::pair<L, size_t>;
82 
84  struct Leaf;
85  struct Choice;
86 
88  struct Node {
89  using Ptr = std::shared_ptr<const Node>;
90 
91 #ifdef DT_DEBUG_MEMORY
92  static int nrNodes;
93 #endif
94 
95  // Constructor
96  Node() {
97 #ifdef DT_DEBUG_MEMORY
98  std::cout << ++nrNodes << " constructed " << id() << std::endl;
99  std::cout.flush();
100 #endif
101  }
102 
103  // Destructor
104  virtual ~Node() {
105 #ifdef DT_DEBUG_MEMORY
106  std::cout << --nrNodes << " destructed " << id() << std::endl;
107  std::cout.flush();
108 #endif
109  }
110 
111  // Unique ID for dot files
112  const void* id() const { return this; }
113 
114  // everything else is virtual, no documentation here as internal
115  virtual void print(const std::string& s,
116  const LabelFormatter& labelFormatter,
117  const ValueFormatter& valueFormatter) const = 0;
118  virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
120  bool showZero) const = 0;
121  virtual bool sameLeaf(const Leaf& q) const = 0;
122  virtual bool sameLeaf(const Node& q) const = 0;
123  virtual bool equals(const Node& other, const CompareFunc& compare =
124  &DefaultCompare) const = 0;
125  virtual const Y& operator()(const Assignment<L>& x) const = 0;
126  virtual Ptr apply(const Unary& op) const = 0;
127  virtual Ptr apply(const UnaryAssignment& op,
128  const Assignment<L>& assignment) const = 0;
129  virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
130  virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
131  virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
132  virtual Ptr choose(const L& label, size_t index) const = 0;
133  virtual bool isLeaf() const = 0;
134 
135  private:
136 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
137 
138  friend class boost::serialization::access;
139  template <class ARCHIVE>
140  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
141 #endif
142  };
145  public:
147  using NodePtr = typename Node::Ptr;
148 
151 
152  protected:
157  template <typename It, typename ValueIt>
158  NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const;
159 
165  template <typename It, typename ValueIt>
166  NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
167 
178  template <typename M, typename X>
180  std::function<L(const M&)> L_of_M,
181  std::function<Y(const X&)> Y_of_X) const;
182 
183  public:
186 
188  DecisionTree();
189 
191  explicit DecisionTree(const Y& y);
192 
200  DecisionTree(const L& label, const Y& y1, const Y& y2);
201 
203  DecisionTree(const LabelC& label, const Y& y1, const Y& y2);
204 
206  DecisionTree(const std::vector<LabelC>& labelCs, const std::vector<Y>& ys);
207 
209  DecisionTree(const std::vector<LabelC>& labelCs, const std::string& table);
210 
212  template<typename Iterator>
213  DecisionTree(Iterator begin, Iterator end, const L& label);
214 
216  DecisionTree(const L& label, const DecisionTree& f0,
217  const DecisionTree& f1);
218 
226  template <typename X, typename Func>
227  DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
228 
239  template <typename M, typename X, typename Func>
240  DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& map,
241  Func Y_of_X);
242 
246 
254  void print(const std::string& s, const LabelFormatter& labelFormatter,
255  const ValueFormatter& valueFormatter) const;
256 
257  // Testable
258  bool equals(const DecisionTree& other,
259  const CompareFunc& compare = &DefaultCompare) const;
260 
264 
266  virtual ~DecisionTree() = default;
267 
269  bool empty() const { return !root_; }
270 
272  bool operator==(const DecisionTree& q) const;
273 
275  const Y& operator()(const Assignment<L>& x) const;
276 
291  template <typename Func>
292  void visit(Func f) const;
293 
308  template <typename Func>
309  void visitLeaf(Func f) const;
310 
325  template <typename Func>
326  void visitWith(Func f) const;
327 
329  size_t nrLeaves() const;
330 
346  template <typename Func, typename X>
347  X fold(Func f, X x0) const;
348 
350  std::set<L> labels() const;
351 
353  DecisionTree apply(const Unary& op) const;
354 
363  DecisionTree apply(const UnaryAssignment& op) const;
364 
366  DecisionTree apply(const DecisionTree& g, const Binary& op) const;
367 
370  DecisionTree choose(const L& label, size_t index) const {
371  NodePtr newRoot = root_->choose(label, index);
372  return DecisionTree(newRoot);
373  }
374 
376  DecisionTree combine(const L& label, size_t cardinality,
377  const Binary& op) const;
378 
380  DecisionTree combine(const LabelC& labelC, const Binary& op) const {
381  return combine(labelC.first, labelC.second, op);
382  }
383 
385  void dot(std::ostream& os, const LabelFormatter& labelFormatter,
386  const ValueFormatter& valueFormatter, bool showZero = true) const;
387 
389  void dot(const std::string& name, const LabelFormatter& labelFormatter,
390  const ValueFormatter& valueFormatter, bool showZero = true) const;
391 
393  std::string dot(const LabelFormatter& labelFormatter,
395  bool showZero = true) const;
396 
399 
400  // internal use only
401  explicit DecisionTree(const NodePtr& root);
402 
403  // internal use only
404  template<typename Iterator> NodePtr
405  compose(Iterator begin, Iterator end, const L& label) const;
406 
408 
409  private:
410 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
411 
412  friend class boost::serialization::access;
413  template <class ARCHIVE>
414  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
415  ar& BOOST_SERIALIZATION_NVP(root_);
416  }
417 #endif
418  }; // DecisionTree
419 
420  template <class L, class Y>
421  struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
422 
425  template<typename L, typename Y>
428  const typename DecisionTree<L, Y>::Unary& op) {
429  return f.apply(op);
430  }
431 
433  template<typename L, typename Y>
435  const typename DecisionTree<L, Y>::UnaryAssignment& op) {
436  return f.apply(op);
437  }
438 
440  template<typename L, typename Y>
442  const DecisionTree<L, Y>& g,
443  const typename DecisionTree<L, Y>::Binary& op) {
444  return f.apply(g, op);
445  }
446 
453  template <typename L, typename T1, typename T2>
454  std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
455  const DecisionTree<L, std::pair<T1, T2> >& input) {
456  return {
457  DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
458  DecisionTree<L, T2>(input, [](std::pair<T1, T2> i) { return i.second; })
459  };
460  }
461 
462 } // namespace gtsam
compare
bool compare
Definition: SolverComparer.cpp:98
gtsam::DecisionTree::create
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const
Definition: DecisionTree-inl.h:690
gtsam::DecisionTree< Key, double >::LabelFormatter
std::function< std::string(Key)> LabelFormatter
Definition: DecisionTree.h:71
name
Annotation for function names.
Definition: attr.h:51
Y
const char Y
Definition: test/EulerAngles.cpp:31
gtsam::DecisionTree< Key, double >::CompareFunc
std::function< bool(const double &, const double &)> CompareFunc
Definition: DecisionTree.h:73
Leaf
static sharedNode Leaf(Key key, const SymbolicFactorGraph &factors)
Definition: testSymbolicEliminationTree.cpp:78
s
RealScalar s
Definition: level1_cplx_impl.h:126
types.h
Typedefs for easier changing of types.
Testable.h
Concept check for values that can be used in unit tests.
gtsam::DecisionTree::empty
bool empty() const
Check if tree is empty.
Definition: DecisionTree.h:269
gtsam::DecisionTree::equals
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
Definition: DecisionTree-inl.h:898
x
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
Definition: gnuplot_common_settings.hh:12
gtsam::DecisionTree< Key, double >::ValueFormatter
std::function< std::string(double)> ValueFormatter
Definition: DecisionTree.h:72
gtsam::DecisionTree::labels
std::set< L > labels() const
Definition: DecisionTree-inl.h:885
gtsam::DecisionTree::Choice
Definition: DecisionTree-inl.h:164
gtsam::DecisionTree::DecisionTree
DecisionTree()
Definition: DecisionTree-inl.h:482
gtsam::DecisionTree::Node::operator()
virtual const Y & operator()(const Assignment< L > &x) const =0
X
#define X
Definition: icosphere.cpp:20
gtsam::DecisionTree::print
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
Definition: DecisionTree-inl.h:904
gtsam::DecisionTree::choose
DecisionTree choose(const L &label, size_t index) const
Definition: DecisionTree.h:370
os
ofstream os("timeSchurFactors.csv")
gtsam::DecisionTree::Node::equals
virtual bool equals(const Node &other, const CompareFunc &compare=&DefaultCompare) const =0
gtsam::DecisionTree::DefaultCompare
static bool DefaultCompare(const Y &a, const Y &b)
Default method for comparison of two objects of type Y.
Definition: DecisionTree.h:66
gtsam::DecisionTree::Node::Ptr
std::shared_ptr< const Node > Ptr
Definition: DecisionTree.h:89
gtsam::DecisionTree::Node::isLeaf
virtual bool isLeaf() const =0
gtsam::DecisionTree::Node::apply_f_op_g
virtual Ptr apply_f_op_g(const Node &, const Binary &) const =0
gtsam::DecisionTree::combine
DecisionTree combine(const LabelC &labelC, const Binary &op) const
Definition: DecisionTree.h:380
gtsam::DecisionTree< Key, double >::Unary
std::function< double(const double &)> Unary
Definition: DecisionTree.h:76
y1
double y1(double x)
Definition: j1.c:199
gtsam::DecisionTree::root_
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition: DecisionTree.h:150
gtsam::DecisionTree::Node::apply_g_op_fL
virtual Ptr apply_g_op_fL(const Leaf &, const Binary &) const =0
table
ArrayXXf table(10, 4)
Eigen::numext::q
EIGEN_DEVICE_FUNC const Scalar & q
Definition: SpecialFunctionsImpl.h:1984
gtsam::DecisionTree::Node::choose
virtual Ptr choose(const L &label, size_t index) const =0
gtsam::DecisionTree< Key, double >::Binary
std::function< double(const double &, const double &)> Binary
Definition: DecisionTree.h:78
Assignment.h
An assignment from labels to a discrete value index (size_t)
x0
static Symbol x0('x', 0)
gtsam::DecisionTree::visit
void visit(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:768
gtsam::DecisionTree::dot
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero=true) const
Definition: DecisionTree-inl.h:981
L
MatrixXd L
Definition: LLT_example.cpp:6
gtsam::Assignment
Definition: Assignment.h:37
gtsam::DecisionTree::visitWith
void visitWith(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:848
gtsam::DecisionTree::apply
DecisionTree apply(const Unary &op) const
Definition: DecisionTree-inl.h:921
gtsam::DecisionTree::Node::dot
virtual void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const =0
g
void g(const string &key, int i)
Definition: testBTree.cpp:41
y
Scalar * y
Definition: level1_cplx_impl.h:124
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
gtsam::DecisionTree
a decision tree is a function from assignments to values.
Definition: DecisionTree.h:63
gtsam::b
const G & b
Definition: Group.h:79
gtsam::DecisionTree::Node::sameLeaf
virtual bool sameLeaf(const Leaf &q) const =0
gtsam::DecisionTree::~DecisionTree
virtual ~DecisionTree()=default
Make virtual.
gtsam::DecisionTree::Leaf
Definition: DecisionTree-inl.h:52
a
ArrayXXi a
Definition: Array_initializer_list_23_cxx11.cpp:1
gtsam
traits
Definition: chartTesting.h:28
gtsam::Testable
Definition: Testable.h:152
gtsam::traits
Definition: Group.h:36
gtsam::DecisionTree::Node::id
const void * id() const
Definition: DecisionTree.h:112
gtsam::DecisionTree::Node::apply
virtual Ptr apply(const Unary &op) const =0
gtsam::DecisionTree::operator==
bool operator==(const DecisionTree &q) const
Definition: DecisionTree-inl.h:911
gtsam::DecisionTree< Key, double >::UnaryAssignment
std::function< double(const Assignment< Key > &, const double &)> UnaryAssignment
Definition: DecisionTree.h:77
gtsam::DecisionTree< Key, double >::LabelC
std::pair< Key, size_t > LabelC
Definition: DecisionTree.h:81
gtsam::DecisionTree::compose
NodePtr compose(Iterator begin, Iterator end, const L &label) const
Definition: DecisionTree-inl.h:581
gtsam::apply
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
Definition: DecisionTree.h:427
gtsam::DecisionTree::Node
Definition: DecisionTree.h:88
gtsam::DecisionTree::nrLeaves
size_t nrLeaves() const
Return the number of leaves in the tree.
Definition: DecisionTree-inl.h:855
gtsam::DecisionTree< Key, double >::NodePtr
typename Node::Ptr NodePtr
Definition: DecisionTree.h:147
unary::f1
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
Definition: testExpression.cpp:79
gtsam::DecisionTree::combine
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
Definition: DecisionTree-inl.h:969
Eigen::placeholders::end
static const EIGEN_DEPRECATED end_t end
Definition: IndexedViewHelper.h:181
gtsam::unzip
std::pair< DecisionTree< L, T1 >, DecisionTree< L, T2 > > unzip(const DecisionTree< L, std::pair< T1, T2 > > &input)
unzip a DecisionTree with std::pair values.
Definition: DecisionTree.h:454
gtsam::valueFormatter
static std::string valueFormatter(const double &v)
Definition: DecisionTreeFactor.cpp:266
gtsam::DecisionTree::fold
X fold(Func f, X x0) const
Fold a binary function over the tree, returning accumulator.
Definition: DecisionTree-inl.h:865
gtsam::DecisionTree::convertFrom
NodePtr convertFrom(const typename DecisionTree< M, X >::NodePtr &f, std::function< L(const M &)> L_of_M, std::function< Y(const X &)> Y_of_X) const
Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
Definition: DecisionTree-inl.h:703
gtsam::DecisionTree::Node::~Node
virtual ~Node()
Definition: DecisionTree.h:104
gtsam::DecisionTree::operator()
const Y & operator()(const Assignment< L > &x) const
Definition: DecisionTree-inl.h:916
gtsam::DecisionTree::Node::Node
Node()
Definition: DecisionTree.h:96
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
pybind_wrapper_test_script.other
other
Definition: pybind_wrapper_test_script.py:42
gtsam::DecisionTree::build
NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const
Definition: DecisionTree-inl.h:649
gtsam::DecisionTree::Node::print
virtual void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const =0
gtsam::DecisionTree::visitLeaf
void visitLeaf(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:805
gtsam::DecisionTree::Node::apply_g_op_fC
virtual Ptr apply_g_op_fC(const Choice &, const Binary &) const =0
M
Matrix< RealScalar, Dynamic, Dynamic > M
Definition: bench_gemm.cpp:51


gtsam
Author(s):
autogenerated on Thu Jun 13 2024 03:02:10