43 #ifdef DT_DEBUG_MEMORY 44 template<
typename L,
typename Y>
45 int DecisionTree<L, Y>::Node::nrNodes = 0;
51 template <
typename L,
typename Y>
65 Leaf(
const Y& constant,
size_t nrAssignments = 1)
66 : constant_(constant), nrAssignments_(nrAssignments) {}
89 if (!other)
return false;
96 std::cout << s <<
" Leaf " <<
valueFormatter(constant_) << std::endl;
102 bool showZero)
const override {
104 if (showZero || value.compare(
"0"))
105 os <<
"\"" << this->
id() <<
"\" [label=\"" << value
106 <<
"\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
123 NodePtr f(
new Leaf(op(assignment, constant_), nrAssignments_));
150 return NodePtr(
new Leaf(constant(), nrAssignments()));
153 bool isLeaf()
const override {
return true; }
158 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION 160 friend class boost::serialization::access;
161 template <
class ARCHIVE>
162 void serialize(ARCHIVE& ar,
const unsigned int ) {
163 ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(
Base);
164 ar& BOOST_SERIALIZATION_NVP(constant_);
165 ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
173 template<
typename L,
typename Y>
195 #ifdef DT_DEBUG_MEMORY 196 std::cout << Node::nrNodes <<
" destructing (Choice) " << this->
id()
203 #ifndef GTSAM_DT_NO_PRUNING 205 assert(f->branches().size() > 0);
208 size_t nrAssignments = 0;
209 for(
auto branch: f->branches()) {
210 assert(branch->isLeaf());
212 std::dynamic_pointer_cast<
const Leaf>(branch)->nrAssignments();
215 new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
223 bool isLeaf()
const override {
return false; }
227 label_(label), allSame_(true) {
228 branches_.reserve(count);
239 branches_.reserve(count);
240 for (
size_t i = 0;
i < count;
i++)
241 push_back(f.
branches_[
i]->apply_f_op_g(g, op));
246 branches_.reserve(count);
247 for (
size_t i = 0;
i < count;
i++)
248 push_back(g.
branches_[
i]->apply_g_op_fC(f, op));
253 branches_.reserve(count);
254 for (
size_t i = 0;
i < count;
i++)
265 return branches_.size();
275 if (allSame_ && !branches_.empty()) {
276 allSame_ = node->sameLeaf(*branches_.back());
278 branches_.push_back(node);
284 std::cout << s <<
" Choice(";
285 std::cout << labelFormatter(label_) <<
") " << std::endl;
286 for (
size_t i = 0;
i < branches_.size();
i++) {
287 branches_[
i]->print(s +
" " + std::to_string(
i), labelFormatter, valueFormatter);
294 bool showZero)
const override {
295 os <<
"\"" << this->
id() <<
"\" [shape=circle, label=\"" << label_
297 size_t B = branches_.size();
298 for (
size_t i = 0;
i < B;
i++) {
299 const NodePtr& branch = branches_[
i];
303 const Leaf*
leaf =
dynamic_cast<const Leaf*
>(branch.get());
307 os <<
"\"" << this->
id() <<
"\" -> \"" << branch->id() <<
"\"";
308 if (B == 2 &&
i == 0) os <<
" [style=dashed]";
310 branch->dot(os, labelFormatter, valueFormatter, showZero);
327 if (!other)
return false;
328 if (this->label_ != other->
label_)
return false;
329 if (branches_.size() != other->
branches_.size())
return false;
331 for (
size_t i = 0;
i < branches_.size();
i++)
342 std::cout <<
"Trying to find value for " << label_ << std::endl;
343 throw std::invalid_argument(
344 "DecisionTree::operator(): value undefined for a label");
347 size_t index = x.at(label_);
348 NodePtr child = branches_[index];
354 label_(label), allSame_(true) {
357 push_back(branch->apply(op));
373 : label_(label), allSame_(true) {
379 assignment_[label_] =
i;
382 push_back(branch->apply(op, assignment_));
385 auto assignment_it = assignment_.find(label_);
386 assignment_.erase(assignment_it);
392 auto r = std::make_shared<Choice>(label_, *
this, op);
399 auto r = std::make_shared<Choice>(label_, *
this, op, assignment);
414 auto h = std::make_shared<Choice>(label(), nrChoices());
415 for (
auto&& branch : branches_)
422 auto h = std::make_shared<Choice>(fC, *
this, op);
427 template<
typename OP>
429 auto h = std::make_shared<Choice>(label(), nrChoices());
430 for (
auto&& branch : branches_)
431 h->push_back(branch->apply_f_op_g(gL, op));
437 if (label_ == label)
return branches_[index];
440 auto r = std::make_shared<Choice>(label_, branches_.size());
441 for (
auto&& branch : branches_)
442 r->push_back(branch->choose(label, index));
449 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION 451 friend class boost::serialization::access;
452 template <
class ARCHIVE>
453 void serialize(ARCHIVE& ar,
const unsigned int ) {
454 ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(
Base);
455 ar& BOOST_SERIALIZATION_NVP(label_);
456 ar& BOOST_SERIALIZATION_NVP(branches_);
457 ar& BOOST_SERIALIZATION_NVP(allSame_);
465 template<
typename L,
typename Y>
469 template<
typename L,
typename Y>
475 template<
typename L,
typename Y>
481 template <
typename L,
typename Y>
483 auto a = std::make_shared<Choice>(label, 2);
487 root_ = Choice::Unique(
a);
491 template <
typename L,
typename Y>
494 if (labelC.second != 2)
throw std::invalid_argument(
495 "DecisionTree: binary constructor called with non-binary label");
496 auto a = std::make_shared<Choice>(labelC.first, 2);
500 root_ = Choice::Unique(
a);
504 template<
typename L,
typename Y>
506 const std::vector<Y>& ys) {
508 root_ =
create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
512 template<
typename L,
typename Y>
514 const std::string&
table) {
517 std::istringstream iss(table);
518 copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
522 root_ =
create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
526 template<
typename L,
typename Y>
528 Iterator begin, Iterator
end,
const L& label) {
529 root_ =
compose(begin, end, label);
533 template<
typename L,
typename Y>
536 const std::vector<DecisionTree> functions{f0, f1};
537 root_ =
compose(functions.begin(), functions.end(), label);
541 template <
typename L,
typename Y>
542 template <
typename X,
typename Func>
546 auto L_of_L = [](
const L& label) {
return label; };
547 root_ = convertFrom<L, X>(other.
root_, L_of_L, Y_of_X);
551 template <
typename L,
typename Y>
552 template <
typename M,
typename X,
typename Func>
554 const std::map<M, L>& map, Func Y_of_X) {
555 auto L_of_M = [&map](
const M& label) ->
L {
return map.at(label); };
556 root_ = convertFrom<M, X>(other.
root_, L_of_M, Y_of_X);
565 template <
typename L,
typename Y>
566 template <
typename Iterator>
568 Iterator begin, Iterator
end,
const L& label)
const {
570 std::optional<L> highestLabel;
571 size_t nrChoices = 0;
572 for (Iterator it = begin; it !=
end; it++) {
573 if (it->root_->isLeaf())
575 std::shared_ptr<const Choice>
c =
576 std::dynamic_pointer_cast<
const Choice>(it->root_);
577 if (!highestLabel || c->label() > *highestLabel) {
578 highestLabel = c->label();
579 nrChoices = c->nrChoices();
584 if (!nrChoices || !highestLabel || label > *highestLabel) {
585 auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
586 for (Iterator it = begin; it !=
end; it++)
587 choiceOnLabel->push_back(it->root_);
588 return Choice::Unique(choiceOnLabel);
591 auto choiceOnHighestLabel =
592 std::make_shared<Choice>(*highestLabel, nrChoices);
594 for (
size_t index = 0; index < nrChoices; index++) {
597 std::vector<DecisionTree> functions;
598 for (Iterator it = begin; it !=
end; it++) {
601 functions.push_back(chosen);
605 choiceOnHighestLabel->push_back(fi);
607 return Choice::Unique(choiceOnHighestLabel);
632 template<
typename L,
typename Y>
633 template<
typename It,
typename ValueIt>
635 It begin, It
end, ValueIt beginY, ValueIt endY)
const {
637 size_t nrChoices = begin->second;
638 size_t size = endY - beginY;
641 It labelC = begin + 1;
645 if (size != nrChoices) {
646 std::cout <<
"Trying to create DD on " << begin->first << std::endl;
647 std::cout <<
"DecisionTree::create: expected " << nrChoices
648 <<
" values but got " << size <<
" instead" << std::endl;
649 throw std::invalid_argument(
"DecisionTree::create invalid argument");
651 auto choice = std::make_shared<Choice>(begin->first, endY - beginY);
652 for (ValueIt
y = beginY;
y != endY;
y++)
654 return Choice::Unique(choice);
660 std::vector<DecisionTree> functions;
661 size_t split = size / nrChoices;
662 for (
size_t i = 0;
i < nrChoices;
i++, beginY +=
split) {
664 functions.emplace_back(f);
666 return compose(functions.begin(), functions.end(), begin->first);
670 template <
typename L,
typename Y>
671 template <
typename M,
typename X>
674 std::function<
L(
const M&)> L_of_M,
675 std::function<
Y(
const X&)> Y_of_X)
const {
682 if (
auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
688 auto choice = std::dynamic_pointer_cast<
const MXChoice>(
f);
689 if (!choice)
throw std::invalid_argument(
690 "DecisionTree::convertFrom: Invalid NodePtr");
693 const M oldLabel = choice->label();
694 const L newLabel = L_of_M(oldLabel);
697 std::vector<LY> functions;
698 for (
auto&& branch : choice->branches()) {
699 functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
701 return LY::compose(functions.begin(), functions.end(), newLabel);
715 template <
typename L,
typename Y>
717 using F = std::function<void(const Y&)>;
724 if (
auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
725 return f(
leaf->constant());
728 auto choice = std::dynamic_pointer_cast<
const Choice>(node);
730 throw std::invalid_argument(
"DecisionTree::Visit: Invalid NodePtr");
731 for (
auto&& branch : choice->branches()) (*
this)(branch);
735 template <
typename L,
typename Y>
736 template <
typename Func>
752 template <
typename L,
typename Y>
761 if (
auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
765 auto choice = std::dynamic_pointer_cast<
const Choice>(node);
767 throw std::invalid_argument(
"DecisionTree::VisitLeaf: Invalid NodePtr");
768 for (
auto&& branch : choice->branches()) (*
this)(branch);
772 template <
typename L,
typename Y>
773 template <
typename Func>
786 template <
typename L,
typename Y>
788 using F = std::function<void(const Assignment<L>&,
const Y&)>;
796 if (
auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
797 return f(assignment,
leaf->constant());
800 auto choice = std::dynamic_pointer_cast<
const Choice>(node);
802 throw std::invalid_argument(
"DecisionTree::VisitWith: Invalid NodePtr");
803 for (
size_t i = 0;
i < choice->nrChoices();
i++) {
804 assignment[choice->label()] =
i;
806 (*this)(choice->branches()[
i]);
809 auto choice_it = assignment.find(choice->label());
810 assignment.erase(choice_it);
815 template <
typename L,
typename Y>
816 template <
typename Func>
823 template <
typename L,
typename Y>
826 visit([&total](
const Y& node) { total += 1; });
832 template <
typename L,
typename Y>
833 template <
typename Func,
typename X>
835 visit([&](
const Y&
y) { x0 =
f(y, x0); });
853 template <
typename L,
typename Y>
857 for (
auto&& kv : assignment) {
858 unique.insert(kv.first);
866 template <
typename L,
typename Y>
872 template <
typename L,
typename Y>
876 root_->print(s, labelFormatter, valueFormatter);
879 template<
typename L,
typename Y>
881 return root_->equals(*other.
root_);
884 template<
typename L,
typename Y>
886 return root_->operator ()(
x);
889 template<
typename L,
typename Y>
893 throw std::runtime_error(
894 "DecisionTree::apply(unary op) undefined for empty tree.");
900 template <
typename L,
typename Y>
905 throw std::runtime_error(
906 "DecisionTree::apply(unary op) undefined for empty tree.");
913 template<
typename L,
typename Y>
918 throw std::runtime_error(
919 "DecisionTree::apply(binary op) undefined for empty trees.");
937 template<
typename L,
typename Y>
939 size_t cardinality,
const Binary& op)
const {
941 for (
size_t index = 1; index < cardinality; index++) {
943 result = result.
apply(chosen, op);
949 template <
typename L,
typename Y>
953 bool showZero)
const {
954 os <<
"digraph G {\n";
955 root_->dot(os, labelFormatter, valueFormatter, showZero);
956 os <<
" [ordering=out]}" << std::endl;
959 template <
typename L,
typename Y>
963 bool showZero)
const {
964 std::ofstream
os((name +
".dot").
c_str());
965 dot(os, labelFormatter, valueFormatter, showZero);
967 system((
"dot -Tpdf " + name +
".dot -o " + name +
".pdf >& /dev/null")
970 throw std::runtime_error(
"DecisionTree::dot system call failed");
973 template <
typename L,
typename Y>
976 bool showZero)
const {
977 std::stringstream
ss;
978 dot(ss, labelFormatter, valueFormatter, showZero);
const Y & operator()(const Assignment< L > &x) const override
evaluate
Decision Tree for use in DiscreteFactors.
Matrix< RealScalar, Dynamic, Dynamic > M
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const override
print (as a tree).
std::function< void(const Assignment< L > &, const Y &)> F
std::function< Y(const Assignment< L > &, const Y &)> UnaryAssignment
NodePtr apply(const Unary &op) const override
apply unary operator.
size_t nrAssignments() const
Return the number of assignments contained within this leaf.
std::vector< NodePtr > branches_
std::string serialize(const T &input)
serializes to a string
std::function< Y(const Y &, const Y &)> Binary
std::vector< std::string > labels
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
static const T & choose(int layout, const T &col, const T &row)
VisitLeaf(F f)
Construct from folding function.
virtual Ptr apply_g_op_fL(const Leaf &, const Binary &) const =0
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
const std::vector< NodePtr > & branches() const
NodePtr apply_g_op_fL(const Leaf &fL, const Binary &op) const override
std::function< Y(const Y &)> Unary
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const bfloat16 &a, const bfloat16 &b)
Choice(const Choice &f, const Choice &g, const Binary &op)
Construct from applying binary op to two Choice nodes.
static std::string valueFormatter(const double &v)
NodePtr apply_fC_op_gL(const Leaf &gL, OP op) const
Choice(const L &label, const Choice &f, const Unary &op)
Construct from applying unary op to a Choice node.
NodePtr apply(const UnaryAssignment &op, const Assignment< L > &assignment) const override
Apply unary operator with assignment.
static const Similarity3 id
virtual bool isLeaf() const =0
void g(const string &key, int i)
const L & label() const
Return the label of this choice node.
T compose(const T &t1, const T &t2)
const char * c_str(Args &&...args)
std::function< void(const Y &)> F
F f
folding function object.
NodePtr apply_g_op_fC(const Choice &fC, const Binary &op) const override
Leaf(const Y &constant, size_t nrAssignments=1)
Constructor from constant.
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
void split(const G &g, const PredecessorMap< KEY > &tree, G &Ab1, G &Ab2)
NodePtr choose(const L &label, size_t index) const override
Scalar EIGEN_BLAS_FUNC() dot(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy)
Choice(const L &label, size_t count)
Constructor, given choice label and mandatory expected branch count.
virtual bool sameLeaf(const Leaf &q) const =0
void operator()(const typename DecisionTree< L, Y >::NodePtr &node) const
Do a depth-first visit on the tree rooted at node.
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const override
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const override
NodePtr apply(const UnaryAssignment &op, const Assignment< L > &assignment) const override
Apply unary operator with assignment.
bool sameLeaf(const Node &q) const override
polymorphic equality: is q a leaf and is it the same as this leaf?
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const override
print
void push_back(const NodePtr &node)
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Leaf()
Default constructor for serialization.
static NodePtr Unique(const ChoicePtr &f)
If all branches of a choice node f are the same, just return a branch.
EIGEN_DEVICE_FUNC const Scalar & q
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
NodePtr apply_g_op_fC(const Choice &fC, const Binary &op) const override
NodePtr choose(const L &label, size_t index) const override
DecisionTree apply(const Unary &op) const
static sharedNode Leaf(Key key, const SymbolicFactorGraph &factors)
typename Node::Ptr NodePtr
void operator()(const typename DecisionTree< L, Y >::NodePtr &node)
Do a depth-first visit on the tree rooted at node.
static std::stringstream ss
NodePtr apply(const Unary &op) const override
std::function< bool(const Y &, const Y &)> CompareFunc
std::function< std::string(L)> LabelFormatter
bool empty() const
Check if tree is empty.
bool isLeaf() const override
bool equals(const Node &q, const CompareFunc &compare) const override
equality
Visit(F f)
Construct from folding function.
ADT create(const Signature &signature)
Assignment< L > assignment
Assignment, mutating through recursion.
ofstream os("timeSchurFactors.csv")
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
virtual Ptr apply_g_op_fC(const Choice &, const Binary &) const =0
std::function< std::string(Y)> ValueFormatter
static EIGEN_DEPRECATED const end_t end
std::function< void(const typename DecisionTree< L, Y >::Leaf &)> F
NodePtr apply_g_op_fL(const Leaf &fL, const Binary &op) const override
bool isLeaf() const override
bool sameLeaf(const Leaf &q) const override
Choice-Leaf equality: always false.
void operator()(const typename DecisionTree< L, Y >::NodePtr &node) const
Do a depth-first visit on the tree rooted at node.
F f
folding function object.
Annotation for function names.
std::shared_ptr< const Choice > ChoicePtr
bool equals(const Node &q, const CompareFunc &compare) const override
equality up to tolerance
std::pair< Key, size_t > LabelC
VisitWith(F f)
Construct from folding function.
Choice(const L &label, const Choice &f, const UnaryAssignment &op, const Assignment< L > &assignment)
Constructor which accepts a UnaryAssignment op and the corresponding assignment.
NodePtr apply_f_op_g(const Node &g, const Binary &op) const override
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
internal::enable_if< internal::valid_indexed_view_overload< RowIndices, ColIndices >::value &&internal::traits< typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::ReturnAsIndexedView, typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::type operator()(const RowIndices &rowIndices, const ColIndices &colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
bool sameLeaf(const Leaf &q) const override
Leaf-Leaf equality.
bool sameLeaf(const Node &q) const override
polymorphic equality: if q is a leaf, could be...
DecisionTree choose(const L &label, size_t index) const
F f
folding function object.
const Y & constant() const
Return the constant.
const Y & operator()(const Assignment< L > &x) const override
NodePtr apply_f_op_g(const Node &g, const Binary &op) const override
Choice()
Default constructor for serialization.