Go to the documentation of this file.
33 DecisionTreeFactor::DecisionTreeFactor() {}
37 const ADT& potentials)
70 if (
auto tf = std::dynamic_pointer_cast<TableFactor>(
f)) {
76 }
else if (
auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(
f)) {
78 result = std::make_shared<DecisionTreeFactor>(this->
operator*(*dtf));
85 result = std::make_shared<DecisionTreeFactor>(
f->operator*(*
this));
93 if (
auto tf = std::dynamic_pointer_cast<TableFactor>(
f)) {
96 return std::make_shared<TableFactor>(tf->operator/(
TableFactor(*
this)));
98 }
else if (
auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(
f)) {
100 return std::make_shared<DecisionTreeFactor>(this->
operator/(*dtf));
104 return std::make_shared<DecisionTreeFactor>(
105 this->
operator/(f->toDecisionTreeFactor()));
114 return (
a == 0 ||
b == 0) ? 0 : (
a /
b);
125 cout <<
" ]" << endl;
151 for (
Key j :
f.keys()) cs[
j] =
f.cardinality(
j);
154 keys.reserve(cs.size());
155 for (
const auto&
key : cs) {
167 if (nrFrontals >
size()) {
168 throw invalid_argument(
169 "DecisionTreeFactor::combine: invalid number of frontal "
171 std::to_string(nrFrontals) +
", nr.keys=" + std::to_string(
size()));
177 for (
i = 0;
i < nrFrontals;
i++) {
184 for (;
i <
keys_.size();
i++) {
188 return std::make_shared<DecisionTreeFactor>(dkeys,
result);
194 if (frontalKeys.size() >
size()) {
195 throw invalid_argument(
196 "DecisionTreeFactor::combine: invalid number of frontal "
198 std::to_string(frontalKeys.size()) +
", nr.keys=" +
199 std::to_string(
size()));
205 for (
i = 0;
i < frontalKeys.size();
i++) {
206 Key j = frontalKeys[
i];
227 return std::make_shared<DecisionTreeFactor>(dkeys,
result);
240 std::vector<std::pair<DiscreteValues, double>>
result;
241 for (
const auto& assignment : assignments) {
252 std::vector<double> probs;
263 std::set<Key> assignment_keys;
264 for (
auto&& [k,
_] :
a) {
265 assignment_keys.insert(k);
269 std::vector<Key> diff;
270 std::set_difference(allKeys.begin(), allKeys.end(),
271 assignment_keys.begin(), assignment_keys.end(),
272 std::back_inserter(diff));
275 size_t nrAssignments = 1;
276 for (
auto&& k : diff) {
280 probs.insert(probs.end(), nrAssignments,
p);
293 std::stringstream
ss;
294 ss << std::setw(4) << std::setprecision(2) << std::fixed <<
v;
301 bool showZero)
const {
308 bool showZero)
const {
314 bool showZero)
const {
327 ss << keyFormatter(
key) <<
"|";
333 for (
size_t j = 0;
j <
size();
j++)
ss <<
":-:|";
338 for (
const auto& kv :
rows) {
340 auto assignment = kv.first;
342 size_t index = assignment.at(
key);
345 ss << kv.second <<
"|\n";
356 ss <<
"<div>\n<table class='DecisionTreeFactor'>\n <thead>\n";
361 ss <<
"<th>" << keyFormatter(
key) <<
"</th>";
363 ss <<
"<th>value</th></tr>\n";
366 ss <<
" </thead>\n <tbody>\n";
370 for (
const auto& kv :
rows) {
372 auto assignment = kv.first;
374 size_t index = assignment.at(
key);
377 ss <<
"<td>" << kv.second <<
"</td>";
380 ss <<
" </tbody>\n</table>\n</div>";
386 const vector<double>&
table)
401 std::vector<double>
v_;
410 std::push_heap(
v_.begin(),
v_.end(), std::greater<double>{});
415 for (
size_t i = 0;
i <
n; ++
i) {
417 std::push_heap(
v_.begin(),
v_.end(), std::greater<double>{});
423 std::pop_heap(
v_.begin(),
v_.end(), std::greater<double>{});
424 double x =
v_.back();
438 std::cout << (
s.empty() ?
"" :
s +
" ");
439 for (
size_t i = 0;
i <
v_.size();
i++) {
440 std::cout <<
v_.at(
i);
441 if (
v_.size() > 1 &&
i <
v_.size() - 1) std::cout <<
", ";
443 std::cout << std::endl;
450 size_t size()
const {
return v_.size(); }
456 std::set<Key> allKeys = this->
labels();
461 std::set<Key> assignment_keys;
462 for (
auto&& [k,
_] :
a) {
463 assignment_keys.insert(k);
467 std::vector<Key> diff;
468 std::set_difference(allKeys.begin(), allKeys.end(),
469 assignment_keys.begin(), assignment_keys.end(),
470 std::back_inserter(diff));
473 size_t nrAssignments = 1;
474 for (
auto&& k : diff) {
480 if (min_heap.
empty()) {
484 for (
size_t i = 0;
i <
std::min(nrAssignments,
N); ++
i) {
489 if (
p > min_heap.
top()) {
490 if (min_heap.
size() ==
N) {
504 return min_heap.
top();
509 const size_t N = maxNrAssignments;
515 auto thresholdFunc = [threshold, &total,
N](
const double&
value) {
526 if (value < threshold || total >=
N) {
Annotation for function names.
Min-Heap class to help with pruning. The top element is always the smallest value.
void push(double x)
Push value onto the heap.
double computeThreshold(const size_t N) const
Compute the probability value which is the threshold above which only N leaves are present.
Array< double, 1, 3 > e(1./3., 0.5, 2.)
void dot(std::ostream &os, const KeyFormatter &keyFormatter=DefaultKeyFormatter, bool showZero=true) const
std::map< Key, size_t > cardinalities_
Map of Keys and their cardinalities.
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
DecisionTreeFactor apply(Unary op) const
std::set< Key > labels() const
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
const KeyFormatter & formatter
double error(const DiscreteValues &values) const override
Calculate error for DiscreteValues x, is -log(probability).
const EIGEN_DEVICE_FUNC LogReturnType log() const
DiscreteKeys is a set of keys that can be assembled using the & operator.
bool equals(const AlgebraicDecisionTree &other, double tol=1e-9) const
Equality method customized to value type double.
double pop()
Pop the top value of the heap.
ofstream os("timeSchurFactors.csv")
const_iterator begin() const
std::shared_ptr< DecisionTreeFactor > shared_ptr
void print(const std::string &s="")
Print the heap as a sequence.
static std::stringstream ss
A thin wrapper around std::set that uses boost's fast_pool_allocator.
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
bool empty() const
Return true if heap is empty.
bool equals(const DiscreteFactor &other, double tol=1e-9) const override
equality
virtual DiscreteFactor::shared_ptr multiply(const DiscreteFactor::shared_ptr &f) const override
Multiply factors, DiscreteFactor::shared_ptr edition.
bool fpEqual(double a, double b, double tol, bool check_relative_also)
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero=true) const
void print(const std::string &s="DecisionTreeFactor:\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
void visitWith(Func f) const
Visit all leaves in depth-first fashion.
virtual double evaluate(const Assignment< Key > &values) const override
DecisionTree apply(const Unary &op) const
const_iterator end() const
DecisionTreeFactor operator/(const DecisionTreeFactor &f) const
Divide by factor f (safely). Division of a factor by another factor results in a function which inv...
std::function< double(const double, const double)> Binary
const gtsam::Symbol key('X', 0)
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
double top()
Return the top value of the heap without popping it.
size_t cardinality(Key j) const
shared_ptr combine(size_t nrFrontals, Binary op) const
KeyVector keys_
The keys involved in this factor.
std::shared_ptr< DiscreteFactor > shared_ptr
shared_ptr to this class
DiscreteValues::Names Names
Translation table from values to strings.
void print(const std::string &s="", const typename Base::LabelFormatter &labelFormatter=&DefaultFormatter) const
print method customized to value type double.
const KeyVector & keys() const
Access the factor's involved variable keys.
std::pair< Key, size_t > DiscreteKey
std::string html(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as html table.
Array< int, Dynamic, 1 > v
std::vector< std::pair< DiscreteValues, double > > enumerate() const
Enumerate all values into a map from values to double.
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
MinHeap()
Default constructor.
std::function< double(const Assignment< Key > &, const double &)> UnaryAssignment
static std::string Translate(const Names &names, Key key, size_t index)
Translate an integer index value for given key to a string.
static std::string valueFormatter(const double &v)
std::uint64_t Key
Integer nonlinear key type.
constexpr descr< N - 1 > _(char const (&text)[N])
std::function< double(const double &)> Unary
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as markdown table.
bool contains(const Key &key) const
Check if key exists in ordering.
void push(double x, size_t n)
Push value x, n number of times.
std::vector< double > probabilities() const
Get all the probabilities in order of assignment values.
static double safe_div(const double &a, const double &b)
size_t size() const
Return the size of the heap.
virtual bool equals(const DiscreteFactor &lf, double tol=1e-9) const
equals
gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:02:08