DecisionTreeFactor.h
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
19 #pragma once
20 
25 
26 #include <algorithm>
27 #include <memory>
28 #include <map>
29 #include <stdexcept>
30 #include <string>
31 #include <utility>
32 #include <vector>
33 
34 namespace gtsam {
35 
// Forward declarations: both types are used in this header only by const
// reference (the DecisionTreeFactor(const DiscreteConditional&) constructor and
// error(const HybridValues&)), so their full definitions are not needed here.
36  class DiscreteConditional;
37  class HybridValues;
38 
/**
 * A discrete factor represented as an algebraic decision tree over Keys.
 * It derives from both DiscreteFactor and AlgebraicDecisionTree<Key>.
 * It also records the cardinality of each of its keys in cardinalities_.
 *
 * NOTE(review): this text is a scraped Doxygen source listing. Several
 * original lines are missing here; the leading line numbers jump at 48,
 * 66-67, 75-76, 80, 115-116 and 124-125. As a result, the declarations
 * below are incomplete as shown. Where the cross-reference index at the end
 * of the page records a lost signature, it is quoted in a NOTE at the gap.
 * Otherwise the gap is only flagged. The includes that provide
 * DiscreteFactor, AlgebraicDecisionTree, DiscreteKeys and Ordering are
 * presumably on the missing lines 21-24; confirm against the real header.
 */
44  class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
45  public AlgebraicDecisionTree<Key> {
46  public:
47  // typedefs needed to play nice with gtsam
    // NOTE(review): original line 48 is missing. The index lists
    // `AlgebraicDecisionTree< Key > ADT`, i.e. presumably
    // `typedef AlgebraicDecisionTree<Key> ADT;`, which defines the `ADT`
    // name used throughout this class.
    /// Typedef to the DiscreteFactor base class.
49  typedef DiscreteFactor Base;
    /// Shared pointer to a DecisionTreeFactor (returned by sum/max/combine).
50  typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
52 
53  protected:
    /// Cardinality recorded for each key of this factor; read by cardinality().
54  std::map<Key, size_t> cardinalities_;
55 
56  public:
59 
62 
    /// Construct from a set of discrete keys and a decision tree of potentials.
64  DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
65 
    // NOTE(review): the opening of the next declaration (original lines
    // 66-67) is missing from this listing. By analogy with the string-table
    // constructor that follows, it is presumably
    // `DecisionTreeFactor(const DiscreteKeys& keys,`; confirm against the header.
68  const std::vector<double>& table);
69 
    /// Construct from a set of discrete keys and a table given as a string.
71  DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
72 
    // NOTE(review): the rest of this template (original lines 75-76) is
    // missing. The index lists it as
    // `DecisionTreeFactor(const DiscreteKey &key, SOURCE table)`, a
    // single-key specialization.
74  template <class SOURCE>
77 
    /// Single-key constructor taking a vector of doubles.
    // NOTE(review): the initializer/body (original line 80) is missing here.
79  DecisionTreeFactor(const DiscreteKey& key, const std::vector<double>& row)
81 
    /// Construct from a DiscreteConditional. Declared explicit, so no
    /// implicit conversion from DiscreteConditional takes place.
83  explicit DecisionTreeFactor(const DiscreteConditional& c);
84 
88 
    /// Equality check against another DiscreteFactor.
    /// `tol` (default 1e-9) is the numerical tolerance.
90  bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
91 
92  // print (s defaults to the header "DecisionTreeFactor:\n"; keys are converted to strings with formatter)
93  void print(
94  const std::string& s = "DecisionTreeFactor:\n",
95  const KeyFormatter& formatter = DefaultKeyFormatter) const override;
96 
100 
    /// Value of the factor for the given assignment.
    /// Identical to operator(); both forward to ADT::operator().
103  double evaluate(const DiscreteValues& values) const {
104  return ADT::operator()(values);
105  }
106 
    /// Evaluate the factor for an assignment (syntactic sugar).
    /// Forwards to ADT::operator().
108  double operator()(const DiscreteValues& values) const override {
109  return ADT::operator()(values);
110  }
111 
    /// Error for the given assignment; defined out of line (not visible here).
113  double error(const DiscreteValues& values) const;
114 
    // NOTE(review): the signature line of this body (original 115-116) is
    // missing. The index lists it as
    // `DecisionTreeFactor operator*(const DecisionTreeFactor &f) const override`,
    // which multiplies two factors via apply() with ADT::Ring::mul.
117  return apply(f, ADT::Ring::mul);
118  }
119 
    /// Division passed to apply() by operator/. Its behavior when b == 0 is
    /// defined out of line; the index describes operator/ as dividing "safely".
120  static double safe_div(const double& a, const double& b);
121 
    /// Cardinality recorded for key j. Uses std::map::at, so it throws
    /// std::out_of_range when j is not one of this factor's keys.
122  size_t cardinality(Key j) const { return cardinalities_.at(j); }
123 
    // NOTE(review): the signature line of this body (original 124-125) is
    // missing. The index lists it as
    // `DecisionTreeFactor operator/(const DecisionTreeFactor &f) const`,
    // which divides by f via apply() with safe_div.
126  return apply(f, safe_div);
127  }
128 
    /// Conversion to DecisionTreeFactor. This already is one, so it returns a
    /// copy of *this.
130  DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
131 
    /// New factor obtained by summing all values with the same separator values.
    /// Calls combine(nrFrontals, ADT::Ring::add).
133  shared_ptr sum(size_t nrFrontals) const {
134  return combine(nrFrontals, ADT::Ring::add);
135  }
136 
    /// New factor obtained by summing over the variables in keys.
    /// Calls combine(keys, ADT::Ring::add).
138  shared_ptr sum(const Ordering& keys) const {
139  return combine(keys, ADT::Ring::add);
140  }
141 
    /// New factor obtained by maximizing over all values with the same
    /// separator. Calls combine(nrFrontals, ADT::Ring::max).
143  shared_ptr max(size_t nrFrontals) const {
144  return combine(nrFrontals, ADT::Ring::max);
145  }
146 
    /// New factor obtained by maximizing over the variables in keys.
    /// Calls combine(keys, ADT::Ring::max).
148  shared_ptr max(const Ordering& keys) const {
149  return combine(keys, ADT::Ring::max);
150  }
151 
155 
    /// Combine this factor with f under the binary operation op.
    /// The multiplication and division operators are built on this.
161  DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;
162 
    /// Reduce with op over nrFrontals variables and return the result.
    /// Presumably these are the leading/frontal keys; confirm in the .cpp.
    /// sum() uses ADT::Ring::add and max() uses ADT::Ring::max.
169  shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
170 
    /// Reduce with op over the variables listed in keys and return the result.
177  shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
178 
    /// (assignment, value) pairs of this factor.
    /// Presumably one entry per assignment of its keys; confirm in the .cpp.
180  std::vector<std::pair<DiscreteValues, double>> enumerate() const;
181 
    /// The discrete keys (key/cardinality pairs) of this factor.
183  DiscreteKeys discreteKeys() const;
184 
    /// Return a pruned copy of this factor, bounded by maxNrAssignments.
    /// The exact pruning rule is defined out of line; TODO confirm.
203  DecisionTreeFactor prune(size_t maxNrAssignments) const;
204 
208 
    /// Write a `dot` rendering (presumably Graphviz) to the stream os.
    /// showZero defaults to true.
210  void dot(std::ostream& os,
211  const KeyFormatter& keyFormatter = DefaultKeyFormatter,
212  bool showZero = true) const;
213 
    /// Write a `dot` rendering identified by name.
    /// Presumably it is written to a file of that name; confirm.
215  void dot(const std::string& name,
216  const KeyFormatter& keyFormatter = DefaultKeyFormatter,
217  bool showZero = true) const;
218 
    /// Return the `dot` rendering as a string.
220  std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
221  bool showZero = true) const;
222 
    /// Render as Markdown. names is a translation table from values to
    /// strings (DiscreteValues::Names).
230  std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
231  const Names& names = {}) const override;
232 
    /// Render as HTML. names is a translation table from values to strings
    /// (DiscreteValues::Names).
240  std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
241  const Names& names = {}) const override;
242 
246 
    /// Error for a HybridValues assignment; overrides a virtual in a base class.
251  double error(const HybridValues& values) const override;
252 
254 
255  private:
256 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
257 
258  friend class boost::serialization::access;
    /// Boost serialization, compiled only when GTSAM_ENABLE_BOOST_SERIALIZATION
    /// is defined. It (de)serializes the DiscreteFactor base, the ADT base and
    /// cardinalities_, in that order.
259  template <class ARCHIVE>
260  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
261  ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
262  ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
263  ar& BOOST_SERIALIZATION_NVP(cardinalities_);
264  }
265 #endif
266  };
267 
268 // traits
// Registers DecisionTreeFactor with GTSAM's traits machinery as a Testable type.
// NOTE(review): Testable<T> presumably exposes T::print / T::equals to GTSAM's
// generic test utilities; confirm in Testable.h.
269 template <>
270 struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
271 } // namespace gtsam
void print(const Matrix &A, const string &s, ostream &stream)
Definition: Matrix.cpp:155
const gtsam::Symbol key('X', 0)
DecisionTreeFactor operator*(const DecisionTreeFactor &f) const override
multiply two factors
#define max(a, b)
Definition: datatypes.h:20
const MATRIX::ConstRowXpr row(const MATRIX &A, size_t j)
Definition: base/Matrix.h:221
double dot(const V1 &a, const V2 &b)
Definition: Vector.h:195
std::string serialize(const T &input)
serializes to a string
DecisionTreeFactor operator/(const DecisionTreeFactor &f) const
divide by factor f (safely)
double mul(const double &a, const double &b)
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
Definition: DecisionTree.h:398
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
shared_ptr max(size_t nrFrontals) const
Create new factor by maximizing over all values with the same separator.
leaf::MyValues values
DecisionTreeFactor(const DiscreteKey &key, SOURCE table)
Single-key specialization.
static const KeyFormatter DefaultKeyFormatter
Definition: Key.h:43
Algebraic Decision Trees.
const KeyFormatter & formatter
DecisionTreeFactor toDecisionTreeFactor() const override
Convert into a decision tree.
string html(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of html.
double evaluate(const DiscreteValues &values) const
size_t cardinality(Key j) const
double operator()(const DiscreteValues &values) const override
Evaluate probability distribution (syntactic sugar).
std::map< Key, size_t > cardinalities_
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Array< double, 1, 3 > e(1./3., 0.5, 2.)
RealScalar s
DiscreteFactor Base
Typedef to base class.
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
const G & b
Definition: Group.h:86
DiscreteValues::Names Names
Translation table from values to strings.
traits
Definition: chartTesting.h:28
specialized key for discrete variables
std::shared_ptr< DecisionTreeFactor > shared_ptr
ofstream os("timeSchurFactors.csv")
ArrayXXf table(10, 4)
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
shared_ptr sum(size_t nrFrontals) const
Create new factor by summing all values with the same separator values.
graph add(PriorFactor< Pose2 >(1, priorMean, priorNoise))
static double error
Definition: testRot3.cpp:37
shared_ptr max(const Ordering &keys) const
Create new factor by maximizing over all values with the same separator.
Annotation for function names.
Definition: attr.h:48
const G double tol
Definition: Group.h:86
const KeyVector keys
shared_ptr sum(const Ordering &keys) const
Create new factor by summing all values with the same separator values.
internal::enable_if< internal::valid_indexed_view_overload< RowIndices, ColIndices >::value &&internal::traits< typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::ReturnAsIndexedView, typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::type operator()(const RowIndices &rowIndices, const ColIndices &colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:102
std::ptrdiff_t j
AlgebraicDecisionTree< Key > ADT
DecisionTreeFactor(const DiscreteKey &key, const std::vector< double > &row)
Single-key specialization, with vector of doubles.
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:34:09