34 using namespace gtsam;
43 #define DOT(x) (dot(x, #x)) 52 void print(
const std::string&
s =
"")
const {
53 auto keyFormatter = [](
const std::string&
s) {
return s; };
56 ss <<
"{" <<
v.a <<
"," << std::setw(4) << std::setprecision(2) <<
v.b <<
"}";
90 void print(
const std::string&
s =
"")
const {
91 auto keyFormatter = [](
const std::string&
s) {
return s; };
93 return std::to_string(
v);
100 auto compare = [](
const int&
v,
const int&
w) {
return v ==
w; };
101 return Base::equals(other,
compare);
114 static inline int zero() {
return 0; }
115 static inline int one() {
return 1; }
116 static inline int id(
const int&
a) {
return a; }
117 static inline int add(
const int&
a,
const int&
b) {
return a +
b; }
118 static inline int mul(
const int&
a,
const int&
b) {
return a *
b; }
125 string A(
"A"),
B(
"B"),
C(
"C");
129 x00[
A] = 0, x00[
B] = 0;
130 x01[
A] = 0, x01[
B] = 1;
131 x10[
A] = 1, x10[
B] = 0;
132 x11[
A] = 1, x11[
B] = 1;
215 x101[
A] = 1, x101[
B] = 0, x101[
C] = 1;
235 string A(
"A"),
B(
"B");
245 x00[
"A"] = 0, x00[
"B"] = 0;
256 string A(
"A"),
B(
"B");
269 x00[
X] = 0, x00[
Y] = 0;
270 x01[
X] = 0, x01[
Y] = 1;
271 x10[
X] = 1, x10[
Y] = 0;
272 x11[
X] = 1, x11[
Y] = 1;
283 string A(
"A"),
B(
"B"),
C(
"C");
300 DT f5(
keys,
"0 4 2 6 1 5 3 7");
308 using Container = std::vector<double>;
312 StringContainerTree
tree;
315 string A(
"A"),
B(
"B");
316 DT stringIntTree(B,
DT(
A, 0, 1),
DT(
A, 2, 3));
319 auto container_of_int = [](
const int&
i) {
324 StringContainerTree converted(stringIntTree, container_of_int);
330 const std::pair<string, size_t>
A(
"A", 2),
B(
"B", 2),
C(
"C", 2);
331 DT tree({
A,
B, C},
"1 1 1 1 1 1 1 1");
336 DT tree2({
C,
B,
A},
"1 1 1 2 3 4 5 5");
351 auto root = std::dynamic_pointer_cast<
const DT::Choice>(tree2.root_);
353 auto choice0 = std::dynamic_pointer_cast<
const DT::Choice>(root->branches()[0]);
355 EXPECT(choice0->branches()[0]->isLeaf());
356 auto choice00 = std::dynamic_pointer_cast<
const DT::Leaf>(choice0->branches()[0]);
360 auto choice1 = std::dynamic_pointer_cast<
const DT::Choice>(root->branches()[1]);
362 auto choice10 = std::dynamic_pointer_cast<
const DT::Choice>(choice1->branches()[0]);
364 auto choice11 = std::dynamic_pointer_cast<
const DT::Leaf>(choice1->branches()[1]);
366 EXPECT(choice11->isLeaf());
374 string A(
"A"),
B(
"B");
377 auto visitor = [&](
int y) { sum +=
y; };
386 string A(
"A"),
B(
"B");
398 std::pair<string, size_t>
A(
"A", 2),
B(
"B", 2),
C(
"C", 2);
399 std::vector<std::pair<string, size_t>>
labels = {
C,
B,
A};
400 std::vector<int>
nodes = {0, 0, 2, 3, 4, 4, 6, 7};
403 std::vector<Assignment<string>> choices;
405 choices.push_back(choice);
413 expectedAssignment = {{
"B", 0}, {
"C", 0}};
414 EXPECT(expectedAssignment == choices.at(0));
416 expectedAssignment = {{
"A", 0}, {
"B", 1}, {
"C", 0}};
417 EXPECT(expectedAssignment == choices.at(1));
419 expectedAssignment = {{
"A", 1}, {
"B", 1}, {
"C", 0}};
420 EXPECT(expectedAssignment == choices.at(2));
422 expectedAssignment = {{
"B", 0}, {
"C", 1}};
423 EXPECT(expectedAssignment == choices.at(3));
425 expectedAssignment = {{
"A", 0}, {
"B", 1}, {
"C", 1}};
426 EXPECT(expectedAssignment == choices.at(4));
428 expectedAssignment = {{
"A", 1}, {
"B", 1}, {
"C", 1}};
429 EXPECT(expectedAssignment == choices.at(5));
436 string A(
"A"),
B(
"B");
438 auto add = [](
const int&
y,
double x) {
return y +
x; };
439 double sum = tree.
fold(
add, 0.0);
447 string A(
"A"),
B(
"B");
461 string A(
"A"),
B(
"B"),
C(
"C");
462 DTP
tree(
B, DTP(
A, {0,
"zero"}, {1,
"one"}),
463 DTP(
A, {2,
"two"}, {1337,
"l33t"}));
465 const auto [dt1, dt2] =
unzip(tree);
467 DT1 tree1(
B, DT1(
A, 0, 1), DT1(
A, 2, 1337));
468 DT2 tree2(
B, DT2(
A,
"zero",
"one"), DT2(
A,
"two",
"l33t"));
470 EXPECT(tree1.equals(dt1));
471 EXPECT(tree2.equals(dt2));
483 auto count = [](
const int&
value,
int count) {
484 return value == 0 ? count + 1 : count;
489 auto threshold = [](
int value) {
return value < 5 ? 0 :
value; };
490 DT thresholded(tree, threshold);
506 keys,
"0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08");
507 double threshold = 0.045;
512 if (probTree(choices) < threshold) {
518 DT prunedTree = tree.
apply(pruner);
520 DT expectedTree(
keys,
"0 0 0 0 5 6 7 8");
528 DT prunedTree2 = prunedTree.
apply(counter);
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
static int id(const int &a)
static int mul(const int &a, const int &b)
Concept check for values that can be used in unit tests.
void visit(Func f) const
Visit all leaves in depth-first fashion.
static int runAllTests(TestResult &result)
double dot(const V1 &a, const V2 &b)
signatures for conditional densities
std::vector< std::string > labels
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
double f2(const Vector2 &x)
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
static std::string valueFormatter(const double &v)
#define CHECK_EXCEPTION(condition, exception_name)
std::set< L > labels() const
static enum @1107 ordering
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
void visitWith(Func f) const
Visit all leaves in depth-first fashion.
std::pair< DecisionTree< L, T1 >, DecisionTree< L, T2 > > unzip(const DecisionTree< L, std::pair< T1, T2 > > &input)
unzip a DecisionTree with std::pair values.
bool bool_of_int(const int &y)
#define EXPECT(condition)
Array< int, Dynamic, 1 > v
bool equals(const Base &other, double tol=1e-9) const
Equality method customized to int node type.
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Array< double, 1, 3 > e(1./3., 0.5, 2.)
X fold(Func f, X x0) const
Fold a binary function over the tree, returning accumulator.
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
void print(const std::string &s="") const
print to stdout
DecisionTree apply(const Unary &op) const
static sharedNode Leaf(Key key, const SymbolicFactorGraph &factors)
static std::stringstream ss
Matrix< Scalar, Dynamic, Dynamic > C
#define LONGS_EQUAL(expected, actual)
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
static int add(const int &a, const int &b)
#define EXPECT_LONGS_EQUAL(expected, actual)
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
double f4(double x, double y, double z)
double f3(double x1, double x2)
void print(const std::string &s="") const
print to stdout
graph add(PriorFactor< Pose2 >(1, priorMean, priorNoise))
DecisionTree< Label, bool > LabelBoolTree
std::pair< string, size_t > LabelC
TEST(SmartFactorBase, Pinhole)
DecisionTree< string, bool > StringBoolTree
bool equals(const CrazyDecisionTree &other, double tol=1e-9) const
Equality method customized to Crazy node type.
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
DecisionTree choose(const L &label, size_t index) const
#define GTSAM_CONCEPT_TESTABLE_INST(T)