31 using namespace gtsam;
36 vector<double> dropoutmask(
size);
38 uniform_int_distribution<> dist(1, 9);
39 auto gen = [&dist, &
g]() {
return dist(
g); };
40 generate(dropoutmask.begin(), dropoutmask.end(), gen);
42 fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0);
43 shuffle(dropoutmask.begin(), dropoutmask.end(),
g);
48 map<double, pair<chrono::microseconds, chrono::microseconds>>
measureTime(
50 vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
51 map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
53 for (
auto dropout : dropouts) {
62 auto tb_start = chrono::high_resolution_clock::now();
64 auto tb_end = chrono::high_resolution_clock::now();
66 chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
69 auto dt_start = chrono::high_resolution_clock::now();
71 auto dt_end = chrono::high_resolution_clock::now();
73 chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
76 for (
auto assignmentVal : actual_dt.
enumerate()) {
77 flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first);
79 std::cout <<
"something is wrong: " << std::endl;
80 assignmentVal.first.print();
81 std::cout <<
"dt: " << actual_dt(assignmentVal.first) << std::endl;
82 std::cout <<
"tb: " << actual(assignmentVal.first) << std::endl;
87 measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff);
89 return measured_times;
92 void printTime(map<
double, pair<chrono::microseconds, chrono::microseconds>>
94 for (
auto&& kv : measured_time) {
95 cout <<
"dropout: " << kv.first
96 <<
" | TableFactor time: " << kv.second.first.count()
97 <<
" | DecisionTreeFactor time: " << kv.second.second.count() << endl;
143 "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667");
157 f1.toDecisionTreeFactor()));
179 DiscreteKey A(0, 5),
B(1, 2),
C(2, 5),
D(3, 2),
E(4, 5),
F(5, 2),
G(6, 3),
180 H(7, 2),
I(8, 5),
J(9, 7),
K(10, 2),
L(11, 3);
185 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
191 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
197 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
203 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
209 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
215 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
221 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
227 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
233 map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
261 auto actual =
f.enumerate();
262 std::vector<std::pair<DiscreteValues, double>>
expected;
264 for (
size_t a : {0, 1, 2}) {
265 for (
size_t b : {0, 1}) {
281 size_t maxNrAssignments = 5;
282 auto pruned5 =
f.prune(maxNrAssignments);
289 maxNrAssignments = 2;
290 auto pruned2 =
f.prune(maxNrAssignments);
297 "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
298 "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
301 "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
302 "0.999952870000 1.0 1.0 1.0 1.0");
303 maxNrAssignments = 5;
304 auto pruned3 = factor.
prune(maxNrAssignments);
341 auto keyFormatter = [](
Key key) {
return key == 12 ?
"A" :
"B"; };
342 TableFactor::Names
names{{12, {
"Zero",
"One",
"Two"}}, {5, {
"-",
"+"}}};
343 string actual =
f.markdown(keyFormatter,
names);
354 "<table class='TableFactor'>\n"
356 " <tr><th>A</th><th>B</th><th>value</th></tr>\n"
359 " <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
360 " <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
361 " <tr><th>One</th><th>-</th><td>3</td></tr>\n"
362 " <tr><th>One</th><th>+</th><td>4</td></tr>\n"
363 " <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
364 " <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
368 auto keyFormatter = [](
Key key) {
return key == 12 ?
"A" :
"B"; };
369 TableFactor::Names
names{{12, {
"Zero",
"One",
"Two"}}, {5, {
"-",
"+"}}};
370 string actual =
f.html(keyFormatter,
names);
381 auto op = [](
const double x) {
return 2 *
x; };
382 auto g =
f.apply(op);
387 auto sq_op = [](
const double x) {
return x *
x; };
388 auto g_sq =
f.apply(sq_op);
401 auto g =
f.apply(op);