41 #ifdef DT_DEBUG_MEMORY
// Debug-only node counter, defined only when DT_DEBUG_MEMORY is set; it is
// printed in the Choice destructor's "destructing (Choice)" trace.
// NOTE(review): where it is incremented/decremented is not visible in this
// excerpt -- presumably in the Node constructor/destructor; confirm.
42 template<
typename L,
typename Y>
43 int DecisionTree<L, Y>::Node::nrNodes = 0;
49 template <
typename L,
typename Y>
// Construct a leaf node that stores a copy of `constant` as its value.
58 Leaf(
const Y& constant) : constant_(constant) {}
67 return constant_ ==
q.constant_;
72 return (
q.isLeaf() &&
q.sameLeaf(*
this));
77 if (!
q.isLeaf())
return false;
91 bool showZero)
const override {
93 if (showZero ||
value.compare(
"0"))
94 os <<
"\"" << this->
id() <<
"\" [label=\"" << value
95 <<
"\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
122 return g.apply_g_op_fL(*
this, op);
// A Leaf is a terminal node of the tree, so this always returns true.
142 bool isLeaf()
const override {
return true; }
147 #if GTSAM_ENABLE_BOOST_SERIALIZATION
149 friend class boost::serialization::access;
150 template <
class ARCHIVE>
151 void serialize(ARCHIVE& ar,
const unsigned int ) {
152 ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(
Base);
153 ar& BOOST_SERIALIZATION_NVP(constant_);
161 template<
typename L,
typename Y>
183 #ifdef DT_DEBUG_MEMORY
184 std::cout << Node::nrNodes <<
" destructing (Choice) " << this->
id()
205 #ifdef GTSAM_DT_MERGING
208 if (node->isLeaf())
return node;
210 auto choice = std::static_pointer_cast<const Choice>(node);
213 auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
216 for (
const auto& branch : choice->branches_) {
217 f->push_back(Unique(branch));
222 assert(
f->branches().size() > 0);
223 auto f0 = std::static_pointer_cast<const Leaf>(
f->branches_[0]);
224 return std::make_shared<Leaf>(
f0->constant());
// A Choice is an internal node that branches on its label_ into branches_,
// so it is never a leaf: this always returns false.
238 bool isLeaf()
const override {
return false; }
242 label_(label), allSame_(true) {
243 branches_.reserve(count);
250 if (
f.label() >
g.label()) {
253 size_t count =
f.nrChoices();
254 branches_.reserve(count);
255 for (
size_t i = 0;
i < count;
i++) {
256 NodePtr newBranch =
f.branches_[
i]->apply_f_op_g(
g, op);
257 push_back(std::move(newBranch));
259 }
else if (
g.label() >
f.label()) {
262 size_t count =
g.nrChoices();
263 branches_.reserve(count);
264 for (
size_t i = 0;
i < count;
i++) {
265 NodePtr newBranch =
g.branches_[
i]->apply_g_op_fC(
f, op);
266 push_back(std::move(newBranch));
271 size_t count =
f.nrChoices();
272 branches_.reserve(count);
273 for (
size_t i = 0;
i < count;
i++) {
274 NodePtr newBranch =
f.branches_[
i]->apply_f_op_g(*
g.branches_[
i], op);
275 push_back(std::move(newBranch));
286 return branches_.size();
300 if (allSame_ && !branches_.empty()) {
301 allSame_ = node->sameLeaf(*branches_.back());
303 branches_.push_back(std::move(node));
309 std::cout <<
s <<
" Choice(";
310 std::cout << labelFormatter(label_) <<
") " << std::endl;
311 for (
size_t i = 0;
i < branches_.size();
i++) {
312 branches_[
i]->print(
s +
" " + std::to_string(
i), labelFormatter,
valueFormatter);
319 bool showZero)
const override {
320 const std::string label = labelFormatter(label_);
321 os <<
"\"" << this->
id() <<
"\" [shape=circle, label=\"" << label
323 size_t B = branches_.size();
324 for (
size_t i = 0;
i <
B;
i++) {
325 const NodePtr& branch = branches_[
i];
328 if (!showZero && branch->isLeaf()) {
329 auto leaf = std::static_pointer_cast<const Leaf>(branch);
333 os <<
"\"" << this->
id() <<
"\" -> \"" << branch->id() <<
"\"";
334 if (
B == 2 &&
i == 0)
os <<
" [style=dashed]";
347 return (
q.isLeaf() &&
q.sameLeaf(*
this));
352 if (
q.isLeaf())
return false;
354 if (this->label_ !=
other->label_)
return false;
355 if (branches_.size() !=
other->branches_.size())
return false;
357 for (
size_t i = 0;
i < branches_.size();
i++)
368 std::cout <<
"Trying to find value for " << label_ << std::endl;
369 throw std::invalid_argument(
370 "DecisionTree::operator(): value undefined for a label");
373 size_t index =
x.at(label_);
374 NodePtr child = branches_[index];
380 label_(label), allSame_(true) {
381 branches_.reserve(
f.branches_.size());
382 for (
const NodePtr& branch :
f.branches_) {
383 push_back(branch->apply(op));
399 : label_(label), allSame_(true) {
400 branches_.reserve(
f.branches_.size());
404 for (
size_t i = 0;
i <
f.branches_.size();
i++) {
405 assignment_[label_] =
i;
408 push_back(branch->apply(op, assignment_));
411 auto assignment_it = assignment_.find(label_);
412 assignment_.erase(assignment_it);
418 auto r = std::make_shared<Choice>(label_, *
this, op);
425 auto r = std::make_shared<Choice>(label_, *
this, op, assignment);
435 return g.apply_g_op_fC(*
this, op);
440 auto h = std::make_shared<Choice>(label(), nrChoices());
441 for (
auto&& branch : branches_)
448 auto h = std::make_shared<Choice>(fC, *
this, op);
453 template<
typename OP>
455 auto h = std::make_shared<Choice>(label(), nrChoices());
456 for (
auto&& branch : branches_)
457 h->push_back(branch->apply_f_op_g(gL, op));
463 if (label_ == label)
return branches_[index];
466 auto r = std::make_shared<Choice>(label_, branches_.size());
467 for (
auto&& branch : branches_) {
468 r->push_back(branch->choose(label, index));
477 #if GTSAM_ENABLE_BOOST_SERIALIZATION
479 friend class boost::serialization::access;
480 template <
class ARCHIVE>
481 void serialize(ARCHIVE& ar,
const unsigned int ) {
482 ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(
Base);
483 ar& BOOST_SERIALIZATION_NVP(label_);
484 ar& BOOST_SERIALIZATION_NVP(branches_);
485 ar& BOOST_SERIALIZATION_NVP(allSame_);
493 template <
typename L,
typename Y>
496 template<
typename L,
typename Y>
501 template<
typename L,
typename Y>
507 template <
typename L,
typename Y>
509 auto a = std::make_shared<Choice>(label, 2);
511 a->push_back(std::move(
l1));
512 a->push_back(std::move(
l2));
513 root_ = Choice::Unique(std::move(
a));
517 template <
typename L,
typename Y>
520 if (labelC.second != 2)
throw std::invalid_argument(
521 "DecisionTree: binary constructor called with non-binary label");
522 auto a = std::make_shared<Choice>(labelC.first, 2);
524 a->push_back(std::move(
l1));
525 a->push_back(std::move(
l2));
526 root_ = Choice::Unique(std::move(
a));
529 template<
typename L,
typename Y>
531 const std::vector<Y>& ys) {
533 root_ =
create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
537 template<
typename L,
typename Y>
539 const std::string&
table) {
542 std::istringstream iss(
table);
543 copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
547 root_ =
create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
551 template<
typename L,
typename Y>
558 template<
typename L,
typename Y>
561 const std::vector<DecisionTree> functions{
f0,
f1};
562 root_ =
compose(functions.begin(), functions.end(), label);
566 template <
typename L,
typename Y>
569 : root_(std::move(
other.root_)) {
576 if (node->isLeaf()) {
578 auto leaf = std::static_pointer_cast<Leaf>(node);
579 leaf->constant_ = op(
leaf->constant_);
582 auto choice = std::static_pointer_cast<Choice>(node);
583 for (
NodePtr& branch : choice->branches()) {
590 ApplyUnary applyUnary{op};
594 other.root_ =
nullptr;
598 template <
typename L,
typename Y>
599 template <
typename X,
typename Func>
602 root_ = convertFrom<X>(
other.root_, Y_of_X);
606 template <
typename L,
typename Y>
607 template <
typename M,
typename X,
typename Func>
609 const std::map<M, L>& map, Func Y_of_X) {
610 auto L_of_M = [&map](
const M& label) ->
L {
return map.at(label); };
611 root_ = convertFrom<M, X>(
other.root_, L_of_M, Y_of_X);
620 template <
typename L,
typename Y>
621 template <
typename Iterator>
625 std::optional<L> highestLabel;
626 size_t nrChoices = 0;
628 if (it->root_->isLeaf())
630 auto c = std::static_pointer_cast<const Choice>(it->root_);
631 if (!highestLabel ||
c->label() > *highestLabel) {
632 highestLabel =
c->label();
633 nrChoices =
c->nrChoices();
638 if (!nrChoices || !highestLabel || label > *highestLabel) {
639 auto choiceOnLabel = std::make_shared<Choice>(label,
end - begin);
642 choiceOnLabel->push_back(std::move(root));
645 return choiceOnLabel;
648 auto choiceOnHighestLabel =
649 std::make_shared<Choice>(*highestLabel, nrChoices);
651 for (
size_t index = 0; index < nrChoices; index++) {
654 std::vector<DecisionTree> functions;
658 functions.push_back(chosen);
662 choiceOnHighestLabel->push_back(std::move(fi));
664 return choiceOnHighestLabel;
689 template<
typename L,
typename Y>
690 template<
typename It,
typename ValueIt>
692 It begin, It
end, ValueIt beginY, ValueIt endY) {
694 size_t nrChoices = begin->second;
695 size_t size = endY - beginY;
698 It labelC = begin + 1;
702 if (
size != nrChoices) {
703 std::cout <<
"Trying to create DD on " << begin->first << std::endl;
704 std::cout <<
"DecisionTree::create: expected " << nrChoices
705 <<
" values but got " <<
size <<
" instead" << std::endl;
706 throw std::invalid_argument(
"DecisionTree::create invalid argument");
708 auto choice = std::make_shared<Choice>(begin->first, endY - beginY);
709 for (ValueIt
y = beginY;
y != endY;
y++) {
718 std::vector<DecisionTree> functions;
719 functions.reserve(nrChoices);
721 for (
size_t i = 0;
i < nrChoices;
i++, beginY +=
split) {
723 functions.emplace_back(
f);
725 return compose(functions.begin(), functions.end(), begin->first);
731 template<
typename L,
typename Y>
732 template<
typename It,
typename ValueIt>
734 It begin, It
end, ValueIt beginY, ValueIt endY) {
735 auto node =
build(begin,
end, beginY, endY);
736 return Choice::Unique(node);
740 template <
typename L,
typename Y>
741 template <
typename X>
744 std::function<
Y(
const X&)> Y_of_X) {
750 auto leaf = std::static_pointer_cast<LXLeaf>(
f);
755 auto choice = std::static_pointer_cast<const LXChoice>(
f);
758 auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices());
761 for (
auto&& branch : choice->branches()) {
762 newChoice->push_back(convertFrom<X>(branch, Y_of_X));
765 return Choice::Unique(newChoice);
769 template <
typename L,
typename Y>
770 template <
typename M,
typename X>
773 std::function<
L(
const M&)> L_of_M, std::function<
Y(
const X&)> Y_of_X) {
780 auto leaf = std::static_pointer_cast<const MXLeaf>(
f);
785 auto choice = std::static_pointer_cast<const MXChoice>(
f);
788 const M oldLabel = choice->label();
789 const L newLabel = L_of_M(oldLabel);
799 std::vector<LY> functions;
800 for (
auto&& branch : choice->branches()) {
801 functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
803 return Choice::Unique(
804 LY::compose(functions.begin(), functions.end(), newLabel));
818 template <
typename L,
typename Y>
820 using F = std::function<void(
const Y&)>;
829 if (node->isLeaf()) {
830 auto leaf = std::static_pointer_cast<const Leaf>(node);
831 return f(
leaf->constant());
834 auto choice = std::static_pointer_cast<const Choice>(node);
835 for (
auto&& branch : choice->branches()) (*this)(branch);
839 template <
typename L,
typename Y>
840 template <
typename Func>
856 template <
typename L,
typename Y>
867 if (node->isLeaf()) {
868 auto leaf = std::static_pointer_cast<const Leaf>(node);
872 auto choice = std::static_pointer_cast<const Choice>(node);
873 for (
auto&& branch : choice->branches()) (*this)(branch);
877 template <
typename L,
typename Y>
878 template <
typename Func>
891 template <
typename L,
typename Y>
903 if (node->isLeaf()) {
904 auto leaf = std::static_pointer_cast<const Leaf>(node);
905 return f(assignment,
leaf->constant());
910 auto choice = std::static_pointer_cast<const Choice>(node);
911 for (
size_t i = 0;
i < choice->nrChoices();
i++) {
912 assignment[choice->label()] =
i;
914 (*this)(choice->branches()[
i]);
917 auto choice_it = assignment.find(choice->label());
918 assignment.erase(choice_it);
923 template <
typename L,
typename Y>
924 template <
typename Func>
931 template <
typename L,
typename Y>
934 visit([&total](
const Y& node) { total += 1; });
940 template <
typename L,
typename Y>
941 template <
typename Func,
typename X>
943 visit([&](
const Y&
y) {
x0 =
f(
y,
x0); });
961 template <
typename L,
typename Y>
965 for (
auto&& kv : assignment) {
966 unique.insert(kv.first);
974 template <
typename L,
typename Y>
980 template <
typename L,
typename Y>
987 template<
typename L,
typename Y>
989 return root_->equals(*
other.root_);
993 template<
typename L,
typename Y>
995 if (root_ ==
nullptr)
996 throw std::invalid_argument(
997 "DecisionTree::operator() called on empty tree");
998 return root_->operator ()(
x);
1002 template<
typename L,
typename Y>
1006 throw std::runtime_error(
1007 "DecisionTree::apply(unary op) undefined for empty tree.");
1014 template <
typename L,
typename Y>
1019 throw std::runtime_error(
1020 "DecisionTree::apply(unary op) undefined for empty tree.");
1027 template<
typename L,
typename Y>
1029 const Binary& op)
const {
1031 if (
empty() ||
g.empty()) {
1032 throw std::runtime_error(
1033 "DecisionTree::apply(binary op) undefined for empty trees.");
1036 NodePtr h = root_->apply_f_op_g(*
g.root_, op);
1051 template<
typename L,
typename Y>
1053 size_t cardinality,
const Binary& op)
const {
1055 for (
size_t index = 1; index < cardinality; index++) {
1063 template <
typename L,
typename Y>
1067 bool showZero)
const {
1068 os <<
"digraph G {\n";
1070 os <<
" [ordering=out]}" << std::endl;
1073 template <
typename L,
typename Y>
1077 bool showZero)
const {
1081 system((
"dot -Tpdf " +
name +
".dot -o " +
name +
".pdf >& /dev/null")
1084 throw std::runtime_error(
"DecisionTree::dot system call failed");
1087 template <
typename L,
typename Y>
1090 bool showZero)
const {
1091 std::stringstream
ss;
1097 template <
typename L,
typename Y>
1098 template <
typename A,
typename B>
1100 std::function<std::pair<A, B>(
const Y&)> AB_of_Y)
const {
1101 using AB = std::pair<A, B>;