25 #include <boost/format.hpp> 26 #include <boost/optional.hpp> 27 #include <boost/tuple/tuple.hpp> 28 #include <boost/assign/std/vector.hpp> 29 using boost::assign::operator+=;
30 #include <boost/unordered_set.hpp> 31 #include <boost/noncopyable.hpp> 43 #ifdef DT_DEBUG_MEMORY 44 template<
typename L,
typename Y>
45 int DecisionTree<L, Y>::Node::nrNodes = 0;
51 template<
typename L,
typename Y>
61 constant_(constant) {}
80 const Leaf* other =
dynamic_cast<const Leaf*
> (&
q);
81 if (!other)
return false;
86 void print(
const std::string&
s)
const override {
88 if (showZero || constant_) std::cout << s <<
" Leaf " << constant_ << std::endl;
92 void dot(std::ostream&
os,
bool showZero)
const override {
93 if (showZero || constant_) os <<
"\"" << this->
id() <<
"\" [label=\"" 94 << boost::format(
"%4.2g") % constant_
95 <<
"\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
134 bool isLeaf()
const override {
return true; }
141 template<
typename L,
typename Y>
159 #ifdef DT_DEBUG_MEMORY 160 std::std::cout << Node::nrNodes <<
" destructing (Choice) " << this->
id() << std::std::endl;
166 #ifndef DT_NO_PRUNING 168 assert(f->branches().size() > 0);
170 assert(f0->isLeaf());
171 NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
178 bool isLeaf()
const override {
return false; }
182 label_(label), allSame_(true) {
183 branches_.reserve(count);
197 branches_.reserve(count);
198 for (
size_t i = 0;
i < count;
i++)
199 push_back(f.
branches_[
i]->apply_f_op_g(g, op));
204 branches_.reserve(count);
205 for (
size_t i = 0;
i < count;
i++)
206 push_back(g.
branches_[
i]->apply_g_op_fC(f, op));
211 branches_.reserve(count);
212 for (
size_t i = 0;
i < count;
i++)
222 return branches_.size();
232 if (allSame_ && !branches_.empty()) {
233 allSame_ = node->sameLeaf(*branches_.back());
235 branches_.push_back(node);
239 void print(
const std::string&
s)
const override {
240 std::cout << s <<
" Choice(";
242 std::cout << label_ <<
") " << std::endl;
243 for (
size_t i = 0;
i < branches_.size();
i++)
244 branches_[
i]->
print((boost::format(
"%s %d") % s %
i).str());
248 void dot(std::ostream&
os,
bool showZero)
const override {
249 os <<
"\"" << this->
id() <<
"\" [shape=circle, label=\"" << label_
251 for (
size_t i = 0;
i < branches_.size();
i++) {
256 const Leaf*
leaf =
dynamic_cast<const Leaf*
> (branch.get());
257 if (leaf && !leaf->
constant())
continue;
260 os <<
"\"" << this->
id() <<
"\" -> \"" << branch->id() <<
"\"";
261 if (
i == 0) os <<
" [style=dashed]";
262 if (
i > 1) os <<
" [style=bold]";
264 branch->dot(os, showZero);
281 if (!other)
return false;
282 if (this->label_ != other->
label_)
return false;
283 if (branches_.size() != other->
branches_.size())
return false;
285 for (
size_t i = 0;
i < branches_.size();
i++)
295 std::cout <<
"Trying to find value for " << label_ << std::endl;
296 throw std::invalid_argument(
297 "DecisionTree::operator(): value undefined for a label");
300 size_t index = x.at(label_);
301 NodePtr child = branches_[index];
309 label_(label), allSame_(true) {
313 push_back(branch->apply(op));
318 boost::shared_ptr<Choice> r(
new Choice(label_, *
this, op));
333 boost::shared_ptr<Choice>
h(
new Choice(label(), nrChoices()));
341 boost::shared_ptr<Choice>
h(
new Choice(fC, *
this, op));
346 template<
typename OP>
348 boost::shared_ptr<Choice>
h(
new Choice(label(), nrChoices()));
349 for(
const NodePtr& branch: branches_)
350 h->push_back(branch->apply_f_op_g(gL, op));
357 return branches_[index];
360 boost::shared_ptr<Choice> r(
new Choice(label_, branches_.size()));
361 for(
const NodePtr& branch: branches_)
362 r->push_back(branch->choose(label, index));
371 template<
typename L,
typename Y>
375 template<
typename L,
typename Y>
381 template<
typename L,
typename Y>
387 template<
typename L,
typename Y>
389 const L& label,
const Y& y1,
const Y& y2) {
390 boost::shared_ptr<Choice>
a(
new Choice(label, 2));
398 template<
typename L,
typename Y>
400 const LabelC& labelC,
const Y& y1,
const Y& y2) {
401 if (labelC.second != 2)
throw std::invalid_argument(
402 "DecisionTree: binary constructor called with non-binary label");
403 boost::shared_ptr<Choice>
a(
new Choice(labelC.first, 2));
411 template<
typename L,
typename Y>
413 const std::vector<Y>& ys) {
415 root_ =
create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
419 template<
typename L,
typename Y>
421 const std::string&
table) {
425 std::istringstream iss(table);
426 copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
430 root_ =
create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
434 template<
typename L,
typename Y>
436 Iterator begin, Iterator
end,
const L& label) {
441 template<
typename L,
typename Y>
444 std::vector<DecisionTree> functions;
446 root_ =
compose(functions.begin(), functions.end(), label);
450 template<
typename L,
typename Y>
451 template<
typename M,
typename X>
453 const std::map<M, L>& map, boost::function<
Y(
const X&)> op) {
463 template<
typename L,
typename Y>
template<
typename Iterator>
465 Iterator
end,
const L& label)
const {
468 boost::optional<L> highestLabel;
469 size_t nrChoices = 0;
470 for (Iterator it = begin; it !=
end; it++) {
471 if (it->root_->isLeaf())
473 boost::shared_ptr<const Choice>
c =
474 boost::dynamic_pointer_cast<
const Choice>(it->root_);
475 if (!highestLabel || c->label() > *highestLabel) {
476 highestLabel.reset(c->label());
482 if (!nrChoices || !highestLabel || label > *highestLabel) {
483 boost::shared_ptr<Choice> choiceOnLabel(
new Choice(label, end - begin));
484 for (Iterator it = begin; it !=
end; it++)
485 choiceOnLabel->push_back(it->root_);
489 boost::shared_ptr<Choice> choiceOnHighestLabel(
new Choice(*highestLabel, nrChoices));
491 for (
size_t index = 0; index < nrChoices; index++) {
494 std::vector<DecisionTree> functions;
495 for (Iterator it = begin; it !=
end; it++) {
498 functions.push_back(chosen);
502 choiceOnHighestLabel->push_back(fi);
529 template<
typename L,
typename Y>
530 template<
typename It,
typename ValueIt>
532 It begin,
It end, ValueIt beginY, ValueIt endY)
const {
535 size_t nrChoices = begin->second;
536 size_t size = endY - beginY;
539 It labelC = begin + 1;
543 if (size != nrChoices) {
544 std::cout <<
"Trying to create DD on " << begin->first << std::endl;
545 std::cout << boost::format(
"DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
546 throw std::invalid_argument(
"DecisionTree::create invalid argument");
548 boost::shared_ptr<Choice> choice(
new Choice(begin->first, endY - beginY));
549 for (ValueIt
y = beginY;
y != endY;
y++)
557 std::vector<DecisionTree> functions;
558 size_t split = size / nrChoices;
559 for (
size_t i = 0;
i < nrChoices;
i++, beginY +=
split) {
563 return compose(functions.begin(), functions.end(), begin->first);
567 template<
typename L,
typename Y>
568 template<
typename M,
typename X>
571 boost::function<
Y(
const X&)> op) {
574 typedef typename MX::Leaf MXLeaf;
575 typedef typename MX::Choice MXChoice;
576 typedef typename MX::NodePtr MXNodePtr;
581 const MXLeaf*
leaf =
dynamic_cast<const MXLeaf*
> (f.get());
582 if (leaf)
return NodePtr(
new Leaf(op(leaf->constant())));
585 boost::shared_ptr<const MXChoice> choice = boost::dynamic_pointer_cast<
const MXChoice> (
f);
586 if (!choice)
throw std::invalid_argument(
587 "DecisionTree::Convert: Invalid NodePtr");
590 M oldLabel = choice->label();
591 L newLabel = map.at(oldLabel);
594 std::vector<LY> functions;
595 for(
const MXNodePtr& branch: choice->branches()) {
596 LY converted(convert<M, X>(branch, map, op));
597 functions += converted;
599 return LY::compose(functions.begin(), functions.end(), newLabel);
603 template<
typename L,
typename Y>
608 template<
typename L,
typename Y>
613 template<
typename L,
typename Y>
618 template<
typename L,
typename Y>
620 return root_->operator ()(
x);
623 template<
typename L,
typename Y>
629 template<
typename L,
typename Y>
648 template<
typename L,
typename Y>
650 size_t cardinality,
const Binary& op)
const {
652 for (
size_t index = 1; index < cardinality; index++) {
654 result = result.
apply(chosen, op);
660 template<
typename L,
typename Y>
662 os <<
"digraph G {\n";
663 root_->dot(os, showZero);
664 os <<
" [ordering=out]}" << std::endl;
667 template<
typename L,
typename Y>
669 std::ofstream
os((name +
".dot").
c_str());
672 (
"dot -Tpdf " + name +
".dot -o " + name +
".pdf >& /dev/null").
c_str());
673 if (result==-1)
throw std::runtime_error(
"DecisionTree::dot system call failed");
NodePtr apply(const Unary &op) const override
bool sameLeaf(const Leaf &q) const override
Choice-Leaf equality: always false.
Decision Tree for use in DiscreteFactors.
Matrix< RealScalar, Dynamic, Dynamic > M
const std::vector< NodePtr > & branches() const
NodePtr apply_g_op_fL(const Leaf &fL, const Binary &op) const override
bool sameLeaf(const Node &q) const override
polymorphic equality: is q is a leaf, could be
Concept check for values that can be used in unit tests.
bool isLeaf() const override
virtual bool isLeaf() const =0
Q id(Eigen::AngleAxisd(0, Q_z_axis))
NodePtr apply_fC_op_gL(const Leaf &gL, OP op) const
bool equals(const DecisionTree &other, double tol=1e-9) const
DecisionTree apply(const Unary &op) const
void push_back(const NodePtr &node)
NodePtr apply(const Unary &op) const override
bool isLeaf() const override
boost::function< Y(const Y &, const Y &)> Binary
static NodePtr Unique(const ChoicePtr &f)
virtual Ptr apply_g_op_fL(const Leaf &, const Binary &) const =0
NodePtr apply_g_op_fC(const Choice &fC, const Binary &op) const override
DecisionTree choose(const L &label, size_t index) const
boost::shared_ptr< const Choice > ChoicePtr
void print(const std::string &s) const override
const mpreal root(const mpreal &x, unsigned long int k, mp_rnd_t r=mpreal::get_default_rnd())
void g(const string &key, int i)
T compose(const T &t1, const T &t2)
const char * c_str(Args &&...args)
NodePtr apply_f_op_g(const Node &g, const Binary &op) const override
void split(const G &g, const PredecessorMap< KEY > &tree, G &Ab1, G &Ab2)
std::vector< NodePtr > branches_
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
boost::function< Y(const Y &)> Unary
NodePtr compose(Iterator begin, Iterator end, const L &label) const
void print(const std::string &s="DecisionTree") const
Choice(const L &label, const Choice &f, const Unary &op)
void dot(std::ostream &os, bool showZero=true) const
const Y & operator()(const Assignment< L > &x) const override
bool equals(const Node &q, double tol) const override
NodePtr apply_f_op_g(const Node &g, const Binary &op) const override
Choice(const L &label, size_t count)
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
NodePtr choose(const L &label, size_t index) const override
EIGEN_DEVICE_FUNC const Scalar & q
NodePtr choose(const L &label, size_t index) const override
Choice(const Choice &f, const Choice &g, const Binary &op)
string::const_iterator It
void dot(std::ostream &os, bool showZero) const override
const Y & operator()(const Assignment< L > &x) const override
bool equals(const Node &q, double tol) const override
const Y & operator()(const Assignment< L > &x) const
NodePtr apply_g_op_fC(const Choice &fC, const Binary &op) const override
bool sameLeaf(const Leaf &q) const override
Leaf-Leaf equality.
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const
ofstream os("timeSchurFactors.csv")
void print(const std::string &s) const override
bool operator==(const DecisionTree &q) const
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
void dot(std::ostream &os, bool showZero) const override
bool sameLeaf(const Node &q) const override
polymorphic equality: if q is a leaf, could be...
NodePtr convert(const typename DecisionTree< M, X >::NodePtr &f, const std::map< M, L > &map, boost::function< Y(const X &)> op)
virtual bool sameLeaf(const Leaf &q) const =0
Annotation for function names.
int EIGEN_BLAS_FUNC() copy(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy)
NodePtr apply_g_op_fL(const Leaf &fL, const Binary &op) const override
virtual Ptr apply_g_op_fC(const Choice &, const Binary &) const =0
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
const Y & constant() const
std::pair< L, size_t > LabelC