DecisionTree.h
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
20 #pragma once
21 
22 #include <gtsam/base/Testable.h>
23 #include <gtsam/base/types.h>
25 
26 #if GTSAM_ENABLE_BOOST_SERIALIZATION
27 #include <boost/serialization/nvp.hpp>
28 #endif
29 #include <memory>
30 #include <functional>
31 #include <iostream>
32 #include <map>
33 #include <set>
34 #include <string>
35 #include <utility>
36 #include <vector>
37 
38 namespace gtsam {
39 
61  template<typename L, typename Y>
62  class DecisionTree {
63  protected:
65  static bool DefaultCompare(const Y& a, const Y& b) {
66  return a == b;
67  }
68 
69  public:
70  using LabelFormatter = std::function<std::string(L)>;
71  using ValueFormatter = std::function<std::string(Y)>;
72  using CompareFunc = std::function<bool(const Y&, const Y&)>;
73 
75  using Unary = std::function<Y(const Y&)>;
76  using UnaryAssignment = std::function<Y(const Assignment<L>&, const Y&)>;
77  using Binary = std::function<Y(const Y&, const Y&)>;
78 
80  using LabelC = std::pair<L, size_t>;
81 
83  struct Leaf;
84  struct Choice;
85 
87  struct Node {
88  using Ptr = std::shared_ptr<Node>;
89 
90 #ifdef DT_DEBUG_MEMORY
91  static int nrNodes;
92 #endif
93 
94  // Constructor
95  Node() {
96 #ifdef DT_DEBUG_MEMORY
97  std::cout << ++nrNodes << " constructed " << id() << std::endl;
98  std::cout.flush();
99 #endif
100  }
101 
102  // Destructor
103  virtual ~Node() {
104 #ifdef DT_DEBUG_MEMORY
105  std::cout << --nrNodes << " destructed " << id() << std::endl;
106  std::cout.flush();
107 #endif
108  }
109 
110  // Unique ID for dot files
111  const void* id() const { return this; }
112 
113  // everything else is virtual, no documentation here as internal
114  virtual void print(const std::string& s,
115  const LabelFormatter& labelFormatter,
116  const ValueFormatter& valueFormatter) const = 0;
117  virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
119  bool showZero) const = 0;
120  virtual bool sameLeaf(const Leaf& q) const = 0;
121  virtual bool sameLeaf(const Node& q) const = 0;
122  virtual bool equals(const Node& other, const CompareFunc& compare =
123  &DefaultCompare) const = 0;
124  virtual const Y& operator()(const Assignment<L>& x) const = 0;
125  virtual Ptr apply(const Unary& op) const = 0;
126  virtual Ptr apply(const UnaryAssignment& op,
127  const Assignment<L>& assignment) const = 0;
128  virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
129  virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
130  virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
131  virtual Ptr choose(const L& label, size_t index) const = 0;
132  virtual bool isLeaf() const = 0;
133 
134  private:
135 #if GTSAM_ENABLE_BOOST_SERIALIZATION
136 
137  friend class boost::serialization::access;
138  template <class ARCHIVE>
139  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
140 #endif
141  };
144  public:
146  using NodePtr = typename Node::Ptr;
147 
150 
151  protected:
156  template <typename It, typename ValueIt>
157  static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);
158 
164  template <typename It, typename ValueIt>
165  static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);
166 
176  template <typename X>
177  static NodePtr convertFrom(const typename DecisionTree<L, X>::NodePtr& f,
178  std::function<Y(const X&)> Y_of_X);
179 
190  template <typename M, typename X>
191  static NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
192  std::function<L(const M&)> L_of_M,
193  std::function<Y(const X&)> Y_of_X);
194 
195  public:
198 
200  DecisionTree();
201 
203  explicit DecisionTree(const Y& y);
204 
212  DecisionTree(const L& label, const Y& y1, const Y& y2);
213 
215  DecisionTree(const LabelC& label, const Y& y1, const Y& y2);
216 
218  DecisionTree(const std::vector<LabelC>& labelCs, const std::vector<Y>& ys);
219 
221  DecisionTree(const std::vector<LabelC>& labelCs, const std::string& table);
222 
224  template<typename Iterator>
225  DecisionTree(Iterator begin, Iterator end, const L& label);
226 
228  DecisionTree(const L& label, const DecisionTree& f0,
229  const DecisionTree& f1);
230 
238  DecisionTree(const Unary& op, DecisionTree&& other) noexcept;
239 
247  template <typename X, typename Func>
248  DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
249 
260  template <typename M, typename X, typename Func>
261  DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& map,
262  Func Y_of_X);
263 
267 
275  void print(const std::string& s, const LabelFormatter& labelFormatter,
276  const ValueFormatter& valueFormatter) const;
277 
278  // Testable
279  bool equals(const DecisionTree& other,
280  const CompareFunc& compare = &DefaultCompare) const;
281 
285 
287  virtual ~DecisionTree() = default;
288 
290  bool empty() const { return !root_; }
291 
293  bool operator==(const DecisionTree& q) const;
294 
296  const Y& operator()(const Assignment<L>& x) const;
297 
312  template <typename Func>
313  void visit(Func f) const;
314 
329  template <typename Func>
330  void visitLeaf(Func f) const;
331 
346  template <typename Func>
347  void visitWith(Func f) const;
348 
350  size_t nrLeaves() const;
351 
367  template <typename Func, typename X>
368  X fold(Func f, X x0) const;
369 
371  std::set<L> labels() const;
372 
374  DecisionTree apply(const Unary& op) const;
375 
384  DecisionTree apply(const UnaryAssignment& op) const;
385 
387  DecisionTree apply(const DecisionTree& g, const Binary& op) const;
388 
391  DecisionTree choose(const L& label, size_t index) const {
392  NodePtr newRoot = root_->choose(label, index);
393  return DecisionTree(newRoot);
394  }
395 
397  DecisionTree combine(const L& label, size_t cardinality,
398  const Binary& op) const;
399 
401  DecisionTree combine(const LabelC& labelC, const Binary& op) const {
402  return combine(labelC.first, labelC.second, op);
403  }
404 
406  void dot(std::ostream& os, const LabelFormatter& labelFormatter,
407  const ValueFormatter& valueFormatter, bool showZero = true) const;
408 
410  void dot(const std::string& name, const LabelFormatter& labelFormatter,
411  const ValueFormatter& valueFormatter, bool showZero = true) const;
412 
414  std::string dot(const LabelFormatter& labelFormatter,
416  bool showZero = true) const;
417 
426  template <typename A, typename B>
427  std::pair<DecisionTree<L, A>, DecisionTree<L, B>> split(
428  std::function<std::pair<A, B>(const Y&)> AB_of_Y) const;
429 
432 
433  // internal use only
434  explicit DecisionTree(const NodePtr& root);
435 
436  // internal use only
437  template<typename Iterator> NodePtr
438  static compose(Iterator begin, Iterator end, const L& label);
439 
441 
442  private:
443 #if GTSAM_ENABLE_BOOST_SERIALIZATION
444 
445  friend class boost::serialization::access;
446  template <class ARCHIVE>
447  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
448  ar& BOOST_SERIALIZATION_NVP(root_);
449  }
450 #endif
451  }; // DecisionTree
452 
453  template <class L, class Y>
454  struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
455 
458  template<typename L, typename Y>
461  const typename DecisionTree<L, Y>::Unary& op) {
462  return f.apply(op);
463  }
464 
466  template<typename L, typename Y>
468  const typename DecisionTree<L, Y>::UnaryAssignment& op) {
469  return f.apply(op);
470  }
471 
473  template<typename L, typename Y>
475  const DecisionTree<L, Y>& g,
476  const typename DecisionTree<L, Y>::Binary& op) {
477  return f.apply(g, op);
478  }
479 
486  template <typename L, typename T1, typename T2>
487  std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
488  const DecisionTree<L, std::pair<T1, T2> >& input) {
489  return {
490  DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
491  DecisionTree<L, T2>(input, [](std::pair<T1, T2> i) { return i.second; })
492  };
493  }
494 
495 } // namespace gtsam
compare
bool compare
Definition: SolverComparer.cpp:98
gtsam::DecisionTree< Key, GaussianFactorGraphValuePair >::LabelFormatter
std::function< std::string(Key)> LabelFormatter
Definition: DecisionTree.h:70
name
Annotation for function names.
Definition: attr.h:51
test_constructor::f1
auto f1
Definition: testHybridNonlinearFactor.cpp:56
gtsam::DecisionTree< Key, GaussianFactorGraphValuePair >::CompareFunc
std::function< bool(const GaussianFactorGraphValuePair &, const GaussianFactorGraphValuePair &)> CompareFunc
Definition: DecisionTree.h:72
Leaf
static sharedNode Leaf(Key key, const SymbolicFactorGraph &factors)
Definition: testSymbolicEliminationTree.cpp:78
s
RealScalar s
Definition: level1_cplx_impl.h:126
gtsam::DecisionTree::split
std::pair< DecisionTree< L, A >, DecisionTree< L, B > > split(std::function< std::pair< A, B >(const Y &)> AB_of_Y) const
Convert into two trees with value types A and B.
Definition: DecisionTree-inl.h:1096
types.h
Typedefs for easier changing of types.
Testable.h
Concept check for values that can be used in unit tests.
gtsam::DecisionTree::empty
bool empty() const
Check if tree is empty.
Definition: DecisionTree.h:290
gtsam::Y
GaussianFactorGraphValuePair Y
Definition: HybridGaussianProductFactor.cpp:29
gtsam::DecisionTree::equals
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
Definition: DecisionTree-inl.h:972
x
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
Definition: gnuplot_common_settings.hh:12
gtsam::DecisionTree< Key, GaussianFactorGraphValuePair >::ValueFormatter
std::function< std::string(GaussianFactorGraphValuePair)> ValueFormatter
Definition: DecisionTree.h:71
gtsam::DecisionTree::labels
std::set< L > labels() const
Definition: DecisionTree-inl.h:959
gtsam::DecisionTree::Choice
Definition: DecisionTree-inl.h:162
gtsam::DecisionTree::DecisionTree
DecisionTree()
Definition: DecisionTree-inl.h:491
gtsam::DecisionTree::Node::operator()
virtual const Y & operator()(const Assignment< L > &x) const =0
Iterator
Definition: typing.h:54
X
#define X
Definition: icosphere.cpp:20
gtsam::DecisionTree::print
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
Definition: DecisionTree-inl.h:978
gtsam::DecisionTree::build
static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY)
Definition: DecisionTree-inl.h:687
gtsam::DecisionTree::choose
DecisionTree choose(const L &label, size_t index) const
Definition: DecisionTree.h:391
os
ofstream os("timeSchurFactors.csv")
gtsam::DecisionTree::Node::equals
virtual bool equals(const Node &other, const CompareFunc &compare=&DefaultCompare) const =0
gtsam::DecisionTree::DefaultCompare
static bool DefaultCompare(const Y &a, const Y &b)
Default method for comparison of two objects of type Y.
Definition: DecisionTree.h:65
gtsam::DecisionTree::Node::isLeaf
virtual bool isLeaf() const =0
gtsam::DecisionTree::Node::apply_f_op_g
virtual Ptr apply_f_op_g(const Node &, const Binary &) const =0
gtsam::DecisionTree::combine
DecisionTree combine(const LabelC &labelC, const Binary &op) const
Definition: DecisionTree.h:401
gtsam::DecisionTree< Key, GaussianFactorGraphValuePair >::Unary
std::function< GaussianFactorGraphValuePair(const GaussianFactorGraphValuePair &)> Unary
Definition: DecisionTree.h:75
y1
double y1(double x)
Definition: j1.c:199
gtsam::DecisionTree::root_
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition: DecisionTree.h:149
gtsam::DecisionTree::Node::apply_g_op_fL
virtual Ptr apply_g_op_fL(const Leaf &, const Binary &) const =0
table
ArrayXXf table(10, 4)
gtsam::DecisionTree::compose
static NodePtr compose(Iterator begin, Iterator end, const L &label)
Definition: DecisionTree-inl.h:617
test_constructor::f0
auto f0
Definition: testHybridNonlinearFactor.cpp:55
Eigen::numext::q
EIGEN_DEVICE_FUNC const Scalar & q
Definition: SpecialFunctionsImpl.h:1984
gtsam::DecisionTree::Node::choose
virtual Ptr choose(const L &label, size_t index) const =0
gtsam::DecisionTree< Key, GaussianFactorGraphValuePair >::Binary
std::function< GaussianFactorGraphValuePair(const GaussianFactorGraphValuePair &, const GaussianFactorGraphValuePair &)> Binary
Definition: DecisionTree.h:77
Assignment.h
An assignment from labels to a discrete value index (size_t)
x0
static Symbol x0('x', 0)
gtsam::DecisionTree::visit
void visit(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:842
gtsam::DecisionTree::dot
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero=true) const
Definition: DecisionTree-inl.h:1061
L
MatrixXd L
Definition: LLT_example.cpp:6
gtsam::Assignment
Definition: Assignment.h:37
gtsam::DecisionTree::visitWith
void visitWith(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:922
gtsam::DecisionTree::apply
DecisionTree apply(const Unary &op) const
Definition: DecisionTree-inl.h:1000
gtsam::DecisionTree::Node::dot
virtual void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const =0
g
void g(const string &key, int i)
Definition: testBTree.cpp:41
y
Scalar * y
Definition: level1_cplx_impl.h:124
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
gtsam::DecisionTree
A decision tree is a function from assignments to values.
Definition: DecisionTree.h:62
gtsam::DecisionTree::create
static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY)
Definition: DecisionTree-inl.h:729
gtsam::b
const G & b
Definition: Group.h:79
gtsam::DecisionTree::Node::sameLeaf
virtual bool sameLeaf(const Leaf &q) const =0
gtsam::DecisionTree::~DecisionTree
virtual ~DecisionTree()=default
Virtual, defaulted destructor.
gtsam::DecisionTree::Leaf
Definition: DecisionTree-inl.h:50
a
ArrayXXi a
Definition: Array_initializer_list_23_cxx11.cpp:1
gtsam
traits
Definition: SFMdata.h:40
gtsam::Testable
Definition: Testable.h:152
gtsam::traits
Definition: Group.h:36
gtsam::DecisionTree::Node::id
const void * id() const
Definition: DecisionTree.h:111
gtsam::DecisionTree::Node::apply
virtual Ptr apply(const Unary &op) const =0
gtsam::DecisionTree::operator==
bool operator==(const DecisionTree &q) const
Definition: DecisionTree-inl.h:985
gtsam::DecisionTree< Key, GaussianFactorGraphValuePair >::UnaryAssignment
std::function< GaussianFactorGraphValuePair(const Assignment< Key > &, const GaussianFactorGraphValuePair &)> UnaryAssignment
Definition: DecisionTree.h:76
gtsam::DecisionTree< Key, GaussianFactorGraphValuePair >::LabelC
std::pair< Key, size_t > LabelC
Definition: DecisionTree.h:80
gtsam::apply
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
Definition: DecisionTree.h:460
gtsam::DecisionTree::Node
Definition: DecisionTree.h:87
gtsam::DecisionTree::nrLeaves
size_t nrLeaves() const
Return the number of leaves in the tree.
Definition: DecisionTree-inl.h:929
gtsam::DecisionTree< Key, GaussianFactorGraphValuePair >::NodePtr
typename Node::Ptr NodePtr
Definition: DecisionTree.h:146
gtsam::DecisionTree::combine
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
Definition: DecisionTree-inl.h:1049
Eigen::placeholders::end
static const EIGEN_DEPRECATED end_t end
Definition: IndexedViewHelper.h:181
gtsam::unzip
std::pair< DecisionTree< L, T1 >, DecisionTree< L, T2 > > unzip(const DecisionTree< L, std::pair< T1, T2 > > &input)
unzip a DecisionTree with std::pair values.
Definition: DecisionTree.h:487
gtsam::valueFormatter
static std::string valueFormatter(const double &v)
Definition: DecisionTreeFactor.cpp:292
gtsam::DecisionTree::fold
X fold(Func f, X x0) const
Fold a binary function over the tree, returning accumulator.
Definition: DecisionTree-inl.h:939
gtsam::DecisionTree::Node::~Node
virtual ~Node()
Definition: DecisionTree.h:103
gtsam::DecisionTree::operator()
const Y & operator()(const Assignment< L > &x) const
Definition: DecisionTree-inl.h:991
gtsam::DecisionTree::convertFrom
static NodePtr convertFrom(const typename DecisionTree< L, X >::NodePtr &f, std::function< Y(const X &)> Y_of_X)
Convert from a DecisionTree<L, X> to DecisionTree<L, Y>.
Definition: DecisionTree-inl.h:742
gtsam::DecisionTree::Node::Node
Node()
Definition: DecisionTree.h:95
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
pybind_wrapper_test_script.other
other
Definition: pybind_wrapper_test_script.py:42
gtsam::DecisionTree::Node::print
virtual void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const =0
gtsam::DecisionTree::visitLeaf
void visitLeaf(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:879
gtsam::DecisionTree::Node::apply_g_op_fC
virtual Ptr apply_g_op_fC(const Choice &, const Binary &) const =0
M
Matrix< RealScalar, Dynamic, Dynamic > M
Definition: bench_gemm.cpp:51
gtsam::DecisionTree::Node::Ptr
std::shared_ptr< Node > Ptr
Definition: DecisionTree.h:88


gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:02:08