31 using namespace gtsam;
36 vector<double> dropoutmask(
size);
38 uniform_int_distribution<> dist(1, 9);
39 auto gen = [&dist, &
g]() {
return dist(
g); };
40 generate(dropoutmask.begin(), dropoutmask.end(), gen);
42 fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0);
43 shuffle(dropoutmask.begin(), dropoutmask.end(),
g);
48 map<double, pair<chrono::microseconds, chrono::microseconds>>
measureTime(
50 vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
51 map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
53 for (
auto dropout : dropouts) {
62 auto tb_start = chrono::high_resolution_clock::now();
64 auto tb_end = chrono::high_resolution_clock::now();
66 chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
69 auto dt_start = chrono::high_resolution_clock::now();
71 auto dt_end = chrono::high_resolution_clock::now();
73 chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
76 for (
auto assignmentVal : actual_dt.
enumerate()) {
77 flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first);
79 std::cout <<
"something is wrong: " << std::endl;
80 assignmentVal.first.print();
81 std::cout <<
"dt: " << actual_dt(assignmentVal.first) << std::endl;
82 std::cout <<
"tb: " << actual(assignmentVal.first) << std::endl;
87 measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff);
89 return measured_times;
92 void printTime(map<
double, pair<chrono::microseconds, chrono::microseconds>>
94 for (
auto&& kv : measured_time) {
95 cout <<
"dropout: " << kv.first
96 <<
" | TableFactor time: " << kv.second.first.count()
97 <<
" | DecisionTreeFactor time: " << kv.second.second.count() << endl;
141 std::string expected_values =
142 "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667";
170 dkeys, std::vector<double>{0, 0, 0, 0.14649446, 0, 0.14648756, 0.14649446,
181 X, {
Y}, std::vector<double>{0.33333333, 0.6, 0.66666667, 0.4});
190 tf2.toDecisionTreeFactor()));
205 empty.toDecisionTreeFactor()));
218 f1.toDecisionTreeFactor()));
240 DiscreteKey A(0, 5),
B(1, 2),
C(2, 5),
D(3, 2),
E(4, 5),
F(5, 2),
G(6, 3),
241 H(7, 2),
I(8, 5),
J(9, 7),
K(10, 2),
L(11, 3);
246 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
252 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
258 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
264 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
270 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
276 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
282 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
288 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
294 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
322 auto actual =
f.enumerate();
323 std::vector<std::pair<DiscreteValues, double>>
expected;
325 for (
size_t a : {0, 1, 2}) {
326 for (
size_t b : {0, 1}) {
342 size_t maxNrAssignments = 5;
343 auto pruned5 =
f.prune(maxNrAssignments);
350 maxNrAssignments = 2;
351 auto pruned2 =
f.prune(maxNrAssignments);
358 "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
359 "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
362 "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
363 "0.999952870000 1.0 1.0 1.0 1.0");
364 maxNrAssignments = 5;
402 auto keyFormatter = [](
Key key) {
return key == 12 ?
"A" :
"B"; };
403 TableFactor::Names
names{{12, {
"Zero",
"One",
"Two"}}, {5, {
"-",
"+"}}};
404 string actual =
f.markdown(keyFormatter,
names);
415 "<table class='TableFactor'>\n"
417 " <tr><th>A</th><th>B</th><th>value</th></tr>\n"
420 " <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
421 " <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
422 " <tr><th>One</th><th>-</th><td>3</td></tr>\n"
423 " <tr><th>One</th><th>+</th><td>4</td></tr>\n"
424 " <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
425 " <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
429 auto keyFormatter = [](
Key key) {
return key == 12 ?
"A" :
"B"; };
430 TableFactor::Names
names{{12, {
"Zero",
"One",
"Two"}}, {5, {
"-",
"+"}}};
431 string actual =
f.html(keyFormatter,
names);
442 auto op = [](
const double x) {
return 2 *
x; };
443 auto g =
f.apply(op);
448 auto sq_op = [](
const double x) {
return x *
x; };
449 auto g_sq =
f.apply(sq_op);
462 auto g =
f.apply(op);