testTableFactor.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * testTableFactor.cpp
14  *
15  * @date Feb 15, 2023
16  * @author Yoonwoo Kim
17  */
18 
#include <gtsam/base/Testable.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <iostream>
#include <map>
#include <random>
#include <utility>
#include <vector>
28 
29 using namespace std;
30 using namespace gtsam;
31 
/**
 * Generate `size` random table values for the multiplication benchmark.
 *
 * Every entry is an integer drawn uniformly from [1, 9], except for a
 * fraction `dropout` of the entries (at random positions) which are set to 0,
 * so the result has a controllable sparsity.
 *
 * @param dropout Fraction of entries to zero out. Values above 1 are treated
 *                as 1; negative values and NaN are treated as 0.
 * @param size    Number of entries to generate.
 * @return Vector of `size` values in {0, 1, ..., 9} containing exactly
 *         round(dropout * size) zeros.
 */
std::vector<double> genArr(double dropout, size_t size) {
  std::random_device rd;
  std::mt19937 g(rd());
  std::vector<double> dropoutmask(size);

  // Start from non-zero values so that only the dropped entries are 0.
  std::uniform_int_distribution<> dist(1, 9);
  std::generate(dropoutmask.begin(), dropoutmask.end(),
                [&dist, &g]() { return dist(g); });

  // Clamp the fraction: asking for more than `size` zeros would make fill_n
  // write past the end of the vector.
  if (!(dropout > 0.0)) {
    dropout = 0.0;  // also catches NaN
  } else if (dropout > 1.0) {
    dropout = 1.0;
  }
  // Round rather than truncate: e.g. 0.7 * 700 evaluates to
  // 489.99999999999994 in floating point, which would truncate to 489.
  const auto nrZeros =
      static_cast<size_t>(std::round(dropout * static_cast<double>(size)));
  std::fill_n(dropoutmask.begin(), nrZeros, 0.0);
  std::shuffle(dropoutmask.begin(), dropoutmask.end(), g);

  return dropoutmask;
}
46 
47 map<double, pair<chrono::microseconds, chrono::microseconds>> measureTime(
48  DiscreteKeys keys1, DiscreteKeys keys2, size_t size) {
49  vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
50  map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
51 
52  for (auto dropout : dropouts) {
53  vector<double> arr1 = genArr(dropout, size);
54  vector<double> arr2 = genArr(dropout, size);
55  TableFactor f1(keys1, arr1);
56  TableFactor f2(keys2, arr2);
57  DecisionTreeFactor f1_dt(keys1, arr1);
58  DecisionTreeFactor f2_dt(keys2, arr2);
59 
60  // measure time TableFactor
61  auto tb_start = chrono::high_resolution_clock::now();
62  TableFactor actual = f1 * f2;
63  auto tb_end = chrono::high_resolution_clock::now();
64  auto tb_time_diff =
65  chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
66 
67  // measure time DT
68  auto dt_start = chrono::high_resolution_clock::now();
69  DecisionTreeFactor actual_dt = f1_dt * f2_dt;
70  auto dt_end = chrono::high_resolution_clock::now();
71  auto dt_time_diff =
72  chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
73 
74  bool flag = true;
75  for (auto assignmentVal : actual_dt.enumerate()) {
76  flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first);
77  if (flag) {
78  std::cout << "something is wrong: " << std::endl;
79  assignmentVal.first.print();
80  std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl;
81  std::cout << "tb: " << actual(assignmentVal.first) << std::endl;
82  break;
83  }
84  }
85  if (flag) break;
86  measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff);
87  }
88  return measured_times;
89 }
90 
/**
 * Print one line per dropout level to stdout with the TableFactor and
 * DecisionTreeFactor multiplication times, in microseconds.
 *
 * @param measured_time Map as returned by measureTime(); printed in ascending
 *                      dropout order. Taken by const reference to avoid
 *                      copying the whole map.
 */
void printTime(const std::map<double, std::pair<std::chrono::microseconds,
                                                std::chrono::microseconds>>&
                   measured_time) {
  for (const auto& kv : measured_time) {
    // '\n' instead of std::endl: no need to flush after every line.
    std::cout << "dropout: " << kv.first
              << " | TableFactor time: " << kv.second.first.count()
              << " | DecisionTreeFactor time: " << kv.second.second.count()
              << '\n';
  }
}
99 
100 /* ************************************************************************* */
101 // Check constructors for TableFactor.
102 TEST(TableFactor, constructors) {
103  // Declare a bunch of keys
104  DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5);
105 
106  // Create factors
107  TableFactor f_zeros(A, {0, 0, 0, 0, 1});
108  TableFactor f1(X, {2, 8});
109  TableFactor f2(X & Y, "2 5 3 6 4 7");
110  TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
111  EXPECT_LONGS_EQUAL(1, f1.size());
112  EXPECT_LONGS_EQUAL(2, f2.size());
113  EXPECT_LONGS_EQUAL(3, f3.size());
114 
116  values[0] = 1; // x
117  values[1] = 2; // y
118  values[2] = 1; // z
119  values[3] = 4; // a
120  EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9);
121  EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
122  EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
123  EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
124 
125  // Assert that error = -log(value)
126  EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
127 }
128 
129 /* ************************************************************************* */
130 // Check multiplication between two TableFactors.
131 TEST(TableFactor, multiplication) {
132  DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
133 
134  // Multiply with a DiscreteDistribution, i.e., Bayes Law!
135  DiscreteDistribution prior(v1 % "1/3");
136  TableFactor f1(v0 & v1, "1 2 3 4");
137  DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
138  CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) *
139  f1.toDecisionTreeFactor()));
140  CHECK(assert_equal(expected, f1 * prior));
141 
142  // Multiply two factors
143  TableFactor f2(v1 & v2, "5 6 7 8");
144  TableFactor actual = f1 * f2;
145  TableFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
146  CHECK(assert_equal(expected2, actual));
147 
148  DiscreteKey A(0, 3), B(1, 2), C(2, 2);
149  TableFactor f_zeros1(A & C, "0 0 0 2 0 3");
150  TableFactor f_zeros2(B & C, "4 0 0 5");
151  TableFactor actual_zeros = f_zeros1 * f_zeros2;
152  TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15");
153  CHECK(assert_equal(expected3, actual_zeros));
154 }
155 
156 /* ************************************************************************* */
157 // Benchmark which compares runtime of multiplication of two TableFactors
158 // and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
160  DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
161  H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
162 
163  // 100
164  DiscreteKeys one_1 = {A, B, C, D};
165  DiscreteKeys one_2 = {C, D, E, F};
166  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
167  measureTime(one_1, one_2, 100);
168  printTime(time_map_1);
169  // 200
170  DiscreteKeys two_1 = {A, B, C, D, F};
171  DiscreteKeys two_2 = {B, C, D, E, F};
172  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
173  measureTime(two_1, two_2, 200);
174  printTime(time_map_2);
175  // 300
176  DiscreteKeys three_1 = {A, B, C, D, G};
177  DiscreteKeys three_2 = {C, D, E, F, G};
178  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
179  measureTime(three_1, three_2, 300);
180  printTime(time_map_3);
181  // 400
182  DiscreteKeys four_1 = {A, B, C, D, F, H};
183  DiscreteKeys four_2 = {B, C, D, E, F, H};
184  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
185  measureTime(four_1, four_2, 400);
186  printTime(time_map_4);
187  // 500
188  DiscreteKeys five_1 = {A, B, C, D, I};
189  DiscreteKeys five_2 = {C, D, E, F, I};
190  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
191  measureTime(five_1, five_2, 500);
192  printTime(time_map_5);
193  // 600
194  DiscreteKeys six_1 = {A, B, C, D, F, G};
195  DiscreteKeys six_2 = {B, C, D, E, F, G};
196  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
197  measureTime(six_1, six_2, 600);
198  printTime(time_map_6);
199  // 700
200  DiscreteKeys seven_1 = {A, B, C, D, J};
201  DiscreteKeys seven_2 = {C, D, E, F, J};
202  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
203  measureTime(seven_1, seven_2, 700);
204  printTime(time_map_7);
205  // 800
206  DiscreteKeys eight_1 = {A, B, C, D, F, H, K};
207  DiscreteKeys eight_2 = {B, C, D, E, F, H, K};
208  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
209  measureTime(eight_1, eight_2, 800);
210  printTime(time_map_8);
211  // 900
212  DiscreteKeys nine_1 = {A, B, C, D, G, L};
213  DiscreteKeys nine_2 = {C, D, E, F, G, L};
214  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
215  measureTime(nine_1, nine_2, 900);
216  printTime(time_map_9);
217 }
218 
219 /* ************************************************************************* */
220 // Check sum and max over frontals.
221 TEST(TableFactor, sum_max) {
222  DiscreteKey v0(0, 3), v1(1, 2);
223  TableFactor f1(v0 & v1, "1 2 3 4 5 6");
224 
225  TableFactor expected(v1, "9 12");
226  TableFactor::shared_ptr actual = f1.sum(1);
227  CHECK(assert_equal(expected, *actual, 1e-5));
228 
229  TableFactor expected2(v1, "5 6");
230  TableFactor::shared_ptr actual2 = f1.max(1);
231  CHECK(assert_equal(expected2, *actual2));
232 
233  TableFactor f2(v1 & v0, "1 2 3 4 5 6");
234  TableFactor::shared_ptr actual22 = f2.sum(1);
235 }
236 
237 /* ************************************************************************* */
238 // Check enumerate yields the correct list of assignment/value pairs.
239 TEST(TableFactor, enumerate) {
240  DiscreteKey A(12, 3), B(5, 2);
241  TableFactor f(A & B, "1 2 3 4 5 6");
242  auto actual = f.enumerate();
243  std::vector<std::pair<DiscreteValues, double>> expected;
245  for (size_t a : {0, 1, 2}) {
246  for (size_t b : {0, 1}) {
247  values[12] = a;
248  values[5] = b;
249  expected.emplace_back(values, f(values));
250  }
251  }
252  EXPECT(actual == expected);
253 }
254 
255 /* ************************************************************************* */
256 // Check pruning of the decision tree works as expected.
257 TEST(TableFactor, Prune) {
258  DiscreteKey A(1, 2), B(2, 2), C(3, 2);
259  TableFactor f(A & B & C, "1 5 3 7 2 6 4 8");
260 
261  // Only keep the leaves with the top 5 values.
262  size_t maxNrAssignments = 5;
263  auto pruned5 = f.prune(maxNrAssignments);
264 
265  // Pruned leaves should be 0
266  TableFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
267  EXPECT(assert_equal(expected, pruned5));
268 
269  // Check for more extreme pruning where we only keep the top 2 leaves
270  maxNrAssignments = 2;
271  auto pruned2 = f.prune(maxNrAssignments);
272  TableFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
273  EXPECT(assert_equal(expected2, pruned2));
274 
275  DiscreteKey D(4, 2);
276  TableFactor factor(
277  D & C & B & A,
278  "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
279  "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
280 
281  TableFactor expected3(D & C & B & A,
282  "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
283  "0.999952870000 1.0 1.0 1.0 1.0");
284  maxNrAssignments = 5;
285  auto pruned3 = factor.prune(maxNrAssignments);
286  EXPECT(assert_equal(expected3, pruned3));
287 }
288 
289 /* ************************************************************************* */
290 // Check markdown representation looks as expected.
292  DiscreteKey A(12, 3), B(5, 2);
293  TableFactor f(A & B, "1 2 3 4 5 6");
294  string expected =
295  "|A|B|value|\n"
296  "|:-:|:-:|:-:|\n"
297  "|0|0|1|\n"
298  "|0|1|2|\n"
299  "|1|0|3|\n"
300  "|1|1|4|\n"
301  "|2|0|5|\n"
302  "|2|1|6|\n";
303  auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
304  string actual = f.markdown(formatter);
305  EXPECT(actual == expected);
306 }
307 
308 /* ************************************************************************* */
309 // Check markdown representation with a value formatter.
310 TEST(TableFactor, markdownWithValueFormatter) {
311  DiscreteKey A(12, 3), B(5, 2);
312  TableFactor f(A & B, "1 2 3 4 5 6");
313  string expected =
314  "|A|B|value|\n"
315  "|:-:|:-:|:-:|\n"
316  "|Zero|-|1|\n"
317  "|Zero|+|2|\n"
318  "|One|-|3|\n"
319  "|One|+|4|\n"
320  "|Two|-|5|\n"
321  "|Two|+|6|\n";
322  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
323  TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
324  string actual = f.markdown(keyFormatter, names);
325  EXPECT(actual == expected);
326 }
327 
328 /* ************************************************************************* */
329 // Check html representation with a value formatter.
330 TEST(TableFactor, htmlWithValueFormatter) {
331  DiscreteKey A(12, 3), B(5, 2);
332  TableFactor f(A & B, "1 2 3 4 5 6");
333  string expected =
334  "<div>\n"
335  "<table class='TableFactor'>\n"
336  " <thead>\n"
337  " <tr><th>A</th><th>B</th><th>value</th></tr>\n"
338  " </thead>\n"
339  " <tbody>\n"
340  " <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
341  " <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
342  " <tr><th>One</th><th>-</th><td>3</td></tr>\n"
343  " <tr><th>One</th><th>+</th><td>4</td></tr>\n"
344  " <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
345  " <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
346  " </tbody>\n"
347  "</table>\n"
348  "</div>";
349  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
350  TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
351  string actual = f.html(keyFormatter, names);
352  EXPECT(actual == expected);
353 }
354 
355 /* ************************************************************************* */
356 int main() {
357  TestResult tr;
358  return TestRegistry::runAllTests(tr);
359 }
360 /* ************************************************************************* */
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
const gtsam::Symbol key('X', 0)
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as markdown table.
#define CHECK(condition)
Definition: Test.h:108
const char Y
shared_ptr max(size_t nrFrontals) const
Create new factor by maximizing over all values with the same separator.
Definition: TableFactor.h:208
TEST(TableFactor, constructors)
#define I
Definition: main.h:112
std::string html(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as html table.
Key F(std::uint64_t j)
Vector v2
Scalar * b
Definition: benchVecAdd.cpp:17
Concept check for values that can be used in unit tests.
static int runAllTests(TestResult &result)
Vector v1
DecisionTreeFactor toDecisionTreeFactor() const override
Convert into a decisiontree.
signatures for conditional densities
JacobiRotation< float > G
Point2 prior(const Point2 &x)
Prior on a single pose.
Definition: simulated2D.h:88
Matrix expected
Definition: testMatrix.cpp:971
size_t size() const
Definition: Factor.h:159
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
leaf::MyValues values
static Cal3_S2 K(500, 500, 0.1, 640/2, 480/2)
MatrixXd L
Definition: LLT_example.cpp:6
Definition: BFloat16.h:88
std::shared_ptr< TableFactor > shared_ptr
Definition: TableFactor.h:94
double f2(const Vector2 &x)
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy y set format x g set format y g set format x2 g set format y2 g set format z g set angles radians set nogrid set key title set key left top Right noreverse box linetype linewidth samplen spacing width set nolabel set noarrow set nologscale set logscale x set set pointsize set encoding default set nopolar set noparametric set set set set surface set nocontour set clabel set mapping cartesian set nohidden3d set cntrparam order set cntrparam linear set cntrparam levels auto set cntrparam points set size set set xzeroaxis lt lw set x2zeroaxis lt lw set yzeroaxis lt lw set y2zeroaxis lt lw set tics in set ticslevel set tics set mxtics default set mytics default set mx2tics default set my2tics default set xtics border mirror norotate autofreq set ytics border mirror norotate autofreq set ztics border nomirror norotate autofreq set nox2tics set noy2tics set timestamp bottom norotate set rrange [*:*] noreverse nowriteback set trange [*:*] noreverse nowriteback set urange [*:*] noreverse nowriteback set vrange [*:*] noreverse nowriteback set xlabel matrix size set x2label set timefmt d m y n H
EIGEN_DEVICE_FUNC const LogReturnType log() const
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
void g(const string &key, int i)
Definition: testBTree.cpp:41
std::vector< std::pair< DiscreteValues, double > > enumerate() const
Enumerate all values into a map from values to double.
const KeyFormatter & formatter
void printTime(map< double, pair< chrono::microseconds, chrono::microseconds >> measured_time)
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
#define Z
Definition: icosphere.cpp:21
#define EXPECT(condition)
Definition: Test.h:150
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Array< double, 1, 3 > e(1./3., 0.5, 2.)
JacobiRotation< float > J
TableFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
traits
Definition: chartTesting.h:28
DiscreteKey E(5, 2)
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
static const double v0
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
double f3(double x1, double x2)
map< double, pair< chrono::microseconds, chrono::microseconds > > measureTime(DiscreteKeys keys1, DiscreteKeys keys2, size_t size)
std::vector< std::pair< DiscreteValues, double > > enumerate() const
Enumerate all values into a map from values to double.
int main()
#define X
Definition: icosphere.cpp:20
shared_ptr sum(size_t nrFrontals) const
Create new factor by summing all values with the same separator values.
Definition: TableFactor.h:198
vector< double > genArr(double dropout, size_t size)
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:102
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:39:43