32 DecisionTreeFactor::DecisionTreeFactor() {}
36 const ADT& potentials)
39 cardinalities_(keys.cardinalities()) {}
50 if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
53 const auto&
f(static_cast<const DecisionTreeFactor&>(other));
73 return (a == 0 || b == 0) ? 0 : (a /
b);
97 keys.reserve(cs.size());
98 for (
const auto&
key : cs) {
99 keys.emplace_back(
key);
110 if (nrFrontals >
size()) {
111 throw invalid_argument(
112 "DecisionTreeFactor::combine: invalid number of frontal " 114 std::to_string(nrFrontals) +
", nr.keys=" + std::to_string(
size()));
120 for (i = 0; i < nrFrontals; i++) {
127 for (; i <
keys().size(); i++) {
131 return std::make_shared<DecisionTreeFactor>(dkeys,
result);
137 if (frontalKeys.size() >
size()) {
138 throw invalid_argument(
139 "DecisionTreeFactor::combine: invalid number of frontal " 141 std::to_string(frontalKeys.size()) +
", nr.keys=" +
142 std::to_string(
size()));
148 for (i = 0; i < frontalKeys.size(); i++) {
149 Key j = frontalKeys[
i];
156 for (i = 0; i <
keys().size(); i++) {
159 if (std::find(frontalKeys.begin(), frontalKeys.end(),
j) !=
164 return std::make_shared<DecisionTreeFactor>(dkeys,
result);
177 std::vector<std::pair<DiscreteValues, double>>
result;
178 for (
const auto& assignment : assignments) {
179 result.emplace_back(assignment,
operator()(assignment));
189 if (std::find(result.begin(), result.end(), dkey) == result.end()) {
190 result.push_back(dkey);
198 std::stringstream
ss;
199 ss << std::setw(4) << std::setprecision(2) << std::fixed <<
v;
206 bool showZero)
const {
213 bool showZero)
const {
219 bool showZero)
const {
232 ss << keyFormatter(
key) <<
"|";
238 for (
size_t j = 0;
j <
size();
j++) ss <<
":-:|";
243 for (
const auto& kv :
rows) {
245 auto assignment = kv.first;
247 size_t index = assignment.at(
key);
250 ss << kv.second <<
"|\n";
261 ss <<
"<div>\n<table class='DecisionTreeFactor'>\n <thead>\n";
266 ss <<
"<th>" << keyFormatter(
key) <<
"</th>";
268 ss <<
"<th>value</th></tr>\n";
271 ss <<
" </thead>\n <tbody>\n";
275 for (
const auto& kv :
rows) {
277 auto assignment = kv.first;
279 size_t index = assignment.at(
key);
282 ss <<
"<td>" << kv.second <<
"</td>";
285 ss <<
" </tbody>\n</table>\n</div>";
291 const vector<double>&
table)
305 const size_t N = maxNrAssignments;
308 std::vector<double> probabilities;
310 size_t nrAssignments = leaf.nrAssignments();
311 double prob = leaf.constant();
312 probabilities.insert(probabilities.end(), nrAssignments, prob);
316 if (probabilities.size() <=
N) {
320 std::sort(probabilities.begin(), probabilities.end(),
321 std::greater<double>{});
323 double threshold = probabilities[N - 1];
327 auto thresholdFunc = [threshold, &total,
N](
const double&
value) {
328 if (value < threshold || total >= N) {
const gtsam::Symbol key('X', 0)
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
bool equals(const AlgebraicDecisionTree &other, double tol=1e-9) const
Equality method customized to value type double.
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
static double safe_div(const double &a, const double &b)
std::function< double(const double &, const double &)> Binary
void print(const std::string &s="", const typename Base::LabelFormatter &labelFormatter=&DefaultFormatter) const
print method customized to value type double.
std::string html(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as html table.
static std::string valueFormatter(const double &v)
EIGEN_DEVICE_FUNC const LogReturnType log() const
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
double error(const DiscreteValues &values) const
Calculate error for DiscreteValues x, is -log(probability).
const KeyFormatter & formatter
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
double evaluate(const DiscreteValues &values) const
size_t cardinality(Key j) const
Array< int, Dynamic, 1 > v
void visitLeaf(Func f) const
Visit all leaves in depth-first fashion.
std::map< Key, size_t > cardinalities_
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
void print(const std::string &s="DecisionTreeFactor:\, const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
DecisionTree apply(const Unary &op) const
static sharedNode Leaf(Key key, const SymbolicFactorGraph &factors)
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
void dot(std::ostream &os, const KeyFormatter &keyFormatter=DefaultKeyFormatter, bool showZero=true) const
static std::stringstream ss
DiscreteValues::Names Names
Translation table from values to strings.
std::shared_ptr< DecisionTreeFactor > shared_ptr
ofstream os("timeSchurFactors.csv")
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero=true) const
static std::string Translate(const Names &names, Key key, size_t index)
Translate an integer index value for given key to a string.
DecisionTreeFactor apply(const DecisionTreeFactor &f, ADT::Binary op) const
const DiscreteValues & discrete() const
Return the discrete values.
std::pair< Key, size_t > DiscreteKey
const KeyVector & keys() const
Access the factor's involved variable keys.
bool equals(const DiscreteFactor &other, double tol=1e-9) const override
equality
Annotation for function names.
std::vector< std::pair< DiscreteValues, double > > enumerate() const
Enumerate all values into a map from values to double.
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as markdown table.
A thin wrapper around std::set that uses boost's fast_pool_allocator.
std::uint64_t Key
Integer nonlinear key type.
DiscreteKeys is a set of keys that can be assembled using the & operator.