32 using namespace gtsam;
37 vector<double> dropoutmask(
size);
39 uniform_int_distribution<> dist(1, 9);
40 auto gen = [&dist, &
g]() {
return dist(
g); };
41 generate(dropoutmask.begin(), dropoutmask.end(), gen);
43 fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0);
44 shuffle(dropoutmask.begin(), dropoutmask.end(),
g);
49 map<double, pair<chrono::microseconds, chrono::microseconds>>
measureTime(
51 vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
52 map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
54 for (
auto dropout : dropouts) {
63 auto tb_start = chrono::high_resolution_clock::now();
65 auto tb_end = chrono::high_resolution_clock::now();
67 chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
70 auto dt_start = chrono::high_resolution_clock::now();
72 auto dt_end = chrono::high_resolution_clock::now();
74 chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
77 for (
auto assignmentVal : actual_dt.
enumerate()) {
78 flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first);
80 std::cout <<
"something is wrong: " << std::endl;
81 assignmentVal.first.print();
82 std::cout <<
"dt: " << actual_dt(assignmentVal.first) << std::endl;
83 std::cout <<
"tb: " << actual(assignmentVal.first) << std::endl;
88 measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff);
90 return measured_times;
93 void printTime(map<
double, pair<chrono::microseconds, chrono::microseconds>>
95 for (
auto&& kv : measured_time) {
96 cout <<
"dropout: " << kv.first
97 <<
" | TableFactor time: " << kv.second.first.count()
98 <<
" | DecisionTreeFactor time: " << kv.second.second.count() << endl;
142 std::string expected_values =
143 "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667";
171 dkeys, std::vector<double>{0, 0, 0, 0.14649446, 0, 0.14648756, 0.14649446,
182 X, {
Y}, std::vector<double>{0.33333333, 0.6, 0.66666667, 0.4});
191 tf2.toDecisionTreeFactor()));
202 single->toDecisionTreeFactor()));
208 empty->toDecisionTreeFactor()));
221 f1.toDecisionTreeFactor()));
243 DiscreteKey A(0, 5),
B(1, 2),
C(2, 5),
D(3, 2),
E(4, 5),
F(5, 2),
G(6, 3),
244 H(7, 2),
I(8, 5),
J(9, 7),
K(10, 2),
L(11, 3);
249 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
255 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
261 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
267 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
273 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
279 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
285 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
291 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
297 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
309 auto actual = std::dynamic_pointer_cast<TableFactor>(
f1.sum(1));
314 auto actual2 = std::dynamic_pointer_cast<TableFactor>(
f1.max(1));
319 auto actual22 = std::dynamic_pointer_cast<TableFactor>(
f2.sum(1));
328 auto actual =
f.enumerate();
329 std::vector<std::pair<DiscreteValues, double>>
expected;
331 for (
size_t a : {0, 1, 2}) {
332 for (
size_t b : {0, 1}) {
348 size_t maxNrAssignments = 5;
349 auto pruned5 =
f.prune(maxNrAssignments);
356 maxNrAssignments = 2;
357 auto pruned2 =
f.prune(maxNrAssignments);
364 "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
365 "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
368 "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
369 "0.999952870000 1.0 1.0 1.0 1.0");
370 maxNrAssignments = 5;
408 auto keyFormatter = [](
Key key) {
return key == 12 ?
"A" :
"B"; };
409 TableFactor::Names
names{{12, {
"Zero",
"One",
"Two"}}, {5, {
"-",
"+"}}};
410 string actual =
f.markdown(keyFormatter,
names);
421 "<table class='TableFactor'>\n"
423 " <tr><th>A</th><th>B</th><th>value</th></tr>\n"
426 " <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
427 " <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
428 " <tr><th>One</th><th>-</th><td>3</td></tr>\n"
429 " <tr><th>One</th><th>+</th><td>4</td></tr>\n"
430 " <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
431 " <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
435 auto keyFormatter = [](
Key key) {
return key == 12 ?
"A" :
"B"; };
436 TableFactor::Names
names{{12, {
"Zero",
"One",
"Two"}}, {5, {
"-",
"+"}}};
437 string actual =
f.html(keyFormatter,
names);
448 auto op = [](
const double x) {
return 2 *
x; };
449 auto g =
f.apply(op);
454 auto sq_op = [](
const double x) {
return x *
x; };
455 auto g_sq =
f.apply(sq_op);
468 auto g =
f.apply(op);