testDecisionTreeFactor.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * testDecisionTreeFactor.cpp
14  *
15  * @date Feb 5, 2012
16  * @author Frank Dellaert
17  * @author Duy-Nguyen Ta
18  */
19 
21 #include <gtsam/base/Testable.h>
27 #include <gtsam/inference/Key.h>
29 
30 using namespace std;
31 using namespace gtsam;
32 
33 /* ************************************************************************* */
// Intended to check that a factor built from a std::vector<double> table and
// one built from the equivalent whitespace-separated string agree.
34 TEST(DecisionTreeFactor, ConstructorsMatch) {
35  // Declare two keys
36  DiscreteKey X(0, 2), Y(1, 3);
37 
38  // Create with vector and with string
// NOTE(review): listing lines 40 and 42 are missing from this copy, so
// `table` is never used in the visible code and nothing is asserted.
// Presumably line 40 builds a factor from `table` and line 42 compares it
// with f2 — confirm against the original source.
39  const std::vector<double> table{2, 5, 3, 6, 4, 7};
41  DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
43 }
44 
45 /* ************************************************************************* */
46 TEST(DecisionTreeFactor, constructors) {
47  // Declare a bunch of keys
48  DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
49 
50  // Create factors
51  DecisionTreeFactor f1(X, {2, 8});
52  DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
53  DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
54  EXPECT_LONGS_EQUAL(1, f1.size());
55  EXPECT_LONGS_EQUAL(2, f2.size());
56  EXPECT_LONGS_EQUAL(3, f3.size());
57 
58  DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
59  EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
60  EXPECT_DOUBLES_EQUAL(7, f2(x121), 1e-9);
61  EXPECT_DOUBLES_EQUAL(75, f3(x121), 1e-9);
62 
63  // Assert that error = -log(value)
64  EXPECT_DOUBLES_EQUAL(-log(f1(x121)), f1.error(x121), 1e-9);
65 
66  // Construct from DiscreteConditional
67  DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
68  DecisionTreeFactor f4(conditional);
69  EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
70 }
71 
72 /* ************************************************************************* */
// NOTE(review): the TEST(...) header for this body (listing line 73) is
// missing from this copy.
74  // Declare a bunch of keys
75  DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
76 
77  // Create factors
78  DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
79 
80  auto errors = f.errorTree();
// NOTE(review): listing line 82, which declares `expected`, is missing here.
// Each regression value below is -log of the matching table entry of f, in
// the same order (e.g. -log(2) = -0.69314718, -log(75) = -4.3174881).
81  // regression
83  {X, Y, Z},
84  vector<double>{-0.69314718, -1.6094379, -1.0986123, -1.7917595,
85  -1.3862944, -1.9459101, -3.2188758, -4.0073332, -3.5553481,
86  -4.1743873, -3.8066625, -4.3174881});
87  EXPECT(assert_equal(expected, errors, 1e-6));
88 }
89 
90 /* ************************************************************************* */
// Checks factor products: a prior times a factor, and factor times factor.
91 TEST(DecisionTreeFactor, multiplication) {
92  DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
93 
94  // Multiply with a DiscreteDistribution, i.e., Bayes Law!
// The prior "1/3" on v1 normalizes to P(v1=0)=0.25, P(v1=1)=0.75, so
// `expected` is f1 scaled by those weights (1*.25, 2*.75, 3*.25, 4*.75).
95  DiscreteDistribution prior(v1 % "1/3");
96  DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
97  DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
// NOTE(review): listing lines 98-99 are missing from this copy, so `prior`
// and `expected` are not used in the visible code; presumably those lines
// perform the prior * f1 check — confirm against the original source.
100 
101  // Multiply two factors
// Each entry is f1(v0,v1) * f2(v1,v2) with v0 varying slowest,
// e.g. 14 = 2 * 7 for (v0,v1,v2) = (0,1,0).
102  DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
103  DecisionTreeFactor actual = f1 * f2;
104  DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
105  CHECK(assert_equal(expected2, actual));
106 }
107 
108 /* ************************************************************************* */
// NOTE(review): the TEST(...) header for this body (listing line 109) is
// missing from this copy.
110  DiscreteKey v0(0, 3), v1(1, 2);
111  DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
112 
// sum(1) eliminates the first (frontal) key v0 by summing:
// 1+3+5 = 9 and 2+4+6 = 12 over v1.
113  DecisionTreeFactor expected(v1, "9 12");
114  DecisionTreeFactor::shared_ptr actual = f1.sum(1);
115  CHECK(assert_equal(expected, *actual, 1e-5));
116 
// max(1) eliminates v0 by taking maxima: max(1,3,5) = 5, max(2,4,6) = 6.
117  DecisionTreeFactor expected2(v1, "5 6");
118  DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
119  CHECK(assert_equal(expected2, *actual2));
120 
// TODO(review): actual22 is computed but never checked. Summing v1 out of
// f2 by hand gives 1+4, 2+5, 3+6 over v0, i.e. presumably
// DecisionTreeFactor(v0, "5 7 9"). Add an assertion or drop the call.
121  DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
122  DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
123 }
124 
125 /* ************************************************************************* */
126 // Check enumerate yields the correct list of assignment/value pairs.
// NOTE(review): the TEST(...) header (listing line 127) and the declaration
// of `values` (listing line 132) are missing from this copy.
128  DiscreteKey A(12, 3), B(5, 2);
129  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
130  auto actual = f.enumerate();
131  std::vector<std::pair<DiscreteValues, double>> expected;
// Build the expected (assignment, value) pairs with A (key 12) as the outer
// loop and B (key 5) as the inner loop. The vector == below is
// order-sensitive, so this also pins the enumeration order.
133  for (size_t a : {0, 1, 2}) {
134  for (size_t b : {0, 1}) {
135  values[12] = a;
136  values[5] = b;
137  expected.emplace_back(values, f(values));
138  }
139  }
140  EXPECT(actual == expected);
141 }
142 
143 /* ************************************************************************* */
144 // Check pruning of the decision tree works as expected.
// NOTE(review): the TEST(...) header for this body (listing line 145) is
// missing from this copy.
146  DiscreteKey A(1, 2), B(2, 2), C(3, 2);
147  DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
148 
149  // Only keep the leaves with the top 5 values.
150  size_t maxNrAssignments = 5;
151  auto pruned5 = f.prune(maxNrAssignments);
152 
153  // Pruned leaves should be 0
// The five largest values are 8, 7, 6, 5, 4; entries 1, 3 and 2 become 0.
154  DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
155  EXPECT(assert_equal(expected, pruned5));
156 
157  // Check for more extreme pruning where we only keep the top 2 leaves
158  maxNrAssignments = 2;
159  auto pruned2 = f.prune(maxNrAssignments);
160  DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
161  EXPECT(assert_equal(expected2, pruned2));
162 
// Regression case over four keys: keeping 5 assignments retains the four
// 1.0 entries and 0.99995287, zeroing everything else.
163  DiscreteKey D(4, 2);
164  DecisionTreeFactor factor(
165  D & C & B & A,
166  "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
167  "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
168 
169  DecisionTreeFactor expected3(D & C & B & A,
170  "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
171  "0.999952870000 1.0 1.0 1.0 1.0");
172  maxNrAssignments = 5;
173  auto pruned3 = factor.prune(maxNrAssignments);
174  EXPECT(assert_equal(expected3, pruned3));
175 }
176 
177 /* ************************************************************************** */
178 // Asia Bayes Network
179 /* ************************************************************************** */
180 
181 #define DISABLE_DOT
182 
183 void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
184 #ifndef DISABLE_DOT
185  std::vector<std::string> names = {"A", "S", "T", "L", "B", "E", "X", "D"};
186  auto formatter = [&](Key key) { return names[key]; };
187  f.dot(filename, formatter, true);
188 #endif
189 }
190 
// NOTE(review): the signature of this helper (listing lines 191-192,
// `DecisionTreeFactor create(const Signature &signature)` per the
// cross-reference index) is missing from this copy.
// Builds a factor over the signature's discrete keys whose values are the
// signature's conditional probability table (cpt()).
193  DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
194  return p;
195 }
196 
197 /* ************************************************************************* */
198 // test Asia Joint
// NOTE(review): the TEST(...) header for this body (listing line 199) is
// missing from this copy.
// Builds the Asia network's CPTs as factors, multiplies them into a joint
// one factor at a time, then checks the joint's keys and those left after
// summing out A and B.
200  DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
201  D(7, 2);
202 
203  gttic_(asiaCPTs);
204  DecisionTreeFactor pA = create(A % "99/1");
205  DecisionTreeFactor pS = create(S % "50/50");
206  DecisionTreeFactor pT = create(T | A = "99/1 95/5");
207  DecisionTreeFactor pL = create(L | S = "99/1 90/10");
208  DecisionTreeFactor pB = create(B | S = "70/30 40/60");
// "F T T T" presumably encodes a deterministic E = T or L table — confirm
// against the Signature string grammar.
209  DecisionTreeFactor pE = create((E | T, L) = "F T T T");
210  DecisionTreeFactor pX = create(X | E = "95/5 2/98");
211  DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
212 
213  // Create joint
214  gttic_(asiaJoint);
215  DecisionTreeFactor joint = pA;
216  maybeSaveDotFile(joint, "Asia-A");
217  joint = joint * pS;
218  maybeSaveDotFile(joint, "Asia-AS");
219  joint = joint * pT;
220  maybeSaveDotFile(joint, "Asia-AST");
221  joint = joint * pL;
222  maybeSaveDotFile(joint, "Asia-ASTL");
223  joint = joint * pB;
224  maybeSaveDotFile(joint, "Asia-ASTLB");
225  joint = joint * pE;
226  maybeSaveDotFile(joint, "Asia-ASTLBE");
227  joint = joint * pX;
228  maybeSaveDotFile(joint, "Asia-ASTLBEX");
229  joint = joint * pD;
230  maybeSaveDotFile(joint, "Asia-ASTLBEXD");
231 
232  // Check that discrete keys are as expected
233  EXPECT(assert_equal(joint.discreteKeys(), {A, S, T, L, B, E, X, D}));
234 
235  // Check that summing out variables maintains the keys even if merged, as is
236  // the case with S.
237  auto noAB = joint.sum(Ordering{A.first, B.first});
238  EXPECT(assert_equal(noAB->discreteKeys(), {S, T, L, E, X, D}));
239 }
240 
241 /* ************************************************************************* */
242 TEST(DecisionTreeFactor, DotWithNames) {
243  DiscreteKey A(12, 3), B(5, 2);
244  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
245  auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
246 
247  for (bool showZero : {true, false}) {
248  string actual = f.dot(formatter, showZero);
249  // pretty weak test, as ids are pointers and not stable across platforms.
250  string expected = "digraph G {";
251  EXPECT(actual.substr(0, 11) == expected);
252  }
253 }
254 
255 /* ************************************************************************* */
256 // Check markdown representation looks as expected.
// NOTE(review): the TEST(...) header for this body (listing line 257) is
// missing from this copy.
// Expected table: one row per assignment, A (key 12) varying slowest, with
// raw state indices because no value names are supplied.
258  DiscreteKey A(12, 3), B(5, 2);
259  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
260  string expected =
261  "|A|B|value|\n"
262  "|:-:|:-:|:-:|\n"
263  "|0|0|1|\n"
264  "|0|1|2|\n"
265  "|1|0|3|\n"
266  "|1|1|4|\n"
267  "|2|0|5|\n"
268  "|2|1|6|\n";
269  auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
270  string actual = f.markdown(formatter);
271  EXPECT(actual == expected);
272 }
273 
274 /* ************************************************************************* */
275 // Check markdown representation with a value formatter.
276 TEST(DecisionTreeFactor, markdownWithValueFormatter) {
277  DiscreteKey A(12, 3), B(5, 2);
278  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
279  string expected =
280  "|A|B|value|\n"
281  "|:-:|:-:|:-:|\n"
282  "|Zero|-|1|\n"
283  "|Zero|+|2|\n"
284  "|One|-|3|\n"
285  "|One|+|4|\n"
286  "|Two|-|5|\n"
287  "|Two|+|6|\n";
288  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
289  DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
290  {5, {"-", "+"}}};
291  string actual = f.markdown(keyFormatter, names);
292  EXPECT(actual == expected);
293 }
294 
295 /* ************************************************************************* */
296 // Check html representation with a value formatter.
297 TEST(DecisionTreeFactor, htmlWithValueFormatter) {
298  DiscreteKey A(12, 3), B(5, 2);
299  DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
300  string expected =
301  "<div>\n"
302  "<table class='DecisionTreeFactor'>\n"
303  " <thead>\n"
304  " <tr><th>A</th><th>B</th><th>value</th></tr>\n"
305  " </thead>\n"
306  " <tbody>\n"
307  " <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
308  " <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
309  " <tr><th>One</th><th>-</th><td>3</td></tr>\n"
310  " <tr><th>One</th><th>+</th><td>4</td></tr>\n"
311  " <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
312  " <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
313  " </tbody>\n"
314  "</table>\n"
315  "</div>";
316  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
317  DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}},
318  {5, {"-", "+"}}};
319  string actual = f.html(keyFormatter, names);
320  EXPECT(actual == expected);
321 }
322 
323 /* ************************************************************************* */
324 int main() {
325  TestResult tr;
326  return TestRegistry::runAllTests(tr);
327 }
328 /* ************************************************************************* */
gtsam::Signature::cpt
std::vector< double > cpt() const
Definition: Signature.cpp:69
TestRegistry::runAllTests
static int runAllTests(TestResult &result)
Definition: TestRegistry.cpp:27
gtsam::markdown
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
Definition: DiscreteValues.cpp:130
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:44
B
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
D
MatrixXcd D
Definition: EigenSolver_EigenSolver_MatrixType.cpp:14
test_constructor::f1
auto f1
Definition: testHybridNonlinearFactor.cpp:56
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
v0
static const double v0
Definition: testCal3DFisheye.cpp:31
EXPECT_LONGS_EQUAL
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
Testable.h
Concept check for values that can be used in unit tests.
EXPECT
#define EXPECT(condition)
Definition: Test.h:150
TestHarness.h
b
Scalar * b
Definition: benchVecAdd.cpp:17
asiaCPTs::pX
ADT pX
Definition: testAlgebraicDecisionTree.cpp:159
gtsam::DiscreteDistribution
Definition: DiscreteDistribution.h:33
gtsam::Y
GaussianFactorGraphValuePair Y
Definition: HybridGaussianProductFactor.cpp:29
asiaCPTs::pA
ADT pA
Definition: testAlgebraicDecisionTree.cpp:153
gtsam::DecisionTreeFactor::prune
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
Definition: DecisionTreeFactor.cpp:410
asiaCPTs
Definition: testAlgebraicDecisionTree.cpp:149
maybeSaveDotFile
void maybeSaveDotFile(const DecisionTreeFactor &f, const string &filename)
Definition: testDecisionTreeFactor.cpp:183
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
f2
double f2(const Vector2 &x)
Definition: testNumericalDerivative.cpp:56
T
Eigen::Triplet< double > T
Definition: Tutorial_sparse_example.cpp:6
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:245
log
const EIGEN_DEVICE_FUNC LogReturnType log() const
Definition: ArrayCwiseUnaryOps.h:128
Ordering.h
Variable ordering for the elimination algorithm.
X
#define X
Definition: icosphere.cpp:20
TEST
TEST(DecisionTreeFactor, ConstructorsMatch)
Definition: testDecisionTreeFactor.cpp:34
create
DecisionTreeFactor create(const Signature &signature)
Definition: testDecisionTreeFactor.cpp:192
gtsam::DecisionTreeFactor::shared_ptr
std::shared_ptr< DecisionTreeFactor > shared_ptr
Definition: DecisionTreeFactor.h:50
A
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
gtsam::AlgebraicDecisionTree< Key >
Key.h
asiaCPTs::pE
ADT pE
Definition: testAlgebraicDecisionTree.cpp:158
table
ArrayXXf table(10, 4)
main
int main()
Definition: testDecisionTreeFactor.cpp:324
gttic_
#define gttic_(label)
Definition: timing.h:245
relicense.filename
filename
Definition: relicense.py:57
DiscreteFactor.h
Signature.h
signatures for conditional densities
cholesky::expected
Matrix expected
Definition: testMatrix.cpp:971
L
MatrixXd L
Definition: LLT_example.cpp:6
asiaCPTs::pT
ADT pT
Definition: testAlgebraicDecisionTree.cpp:155
asiaCPTs::pS
ADT pS
Definition: testAlgebraicDecisionTree.cpp:154
Eigen::Triplet< double >
EXPECT_DOUBLES_EQUAL
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
asiaCPTs::pD
ADT pD
Definition: testAlgebraicDecisionTree.cpp:160
gtsam::DecisionTreeFactor::sum
shared_ptr sum(size_t nrFrontals) const
Create new factor by summing all values with the same separator values.
Definition: DecisionTreeFactor.h:162
serializationTestHelpers.h
TestResult
Definition: TestResult.h:26
key
const gtsam::Symbol key('X', 0)
E
DiscreteKey E(5, 2)
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
process_shonan_timing_results.names
dictionary names
Definition: process_shonan_timing_results.py:175
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
a
ArrayXXi a
Definition: Array_initializer_list_23_cxx11.cpp:1
C
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
gtsam
traits
Definition: SFMdata.h:40
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
f3
double f3(double x1, double x2)
Definition: testNumericalDerivative.cpp:76
DiscreteDistribution.h
CHECK
#define CHECK(condition)
Definition: Test.h:108
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
std
Definition: BFloat16.h:88
p
float * p
Definition: Tutorial_Map_using.cpp:9
v2
Vector v2
Definition: testSerializationBase.cpp:39
f4
double f4(double x, double y, double z)
Definition: testNumericalDerivative.cpp:105
gtsam::assert_equal
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
gtsam::Signature::discreteKeys
DiscreteKeys discreteKeys() const
Definition: Signature.cpp:55
different_sigmas::prior
const auto prior
Definition: testHybridBayesNet.cpp:238
gtsam::DiscreteFactor::discreteKeys
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
Definition: DiscreteFactor.cpp:32
asiaCPTs::pB
ADT pB
Definition: testAlgebraicDecisionTree.cpp:157
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
Z
#define Z
Definition: icosphere.cpp:21
gtsam::Ordering
Definition: inference/Ordering.h:33
asiaCPTs::pL
ADT pL
Definition: testAlgebraicDecisionTree.cpp:156
DecisionTreeFactor.h
gtsam::Signature
Definition: Signature.h:54
S
DiscreteKey S(1, 2)
v1
Vector v1
Definition: testSerializationBase.cpp:38


gtsam
Author(s):
autogenerated on Sat Nov 16 2024 04:07:21