testAlgebraicDecisionTree.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * @file testAlgebraicDecisionTree.cpp
14  * @brief Unit tests for AlgebraicDecisionTree
15  * @author Frank Dellaert
16  * @date Mar 6, 2011
17  */
18 
19 #include <gtsam/base/Testable.h>
20 #include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
22 // headers first to make sure no missing headers
23 //#define GTSAM_DT_NO_PRUNING
25 #include <gtsam/discrete/DecisionTree-inl.h> // for convert only
26 #define DISABLE_TIMING
27 
29 #include <gtsam/base/timing.h>
31 
32 using namespace std;
33 using namespace gtsam;
34 
35 /* ************************************************************************** */
37 
38 // traits
39 namespace gtsam {
40 template <>
41 struct traits<ADT> : public Testable<ADT> {};
42 } // namespace gtsam
43 
44 #define DISABLE_DOT
45 
// Forward to f.dot(filename), unless DISABLE_DOT is defined, in which case
// the call compiles to nothing.
template <typename T>
void dot(const T& f, const std::string& filename) {
#ifdef DISABLE_DOT
  // Output disabled: reference the arguments so they do not look unused.
  (void)f;
  (void)filename;
#else
  f.dot(filename);
#endif
}
52 
71 /* ************************************************************************** */
72 // instrumented operators
73 /* ************************************************************************** */
// Tallies of calls to the instrumented operators mul() and add_().
size_t muls = 0;
size_t adds = 0;
// Timing value reported by printCounts(); assigned by the tests.
double elapsed;

// Set both operation tallies back to zero.
void resetCounts() { muls = adds = 0; }
80 void printCounts(const string& s) {
81 #ifndef DISABLE_TIMING
82 cout << s << ": " << std::setw(3) << muls << " muls, " <<
83  std::setw(3) << adds << " adds, " << 1000 * elapsed << " ms."
84  << endl;
85 #endif
86  resetCounts();
87 }
88 double mul(const double& a, const double& b) {
89  muls++;
90  return a * b;
91 }
92 double add_(const double& a, const double& b) {
93  adds++;
94  return a + b;
95 }
96 
97 /* ************************************************************************** */
98 // test ADT
99 TEST(ADT, example3) {
100  // Create labels
101  DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
102 
103  // Literals
104  ADT a(A, 0.5, 0.5);
105  ADT notb(B, 1, 0);
106  ADT c(C, 0.1, 0.9);
107  ADT d(D, 0.1, 0.9);
108  ADT note(E, 0.9, 0.1);
109 
110  ADT cnotb = c * notb;
111  dot(cnotb, "ADT-cnotb");
112 
113  // a.print("a: ");
114  // cnotb.print("cnotb: ");
115  ADT acnotb = a * cnotb;
116  // acnotb.print("acnotb: ");
117  // acnotb.printCache("acnotb Cache:");
118 
119  dot(acnotb, "ADT-acnotb");
120 
121  ADT big = apply(apply(d, note, &mul), acnotb, &add_);
122  dot(big, "ADT-big");
123 }
124 
125 /* ************************************************************************** */
126 // Asia Bayes Network
127 /* ************************************************************************** */
128 
130 ADT create(const Signature& signature) {
131  ADT p(signature.discreteKeys(), signature.cpt());
132  static size_t count = 0;
133  const DiscreteKey& key = signature.key();
134  std::stringstream ss;
135  ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" << key.first;
136  string DOTfile = ss.str();
137  dot(p, DOTfile);
138  return p;
139 }
140 
141 /* ************************************************************************* */
142 // test Asia Joint
143 TEST(ADT, joint) {
     // Eight binary variables, keyed 0..7 in the order A, S, T, L, B, E, X, D.
144  DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
145  D(7, 2);
146 
     // Build one ADT per signature via create(), timed under the label
     // asiaCPTs.
147  resetCounts();
148  gttic_(asiaCPTs);
149  ADT pA = create(A % "99/1");
150  ADT pS = create(S % "50/50");
151  ADT pT = create(T | A = "99/1 95/5");
152  ADT pL = create(L | S = "99/1 90/10");
153  ADT pB = create(B | S = "70/30 40/60");
154  ADT pE = create((E | T, L) = "F T T T");
155  ADT pX = create(X | E = "95/5 2/98");
156  ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
157  gttoc_(asiaCPTs);
158  tictoc_getNode(asiaCPTsNode, asiaCPTs);
     // elapsed is the sum of the timing node's secs() and wall() readings.
159  elapsed = asiaCPTsNode->secs() + asiaCPTsNode->wall();
160  tictoc_reset_();
     // printCounts also resets the muls/adds tallies.
161  printCounts("Asia CPTs");
162 
163  // Create joint
     // Fold the factors in one at a time with the instrumented mul, in the
     // order A, S, T, L, B, E, X, D; dot() is called after every step.
164  resetCounts();
165  gttic_(asiaJoint);
166  ADT joint = pA;
167  dot(joint, "Asia-A");
168  joint = apply(joint, pS, &mul);
169  dot(joint, "Asia-AS");
170  joint = apply(joint, pT, &mul);
171  dot(joint, "Asia-AST");
172  joint = apply(joint, pL, &mul);
173  dot(joint, "Asia-ASTL");
174  joint = apply(joint, pB, &mul);
175  dot(joint, "Asia-ASTLB");
176  joint = apply(joint, pE, &mul);
177  dot(joint, "Asia-ASTLBE");
178  joint = apply(joint, pX, &mul);
179  dot(joint, "Asia-ASTLBEX");
180  joint = apply(joint, pD, &mul);
181  dot(joint, "Asia-ASTLBEXD");
     // The seven apply() calls above are expected to invoke mul exactly 346
     // times; this must be checked before printCounts clears the tally.
182  EXPECT_LONGS_EQUAL(346, muls);
183  gttoc_(asiaJoint);
184  tictoc_getNode(asiaJointNode, asiaJoint);
185  elapsed = asiaJointNode->secs() + asiaJointNode->wall();
186  tictoc_reset_();
187  printCounts("Asia joint");
188 
189  // Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
190  ADT pASTL = pA;
191  pASTL = apply(pASTL, pS, &mul);
192  pASTL = apply(pASTL, pT, &mul);
193  pASTL = apply(pASTL, pL, &mul);
194 
195  // test combine to check that P(A) = \sum_{S,T,L} P(A,S,T,L)
     // The sum is taken in two different elimination orders (L,T,S and S,T,L);
     // both results are compared against pA.
196  ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_);
197  EXPECT(assert_equal(pA, fAa));
198  ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_);
199  EXPECT(assert_equal(pA, fAb));
200 }
201 
202 /* ************************************************************************* */
203 // test Inference with joint
204 TEST(ADT, inference) {
     // Same eight binary variables as the joint test, but keyed in a different
     // order: A=0, D=1, B=2, L=3, E=4, S=5, T=6, X=7.
205  DiscreteKey A(0, 2), D(1, 2), //
206  B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
207 
     // Build one ADT per signature via create(), timed under the label
     // infCPTs.
208  resetCounts();
209  gttic_(infCPTs);
210  ADT pA = create(A % "99/1");
211  ADT pS = create(S % "50/50");
212  ADT pT = create(T | A = "99/1 95/5");
213  ADT pL = create(L | S = "99/1 90/10");
214  ADT pB = create(B | S = "70/30 40/60");
215  ADT pE = create((E | T, L) = "F T T T");
216  ADT pX = create(X | E = "95/5 2/98");
217  ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
218  gttoc_(infCPTs);
219  tictoc_getNode(infCPTsNode, infCPTs);
     // elapsed is the sum of the timing node's secs() and wall() readings.
220  elapsed = infCPTsNode->secs() + infCPTsNode->wall();
221  tictoc_reset_();
222  // printCounts("Inference CPTs");
223 
224  // Create joint
     // Multiply the factors together one at a time with the instrumented mul;
     // dot() is called after every step.
225  resetCounts();
226  gttic_(asiaProd);
227  ADT joint = pA;
228  dot(joint, "Joint-Product-A");
229  joint = apply(joint, pS, &mul);
230  dot(joint, "Joint-Product-AS");
231  joint = apply(joint, pT, &mul);
232  dot(joint, "Joint-Product-AST");
233  joint = apply(joint, pL, &mul);
234  dot(joint, "Joint-Product-ASTL");
235  joint = apply(joint, pB, &mul);
236  dot(joint, "Joint-Product-ASTLB");
237  joint = apply(joint, pE, &mul);
238  dot(joint, "Joint-Product-ASTLBE");
239  joint = apply(joint, pX, &mul);
240  dot(joint, "Joint-Product-ASTLBEX");
241  joint = apply(joint, pD, &mul);
242  dot(joint, "Joint-Product-ASTLBEXD");
     // Expected mul count for this key ordering (the joint test, which uses
     // another ordering, expects 346).
243  EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
244  gttoc_(asiaProd);
245  tictoc_getNode(asiaProdNode, asiaProd);
246  elapsed = asiaProdNode->secs() + asiaProdNode->wall();
247  tictoc_reset_();
     // printCounts also resets the muls/adds tallies.
248  printCounts("Asia product");
249 
     // Sum X, T, S and E out of the joint, in that order, with the
     // instrumented add_.
250  resetCounts();
251  gttic_(asiaSum);
252  ADT marginal = joint;
253  marginal = marginal.combine(X, &add_);
254  dot(marginal, "Joint-Sum-ADBLEST");
255  marginal = marginal.combine(T, &add_);
256  dot(marginal, "Joint-Sum-ADBLES");
257  marginal = marginal.combine(S, &add_);
258  dot(marginal, "Joint-Sum-ADBLE");
259  marginal = marginal.combine(E, &add_);
260  dot(marginal, "Joint-Sum-ADBL");
     // The four combine() calls are expected to invoke add_ exactly 161 times.
261  EXPECT_LONGS_EQUAL(161, (long)adds);
262  gttoc_(asiaSum);
263  tictoc_getNode(asiaSumNode, asiaSum);
264  elapsed = asiaSumNode->secs() + asiaSumNode->wall();
265  tictoc_reset_();
266  printCounts("Asia sum");
267 }
268 
269 /* ************************************************************************* */
270 TEST(ADT, factor_graph) {
     // Six binary variables, keyed B=0, L=1, E=2, S=3, T=4, X=5.
271  DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
272 
     // Build the factors via create(), timed under the label createCPTs.
273  resetCounts();
274  gttic_(createCPTs);
275  ADT pS = create(S % "50/50");
276  ADT pT = create(T % "95/5");
277  ADT pL = create(L | S = "99/1 90/10");
278  ADT pE = create((E | T, L) = "F T T T");
279  ADT pX = create(X | E = "95/5 2/98");
     // NOTE(review): pD is built on key B given E; no D key is declared in
     // this test -- confirm that this is intended.
280  ADT pD = create(B | E = "1/8 7/9");
281  ADT pB = create(B | S = "70/30 40/60");
282  gttoc_(createCPTs);
283  tictoc_getNode(createCPTsNode, createCPTs);
     // elapsed is the sum of the timing node's secs() and wall() readings.
284  elapsed = createCPTsNode->secs() + createCPTsNode->wall();
285  tictoc_reset_();
286  // printCounts("Create CPTs");
287 
288  // Create joint
     // Multiply all seven factors together with the instrumented mul, in the
     // order S, T, L, B, E, X, D.
289  resetCounts();
290  gttic_(asiaFG);
291  ADT fg = pS;
292  fg = apply(fg, pT, &mul);
293  fg = apply(fg, pL, &mul);
294  fg = apply(fg, pB, &mul);
295  fg = apply(fg, pE, &mul);
296  fg = apply(fg, pX, &mul);
297  fg = apply(fg, pD, &mul);
298  dot(fg, "FactorGraph");
     // The six apply() calls are expected to invoke mul exactly 158 times.
299  EXPECT_LONGS_EQUAL(158, (long)muls);
300  gttoc_(asiaFG);
301  tictoc_getNode(asiaFGNode, asiaFG);
302  elapsed = asiaFGNode->secs() + asiaFGNode->wall();
303  tictoc_reset_();
     // printCounts also resets the muls/adds tallies.
304  printCounts("Asia FG");
305 
     // Sum X, T, S, E and L out of the full product, in that order, with the
     // instrumented add_.
306  resetCounts();
307  gttic_(marg);
308  fg = fg.combine(X, &add_);
309  dot(fg, "Marginalized-6X");
310  fg = fg.combine(T, &add_);
311  dot(fg, "Marginalized-5T");
312  fg = fg.combine(S, &add_);
313  dot(fg, "Marginalized-4S");
314  fg = fg.combine(E, &add_);
315  dot(fg, "Marginalized-3E");
316  fg = fg.combine(L, &add_);
317  dot(fg, "Marginalized-2L");
     // The five combine() calls are expected to invoke add_ exactly 49 times.
     // This check uses LONGS_EQUAL where the earlier ones use
     // EXPECT_LONGS_EQUAL.
318  LONGS_EQUAL(49, adds);
319  gttoc_(marg);
320  tictoc_getNode(margNode, marg);
321  elapsed = margNode->secs() + margNode->wall();
322  tictoc_reset_();
323  printCounts("marginalize");
324 
325  // BLESTX
     // The remaining sections eliminate one variable at a time: multiply the
     // factors used for that variable, then sum the variable out. No counts
     // are asserted; each section only reports through printCounts.
326 
327  // Eliminate X
328  resetCounts();
329  gttic_(elimX);
330  ADT fE = pX;
331  dot(fE, "Eliminate-01-fEX");
332  fE = fE.combine(X, &add_);
333  dot(fE, "Eliminate-02-fE");
334  gttoc_(elimX);
335  tictoc_getNode(elimXNode, elimX);
336  elapsed = elimXNode->secs() + elimXNode->wall();
337  tictoc_reset_();
338  printCounts("Eliminate X");
339 
340  // Eliminate T
341  resetCounts();
342  gttic_(elimT);
343  ADT fLE = pT;
344  fLE = apply(fLE, pE, &mul);
345  dot(fLE, "Eliminate-03-fLET");
346  fLE = fLE.combine(T, &add_);
347  dot(fLE, "Eliminate-04-fLE");
348  gttoc_(elimT);
349  tictoc_getNode(elimTNode, elimT);
350  elapsed = elimTNode->secs() + elimTNode->wall();
351  tictoc_reset_();
352  printCounts("Eliminate T");
353 
354  // Eliminate S
355  resetCounts();
356  gttic_(elimS);
357  ADT fBL = pS;
358  fBL = apply(fBL, pL, &mul);
359  fBL = apply(fBL, pB, &mul);
360  dot(fBL, "Eliminate-05-fBLS");
361  fBL = fBL.combine(S, &add_);
362  dot(fBL, "Eliminate-06-fBL");
363  gttoc_(elimS);
364  tictoc_getNode(elimSNode, elimS);
365  elapsed = elimSNode->secs() + elimSNode->wall();
366  tictoc_reset_();
367  printCounts("Eliminate S");
368 
369  // Eliminate E
     // Uses the results of eliminating X (fE) and T (fLE), together with pD.
370  resetCounts();
371  gttic_(elimE);
372  ADT fBL2 = fE;
373  fBL2 = apply(fBL2, fLE, &mul);
374  fBL2 = apply(fBL2, pD, &mul);
375  dot(fBL2, "Eliminate-07-fBLE");
376  fBL2 = fBL2.combine(E, &add_);
377  dot(fBL2, "Eliminate-08-fBL2");
378  gttoc_(elimE);
379  tictoc_getNode(elimENode, elimE);
380  elapsed = elimENode->secs() + elimENode->wall();
381  tictoc_reset_();
382  printCounts("Eliminate E");
383 
384  // Eliminate L
     // Multiplies the results of eliminating S (fBL) and E (fBL2), then sums
     // out L.
385  resetCounts();
386  gttic_(elimL);
387  ADT fB = fBL;
388  fB = apply(fB, fBL2, &mul);
389  dot(fB, "Eliminate-09-fBL");
390  fB = fB.combine(L, &add_);
391  dot(fB, "Eliminate-10-fB");
392  gttoc_(elimL);
393  tictoc_getNode(elimLNode, elimL);
394  elapsed = elimLNode->secs() + elimLNode->wall();
395  tictoc_reset_();
396  printCounts("Eliminate L");
397 }
398 
399 /* ************************************************************************* */
400 // test equality
401 TEST(ADT, equality_noparser) {
402  const DiscreteKey A(0, 2), B(1, 2);
403  const Signature::Row rA{80, 20}, rB{60, 40};
404  const Signature::Table tableA{rA}, tableB{rB};
405 
406  // Check straight equality
407  ADT pA1 = create(A % tableA);
408  ADT pA2 = create(A % tableA);
409  EXPECT(pA1.equals(pA2)); // should be equal
410 
411  // Check equality after apply
412  ADT pB = create(B % tableB);
413  ADT pAB1 = apply(pA1, pB, &mul);
414  ADT pAB2 = apply(pB, pA1, &mul);
415  EXPECT(pAB2.equals(pAB1));
416 }
417 
418 /* ************************************************************************* */
419 // test equality
420 TEST(ADT, equality_parser) {
421  DiscreteKey A(0, 2), B(1, 2);
422  // Check straight equality
423  ADT pA1 = create(A % "80/20");
424  ADT pA2 = create(A % "80/20");
425  EXPECT(pA1.equals(pA2)); // should be equal
426 
427  // Check equality after apply
428  ADT pB = create(B % "60/40");
429  ADT pAB1 = apply(pA1, pB, &mul);
430  ADT pAB2 = apply(pB, pA1, &mul);
431  EXPECT(pAB2.equals(pAB1));
432 }
433 
434 /* ************************************************************************** */
435 // Factor graph construction
436 // test constructor from strings
438  DiscreteKey v0(0, 2), v1(1, 3);
439  DiscreteValues x00, x01, x02, x10, x11, x12;
440  x00[0] = 0, x00[1] = 0;
441  x01[0] = 0, x01[1] = 1;
442  x02[0] = 0, x02[1] = 2;
443  x10[0] = 1, x10[1] = 0;
444  x11[0] = 1, x11[1] = 1;
445  x12[0] = 1, x12[1] = 2;
446 
447  ADT f1(v0 & v1, "0 1 2 3 4 5");
448  EXPECT_DOUBLES_EQUAL(0, f1(x00), 1e-9);
449  EXPECT_DOUBLES_EQUAL(1, f1(x01), 1e-9);
450  EXPECT_DOUBLES_EQUAL(2, f1(x02), 1e-9);
451  EXPECT_DOUBLES_EQUAL(3, f1(x10), 1e-9);
452  EXPECT_DOUBLES_EQUAL(4, f1(x11), 1e-9);
453  EXPECT_DOUBLES_EQUAL(5, f1(x12), 1e-9);
454 
455  ADT f2(v1 & v0, "0 1 2 3 4 5");
456  EXPECT_DOUBLES_EQUAL(0, f2(x00), 1e-9);
457  EXPECT_DOUBLES_EQUAL(2, f2(x01), 1e-9);
458  EXPECT_DOUBLES_EQUAL(4, f2(x02), 1e-9);
459  EXPECT_DOUBLES_EQUAL(1, f2(x10), 1e-9);
460  EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
461  EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
462 
463  DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
464  vector<double> table(5 * 4 * 3 * 2);
465  double x = 0;
466  for (double& t : table) t = x++;
467  ADT f3(z0 & z1 & z2 & z3, table);
468  DiscreteValues assignment;
469  assignment[0] = 0;
470  assignment[1] = 0;
471  assignment[2] = 0;
472  assignment[3] = 1;
473  EXPECT_DOUBLES_EQUAL(1, f3(assignment), 1e-9);
474 }
475 
476 /* ************************************************************************* */
477 // test conversion to integer indices
478 // Only works if DiscreteKeys are binary, as size_t has binary cardinality!
479 TEST(ADT, conversion) {
480  DiscreteKey X(0, 2), Y(1, 2);
481  ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
482  dot(fDiscreteKey, "conversion-f1");
483 
484  std::map<Key, Key> keyMap;
485  keyMap[0] = 5;
486  keyMap[1] = 2;
487 
488  AlgebraicDecisionTree<Key> fIndexKey(fDiscreteKey, keyMap);
489  // f1.print("f1");
490  // f2.print("f2");
491  dot(fIndexKey, "conversion-f2");
492 
493  DiscreteValues x00, x01, x02, x10, x11, x12;
494  x00[5] = 0, x00[2] = 0;
495  x01[5] = 0, x01[2] = 1;
496  x10[5] = 1, x10[2] = 0;
497  x11[5] = 1, x11[2] = 1;
498  EXPECT_DOUBLES_EQUAL(0.2, fIndexKey(x00), 1e-9);
499  EXPECT_DOUBLES_EQUAL(0.5, fIndexKey(x01), 1e-9);
500  EXPECT_DOUBLES_EQUAL(0.3, fIndexKey(x10), 1e-9);
501  EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
502 }
503 
504 /* ************************************************************************** */
505 // test operations in elimination
506 TEST(ADT, elimination) {
507  DiscreteKey A(0, 2), B(1, 3), C(2, 2);
508  ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
509  dot(f1, "elimination-f1");
510 
511  {
512  // sum out lower key
513  ADT actualSum = f1.sum(C);
514  ADT expectedSum(A & B, "3 7 11 9 6 10");
515  CHECK(assert_equal(expectedSum, actualSum));
516 
517  // normalize
518  ADT actual = f1 / actualSum;
519  const vector<double> cpt{
520  1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
521  1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
522  ADT expected(A & B & C, cpt);
523  CHECK(assert_equal(expected, actual));
524  }
525 
526  {
527  // sum out lower 2 keys
528  ADT actualSum = f1.sum(C).sum(B);
529  ADT expectedSum(A, 21, 25);
530  CHECK(assert_equal(expectedSum, actualSum));
531 
532  // normalize
533  ADT actual = f1 / actualSum;
534  const vector<double> cpt{
535  1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
536  1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25};
537  ADT expected(A & B & C, cpt);
538  CHECK(assert_equal(expected, actual));
539  }
540 }
541 
542 /* ************************************************************************** */
543 // Test non-commutative op
544 TEST(ADT, div) {
545  DiscreteKey A(0, 2), B(1, 2);
546 
547  // Literals
548  ADT a(A, 8, 16);
549  ADT b(B, 2, 4);
550  ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
551  ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
552  EXPECT(assert_equal(expected_a_div_b, a / b));
553  EXPECT(assert_equal(expected_b_div_a, b / a));
554 }
555 
556 /* ************************************************************************** */
557 // test zero shortcut
559  DiscreteKey A(0, 2), B(1, 2);
560 
561  // Literals
562  ADT a(A, 0, 1);
563  ADT notb(B, 1, 0);
564  ADT anotb = a * notb;
565  // GTSAM_PRINT(anotb);
566  DiscreteValues x00, x01, x10, x11;
567  x00[0] = 0, x00[1] = 0;
568  x01[0] = 0, x01[1] = 1;
569  x10[0] = 1, x10[1] = 0;
570  x11[0] = 1, x11[1] = 1;
571  EXPECT_DOUBLES_EQUAL(0, anotb(x00), 1e-9);
572  EXPECT_DOUBLES_EQUAL(0, anotb(x01), 1e-9);
573  EXPECT_DOUBLES_EQUAL(1, anotb(x10), 1e-9);
574  EXPECT_DOUBLES_EQUAL(0, anotb(x11), 1e-9);
575 }
576 
577 /* ************************************************************************* */
578 int main() {
579  TestResult tr;
580  return TestRegistry::runAllTests(tr);
581 }
582 /* ************************************************************************* */
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
const gtsam::Symbol key('X', 0)
AlgebraicDecisionTree sum(const L &label, size_t cardinality) const
#define tictoc_getNode(variable, label)
Definition: timing.h:276
bool equals(const AlgebraicDecisionTree &other, double tol=1e-9) const
Equality method customized to value type double.
#define CHECK(condition)
Definition: Test.h:108
const char Y
void printCounts(const string &s)
double add_(const double &a, const double &b)
Scalar * b
Definition: benchVecAdd.cpp:17
#define gttic_(label)
Definition: timing.h:245
Concept check for values that can be used in unit tests.
static const Unit3 z2
static int runAllTests(TestResult &result)
void tictoc_reset_()
Definition: timing.h:282
Vector v1
signatures for conditional densities
Point2 pB(size_t i)
Matrix expected
Definition: testMatrix.cpp:971
double mul(const double &a, const double &b)
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
Apply unary operator op to DecisionTree f.
Definition: DecisionTree.h:398
EIGEN_DONT_INLINE Scalar zero()
Definition: svd_common.h:296
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
AlgebraicDecisionTree< Key > ADT
MatrixXd L
Definition: LLT_example.cpp:6
Definition: BFloat16.h:88
double f2(const Vector2 &x)
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
const DiscreteKey & key() const
Definition: Signature.h:115
DiscreteKey S(1, 2)
Algebraic Decision Trees.
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
DiscreteKeys discreteKeys() const
Definition: Signature.cpp:55
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
std::vector< Row > Table
Definition: Signature.h:60
static const Unit3 z3
#define EXPECT(condition)
Definition: Test.h:150
Eigen::Triplet< double > T
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Array< double, 1, 3 > e(1./3., 0.5, 2.)
RealScalar s
double elapsed
std::vector< double > Row
Definition: Signature.h:59
static std::stringstream ss
Definition: testBTree.cpp:31
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
#define LONGS_EQUAL(expected, actual)
Definition: Test.h:134
traits
Definition: chartTesting.h:28
std::vector< double > cpt() const
Definition: Signature.cpp:69
Point2 pA(size_t i)
specialized key for discrete variables
DiscreteKey E(5, 2)
ADT create(const Signature &signature)
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
ArrayXXf table(10, 4)
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
static const double v0
float * p
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
double f3(double x1, double x2)
TEST(ADT, example3)
#define X
Definition: icosphere.cpp:20
#define gttoc_(label)
Definition: timing.h:250
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
void resetCounts()
Timing utilities.
Point2 t(10, 10)
void dot(const T &f, const string &filename)


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:37:47