DecisionTreeFactor.h
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
19 #pragma once
20 
24 #include <gtsam/discrete/Ring.h>
26 
27 #include <algorithm>
28 #include <map>
29 #include <memory>
30 #include <stdexcept>
31 #include <string>
32 #include <utility>
33 #include <vector>
34 
35 namespace gtsam {
36 
37  class DiscreteConditional;
38  class HybridValues;
39 
45  class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
46  public AlgebraicDecisionTree<Key> {
47  public:
48  // typedefs needed to play nice with gtsam
50  typedef DiscreteFactor Base;
51  typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
53 
54  // Needed since we have definitions in both DiscreteFactor and DecisionTree
55  using Base::Binary;
56  using Base::Unary;
58 
61 
64 
66  DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
67 
88  const std::vector<double>& table);
89 
108  DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
109 
111  template <class SOURCE>
114 
116  DecisionTreeFactor(const DiscreteKey& key, const std::vector<double>& row)
118 
120  explicit DecisionTreeFactor(const DiscreteConditional& c);
121 
125 
127  bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
128 
129  // print
130  void print(
131  const std::string& s = "DecisionTreeFactor:\n",
132  const KeyFormatter& formatter = DefaultKeyFormatter) const override;
133 
137 
140  virtual double evaluate(const Assignment<Key>& values) const override {
141  return ADT::operator()(values);
142  }
143 
145  using DiscreteFactor::operator();
146 
148  double error(const DiscreteValues& values) const override;
149 
164  virtual DiscreteFactor::shared_ptr multiply(
165  const DiscreteFactor::shared_ptr& f) const override;
166 
168  DiscreteFactor::shared_ptr operator*(double s) const override {
169  return std::make_shared<DecisionTreeFactor>(
170  apply([s](const double& a) { return Ring::mul(a, s); }));
171  }
172 
175  return apply(f, Ring::mul);
176  }
177 
178  static double safe_div(const double& a, const double& b);
179 
190  return apply(f, safe_div);
191  }
192 
195  const DiscreteFactor::shared_ptr& f) const override;
196 
198  DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
199 
201  DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
202  return combine(nrFrontals, Ring::add);
203  }
204 
206  DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
207  return combine(keys, Ring::add);
208  }
209 
211  double max() const override { return ADT::max(); };
212 
214  DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
215  return combine(nrFrontals, Ring::max);
216  }
217 
219  DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
220  return combine(keys, Ring::max);
221  }
222 
225  const DiscreteValues& assignment) const override;
226 
230 
235  DecisionTreeFactor apply(Unary op) const;
236 
242  DecisionTreeFactor apply(UnaryAssignment op) const;
243 
249  DecisionTreeFactor apply(const DecisionTreeFactor& f, Binary op) const;
250 
257  shared_ptr combine(size_t nrFrontals, Binary op) const;
258 
265  shared_ptr combine(const Ordering& keys, Binary op) const;
266 
268  std::vector<std::pair<DiscreteValues, double>> enumerate() const;
269 
271  std::vector<double> probabilities() const;
272 
282  double computeThreshold(const size_t N) const;
283 
302  DecisionTreeFactor prune(size_t maxNrAssignments) const;
303 
308  uint64_t nrValues() const override { return nrLeaves(); }
309 
313 
315  void dot(std::ostream& os,
316  const KeyFormatter& keyFormatter = DefaultKeyFormatter,
317  bool showZero = true) const;
318 
320  void dot(const std::string& name,
321  const KeyFormatter& keyFormatter = DefaultKeyFormatter,
322  bool showZero = true) const;
323 
325  std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
326  bool showZero = true) const;
327 
335  std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
336  const Names& names = {}) const override;
337 
345  std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
346  const Names& names = {}) const override;
347 
351 
356  double error(const HybridValues& values) const override;
357 
359 
360  private:
361 #if GTSAM_ENABLE_BOOST_SERIALIZATION
362 
363  friend class boost::serialization::access;
364  template <class ARCHIVE>
365  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
366  ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
367  ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
368  }
369 #endif
370  };
371 
372 // traits
373 template <>
374 struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
375 } // namespace gtsam
gtsam::markdown
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
Definition: DiscreteValues.cpp:153
name
Annotation for function names.
Definition: attr.h:51
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:45
gtsam::DecisionTreeFactor::toDecisionTreeFactor
DecisionTreeFactor toDecisionTreeFactor() const override
Convert into a decision tree.
Definition: DecisionTreeFactor.h:198
s
RealScalar s
Definition: level1_cplx_impl.h:126
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
keys
const KeyVector keys
Definition: testRegularImplicitSchurFactor.cpp:40
c
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:247
Ordering.h
Variable ordering for the elimination algorithm.
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
HybridValues
os
ofstream os("timeSchurFactors.csv")
gtsam::DecisionTreeFactor::shared_ptr
std::shared_ptr< DecisionTreeFactor > shared_ptr
Definition: DecisionTreeFactor.h:51
DiscreteFactor
Discrete values for.
gtsam::DecisionTreeFactor::max
DiscreteFactor::shared_ptr max(const Ordering &keys) const override
Create new factor by maximizing over all values with the same separator.
Definition: DecisionTreeFactor.h:219
Ring::mul
static double mul(const double &a, const double &b)
Definition: Ring.h:31
gtsam::DefaultKeyFormatter
KeyFormatter DefaultKeyFormatter
Assign default key formatter.
Definition: Key.cpp:30
gtsam::AlgebraicDecisionTree
Definition: AlgebraicDecisionTree.h:41
gtsam::DecisionTreeFactor::DecisionTreeFactor
DecisionTreeFactor(const DiscreteKey &key, const std::vector< double > &row)
Single-key specialization, with vector of doubles.
Definition: DecisionTreeFactor.h:116
AlgebraicDecisionTree.h
Algebraic Decision Trees.
gtsam::print
void print(const Matrix &A, const string &s, ostream &stream)
Definition: Matrix.cpp:145
gtsam::DecisionTreeFactor::Base
DiscreteFactor Base
Typedef to base class.
Definition: DecisionTreeFactor.h:50
ADT
AlgebraicDecisionTree< Key > ADT
Definition: testAlgebraicDecisionTree.cpp:32
table
ArrayXXf table(10, 4)
gtsam::row
const MATRIX::ConstRowXpr row(const MATRIX &A, size_t j)
Definition: base/Matrix.h:215
operator()
internal::enable_if< internal::valid_indexed_view_overload< RowIndices, ColIndices >::value &&internal::traits< typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::ReturnAsIndexedView, typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::type operator()(const RowIndices &rowIndices, const ColIndices &colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
Definition: IndexedViewMethods.h:73
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
DiscreteFactor.h
Ring::max
static double max(const double &a, const double &b)
Definition: Ring.h:28
gtsam::DecisionTreeFactor::This
DecisionTreeFactor This
Definition: DecisionTreeFactor.h:49
gtsam::DecisionTreeFactor::operator*
DecisionTreeFactor operator*(const DecisionTreeFactor &f) const override
multiply two factors
Definition: DecisionTreeFactor.h:174
gtsam::dot
double dot(const V1 &a, const V2 &b)
Definition: Vector.h:196
gtsam::Assignment< Key >
gtsam::DecisionTreeFactor::evaluate
virtual double evaluate(const Assignment< Key > &values) const override
Definition: DecisionTreeFactor.h:140
gtsam::DecisionTreeFactor::operator/
DecisionTreeFactor operator/(const DecisionTreeFactor &f) const
Divide by factor f (safely). Division of a factor by another factor results in a function which inv...
Definition: DecisionTreeFactor.h:189
gtsam::DiscreteFactor::Binary
std::function< double(const double, const double)> Binary
Definition: DiscreteFactor.h:53
DiscreteKey.h
specialized key for discrete variables
key
const gtsam::Symbol key('X', 0)
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
process_shonan_timing_results.names
dictionary names
Definition: process_shonan_timing_results.py:175
gtsam::b
const G & b
Definition: Group.h:79
a
ArrayXXi a
Definition: Array_initializer_list_23_cxx11.cpp:1
Ring::add
static double add(const double &a, const double &b)
Definition: Ring.h:27
gtsam
traits
Definition: SFMdata.h:40
gtsam::Testable
Definition: Testable.h:152
gtsam::DecisionTreeFactor::max
double max() const override
Find the maximum value in the factor.
Definition: DecisionTreeFactor.h:211
gtsam::DiscreteFactor::shared_ptr
std::shared_ptr< DiscreteFactor > shared_ptr
shared_ptr to this class
Definition: DiscreteFactor.h:45
error
static double error
Definition: testRot3.cpp:37
gtsam::traits
Definition: Group.h:36
operator/
EIGEN_DEVICE_FUNC const EIGEN_STRONG_INLINE CwiseBinaryOp< internal::scalar_quotient_op< Scalar, typename OtherDerived::Scalar >, const Derived, const OtherDerived > operator/(const EIGEN_CURRENT_STORAGE_BASE_CLASS< OtherDerived > &other) const
Definition: ArrayCwiseBinaryOps.h:21
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
Ring.h
Real Ring definition.
gtsam::apply
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
Definition: DecisionTree.h:467
gtsam::DecisionTreeFactor::sum
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override
Create new factor by summing all values with the same separator values.
Definition: DecisionTreeFactor.h:201
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::DecisionTreeFactor::max
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override
Create new factor by maximizing over all values with the same separator.
Definition: DecisionTreeFactor.h:214
uint64_t
unsigned __int64 uint64_t
Definition: ms_stdint.h:95
gtsam::DecisionTreeFactor::ADT
AlgebraicDecisionTree< Key > ADT
Definition: DecisionTreeFactor.h:52
N
#define N
Definition: igam.h:9
gtsam::html
string html(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of html.
Definition: DiscreteValues.cpp:158
gtsam::DiscreteFactor
Definition: DiscreteFactor.h:40
gtsam::DiscreteFactor::UnaryAssignment
std::function< double(const Assignment< Key > &, const double &)> UnaryAssignment
Definition: DiscreteFactor.h:52
Base
Definition: test_virtual_functions.cpp:156
gtsam::DiscreteFactor::Unary
std::function< double(const double &)> Unary
Definition: DiscreteFactor.h:50
gtsam::DecisionTreeFactor::nrValues
uint64_t nrValues() const override
Definition: DecisionTreeFactor.h:308
max
#define max(a, b)
Definition: datatypes.h:20
gtsam::Ordering
Definition: inference/Ordering.h:33
gtsam::DecisionTreeFactor::operator*
DiscreteFactor::shared_ptr operator*(double s) const override
multiply with a scalar
Definition: DecisionTreeFactor.h:168
pybind_wrapper_test_script.other
other
Definition: pybind_wrapper_test_script.py:42
gtsam::DecisionTreeFactor::DecisionTreeFactor
DecisionTreeFactor(const DiscreteKey &key, SOURCE table)
Single-key specialization.
Definition: DecisionTreeFactor.h:112
gtsam::DecisionTreeFactor::sum
DiscreteFactor::shared_ptr sum(const Ordering &keys) const override
Create new factor by summing all values with the same separator values.
Definition: DecisionTreeFactor.h:206


gtsam
Author(s):
autogenerated on Wed Mar 19 2025 03:01:34