41 #ifdef DT_DEBUG_MEMORY
// Out-of-class definition of the static debug counter
// DecisionTree<L, Y>::Node::nrNodes, initialised to 0. It is compiled only
// when DT_DEBUG_MEMORY is defined (enclosing #ifdef).
// NOTE(review): presumably counts live Node objects, since it is printed when a
// Choice node is destructed; the increment/decrement sites are not visible
// here, so confirm before relying on that meaning.
42 template<
typename L,
typename Y>
43 int DecisionTree<L, Y>::Node::nrNodes = 0;
49 template <
typename L,
typename Y>
58 Leaf(
const Y& constant) : constant_(constant) {}
67 return constant_ ==
q.constant_;
72 return (
q.isLeaf() &&
q.sameLeaf(*
this));
78 if (!
other)
return false;
91 bool showZero)
const override {
93 if (showZero ||
value.compare(
"0"))
94 os <<
"\"" << this->
id() <<
"\" [label=\"" << value
95 <<
"\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
122 return g.apply_g_op_fL(*
this, op);
142 bool isLeaf()
const override {
return true; }
147 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
149 friend class boost::serialization::access;
150 template <
class ARCHIVE>
151 void serialize(ARCHIVE& ar,
const unsigned int ) {
152 ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(
Base);
153 ar& BOOST_SERIALIZATION_NVP(constant_);
161 template<
typename L,
typename Y>
183 #ifdef DT_DEBUG_MEMORY
184 std::cout << Node::nrNodes <<
" destructing (Choice) " << this->
id()
206 if (
auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
209 auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
212 for (
size_t i = 0;
i < choice->nrChoices();
i++) {
213 auto branch = choice->branches_[
i];
214 f->push_back(Unique(branch));
217 #ifdef GTSAM_DT_MERGING
220 assert(
f->branches().size() > 0);
224 new Leaf(std::dynamic_pointer_cast<const Leaf>(
f0)->constant()));
235 bool isLeaf()
const override {
return false; }
239 label_(label), allSame_(true) {
240 branches_.reserve(count);
247 if (
f.label() >
g.label()) {
250 size_t count =
f.nrChoices();
251 branches_.reserve(count);
252 for (
size_t i = 0;
i < count;
i++) {
253 NodePtr newBranch =
f.branches_[
i]->apply_f_op_g(
g, op);
254 push_back(std::move(newBranch));
256 }
else if (
g.label() >
f.label()) {
259 size_t count =
g.nrChoices();
260 branches_.reserve(count);
261 for (
size_t i = 0;
i < count;
i++) {
262 NodePtr newBranch =
g.branches_[
i]->apply_g_op_fC(
f, op);
263 push_back(std::move(newBranch));
268 size_t count =
f.nrChoices();
269 branches_.reserve(count);
270 for (
size_t i = 0;
i < count;
i++) {
271 NodePtr newBranch =
f.branches_[
i]->apply_f_op_g(*
g.branches_[
i], op);
272 push_back(std::move(newBranch));
283 return branches_.size();
297 if (allSame_ && !branches_.empty()) {
298 allSame_ = node->sameLeaf(*branches_.back());
300 branches_.push_back(std::move(node));
306 std::cout <<
s <<
" Choice(";
307 std::cout << labelFormatter(label_) <<
") " << std::endl;
308 for (
size_t i = 0;
i < branches_.size();
i++) {
309 branches_[
i]->print(
s +
" " + std::to_string(
i), labelFormatter,
valueFormatter);
316 bool showZero)
const override {
317 const std::string label = labelFormatter(label_);
318 os <<
"\"" << this->
id() <<
"\" [shape=circle, label=\"" << label
320 size_t B = branches_.size();
321 for (
size_t i = 0;
i <
B;
i++) {
322 const NodePtr& branch = branches_[
i];
326 const Leaf*
leaf =
dynamic_cast<const Leaf*
>(branch.get());
330 os <<
"\"" << this->
id() <<
"\" -> \"" << branch->id() <<
"\"";
331 if (
B == 2 &&
i == 0)
os <<
" [style=dashed]";
344 return (
q.isLeaf() &&
q.sameLeaf(*
this));
350 if (!
other)
return false;
351 if (this->label_ !=
other->label_)
return false;
352 if (branches_.size() !=
other->branches_.size())
return false;
354 for (
size_t i = 0;
i < branches_.size();
i++)
365 std::cout <<
"Trying to find value for " << label_ << std::endl;
366 throw std::invalid_argument(
367 "DecisionTree::operator(): value undefined for a label");
370 size_t index =
x.at(label_);
371 NodePtr child = branches_[index];
377 label_(label), allSame_(true) {
378 branches_.reserve(
f.branches_.size());
379 for (
const NodePtr& branch :
f.branches_) {
380 push_back(branch->apply(op));
396 : label_(label), allSame_(true) {
397 branches_.reserve(
f.branches_.size());
401 for (
size_t i = 0;
i <
f.branches_.size();
i++) {
402 assignment_[label_] =
i;
405 push_back(branch->apply(op, assignment_));
408 auto assignment_it = assignment_.find(label_);
409 assignment_.erase(assignment_it);
415 auto r = std::make_shared<Choice>(label_, *
this, op);
422 auto r = std::make_shared<Choice>(label_, *
this, op, assignment);
432 return g.apply_g_op_fC(*
this, op);
437 auto h = std::make_shared<Choice>(label(), nrChoices());
438 for (
auto&& branch : branches_)
445 auto h = std::make_shared<Choice>(fC, *
this, op);
450 template<
typename OP>
452 auto h = std::make_shared<Choice>(label(), nrChoices());
453 for (
auto&& branch : branches_)
454 h->push_back(branch->apply_f_op_g(gL, op));
460 if (label_ == label)
return branches_[index];
463 auto r = std::make_shared<Choice>(label_, branches_.size());
464 for (
auto&& branch : branches_) {
465 r->push_back(branch->choose(label, index));
474 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
476 friend class boost::serialization::access;
477 template <
class ARCHIVE>
478 void serialize(ARCHIVE& ar,
const unsigned int ) {
479 ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(
Base);
480 ar& BOOST_SERIALIZATION_NVP(label_);
481 ar& BOOST_SERIALIZATION_NVP(branches_);
482 ar& BOOST_SERIALIZATION_NVP(allSame_);
490 template <
typename L,
typename Y>
493 template<
typename L,
typename Y>
498 template<
typename L,
typename Y>
504 template <
typename L,
typename Y>
506 auto a = std::make_shared<Choice>(label, 2);
508 a->push_back(std::move(
l1));
509 a->push_back(std::move(
l2));
510 root_ = Choice::Unique(std::move(
a));
514 template <
typename L,
typename Y>
517 if (labelC.second != 2)
throw std::invalid_argument(
518 "DecisionTree: binary constructor called with non-binary label");
519 auto a = std::make_shared<Choice>(labelC.first, 2);
521 a->push_back(std::move(
l1));
522 a->push_back(std::move(
l2));
523 root_ = Choice::Unique(std::move(
a));
526 template<
typename L,
typename Y>
528 const std::vector<Y>& ys) {
530 root_ =
create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
534 template<
typename L,
typename Y>
536 const std::string&
table) {
539 std::istringstream iss(
table);
540 copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
544 root_ =
create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
548 template<
typename L,
typename Y>
555 template<
typename L,
typename Y>
558 const std::vector<DecisionTree> functions{
f0,
f1};
559 root_ =
compose(functions.begin(), functions.end(), label);
563 template <
typename L,
typename Y>
566 : root_(std::move(
other.root_)) {
573 if (
auto leaf = std::dynamic_pointer_cast<Leaf>(node)) {
575 leaf->constant_ = op(
leaf->constant_);
576 }
else if (
auto choice = std::dynamic_pointer_cast<Choice>(node)) {
578 for (
NodePtr& branch : choice->branches()) {
585 ApplyUnary applyUnary{op};
589 other.root_ =
nullptr;
593 template <
typename L,
typename Y>
594 template <
typename X,
typename Func>
597 root_ = convertFrom<X>(
other.root_, Y_of_X);
601 template <
typename L,
typename Y>
602 template <
typename M,
typename X,
typename Func>
604 const std::map<M, L>& map, Func Y_of_X) {
605 auto L_of_M = [&map](
const M& label) ->
L {
return map.at(label); };
606 root_ = convertFrom<M, X>(
other.root_, L_of_M, Y_of_X);
615 template <
typename L,
typename Y>
616 template <
typename Iterator>
620 std::optional<L> highestLabel;
621 size_t nrChoices = 0;
623 if (it->root_->isLeaf())
625 std::shared_ptr<const Choice>
c =
626 std::dynamic_pointer_cast<const Choice>(it->root_);
627 if (!highestLabel ||
c->label() > *highestLabel) {
628 highestLabel =
c->label();
629 nrChoices =
c->nrChoices();
634 if (!nrChoices || !highestLabel || label > *highestLabel) {
635 auto choiceOnLabel = std::make_shared<Choice>(label,
end - begin);
638 choiceOnLabel->push_back(std::move(root));
641 return choiceOnLabel;
644 auto choiceOnHighestLabel =
645 std::make_shared<Choice>(*highestLabel, nrChoices);
647 for (
size_t index = 0; index < nrChoices; index++) {
650 std::vector<DecisionTree> functions;
654 functions.push_back(chosen);
658 choiceOnHighestLabel->push_back(std::move(fi));
660 return choiceOnHighestLabel;
685 template<
typename L,
typename Y>
686 template<
typename It,
typename ValueIt>
688 It begin, It
end, ValueIt beginY, ValueIt endY) {
690 size_t nrChoices = begin->second;
691 size_t size = endY - beginY;
694 It labelC = begin + 1;
698 if (
size != nrChoices) {
699 std::cout <<
"Trying to create DD on " << begin->first << std::endl;
700 std::cout <<
"DecisionTree::create: expected " << nrChoices
701 <<
" values but got " <<
size <<
" instead" << std::endl;
702 throw std::invalid_argument(
"DecisionTree::create invalid argument");
704 auto choice = std::make_shared<Choice>(begin->first, endY - beginY);
705 for (ValueIt
y = beginY;
y != endY;
y++) {
714 std::vector<DecisionTree> functions;
715 functions.reserve(nrChoices);
717 for (
size_t i = 0;
i < nrChoices;
i++, beginY +=
split) {
719 functions.emplace_back(
f);
721 return compose(functions.begin(), functions.end(), begin->first);
727 template<
typename L,
typename Y>
728 template<
typename It,
typename ValueIt>
730 It begin, It
end, ValueIt beginY, ValueIt endY) {
731 auto node =
build(begin,
end, beginY, endY);
732 if (
auto choice = std::dynamic_pointer_cast<Choice>(node)) {
733 return Choice::Unique(choice);
740 template <
typename L,
typename Y>
741 template <
typename X>
744 std::function<
Y(
const X&)> Y_of_X) {
748 if (
auto leaf = std::dynamic_pointer_cast<LXLeaf>(
f)) {
754 auto choice = std::dynamic_pointer_cast<const LXChoice>(
f);
755 if (!choice)
throw std::invalid_argument(
756 "DecisionTree::convertFrom: Invalid NodePtr");
759 auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices());
762 for (
auto&& branch : choice->branches()) {
763 newChoice->push_back(convertFrom<X>(branch, Y_of_X));
766 return Choice::Unique(newChoice);
770 template <
typename L,
typename Y>
771 template <
typename M,
typename X>
774 std::function<
L(
const M&)> L_of_M, std::function<
Y(
const X&)> Y_of_X) {
779 if (
auto leaf = std::dynamic_pointer_cast<const MXLeaf>(
f)) {
785 auto choice = std::dynamic_pointer_cast<const MXChoice>(
f);
787 throw std::invalid_argument(
"DecisionTree::convertFrom: Invalid NodePtr");
790 const M oldLabel = choice->label();
791 const L newLabel = L_of_M(oldLabel);
801 std::vector<LY> functions;
802 for (
auto&& branch : choice->branches()) {
803 functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
805 return Choice::Unique(
806 LY::compose(functions.begin(), functions.end(), newLabel));
820 template <
typename L,
typename Y>
822 using F = std::function<void(
const Y&)>;
829 if (
auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
830 return f(
leaf->constant());
833 auto choice = std::dynamic_pointer_cast<const Choice>(node);
835 throw std::invalid_argument(
"DecisionTree::Visit: Invalid NodePtr");
836 for (
auto&& branch : choice->branches()) (*this)(branch);
840 template <
typename L,
typename Y>
841 template <
typename Func>
857 template <
typename L,
typename Y>
866 if (
auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
870 auto choice = std::dynamic_pointer_cast<const Choice>(node);
872 throw std::invalid_argument(
"DecisionTree::VisitLeaf: Invalid NodePtr");
873 for (
auto&& branch : choice->branches()) (*this)(branch);
877 template <
typename L,
typename Y>
878 template <
typename Func>
891 template <
typename L,
typename Y>
901 if (
auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
902 return f(assignment,
leaf->constant());
905 auto choice = std::dynamic_pointer_cast<const Choice>(node);
907 throw std::invalid_argument(
"DecisionTree::VisitWith: Invalid NodePtr");
908 for (
size_t i = 0;
i < choice->nrChoices();
i++) {
909 assignment[choice->label()] =
i;
911 (*this)(choice->branches()[
i]);
914 auto choice_it = assignment.find(choice->label());
915 assignment.erase(choice_it);
920 template <
typename L,
typename Y>
921 template <
typename Func>
928 template <
typename L,
typename Y>
931 visit([&total](
const Y& node) { total += 1; });
937 template <
typename L,
typename Y>
938 template <
typename Func,
typename X>
940 visit([&](
const Y&
y) {
x0 =
f(
y,
x0); });
958 template <
typename L,
typename Y>
962 for (
auto&& kv : assignment) {
963 unique.insert(kv.first);
971 template <
typename L,
typename Y>
977 template <
typename L,
typename Y>
984 template<
typename L,
typename Y>
986 return root_->equals(*
other.root_);
990 template<
typename L,
typename Y>
992 if (root_ ==
nullptr)
993 throw std::invalid_argument(
994 "DecisionTree::operator() called on empty tree");
995 return root_->operator ()(
x);
999 template<
typename L,
typename Y>
1003 throw std::runtime_error(
1004 "DecisionTree::apply(unary op) undefined for empty tree.");
1011 template <
typename L,
typename Y>
1016 throw std::runtime_error(
1017 "DecisionTree::apply(unary op) undefined for empty tree.");
1024 template<
typename L,
typename Y>
1026 const Binary& op)
const {
1028 if (
empty() ||
g.empty()) {
1029 throw std::runtime_error(
1030 "DecisionTree::apply(binary op) undefined for empty trees.");
1033 NodePtr h = root_->apply_f_op_g(*
g.root_, op);
1048 template<
typename L,
typename Y>
1050 size_t cardinality,
const Binary& op)
const {
1052 for (
size_t index = 1; index < cardinality; index++) {
1060 template <
typename L,
typename Y>
1064 bool showZero)
const {
1065 os <<
"digraph G {\n";
1067 os <<
" [ordering=out]}" << std::endl;
1070 template <
typename L,
typename Y>
1074 bool showZero)
const {
1078 system((
"dot -Tpdf " +
name +
".dot -o " +
name +
".pdf >& /dev/null")
1081 throw std::runtime_error(
"DecisionTree::dot system call failed");
1084 template <
typename L,
typename Y>
1087 bool showZero)
const {
1088 std::stringstream
ss;
1094 template <
typename L,
typename Y>
1095 template <
typename A,
typename B>
1097 std::function<std::pair<A, B>(
const Y&)> AB_of_Y)
const {
1098 using AB = std::pair<A, B>;