31 using namespace gtsam;
36 vector<double> dropoutmask(
size);
38 uniform_int_distribution<> dist(1, 9);
39 auto gen = [&dist, &
g]() {
return dist(
g); };
40 generate(dropoutmask.begin(), dropoutmask.end(), gen);
42 fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0);
43 shuffle(dropoutmask.begin(), dropoutmask.end(),
g);
48 map<double, pair<chrono::microseconds, chrono::microseconds>>
measureTime(
50 vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
51 map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
53 for (
auto dropout : dropouts) {
62 auto tb_start = chrono::high_resolution_clock::now();
64 auto tb_end = chrono::high_resolution_clock::now();
66 chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
69 auto dt_start = chrono::high_resolution_clock::now();
71 auto dt_end = chrono::high_resolution_clock::now();
73 chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
76 for (
auto assignmentVal : actual_dt.
enumerate()) {
77 flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first);
79 std::cout <<
"something is wrong: " << std::endl;
80 assignmentVal.first.print();
81 std::cout <<
"dt: " << actual_dt(assignmentVal.first) << std::endl;
82 std::cout <<
"tb: " << actual(assignmentVal.first) << std::endl;
87 measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff);
89 return measured_times;
92 void printTime(map<
double, pair<chrono::microseconds, chrono::microseconds>>
94 for (
auto&& kv : measured_time) {
95 cout <<
"dropout: " << kv.first
96 <<
" | TableFactor time: " << kv.second.first.count()
97 <<
" | DecisionTreeFactor time: " << kv.second.second.count() << endl;
141 std::string expected_values =
142 "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667";
170 dkeys, std::vector<double>{0, 0, 0, 0.14649446, 0, 0.14648756, 0.14649446,
181 X, {
Y}, std::vector<double>{0.33333333, 0.6, 0.66666667, 0.4});
190 tf2.toDecisionTreeFactor()));
201 single->toDecisionTreeFactor()));
207 empty->toDecisionTreeFactor()));
220 f1.toDecisionTreeFactor()));
242 DiscreteKey A(0, 5),
B(1, 2),
C(2, 5),
D(3, 2),
E(4, 5),
F(5, 2),
G(6, 3),
243 H(7, 2),
I(8, 5),
J(9, 7),
K(10, 2),
L(11, 3);
248 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
254 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
260 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
266 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
272 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
278 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
284 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
290 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
296 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
308 auto actual = std::dynamic_pointer_cast<TableFactor>(
f1.sum(1));
313 auto actual2 = std::dynamic_pointer_cast<TableFactor>(
f1.max(1));
318 auto actual22 = std::dynamic_pointer_cast<TableFactor>(
f2.sum(1));
327 auto actual =
f.enumerate();
328 std::vector<std::pair<DiscreteValues, double>>
expected;
330 for (
size_t a : {0, 1, 2}) {
331 for (
size_t b : {0, 1}) {
347 size_t maxNrAssignments = 5;
348 auto pruned5 =
f.prune(maxNrAssignments);
355 maxNrAssignments = 2;
356 auto pruned2 =
f.prune(maxNrAssignments);
363 "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
364 "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
367 "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
368 "0.999952870000 1.0 1.0 1.0 1.0");
369 maxNrAssignments = 5;
407 auto keyFormatter = [](
Key key) {
return key == 12 ?
"A" :
"B"; };
408 TableFactor::Names
names{{12, {
"Zero",
"One",
"Two"}}, {5, {
"-",
"+"}}};
409 string actual =
f.markdown(keyFormatter,
names);
420 "<table class='TableFactor'>\n"
422 " <tr><th>A</th><th>B</th><th>value</th></tr>\n"
425 " <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
426 " <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
427 " <tr><th>One</th><th>-</th><td>3</td></tr>\n"
428 " <tr><th>One</th><th>+</th><td>4</td></tr>\n"
429 " <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
430 " <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
434 auto keyFormatter = [](
Key key) {
return key == 12 ?
"A" :
"B"; };
435 TableFactor::Names
names{{12, {
"Zero",
"One",
"Two"}}, {5, {
"-",
"+"}}};
436 string actual =
f.html(keyFormatter,
names);
447 auto op = [](
const double x) {
return 2 *
x; };
448 auto g =
f.apply(op);
453 auto sq_op = [](
const double x) {
return x *
x; };
454 auto g_sq =
f.apply(sq_op);
467 auto g =
f.apply(op);