1 /* ----------------------------------------------------------------------------
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
8  * See LICENSE for the license information
10  * -------------------------------------------------------------------------- */
12 /*
13  * testTableFactor.cpp
14  *
15  * @date Feb 15, 2023
16  * @author Yoonwoo Kim
17  */
20 #include <gtsam/base/Testable.h>
26 #include <chrono>
27 #include <random>
29 using namespace std;
30 using namespace gtsam;
/**
 * Generate a random table of `size` values for benchmarking.
 *
 * Every entry is drawn uniformly from the integers 1..9. Then a fraction
 * `dropout` of the entries (rounded to the nearest count) is set to 0, and the
 * table is shuffled so the zeros land at random positions.
 *
 * @param dropout fraction of entries to zero. Values <= 0 (or NaN) give no
 *                zeros; values >= 1 give an all-zero table.
 * @param size    number of entries in the returned table.
 * @return vector of `size` doubles, each either 0 or an integer in [1, 9].
 */
std::vector<double> genArr(double dropout, size_t size) {
  std::random_device rd;
  std::mt19937 g(rd());

  std::vector<double> dropoutmask(size);
  std::uniform_int_distribution<> dist(1, 9);
  std::generate(dropoutmask.begin(), dropoutmask.end(),
                [&dist, &g]() { return dist(g); });

  // Compute the zero count in integer space. Handing the raw double product to
  // fill_n truncates it (0.29 * 100 == 28.999... -> 28 zeros), and a dropout
  // above 1 would make fill_n write past the end of the vector.
  size_t numZeros = 0;
  if (dropout >= 1.0) {
    numZeros = size;
  } else if (dropout > 0.0) {
    numZeros = std::min(size, static_cast<size_t>(dropout * size + 0.5));
  }
  std::fill_n(dropoutmask.begin(), numZeros, 0.0);
  std::shuffle(dropoutmask.begin(), dropoutmask.end(), g);

  return dropoutmask;
}
47 map<double, pair<chrono::microseconds, chrono::microseconds>> measureTime(
48  DiscreteKeys keys1, DiscreteKeys keys2, size_t size) {
49  vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
50  map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
52  for (auto dropout : dropouts) {
53  vector<double> arr1 = genArr(dropout, size);
54  vector<double> arr2 = genArr(dropout, size);
55  TableFactor f1(keys1, arr1);
56  TableFactor f2(keys2, arr2);
57  DecisionTreeFactor f1_dt(keys1, arr1);
58  DecisionTreeFactor f2_dt(keys2, arr2);
60  // measure time TableFactor
61  auto tb_start = chrono::high_resolution_clock::now();
62  TableFactor actual = f1 * f2;
63  auto tb_end = chrono::high_resolution_clock::now();
64  auto tb_time_diff =
65  chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
67  // measure time DT
68  auto dt_start = chrono::high_resolution_clock::now();
69  DecisionTreeFactor actual_dt = f1_dt * f2_dt;
70  auto dt_end = chrono::high_resolution_clock::now();
71  auto dt_time_diff =
72  chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
74  bool flag = true;
75  for (auto assignmentVal : actual_dt.enumerate()) {
76  flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first);
77  if (flag) {
78  std::cout << "something is wrong: " << std::endl;
79  assignmentVal.first.print();
80  std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl;
81  std::cout << "tb: " << actual(assignmentVal.first) << std::endl;
82  break;
83  }
84  }
85  if (flag) break;
86  measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff);
87  }
88  return measured_times;
89 }
/**
 * Print one line per dropout level with the TableFactor and
 * DecisionTreeFactor multiplication times, in microseconds.
 *
 * @param measured_time map from dropout to (TableFactor, DecisionTreeFactor)
 *                      durations, as returned by measureTime. It is taken by
 *                      const reference to avoid copying the whole map.
 */
void printTime(
    const std::map<double, std::pair<std::chrono::microseconds,
                                     std::chrono::microseconds>>& measured_time) {
  for (const auto& kv : measured_time) {
    std::cout << "dropout: " << kv.first
              << " | TableFactor time: " << kv.second.first.count()
              << " | DecisionTreeFactor time: " << kv.second.second.count()
              << std::endl;
  }
}
100 /* ************************************************************************* */
101 // Check constructors for TableFactor.
102 TEST(TableFactor, constructors) {
103  // Declare a bunch of keys
104  DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5);
106  // Create factors
107  TableFactor f_zeros(A, {0, 0, 0, 0, 1});
108  TableFactor f1(X, {2, 8});
109  TableFactor f2(X & Y, "2 5 3 6 4 7");
110  TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
111  EXPECT_LONGS_EQUAL(1, f1.size());
112  EXPECT_LONGS_EQUAL(2, f2.size());
113  EXPECT_LONGS_EQUAL(3, f3.size());
116  values[0] = 1; // x
117  values[1] = 2; // y
118  values[2] = 1; // z
119  values[3] = 4; // a
120  EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9);
121  EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
122  EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
123  EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
125  // Assert that error = -log(value)
126  EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
127 }
129 /* ************************************************************************* */
130 // Check multiplication between two TableFactors.
131 TEST(TableFactor, multiplication) {
132  DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
134  // Multiply with a DiscreteDistribution, i.e., Bayes Law!
135  DiscreteDistribution prior(v1 % "1/3");
136  TableFactor f1(v0 & v1, "1 2 3 4");
137  DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
138  CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) *
139  f1.toDecisionTreeFactor()));
140  CHECK(assert_equal(expected, f1 * prior));
142  // Multiply two factors
143  TableFactor f2(v1 & v2, "5 6 7 8");
144  TableFactor actual = f1 * f2;
145  TableFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
146  CHECK(assert_equal(expected2, actual));
148  DiscreteKey A(0, 3), B(1, 2), C(2, 2);
149  TableFactor f_zeros1(A & C, "0 0 0 2 0 3");
150  TableFactor f_zeros2(B & C, "4 0 0 5");
151  TableFactor actual_zeros = f_zeros1 * f_zeros2;
152  TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15");
153  CHECK(assert_equal(expected3, actual_zeros));
154 }
156 /* ************************************************************************* */
157 // Benchmark which compares runtime of multiplication of two TableFactors
158 // and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
160  DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
161  H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
163  // 100
164  DiscreteKeys one_1 = {A, B, C, D};
165  DiscreteKeys one_2 = {C, D, E, F};
166  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
167  measureTime(one_1, one_2, 100);
168  printTime(time_map_1);
169  // 200
170  DiscreteKeys two_1 = {A, B, C, D, F};
171  DiscreteKeys two_2 = {B, C, D, E, F};
172  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
173  measureTime(two_1, two_2, 200);
174  printTime(time_map_2);
175  // 300
176  DiscreteKeys three_1 = {A, B, C, D, G};
177  DiscreteKeys three_2 = {C, D, E, F, G};
178  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
179  measureTime(three_1, three_2, 300);
180  printTime(time_map_3);
181  // 400
182  DiscreteKeys four_1 = {A, B, C, D, F, H};
183  DiscreteKeys four_2 = {B, C, D, E, F, H};
184  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
185  measureTime(four_1, four_2, 400);
186  printTime(time_map_4);
187  // 500
188  DiscreteKeys five_1 = {A, B, C, D, I};
189  DiscreteKeys five_2 = {C, D, E, F, I};
190  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
191  measureTime(five_1, five_2, 500);
192  printTime(time_map_5);
193  // 600
194  DiscreteKeys six_1 = {A, B, C, D, F, G};
195  DiscreteKeys six_2 = {B, C, D, E, F, G};
196  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
197  measureTime(six_1, six_2, 600);
198  printTime(time_map_6);
199  // 700
200  DiscreteKeys seven_1 = {A, B, C, D, J};
201  DiscreteKeys seven_2 = {C, D, E, F, J};
202  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
203  measureTime(seven_1, seven_2, 700);
204  printTime(time_map_7);
205  // 800
206  DiscreteKeys eight_1 = {A, B, C, D, F, H, K};
207  DiscreteKeys eight_2 = {B, C, D, E, F, H, K};
208  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
209  measureTime(eight_1, eight_2, 800);
210  printTime(time_map_8);
211  // 900
212  DiscreteKeys nine_1 = {A, B, C, D, G, L};
213  DiscreteKeys nine_2 = {C, D, E, F, G, L};
214  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
215  measureTime(nine_1, nine_2, 900);
216  printTime(time_map_9);
217 }
219 /* ************************************************************************* */
220 // Check sum and max over frontals.
221 TEST(TableFactor, sum_max) {
222  DiscreteKey v0(0, 3), v1(1, 2);
223  TableFactor f1(v0 & v1, "1 2 3 4 5 6");
225  TableFactor expected(v1, "9 12");
226  TableFactor::shared_ptr actual = f1.sum(1);
227  CHECK(assert_equal(expected, *actual, 1e-5));
229  TableFactor expected2(v1, "5 6");
230  TableFactor::shared_ptr actual2 = f1.max(1);
231  CHECK(assert_equal(expected2, *actual2));
233  TableFactor f2(v1 & v0, "1 2 3 4 5 6");
234  TableFactor::shared_ptr actual22 = f2.sum(1);
235 }
237 /* ************************************************************************* */
238 // Check enumerate yields the correct list of assignment/value pairs.
239 TEST(TableFactor, enumerate) {
240  DiscreteKey A(12, 3), B(5, 2);
241  TableFactor f(A & B, "1 2 3 4 5 6");
242  auto actual = f.enumerate();
243  std::vector<std::pair<DiscreteValues, double>> expected;
245  for (size_t a : {0, 1, 2}) {
246  for (size_t b : {0, 1}) {
247  values[12] = a;
248  values[5] = b;
249  expected.emplace_back(values, f(values));
250  }
251  }
252  EXPECT(actual == expected);
253 }
255 /* ************************************************************************* */
256 // Check pruning of the decision tree works as expected.
257 TEST(TableFactor, Prune) {
258  DiscreteKey A(1, 2), B(2, 2), C(3, 2);
259  TableFactor f(A & B & C, "1 5 3 7 2 6 4 8");
261  // Only keep the leaves with the top 5 values.
262  size_t maxNrAssignments = 5;
263  auto pruned5 = f.prune(maxNrAssignments);
265  // Pruned leaves should be 0
266  TableFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
267  EXPECT(assert_equal(expected, pruned5));
269  // Check for more extreme pruning where we only keep the top 2 leaves
270  maxNrAssignments = 2;
271  auto pruned2 = f.prune(maxNrAssignments);
272  TableFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
273  EXPECT(assert_equal(expected2, pruned2));
275  DiscreteKey D(4, 2);
276  TableFactor factor(
277  D & C & B & A,
278  "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
279  "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
281  TableFactor expected3(D & C & B & A,
282  "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
283  "0.999952870000 1.0 1.0 1.0 1.0");
284  maxNrAssignments = 5;
285  auto pruned3 = factor.prune(maxNrAssignments);
286  EXPECT(assert_equal(expected3, pruned3));
287 }
289 /* ************************************************************************* */
290 // Check markdown representation looks as expected.
292  DiscreteKey A(12, 3), B(5, 2);
293  TableFactor f(A & B, "1 2 3 4 5 6");
294  string expected =
295  "|A|B|value|\n"
296  "|:-:|:-:|:-:|\n"
297  "|0|0|1|\n"
298  "|0|1|2|\n"
299  "|1|0|3|\n"
300  "|1|1|4|\n"
301  "|2|0|5|\n"
302  "|2|1|6|\n";
303  auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
304  string actual = f.markdown(formatter);
305  EXPECT(actual == expected);
306 }
308 /* ************************************************************************* */
309 // Check markdown representation with a value formatter.
310 TEST(TableFactor, markdownWithValueFormatter) {
311  DiscreteKey A(12, 3), B(5, 2);
312  TableFactor f(A & B, "1 2 3 4 5 6");
313  string expected =
314  "|A|B|value|\n"
315  "|:-:|:-:|:-:|\n"
316  "|Zero|-|1|\n"
317  "|Zero|+|2|\n"
318  "|One|-|3|\n"
319  "|One|+|4|\n"
320  "|Two|-|5|\n"
321  "|Two|+|6|\n";
322  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
323  TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
324  string actual = f.markdown(keyFormatter, names);
325  EXPECT(actual == expected);
326 }
328 /* ************************************************************************* */
329 // Check html representation with a value formatter.
330 TEST(TableFactor, htmlWithValueFormatter) {
331  DiscreteKey A(12, 3), B(5, 2);
332  TableFactor f(A & B, "1 2 3 4 5 6");
333  string expected =
334  "<div>\n"
335  "<table class='TableFactor'>\n"
336  " <thead>\n"
337  " <tr><th>A</th><th>B</th><th>value</th></tr>\n"
338  " </thead>\n"
339  " <tbody>\n"
340  " <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
341  " <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
342  " <tr><th>One</th><th>-</th><td>3</td></tr>\n"
343  " <tr><th>One</th><th>+</th><td>4</td></tr>\n"
344  " <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
345  " <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
346  " </tbody>\n"
347  "</table>\n"
348  "</div>";
349  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
350  TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
351  string actual = f.html(keyFormatter, names);
352  EXPECT(actual == expected);
353 }
355 /* ************************************************************************* */
356 int main() {
357  TestResult tr;
358  return TestRegistry::runAllTests(tr);
359 }
360 /* ************************************************************************* */
