43 #ifdef DT_DEBUG_MEMORY
// Debug-only static counter on Node, defined only under DT_DEBUG_MEMORY and
// initialized to zero. It is printed in the Choice destructor's debug trace;
// presumably it tracks the number of live Node objects (TODO confirm against
// the Node constructor/destructor).
44 template<
typename L,
typename Y>
45 int DecisionTree<L, Y>::Node::nrNodes = 0;
51 template <
typename L,
typename Y>
/// Construct a leaf node that stores a copy of `constant` as its value.
/// NOTE(review): this single-argument constructor is not `explicit`, so a Y
/// converts implicitly to a Leaf; keep that in mind before changing it.
60 Leaf(
const Y& constant) : constant_(constant) {}
69 return constant_ ==
q.constant_;
74 return (
q.isLeaf() &&
q.sameLeaf(*
this));
80 if (!
other)
return false;
93 bool showZero)
const override {
95 if (showZero ||
value.compare(
"0"))
96 os <<
"\"" << this->
id() <<
"\" [label=\"" << value
97 <<
"\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
124 return g.apply_g_op_fL(*
this, op);
/// A Leaf node always reports itself as a leaf.
144 bool isLeaf()
const override {
return true; }
149 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
151 friend class boost::serialization::access;
152 template <
class ARCHIVE>
153 void serialize(ARCHIVE& ar,
const unsigned int ) {
154 ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(
Base);
155 ar& BOOST_SERIALIZATION_NVP(constant_);
163 template<
typename L,
typename Y>
185 #ifdef DT_DEBUG_MEMORY
186 std::cout << Node::nrNodes <<
" destructing (Choice) " << this->
id()
208 if (
auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
211 auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
214 for (
size_t i = 0;
i < choice->nrChoices();
i++) {
215 auto branch = choice->branches_[
i];
216 f->push_back(Unique(branch));
219 #ifdef GTSAM_DT_MERGING
222 assert(
f->branches().size() > 0);
226 new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant()));
/// A Choice node branches on a label into child nodes, so it is never a leaf.
237 bool isLeaf()
const override {
return false; }
241 label_(label), allSame_(true) {
242 branches_.reserve(count);
249 if (
f.label() >
g.label()) {
252 size_t count =
f.nrChoices();
253 branches_.reserve(count);
254 for (
size_t i = 0;
i < count;
i++)
255 push_back(
f.branches_[
i]->apply_f_op_g(
g, op));
256 }
else if (
g.label() >
f.label()) {
259 size_t count =
g.nrChoices();
260 branches_.reserve(count);
261 for (
size_t i = 0;
i < count;
i++)
262 push_back(
g.branches_[
i]->apply_g_op_fC(
f, op));
266 size_t count =
f.nrChoices();
267 branches_.reserve(count);
268 for (
size_t i = 0;
i < count;
i++)
269 push_back(
f.branches_[
i]->apply_f_op_g(*
g.branches_[
i], op));
279 return branches_.size();
289 if (allSame_ && !branches_.empty()) {
290 allSame_ = node->sameLeaf(*branches_.back());
292 branches_.push_back(node);
298 std::cout <<
s <<
" Choice(";
299 std::cout << labelFormatter(label_) <<
") " << std::endl;
300 for (
size_t i = 0;
i < branches_.size();
i++) {
301 branches_[
i]->print(
s +
" " + std::to_string(
i), labelFormatter,
valueFormatter);
308 bool showZero)
const override {
309 os <<
"\"" << this->
id() <<
"\" [shape=circle, label=\"" << label_
311 size_t B = branches_.size();
312 for (
size_t i = 0;
i <
B;
i++) {
313 const NodePtr& branch = branches_[
i];
317 const Leaf*
leaf =
dynamic_cast<const Leaf*
>(branch.get());
321 os <<
"\"" << this->
id() <<
"\" -> \"" << branch->id() <<
"\"";
322 if (
B == 2 &&
i == 0)
os <<
" [style=dashed]";
335 return (
q.isLeaf() &&
q.sameLeaf(*
this));
341 if (!
other)
return false;
342 if (this->label_ !=
other->label_)
return false;
343 if (branches_.size() !=
other->branches_.size())
return false;
345 for (
size_t i = 0;
i < branches_.size();
i++)
356 std::cout <<
"Trying to find value for " << label_ << std::endl;
357 throw std::invalid_argument(
358 "DecisionTree::operator(): value undefined for a label");
361 size_t index =
x.at(label_);
362 NodePtr child = branches_[index];
368 label_(label), allSame_(true) {
369 branches_.reserve(
f.branches_.size());
370 for (
const NodePtr& branch :
f.branches_) {
371 push_back(branch->apply(op));
387 : label_(label), allSame_(true) {
388 branches_.reserve(
f.branches_.size());
392 for (
size_t i = 0;
i <
f.branches_.size();
i++) {
393 assignment_[label_] =
i;
396 push_back(branch->apply(op, assignment_));
399 auto assignment_it = assignment_.find(label_);
400 assignment_.erase(assignment_it);
406 auto r = std::make_shared<Choice>(label_, *
this, op);
413 auto r = std::make_shared<Choice>(label_, *
this, op, assignment);
423 return g.apply_g_op_fC(*
this, op);
428 auto h = std::make_shared<Choice>(label(), nrChoices());
429 for (
auto&& branch : branches_)
436 auto h = std::make_shared<Choice>(fC, *
this, op);
441 template<
typename OP>
443 auto h = std::make_shared<Choice>(label(), nrChoices());
444 for (
auto&& branch : branches_)
445 h->push_back(branch->apply_f_op_g(gL, op));
451 if (label_ == label)
return branches_[index];
454 auto r = std::make_shared<Choice>(label_, branches_.size());
455 for (
auto&& branch : branches_) {
456 r->push_back(branch->choose(label, index));
465 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
467 friend class boost::serialization::access;
468 template <
class ARCHIVE>
469 void serialize(ARCHIVE& ar,
const unsigned int ) {
470 ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(
Base);
471 ar& BOOST_SERIALIZATION_NVP(label_);
472 ar& BOOST_SERIALIZATION_NVP(branches_);
473 ar& BOOST_SERIALIZATION_NVP(allSame_);
481 template<
typename L,
typename Y>
484 template<
typename L,
typename Y>
489 template<
typename L,
typename Y>
495 template <
typename L,
typename Y>
497 auto a = std::make_shared<Choice>(label, 2);
501 root_ = Choice::Unique(
a);
505 template <
typename L,
typename Y>
508 if (labelC.second != 2)
throw std::invalid_argument(
509 "DecisionTree: binary constructor called with non-binary label");
510 auto a = std::make_shared<Choice>(labelC.first, 2);
514 root_ = Choice::Unique(
a);
518 template<
typename L,
typename Y>
520 const std::vector<Y>& ys) {
522 root_ =
create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
526 template<
typename L,
typename Y>
528 const std::string&
table) {
531 std::istringstream iss(
table);
532 copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
536 root_ =
create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
540 template<
typename L,
typename Y>
542 Iterator begin, Iterator
end,
const L& label) {
547 template<
typename L,
typename Y>
550 const std::vector<DecisionTree> functions{f0,
f1};
551 root_ =
compose(functions.begin(), functions.end(), label);
555 template <
typename L,
typename Y>
556 template <
typename X,
typename Func>
560 auto L_of_L = [](
const L& label) {
return label; };
561 root_ = convertFrom<L, X>(
other.root_, L_of_L, Y_of_X);
565 template <
typename L,
typename Y>
566 template <
typename M,
typename X,
typename Func>
568 const std::map<M, L>& map, Func Y_of_X) {
569 auto L_of_M = [&map](
const M& label) ->
L {
return map.at(label); };
570 root_ = convertFrom<M, X>(
other.root_, L_of_M, Y_of_X);
579 template <
typename L,
typename Y>
580 template <
typename Iterator>
582 Iterator begin, Iterator
end,
const L& label)
const {
584 std::optional<L> highestLabel;
585 size_t nrChoices = 0;
586 for (Iterator it = begin; it !=
end; it++) {
587 if (it->root_->isLeaf())
589 std::shared_ptr<const Choice>
c =
590 std::dynamic_pointer_cast<const Choice>(it->root_);
591 if (!highestLabel ||
c->label() > *highestLabel) {
592 highestLabel =
c->label();
593 nrChoices =
c->nrChoices();
598 if (!nrChoices || !highestLabel || label > *highestLabel) {
599 auto choiceOnLabel = std::make_shared<Choice>(label,
end - begin);
600 for (Iterator it = begin; it !=
end; it++)
601 choiceOnLabel->push_back(it->root_);
603 return choiceOnLabel;
606 auto choiceOnHighestLabel =
607 std::make_shared<Choice>(*highestLabel, nrChoices);
609 for (
size_t index = 0; index < nrChoices; index++) {
612 std::vector<DecisionTree> functions;
613 for (Iterator it = begin; it !=
end; it++) {
616 functions.push_back(chosen);
620 choiceOnHighestLabel->push_back(fi);
622 return choiceOnHighestLabel;
647 template<
typename L,
typename Y>
648 template<
typename It,
typename ValueIt>
650 It begin, It
end, ValueIt beginY, ValueIt endY)
const {
652 size_t nrChoices = begin->second;
653 size_t size = endY - beginY;
656 It labelC = begin + 1;
660 if (
size != nrChoices) {
661 std::cout <<
"Trying to create DD on " << begin->first << std::endl;
662 std::cout <<
"DecisionTree::create: expected " << nrChoices
663 <<
" values but got " <<
size <<
" instead" << std::endl;
664 throw std::invalid_argument(
"DecisionTree::create invalid argument");
666 auto choice = std::make_shared<Choice>(begin->first, endY - beginY);
667 for (ValueIt
y = beginY;
y != endY;
y++) {
676 std::vector<DecisionTree> functions;
678 for (
size_t i = 0;
i < nrChoices;
i++, beginY +=
split) {
680 functions.emplace_back(
f);
682 return compose(functions.begin(), functions.end(), begin->first);
688 template<
typename L,
typename Y>
689 template<
typename It,
typename ValueIt>
691 It begin, It
end, ValueIt beginY, ValueIt endY)
const {
692 auto node =
build(begin,
end, beginY, endY);
693 if (
auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
694 return Choice::Unique(choice);
701 template <
typename L,
typename Y>
702 template <
typename M,
typename X>
705 std::function<
L(
const M&)> L_of_M,
706 std::function<
Y(
const X&)> Y_of_X)
const {
713 if (
auto leaf = std::dynamic_pointer_cast<const MXLeaf>(
f)) {
719 auto choice = std::dynamic_pointer_cast<const MXChoice>(
f);
720 if (!choice)
throw std::invalid_argument(
721 "DecisionTree::convertFrom: Invalid NodePtr");
724 const M oldLabel = choice->label();
725 const L newLabel = L_of_M(oldLabel);
728 std::vector<LY> functions;
729 for (
auto&& branch : choice->branches()) {
730 functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
732 return Choice::Unique(
LY::compose(functions.begin(), functions.end(), newLabel));
746 template <
typename L,
typename Y>
748 using F = std::function<void(
const Y&)>;
755 if (
auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
756 return f(
leaf->constant());
759 auto choice = std::dynamic_pointer_cast<const Choice>(node);
761 throw std::invalid_argument(
"DecisionTree::Visit: Invalid NodePtr");
762 for (
auto&& branch : choice->branches()) (*this)(branch);
766 template <
typename L,
typename Y>
767 template <
typename Func>
783 template <
typename L,
typename Y>
792 if (
auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
796 auto choice = std::dynamic_pointer_cast<const Choice>(node);
798 throw std::invalid_argument(
"DecisionTree::VisitLeaf: Invalid NodePtr");
799 for (
auto&& branch : choice->branches()) (*this)(branch);
803 template <
typename L,
typename Y>
804 template <
typename Func>
817 template <
typename L,
typename Y>
827 if (
auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
828 return f(assignment,
leaf->constant());
831 auto choice = std::dynamic_pointer_cast<const Choice>(node);
833 throw std::invalid_argument(
"DecisionTree::VisitWith: Invalid NodePtr");
834 for (
size_t i = 0;
i < choice->nrChoices();
i++) {
835 assignment[choice->label()] =
i;
837 (*this)(choice->branches()[
i]);
840 auto choice_it = assignment.find(choice->label());
841 assignment.erase(choice_it);
846 template <
typename L,
typename Y>
847 template <
typename Func>
854 template <
typename L,
typename Y>
857 visit([&total](
const Y& node) { total += 1; });
863 template <
typename L,
typename Y>
864 template <
typename Func,
typename X>
866 visit([&](
const Y&
y) {
x0 =
f(
y,
x0); });
884 template <
typename L,
typename Y>
888 for (
auto&& kv : assignment) {
889 unique.insert(kv.first);
897 template <
typename L,
typename Y>
903 template <
typename L,
typename Y>
910 template<
typename L,
typename Y>
912 return root_->equals(*
other.root_);
915 template<
typename L,
typename Y>
917 return root_->operator ()(
x);
920 template<
typename L,
typename Y>
924 throw std::runtime_error(
925 "DecisionTree::apply(unary op) undefined for empty tree.");
931 template <
typename L,
typename Y>
936 throw std::runtime_error(
937 "DecisionTree::apply(unary op) undefined for empty tree.");
944 template<
typename L,
typename Y>
948 if (
empty() ||
g.empty()) {
949 throw std::runtime_error(
950 "DecisionTree::apply(binary op) undefined for empty trees.");
953 NodePtr h = root_->apply_f_op_g(*
g.root_, op);
968 template<
typename L,
typename Y>
970 size_t cardinality,
const Binary& op)
const {
972 for (
size_t index = 1; index < cardinality; index++) {
980 template <
typename L,
typename Y>
984 bool showZero)
const {
985 os <<
"digraph G {\n";
987 os <<
" [ordering=out]}" << std::endl;
990 template <
typename L,
typename Y>
994 bool showZero)
const {
998 system((
"dot -Tpdf " +
name +
".dot -o " +
name +
".pdf >& /dev/null")
1001 throw std::runtime_error(
"DecisionTree::dot system call failed");
1004 template <
typename L,
typename Y>
1007 bool showZero)
const {
1008 std::stringstream
ss;