DecisionTreeFactor.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
20 #include <gtsam/base/FastSet.h>
24 
25 #include <utility>
26 
27 using namespace std;
28 
29 namespace gtsam {
30 
31  /* ************************************************************************ */
32  DecisionTreeFactor::DecisionTreeFactor() {}
33 
34  /* ************************************************************************ */
35  DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
36  const ADT& potentials)
37  : DiscreteFactor(keys.indices()),
38  ADT(potentials),
39  cardinalities_(keys.cardinalities()) {}
40 
41  /* ************************************************************************ */
43  : DiscreteFactor(c.keys()),
46 
47  /* ************************************************************************ */
49  double tol) const {
50  if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
51  return false;
52  } else {
53  const auto& f(static_cast<const DecisionTreeFactor&>(other));
54  return ADT::equals(f, tol);
55  }
56  }
57 
58  /* ************************************************************************ */
60  return -std::log(evaluate(values));
61  }
62 
63  /* ************************************************************************ */
65  return error(values.discrete());
66  }
67 
68  /* ************************************************************************ */
69  double DecisionTreeFactor::safe_div(const double& a, const double& b) {
70  // The use for safe_div is when we divide the product factor by the sum
71  // factor. If the product or sum is zero, we accord zero probability to the
72  // event.
73  return (a == 0 || b == 0) ? 0 : (a / b);
74  }
75 
76  /* ************************************************************************ */
77  void DecisionTreeFactor::print(const string& s,
78  const KeyFormatter& formatter) const {
79  cout << s;
80  cout << " f[";
81  for (auto&& key : keys()) {
82  cout << " (" << formatter(key) << "," << cardinality(key) << "),";
83  }
84  cout << " ]" << endl;
85  ADT::print("", formatter);
86  }
87 
88  /* ************************************************************************ */
90  ADT::Binary op) const {
91  map<Key, size_t> cs; // new cardinalities
92  // make unique key-cardinality map
93  for (Key j : keys()) cs[j] = cardinality(j);
94  for (Key j : f.keys()) cs[j] = f.cardinality(j);
95  // Convert map into keys
97  keys.reserve(cs.size());
98  for (const auto& key : cs) {
99  keys.emplace_back(key);
100  }
101  // apply operand
102  ADT result = ADT::apply(f, op);
103  // Make a new factor
104  return DecisionTreeFactor(keys, result);
105  }
106 
107  /* ************************************************************************ */
109  size_t nrFrontals, ADT::Binary op) const {
110  if (nrFrontals > size()) {
111  throw invalid_argument(
112  "DecisionTreeFactor::combine: invalid number of frontal "
113  "keys " +
114  std::to_string(nrFrontals) + ", nr.keys=" + std::to_string(size()));
115  }
116 
117  // sum over nrFrontals keys
118  size_t i;
119  ADT result(*this);
120  for (i = 0; i < nrFrontals; i++) {
121  Key j = keys()[i];
122  result = result.combine(j, cardinality(j), op);
123  }
124 
125  // create new factor, note we start keys after nrFrontals
126  DiscreteKeys dkeys;
127  for (; i < keys().size(); i++) {
128  Key j = keys()[i];
129  dkeys.push_back(DiscreteKey(j, cardinality(j)));
130  }
131  return std::make_shared<DecisionTreeFactor>(dkeys, result);
132  }
133 
134  /* ************************************************************************ */
136  const Ordering& frontalKeys, ADT::Binary op) const {
137  if (frontalKeys.size() > size()) {
138  throw invalid_argument(
139  "DecisionTreeFactor::combine: invalid number of frontal "
140  "keys " +
141  std::to_string(frontalKeys.size()) + ", nr.keys=" +
142  std::to_string(size()));
143  }
144 
145  // sum over nrFrontals keys
146  size_t i;
147  ADT result(*this);
148  for (i = 0; i < frontalKeys.size(); i++) {
149  Key j = frontalKeys[i];
150  result = result.combine(j, cardinality(j), op);
151  }
152 
153  // create new factor, note we collect keys that are not in frontalKeys
154  // TODO(frank): why do we need this??? result should contain correct keys!!!
155  DiscreteKeys dkeys;
156  for (i = 0; i < keys().size(); i++) {
157  Key j = keys()[i];
158  // TODO(frank): inefficient!
159  if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
160  frontalKeys.end())
161  continue;
162  dkeys.push_back(DiscreteKey(j, cardinality(j)));
163  }
164  return std::make_shared<DecisionTreeFactor>(dkeys, result);
165  }
166 
167  /* ************************************************************************ */
168  std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
169  const {
170  // Get all possible assignments
171  DiscreteKeys pairs = discreteKeys();
172  // Reverse to make cartesian product output a more natural ordering.
173  DiscreteKeys rpairs(pairs.rbegin(), pairs.rend());
174  const auto assignments = DiscreteValues::CartesianProduct(rpairs);
175 
176  // Construct unordered_map with values
177  std::vector<std::pair<DiscreteValues, double>> result;
178  for (const auto& assignment : assignments) {
179  result.emplace_back(assignment, operator()(assignment));
180  }
181  return result;
182  }
183 
184  /* ************************************************************************ */
187  for (auto&& key : keys()) {
188  DiscreteKey dkey(key, cardinality(key));
189  if (std::find(result.begin(), result.end(), dkey) == result.end()) {
190  result.push_back(dkey);
191  }
192  }
193  return result;
194  }
195 
196  /* ************************************************************************ */
/// Format a leaf value for dot output: fixed-point, two decimals, in a field
/// at least four characters wide.
static std::string valueFormatter(const double& v) {
  std::ostringstream out;
  out << std::fixed << std::setprecision(2) << std::setw(4) << v;
  return out.str();
}
202 
204  void DecisionTreeFactor::dot(std::ostream& os,
205  const KeyFormatter& keyFormatter,
206  bool showZero) const {
207  ADT::dot(os, keyFormatter, valueFormatter, showZero);
208  }
209 
211  void DecisionTreeFactor::dot(const std::string& name,
212  const KeyFormatter& keyFormatter,
213  bool showZero) const {
214  ADT::dot(name, keyFormatter, valueFormatter, showZero);
215  }
216 
218  std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
219  bool showZero) const {
220  return ADT::dot(keyFormatter, valueFormatter, showZero);
221  }
222 
223  // Print out header.
224  /* ************************************************************************ */
225  string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
226  const Names& names) const {
227  stringstream ss;
228 
229  // Print out header.
230  ss << "|";
231  for (auto& key : keys()) {
232  ss << keyFormatter(key) << "|";
233  }
234  ss << "value|\n";
235 
236  // Print out separator with alignment hints.
237  ss << "|";
238  for (size_t j = 0; j < size(); j++) ss << ":-:|";
239  ss << ":-:|\n";
240 
241  // Print out all rows.
242  auto rows = enumerate();
243  for (const auto& kv : rows) {
244  ss << "|";
245  auto assignment = kv.first;
246  for (auto& key : keys()) {
247  size_t index = assignment.at(key);
248  ss << DiscreteValues::Translate(names, key, index) << "|";
249  }
250  ss << kv.second << "|\n";
251  }
252  return ss.str();
253  }
254 
255  /* ************************************************************************ */
256  string DecisionTreeFactor::html(const KeyFormatter& keyFormatter,
257  const Names& names) const {
258  stringstream ss;
259 
260  // Print out preamble.
261  ss << "<div>\n<table class='DecisionTreeFactor'>\n <thead>\n";
262 
263  // Print out header row.
264  ss << " <tr>";
265  for (auto& key : keys()) {
266  ss << "<th>" << keyFormatter(key) << "</th>";
267  }
268  ss << "<th>value</th></tr>\n";
269 
270  // Finish header and start body.
271  ss << " </thead>\n <tbody>\n";
272 
273  // Print out all rows.
274  auto rows = enumerate();
275  for (const auto& kv : rows) {
276  ss << " <tr>";
277  auto assignment = kv.first;
278  for (auto& key : keys()) {
279  size_t index = assignment.at(key);
280  ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
281  }
282  ss << "<td>" << kv.second << "</td>"; // value
283  ss << "</tr>\n";
284  }
285  ss << " </tbody>\n</table>\n</div>";
286  return ss.str();
287  }
288 
289  /* ************************************************************************ */
291  const vector<double>& table)
292  : DiscreteFactor(keys.indices()),
293  AlgebraicDecisionTree<Key>(keys, table),
294  cardinalities_(keys.cardinalities()) {}
295 
296  /* ************************************************************************ */
298  const string& table)
299  : DiscreteFactor(keys.indices()),
300  AlgebraicDecisionTree<Key>(keys, table),
301  cardinalities_(keys.cardinalities()) {}
302 
303  /* ************************************************************************ */
304  DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
305  const size_t N = maxNrAssignments;
306 
307  // Get the probabilities in the decision tree so we can threshold.
308  std::vector<double> probabilities;
309  this->visitLeaf([&](const Leaf& leaf) {
310  size_t nrAssignments = leaf.nrAssignments();
311  double prob = leaf.constant();
312  probabilities.insert(probabilities.end(), nrAssignments, prob);
313  });
314 
315  // The number of probabilities can be lower than max_leaves
316  if (probabilities.size() <= N) {
317  return *this;
318  }
319 
320  std::sort(probabilities.begin(), probabilities.end(),
321  std::greater<double>{});
322 
323  double threshold = probabilities[N - 1];
324 
325  // Now threshold the decision tree
326  size_t total = 0;
327  auto thresholdFunc = [threshold, &total, N](const double& value) {
328  if (value < threshold || total >= N) {
329  return 0.0;
330  } else {
331  total += 1;
332  return value;
333  }
334  };
335  DecisionTree<Key, double> thresholded(*this, thresholdFunc);
336 
337  // Create pruned decision tree factor and return.
338  return DecisionTreeFactor(this->discreteKeys(), thresholded);
339  }
340 
341  /* ************************************************************************ */
342 } // namespace gtsam
const gtsam::Symbol key('X', 0)
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
bool equals(const AlgebraicDecisionTree &other, double tol=1e-9) const
Equality method customized to value type double.
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
static double safe_div(const double &a, const double &b)
std::function< double(const double &, const double &)> Binary
Definition: DecisionTree.h:64
void print(const std::string &s="", const typename Base::LabelFormatter &labelFormatter=&DefaultFormatter) const
print method customized to value type double.
std::string html(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as html table.
size_t size() const
Definition: Factor.h:159
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
leaf::MyValues values
Definition: BFloat16.h:88
#define N
Definition: gksort.c:12
static std::string valueFormatter(const double &v)
EIGEN_DEVICE_FUNC const LogReturnType log() const
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
double error(const DiscreteValues &values) const
Calculate error for DiscreteValues x, is -log(probability).
const KeyFormatter & formatter
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
Values result
double evaluate(const DiscreteValues &values) const
size_t cardinality(Key j) const
Array< int, Dynamic, 1 > v
void visitLeaf(Func f) const
Visit all leaves in depth-first fashion.
std::map< Key, size_t > cardinalities_
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
RealScalar s
void print(const std::string &s="DecisionTreeFactor:\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
DecisionTree apply(const Unary &op) const
static sharedNode Leaf(Key key, const SymbolicFactorGraph &factors)
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
const G & b
Definition: Group.h:86
void dot(std::ostream &os, const KeyFormatter &keyFormatter=DefaultKeyFormatter, bool showZero=true) const
static std::stringstream ss
Definition: testBTree.cpp:31
DiscreteValues::Names Names
Translation table from values to strings.
traits
Definition: chartTesting.h:28
std::shared_ptr< DecisionTreeFactor > shared_ptr
ofstream os("timeSchurFactors.csv")
ArrayXXf table(10, 4)
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero=true) const
static std::string Translate(const Names &names, Key key, size_t index)
Translate an integer index value for given key to a string.
DecisionTreeFactor apply(const DecisionTreeFactor &f, ADT::Binary op) const
const DiscreteValues & discrete() const
Return the discrete values.
Definition: HybridValues.h:92
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
const KeyVector & keys() const
Access the factor's involved variable keys.
Definition: Factor.h:142
bool equals(const DiscreteFactor &other, double tol=1e-9) const override
equality
Annotation for function names.
Definition: attr.h:48
const G double tol
Definition: Group.h:86
std::vector< std::pair< DiscreteValues, double > > enumerate() const
Enumerate all values into a map from values to double.
const KeyVector keys
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as markdown table.
A thin wrapper around std::set that uses boost's fast_pool_allocator.
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:102
std::ptrdiff_t j
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:34:09