31 using namespace gtsam;
36 vector<double> dropoutmask(
size);
38 uniform_int_distribution<> dist(1, 9);
39 auto gen = [&dist, &
g]() {
return dist(
g); };
40 generate(dropoutmask.begin(), dropoutmask.end(), gen);
42 fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0);
43 shuffle(dropoutmask.begin(), dropoutmask.end(),
g);
48 map<double, pair<chrono::microseconds, chrono::microseconds>>
measureTime(
50 vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
51 map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
53 for (
auto dropout : dropouts) {
62 auto tb_start = chrono::high_resolution_clock::now();
64 auto tb_end = chrono::high_resolution_clock::now();
66 chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
69 auto dt_start = chrono::high_resolution_clock::now();
71 auto dt_end = chrono::high_resolution_clock::now();
73 chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
76 for (
auto assignmentVal : actual_dt.
enumerate()) {
77 flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first);
79 std::cout <<
"something is wrong: " << std::endl;
80 assignmentVal.first.print();
81 std::cout <<
"dt: " << actual_dt(assignmentVal.first) << std::endl;
82 std::cout <<
"tb: " << actual(assignmentVal.first) << std::endl;
87 measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff);
89 return measured_times;
92 void printTime(map<
double, pair<chrono::microseconds, chrono::microseconds>>
94 for (
auto&& kv : measured_time) {
95 cout <<
"dropout: " << kv.first
96 <<
" | TableFactor time: " << kv.second.first.count()
97 <<
" | DecisionTreeFactor time: " << kv.second.second.count() << endl;
141 std::string expected_values =
142 "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667";
170 dkeys, std::vector<double>{0, 0, 0, 0.14649446, 0, 0.14648756, 0.14649446,
188 f1.toDecisionTreeFactor()));
210 DiscreteKey A(0, 5),
B(1, 2),
C(2, 5),
D(3, 2),
E(4, 5),
F(5, 2),
G(6, 3),
211 H(7, 2),
I(8, 5),
J(9, 7),
K(10, 2),
L(11, 3);
216 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
222 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
228 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
234 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
240 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
246 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
252 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
258 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
264 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
292 auto actual =
f.enumerate();
293 std::vector<std::pair<DiscreteValues, double>>
expected;
295 for (
size_t a : {0, 1, 2}) {
296 for (
size_t b : {0, 1}) {
312 size_t maxNrAssignments = 5;
313 auto pruned5 =
f.prune(maxNrAssignments);
320 maxNrAssignments = 2;
321 auto pruned2 =
f.prune(maxNrAssignments);
328 "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
329 "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
332 "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
333 "0.999952870000 1.0 1.0 1.0 1.0");
334 maxNrAssignments = 5;
372 auto keyFormatter = [](
Key key) {
return key == 12 ?
"A" :
"B"; };
373 TableFactor::Names
names{{12, {
"Zero",
"One",
"Two"}}, {5, {
"-",
"+"}}};
374 string actual =
f.markdown(keyFormatter,
names);
385 "<table class='TableFactor'>\n"
387 " <tr><th>A</th><th>B</th><th>value</th></tr>\n"
390 " <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
391 " <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
392 " <tr><th>One</th><th>-</th><td>3</td></tr>\n"
393 " <tr><th>One</th><th>+</th><td>4</td></tr>\n"
394 " <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
395 " <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
399 auto keyFormatter = [](
Key key) {
return key == 12 ?
"A" :
"B"; };
400 TableFactor::Names
names{{12, {
"Zero",
"One",
"Two"}}, {5, {
"-",
"+"}}};
401 string actual =
f.html(keyFormatter,
names);
412 auto op = [](
const double x) {
return 2 *
x; };
413 auto g =
f.apply(op);
418 auto sq_op = [](
const double x) {
return x *
x; };
419 auto g_sq =
f.apply(sq_op);
432 auto g =
f.apply(op);