DecisionTreeFactor.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
20 #include <gtsam/base/FastSet.h>
24 
25 #include <utility>
26 
27 using namespace std;
28 
29 namespace gtsam {
30 
31  /* ************************************************************************ */
32  DecisionTreeFactor::DecisionTreeFactor() {}
33 
34  /* ************************************************************************ */
35  DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
36  const ADT& potentials)
37  : DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {}
38 
39  /* ************************************************************************ */
41  : DiscreteFactor(c.keys(), c.cardinalities()),
43 
44  /* ************************************************************************ */
46  double tol) const {
47  if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
48  return false;
49  } else {
50  const auto& f(static_cast<const DecisionTreeFactor&>(other));
51  return ADT::equals(f, tol);
52  }
53  }
54 
55  /* ************************************************************************ */
57  return -std::log(evaluate(values));
58  }
59 
60  /* ************************************************************************ */
62  return error(values.discrete());
63  }
64 
65  /* ************************************************************************ */
67  // Get all possible assignments
68  DiscreteKeys dkeys = discreteKeys();
69  // Reverse to make cartesian product output a more natural ordering.
70  DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
71  const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
72 
73  // Construct vector with error values
74  std::vector<double> errors;
75  for (const auto& assignment : assignments) {
76  errors.push_back(error(assignment));
77  }
78  return AlgebraicDecisionTree<Key>(dkeys, errors);
79  }
80 
81  /* ************************************************************************ */
82  double DecisionTreeFactor::safe_div(const double& a, const double& b) {
83  // The use for safe_div is when we divide the product factor by the sum
84  // factor. If the product or sum is zero, we accord zero probability to the
85  // event.
86  return (a == 0 || b == 0) ? 0 : (a / b);
87  }
88 
89  /* ************************************************************************ */
90  void DecisionTreeFactor::print(const string& s,
91  const KeyFormatter& formatter) const {
92  cout << s;
93  cout << " f[";
94  for (auto&& key : keys()) {
95  cout << " (" << formatter(key) << "," << cardinality(key) << "),";
96  }
97  cout << " ]" << endl;
98  ADT::print("", formatter);
99  }
100 
101  /* ************************************************************************ */
103  // apply operand
104  ADT result = ADT::apply(op);
105  // Make a new factor
107  }
108 
109  /* ************************************************************************ */
110  DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
111  // apply operand
112  ADT result = ADT::apply(op);
113  // Make a new factor
115  }
116 
117  /* ************************************************************************ */
119  ADT::Binary op) const {
120  map<Key, size_t> cs; // new cardinalities
121  // make unique key-cardinality map
122  for (Key j : keys()) cs[j] = cardinality(j);
123  for (Key j : f.keys()) cs[j] = f.cardinality(j);
124  // Convert map into keys
126  keys.reserve(cs.size());
127  for (const auto& key : cs) {
128  keys.emplace_back(key);
129  }
130  // apply operand
131  ADT result = ADT::apply(f, op);
132  // Make a new factor
133  return DecisionTreeFactor(keys, result);
134  }
135 
136  /* ************************************************************************ */
138  size_t nrFrontals, ADT::Binary op) const {
139  if (nrFrontals > size()) {
140  throw invalid_argument(
141  "DecisionTreeFactor::combine: invalid number of frontal "
142  "keys " +
143  std::to_string(nrFrontals) + ", nr.keys=" + std::to_string(size()));
144  }
145 
146  // sum over nrFrontals keys
147  size_t i;
148  ADT result(*this);
149  for (i = 0; i < nrFrontals; i++) {
150  Key j = keys()[i];
151  result = result.combine(j, cardinality(j), op);
152  }
153 
154  // create new factor, note we start keys after nrFrontals
155  DiscreteKeys dkeys;
156  for (; i < keys().size(); i++) {
157  Key j = keys()[i];
158  dkeys.push_back(DiscreteKey(j, cardinality(j)));
159  }
160  return std::make_shared<DecisionTreeFactor>(dkeys, result);
161  }
162 
163  /* ************************************************************************ */
165  const Ordering& frontalKeys, ADT::Binary op) const {
166  if (frontalKeys.size() > size()) {
167  throw invalid_argument(
168  "DecisionTreeFactor::combine: invalid number of frontal "
169  "keys " +
170  std::to_string(frontalKeys.size()) + ", nr.keys=" +
171  std::to_string(size()));
172  }
173 
174  // sum over nrFrontals keys
175  size_t i;
176  ADT result(*this);
177  for (i = 0; i < frontalKeys.size(); i++) {
178  Key j = frontalKeys[i];
179  result = result.combine(j, cardinality(j), op);
180  }
181 
182  // create new factor, note we collect keys that are not in frontalKeys
183  /*
184  Due to branch merging, the labels in `result` may be missing some keys
185  E.g. After branch merging, we may get a ADT like:
186  Leaf [2] 1.0204082
187 
188  This is missing the key values used for branching.
189  */
190  KeyVector difference, frontalKeys_(frontalKeys), keys_(keys());
191  // Get the difference of the frontalKeys and the factor keys using set_difference
192  std::sort(keys_.begin(), keys_.end());
193  std::sort(frontalKeys_.begin(), frontalKeys_.end());
194  std::set_difference(keys_.begin(), keys_.end(), frontalKeys_.begin(),
195  frontalKeys_.end(), back_inserter(difference));
196 
197  DiscreteKeys dkeys;
198  for (Key key : difference) {
199  dkeys.push_back(DiscreteKey(key, cardinality(key)));
200  }
201  return std::make_shared<DecisionTreeFactor>(dkeys, result);
202  }
203 
204  /* ************************************************************************ */
205  std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
206  const {
207  // Get all possible assignments
208  DiscreteKeys pairs = discreteKeys();
209  // Reverse to make cartesian product output a more natural ordering.
210  DiscreteKeys rpairs(pairs.rbegin(), pairs.rend());
211  const auto assignments = DiscreteValues::CartesianProduct(rpairs);
212 
213  // Construct unordered_map with values
214  std::vector<std::pair<DiscreteValues, double>> result;
215  for (const auto& assignment : assignments) {
216  result.emplace_back(assignment, operator()(assignment));
217  }
218  return result;
219  }
220 
221  /* ************************************************************************ */
222  std::vector<double> DecisionTreeFactor::probabilities() const {
223  // Set of all keys
224  std::set<Key> allKeys(keys().begin(), keys().end());
225 
226  std::vector<double> probs;
227 
228  /* An operation that takes each leaf probability, and computes the
229  * nrAssignments by checking the difference between the keys in the factor
230  * and the keys in the assignment.
231  * The nrAssignments is then used to append
232  * the correct number of leaf probability values to the `probs` vector
233  * defined above.
234  */
235  auto op = [&](const Assignment<Key>& a, double p) {
236  // Get all the keys in the current assignment
237  std::set<Key> assignment_keys;
238  for (auto&& [k, _] : a) {
239  assignment_keys.insert(k);
240  }
241 
242  // Find the keys missing in the assignment
243  std::vector<Key> diff;
244  std::set_difference(allKeys.begin(), allKeys.end(),
245  assignment_keys.begin(), assignment_keys.end(),
246  std::back_inserter(diff));
247 
248  // Compute the total number of assignments in the (pruned) subtree
249  size_t nrAssignments = 1;
250  for (auto&& k : diff) {
251  nrAssignments *= cardinalities_.at(k);
252  }
253  // Add p `nrAssignments` times to the probs vector.
254  probs.insert(probs.end(), nrAssignments, p);
255 
256  return p;
257  };
258 
259  // Go through the tree
260  this->apply(op);
261 
262  return probs;
263  }
264 
265  /* ************************************************************************ */
/// Format a leaf value in fixed notation with two decimals and a minimum
/// field width of four characters.
static std::string valueFormatter(const double& v) {
  std::ostringstream out;
  out << std::fixed << std::setprecision(2);
  out << std::setw(4) << v;
  return out.str();
}
271 
273  void DecisionTreeFactor::dot(std::ostream& os,
274  const KeyFormatter& keyFormatter,
275  bool showZero) const {
276  ADT::dot(os, keyFormatter, valueFormatter, showZero);
277  }
278 
280  void DecisionTreeFactor::dot(const std::string& name,
281  const KeyFormatter& keyFormatter,
282  bool showZero) const {
283  ADT::dot(name, keyFormatter, valueFormatter, showZero);
284  }
285 
287  std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
288  bool showZero) const {
289  return ADT::dot(keyFormatter, valueFormatter, showZero);
290  }
291 
292  // Render the factor as a markdown table.
293  /* ************************************************************************ */
294  string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
295  const Names& names) const {
296  stringstream ss;
297 
298  // Print out header.
299  ss << "|";
300  for (auto& key : keys()) {
301  ss << keyFormatter(key) << "|";
302  }
303  ss << "value|\n";
304 
305  // Print out separator with alignment hints.
306  ss << "|";
307  for (size_t j = 0; j < size(); j++) ss << ":-:|";
308  ss << ":-:|\n";
309 
310  // Print out all rows.
311  auto rows = enumerate();
312  for (const auto& kv : rows) {
313  ss << "|";
314  auto assignment = kv.first;
315  for (auto& key : keys()) {
316  size_t index = assignment.at(key);
317  ss << DiscreteValues::Translate(names, key, index) << "|";
318  }
319  ss << kv.second << "|\n";
320  }
321  return ss.str();
322  }
323 
324  /* ************************************************************************ */
325  string DecisionTreeFactor::html(const KeyFormatter& keyFormatter,
326  const Names& names) const {
327  stringstream ss;
328 
329  // Print out preamble.
330  ss << "<div>\n<table class='DecisionTreeFactor'>\n <thead>\n";
331 
332  // Print out header row.
333  ss << " <tr>";
334  for (auto& key : keys()) {
335  ss << "<th>" << keyFormatter(key) << "</th>";
336  }
337  ss << "<th>value</th></tr>\n";
338 
339  // Finish header and start body.
340  ss << " </thead>\n <tbody>\n";
341 
342  // Print out all rows.
343  auto rows = enumerate();
344  for (const auto& kv : rows) {
345  ss << " <tr>";
346  auto assignment = kv.first;
347  for (auto& key : keys()) {
348  size_t index = assignment.at(key);
349  ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
350  }
351  ss << "<td>" << kv.second << "</td>"; // value
352  ss << "</tr>\n";
353  }
354  ss << " </tbody>\n</table>\n</div>";
355  return ss.str();
356  }
357 
358  /* ************************************************************************ */
360  const vector<double>& table)
361  : DiscreteFactor(keys.indices(), keys.cardinalities()),
363 
364  /* ************************************************************************ */
366  const string& table)
367  : DiscreteFactor(keys.indices(), keys.cardinalities()),
369 
370  /* ************************************************************************ */
371  DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
372  const size_t N = maxNrAssignments;
373 
374  // Get the probabilities in the decision tree so we can threshold.
375  std::vector<double> probabilities = this->probabilities();
376 
377  // The number of probabilities can be lower than max_leaves
378  if (probabilities.size() <= N) {
379  return *this;
380  }
381 
382  std::sort(probabilities.begin(), probabilities.end(),
383  std::greater<double>{});
384 
385  double threshold = probabilities[N - 1];
386 
387  // Now threshold the decision tree
388  size_t total = 0;
389  auto thresholdFunc = [threshold, &total, N](const double& value) {
390  if (value < threshold || total >= N) {
391  return 0.0;
392  } else {
393  total += 1;
394  return value;
395  }
396  };
397  DecisionTree<Key, double> thresholded(*this, thresholdFunc);
398 
399  // Create pruned decision tree factor and return.
400  return DecisionTreeFactor(this->discreteKeys(), thresholded);
401  }
402 
403  /* ************************************************************************ */
404 } // namespace gtsam
name
Annotation for function names.
Definition: attr.h:51
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:44
gtsam::HybridValues
Definition: HybridValues.h:38
gtsam::DecisionTreeFactor::DecisionTreeFactor
DecisionTreeFactor()
Definition: DecisionTreeFactor.cpp:32
s
RealScalar s
Definition: level1_cplx_impl.h:126
gtsam::DecisionTreeFactor::dot
void dot(std::ostream &os, const KeyFormatter &keyFormatter=DefaultKeyFormatter, bool showZero=true) const
Definition: DecisionTreeFactor.cpp:273
keys
const KeyVector keys
Definition: testRegularImplicitSchurFactor.cpp:40
c
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
gtsam::DiscreteFactor::cardinalities_
std::map< Key, size_t > cardinalities_
Map of Keys and their cardinalities.
Definition: DiscreteFactor.h:51
gtsam::DecisionTreeFactor::prune
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
Definition: DecisionTreeFactor.cpp:371
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
log
const EIGEN_DEVICE_FUNC LogReturnType log() const
Definition: ArrayCwiseUnaryOps.h:128
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
gtsam::AlgebraicDecisionTree< Key >::equals
bool equals(const AlgebraicDecisionTree &other, double tol=1e-9) const
Equality method customized to value type double.
Definition: AlgebraicDecisionTree.h:222
DiscreteConditional.h
gtsam::DecisionTreeFactor::errorTree
AlgebraicDecisionTree< Key > errorTree() const override
Compute error for each assignment and return as a tree.
Definition: DecisionTreeFactor.cpp:66
os
ofstream os("timeSchurFactors.csv")
gtsam::Factor::begin
const_iterator begin() const
Definition: Factor.h:145
gtsam::DecisionTreeFactor::shared_ptr
std::shared_ptr< DecisionTreeFactor > shared_ptr
Definition: DecisionTreeFactor.h:50
gtsam::KeyVector
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:92
result
Values result
Definition: OdometryOptimize.cpp:8
rows
int rows
Definition: Tutorial_commainit_02.cpp:1
gtsam::AlgebraicDecisionTree< Key >
test_eigen_tensor.indices
indices
Definition: test_eigen_tensor.py:31
ss
static std::stringstream ss
Definition: testBTree.cpp:31
gtsam::DecisionTreeFactor::evaluate
double evaluate(const DiscreteValues &values) const
Definition: DecisionTreeFactor.h:134
FastSet.h
A thin wrapper around std::set that uses boost's fast_pool_allocator.
table
ArrayXXf table(10, 4)
j
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2
gtsam::DiscreteValues::CartesianProduct
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
Definition: DiscreteValues.h:85
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
gtsam::DecisionTreeFactor::equals
bool equals(const DiscreteFactor &other, double tol=1e-9) const override
equality
Definition: DecisionTreeFactor.cpp:45
gtsam::DecisionTreeFactor::error
double error(const DiscreteValues &values) const
Calculate error for DiscreteValues x, is -log(probability).
Definition: DecisionTreeFactor.cpp:56
gtsam::DecisionTree< Key, double >::dot
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero=true) const
Definition: DecisionTree-inl.h:981
gtsam::DecisionTreeFactor::print
void print(const std::string &s="DecisionTreeFactor:\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
Definition: DecisionTreeFactor.cpp:90
gtsam::Assignment< Key >
gtsam::DecisionTree< Key, double >::apply
DecisionTree apply(const Unary &op) const
Definition: DecisionTree-inl.h:921
gtsam::Factor::end
const_iterator end() const
Definition: Factor.h:148
key
const gtsam::Symbol key('X', 0)
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
gtsam::DecisionTree< Key, double >
gtsam::DiscreteFactor::cardinality
size_t cardinality(Key j) const
Definition: DiscreteFactor.h:93
process_shonan_timing_results.names
dictionary names
Definition: process_shonan_timing_results.py:175
gtsam::b
const G & b
Definition: Group.h:79
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
a
ArrayXXi a
Definition: Array_initializer_list_23_cxx11.cpp:1
gtsam
traits
Definition: chartTesting.h:28
gtsam::Factor::keys_
KeyVector keys_
The keys involved in this factor.
Definition: Factor.h:87
gtsam::DiscreteFactor::Names
DiscreteValues::Names Names
Translation table from values to strings.
Definition: DiscreteFactor.h:121
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
leaf::values
leaf::MyValues values
gtsam::AlgebraicDecisionTree< Key >::print
void print(const std::string &s="", const typename Base::LabelFormatter &labelFormatter=&DefaultFormatter) const
print method customized to value type double.
Definition: AlgebraicDecisionTree.h:210
gtsam::Factor::keys
const KeyVector & keys() const
Access the factor's involved variable keys.
Definition: Factor.h:142
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
std
Definition: BFloat16.h:88
p
float * p
Definition: Tutorial_Map_using.cpp:9
gtsam::DecisionTreeFactor::combine
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const
Definition: DecisionTreeFactor.cpp:137
gtsam::DecisionTreeFactor::html
std::string html(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as html table.
Definition: DecisionTreeFactor.cpp:325
v
Array< int, Dynamic, 1 > v
Definition: Array_initializer_list_vector_cxx11.cpp:1
gtsam::DecisionTreeFactor::enumerate
std::vector< std::pair< DiscreteValues, double > > enumerate() const
Enumerate all values into a map from values to double.
Definition: DecisionTreeFactor.cpp:205
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::DiscreteFactor::discreteKeys
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
Definition: DiscreteFactor.cpp:32
N
#define N
Definition: igam.h:9
gtsam::DiscreteFactor
Definition: DiscreteFactor.h:39
gtsam::DiscreteValues::Translate
static std::string Translate(const Names &names, Key key, size_t index)
Translate an integer index value for given key to a string.
Definition: DiscreteValues.cpp:78
gtsam::valueFormatter
static std::string valueFormatter(const double &v)
Definition: DecisionTreeFactor.cpp:266
gtsam::Factor::size
size_t size() const
Definition: Factor.h:159
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
_
constexpr descr< N - 1 > _(char const (&text)[N])
Definition: descr.h:109
HybridValues.h
gtsam::Ordering
Definition: inference/Ordering.h:33
DecisionTreeFactor.h
gtsam::DecisionTreeFactor::apply
DecisionTreeFactor apply(ADT::Unary op) const
Definition: DecisionTreeFactor.cpp:102
gtsam::DecisionTreeFactor::markdown
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as markdown table.
Definition: DecisionTreeFactor.cpp:294
test_callbacks.value
value
Definition: test_callbacks.py:158
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
pybind_wrapper_test_script.other
other
Definition: pybind_wrapper_test_script.py:42
gtsam::DecisionTreeFactor::probabilities
std::vector< double > probabilities() const
Get all the probabilities in order of assignment values.
Definition: DecisionTreeFactor.cpp:222
gtsam::DecisionTreeFactor::safe_div
static double safe_div(const double &a, const double &b)
Definition: DecisionTreeFactor.cpp:82


gtsam
Author(s):
autogenerated on Thu Jun 13 2024 03:02:10