Go to the documentation of this file.
32 DecisionTreeFactor::DecisionTreeFactor() {}
36 const ADT& potentials)
74 std::vector<double> errors;
75 for (
const auto& assignment : assignments) {
76 errors.push_back(
error(assignment));
86 return (
a == 0 ||
b == 0) ? 0 : (
a /
b);
119 ADT::Binary op)
const {
123 for (
Key j :
f.keys()) cs[
j] =
f.cardinality(
j);
126 keys.reserve(cs.size());
127 for (
const auto&
key : cs) {
138 size_t nrFrontals, ADT::Binary op)
const {
139 if (nrFrontals >
size()) {
140 throw invalid_argument(
141 "DecisionTreeFactor::combine: invalid number of frontal "
143 std::to_string(nrFrontals) +
", nr.keys=" + std::to_string(
size()));
149 for (
i = 0;
i < nrFrontals;
i++) {
156 for (;
i <
keys().size();
i++) {
160 return std::make_shared<DecisionTreeFactor>(dkeys,
result);
165 const Ordering& frontalKeys, ADT::Binary op)
const {
166 if (frontalKeys.size() >
size()) {
167 throw invalid_argument(
168 "DecisionTreeFactor::combine: invalid number of frontal "
170 std::to_string(frontalKeys.size()) +
", nr.keys=" +
171 std::to_string(
size()));
177 for (
i = 0;
i < frontalKeys.size();
i++) {
178 Key j = frontalKeys[
i];
193 std::sort(frontalKeys_.begin(), frontalKeys_.end());
194 std::set_difference(
keys_.begin(),
keys_.end(), frontalKeys_.begin(),
195 frontalKeys_.end(), back_inserter(difference));
198 for (
Key key : difference) {
201 return std::make_shared<DecisionTreeFactor>(dkeys,
result);
214 std::vector<std::pair<DiscreteValues, double>>
result;
215 for (
const auto& assignment : assignments) {
216 result.emplace_back(assignment,
operator()(assignment));
226 std::vector<double> probs;
237 std::set<Key> assignment_keys;
238 for (
auto&& [k,
_] :
a) {
239 assignment_keys.insert(k);
243 std::vector<Key> diff;
244 std::set_difference(allKeys.begin(), allKeys.end(),
245 assignment_keys.begin(), assignment_keys.end(),
246 std::back_inserter(diff));
249 size_t nrAssignments = 1;
250 for (
auto&& k : diff) {
254 probs.insert(probs.end(), nrAssignments,
p);
267 std::stringstream
ss;
268 ss << std::setw(4) << std::setprecision(2) << std::fixed <<
v;
275 bool showZero)
const {
282 bool showZero)
const {
288 bool showZero)
const {
301 ss << keyFormatter(
key) <<
"|";
307 for (
size_t j = 0;
j <
size();
j++)
ss <<
":-:|";
312 for (
const auto& kv :
rows) {
314 auto assignment = kv.first;
316 size_t index = assignment.at(
key);
319 ss << kv.second <<
"|\n";
330 ss <<
"<div>\n<table class='DecisionTreeFactor'>\n <thead>\n";
335 ss <<
"<th>" << keyFormatter(
key) <<
"</th>";
337 ss <<
"<th>value</th></tr>\n";
340 ss <<
" </thead>\n <tbody>\n";
344 for (
const auto& kv :
rows) {
346 auto assignment = kv.first;
348 size_t index = assignment.at(
key);
351 ss <<
"<td>" << kv.second <<
"</td>";
354 ss <<
" </tbody>\n</table>\n</div>";
360 const vector<double>&
table)
372 const size_t N = maxNrAssignments;
378 if (probabilities.size() <=
N) {
383 std::greater<double>{});
389 auto thresholdFunc = [threshold, &total,
N](
const double&
value) {
390 if (value < threshold || total >=
N) {
Annotation for function names.
void dot(std::ostream &os, const KeyFormatter &keyFormatter=DefaultKeyFormatter, bool showZero=true) const
std::map< Key, size_t > cardinalities_
Map of Keys and their cardinalities.
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
const KeyFormatter & formatter
const EIGEN_DEVICE_FUNC LogReturnType log() const
DiscreteKeys is a set of keys that can be assembled using the & operator.
bool equals(const AlgebraicDecisionTree &other, double tol=1e-9) const
Equality method customized to value type double.
AlgebraicDecisionTree< Key > errorTree() const override
Compute error for each assignment and return as a tree.
ofstream os("timeSchurFactors.csv")
const_iterator begin() const
std::shared_ptr< DecisionTreeFactor > shared_ptr
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
static std::stringstream ss
double evaluate(const DiscreteValues &values) const
A thin wrapper around std::set that uses boost's fast_pool_allocator.
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
bool equals(const DiscreteFactor &other, double tol=1e-9) const override
Check equality with another DiscreteFactor, using tolerance tol.
double error(const DiscreteValues &values) const
Calculate the error for DiscreteValues x, which is -log(probability).
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero=true) const
void print(const std::string &s="DecisionTreeFactor:\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
DecisionTree apply(const Unary &op) const
const_iterator end() const
const gtsam::Symbol key('X', 0)
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
size_t cardinality(Key j) const
KeyVector keys_
The keys involved in this factor.
DiscreteValues::Names Names
Translation table from values to strings.
void print(const std::string &s="", const typename Base::LabelFormatter &labelFormatter=&DefaultFormatter) const
print method customized to value type double.
const KeyVector & keys() const
Access the factor's involved variable keys.
std::pair< Key, size_t > DiscreteKey
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const
std::string html(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as html table.
Array< int, Dynamic, 1 > v
std::vector< std::pair< DiscreteValues, double > > enumerate() const
Enumerate all assignments with their values, returned as a vector of (DiscreteValues, double) pairs.
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
static std::string Translate(const Names &names, Key key, size_t index)
Translate an integer index value for given key to a string.
static std::string valueFormatter(const double &v)
std::uint64_t Key
Integer nonlinear key type.
constexpr descr< N - 1 > _(char const (&text)[N])
DecisionTreeFactor apply(ADT::Unary op) const
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as markdown table.
std::vector< double > probabilities() const
Get all the probabilities in order of assignment values.
static double safe_div(const double &a, const double &b)
gtsam
Author(s):
autogenerated on Thu Jun 13 2024 03:02:10