testDiscreteConditional.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * @file testDiscreteConditional.cpp
14  * @brief unit tests for DiscreteConditional
15  * @author Duy-Nguyen Ta
16  * @author Frank Dellaert
17  * @date Feb 14, 2011
18  */
19 
24 #include <gtsam/inference/Symbol.h>
25 
26 
27 using namespace std;
28 using namespace gtsam;
29 
30 /* ************************************************************************* */
31 TEST(DiscreteConditional, constructors) {
32  DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
33 
34  DiscreteConditional actual(X | Y = "1/1 2/3 1/4");
35  EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals()));
36  EXPECT_LONGS_EQUAL(2, *(actual.beginParents()));
37  EXPECT(actual.endParents() == actual.end());
38  EXPECT(actual.endFrontals() == actual.beginParents());
39 
40  DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
41  DiscreteConditional expected1(1, f1);
42  EXPECT(assert_equal(expected1, actual, 1e-9));
43 
45  X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
46  DiscreteConditional actual2(1, f2);
47  DecisionTreeFactor expected2 = f2 / *f2.sum(1);
48  EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
49 }
50 
51 /* ************************************************************************* */
52 TEST(DiscreteConditional, constructors_alt_interface) {
53  DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
54 
55  const Signature::Row r1{1, 1}, r2{2, 3}, r3{1, 4};
56  const Signature::Table table{r1, r2, r3};
57  DiscreteConditional actual1(X, {Y}, table);
58 
59  DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
60  DiscreteConditional expected1(1, f1);
61  EXPECT(assert_equal(expected1, actual1, 1e-9));
62 
64  X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
65  DiscreteConditional actual2(1, f2);
66  DecisionTreeFactor expected2 = f2 / *f2.sum(1);
67  EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
68 }
69 
70 /* ************************************************************************* */
71 TEST(DiscreteConditional, constructors2) {  // P(C|B) built directly from a Signature
72  DiscreteKey C(0, 2), B(1, 2);
73  Signature signature((C | B) = "4/1 3/1");  // two rows, one per value of B; each row is normalized (4/5, 1/5) and (3/4, 1/4)
74  DiscreteConditional actual(signature);
75 
76  DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25");  // the normalized row entries 0.8/0.2 and 0.75/0.25
77  EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));  // compare via the DecisionTreeFactor conversion
78 }
79 
80 /* ************************************************************************* */
81 TEST(DiscreteConditional, constructors3) {  // P(C|B,A) with two parents, built from a Signature
82  DiscreteKey C(0, 2), B(1, 2), A(2, 2);
83  Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");  // four rows, one per joint (B, A) assignment
84  DiscreteConditional actual(signature);
85 
86  DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");  // first four values: each row's normalized first entry; last four: each row's second entry
87  EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
88 }
89 
90 /* ****************************************************************************/
91 // Test evaluate for a discrete Prior P(Asia).
92 TEST(DiscreteConditional, PriorProbability) {
93  constexpr Key asiaKey = 0;
94  const DiscreteKey Asia(asiaKey, 2);
95  DiscreteConditional dc(Asia, "4/6");
97  EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
98  EXPECT(DiscreteConditional::CheckInvariants(dc, values));
99 }
100 
101 /* ************************************************************************* */
102 // Check that error, logProbability, and evaluate all work as expected.
103 TEST(DiscreteConditional, probability) {  // evaluate, operator(), logProbability and error at one full assignment
104  DiscreteKey C(2, 2), D(4, 2), E(3, 2);
105  DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
106 
107  DiscreteValues given {{C.first, 1}, {D.first, 0}, {E.first, 0}};  // C=1, D=0, E=0
108  EXPECT_DOUBLES_EQUAL(0.2, C_given_DE.evaluate(given), 1e-9);  // 0.2 = 1/5, the normalized second entry of the "4/1" row
109  EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);  // operator() agrees with evaluate
110  EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
111  EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);  // error is the negated log-probability
112  EXPECT(DiscreteConditional::CheckInvariants(C_given_DE, given));
113 }
114 
115 /* ************************************************************************* */
116 // Check calculation of joint P(A,B)
118  DiscreteKey A(1, 2), B(0, 2);
119  DiscreteConditional conditional(A | B = "1/2 2/1");
120  DiscreteConditional prior(B % "1/2");
121 
122  // The expected factor
123  DecisionTreeFactor f(A & B, "1 4 2 2");
125 
126  // P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
127  for (auto&& actual : {prior * conditional, conditional * prior}) {
128  EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
129  KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
130  EXPECT((frontals == KeyVector{0, 1}));
131  for (auto&& it : actual.enumerate()) {
132  const DiscreteValues& v = it.first;
133  EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9);
134  }
135  // And for good measure:
136  EXPECT(assert_equal(expected, actual));
137  }
138 }
139 
140 /* ************************************************************************* */
141 // Check calculation of conditional joint P(A,B|C)
143  DiscreteKey A(0, 2), B(1, 2), C(2, 2);
144  DiscreteConditional A_given_B(A | B = "1/3 3/1");
145  DiscreteConditional B_given_C(B | C = "1/3 3/1");
146 
147  // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
148  for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
149  EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
150  EXPECT_LONGS_EQUAL(1, actual.nrParents());
151  KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
152  EXPECT((frontals == KeyVector{0, 1}));
153  for (auto&& it : actual.enumerate()) {
154  const DiscreteValues& v = it.first;
155  EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
156  }
157  }
158 }
159 
160 /* ************************************************************************* */
161 // Check calculation of conditional joint P(A,B|C), double-checking key ordering
163  DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!!
164  DiscreteConditional A_given_B(A | B = "1/3 3/1");
165  DiscreteConditional B_given_C(B | C = "1/3 3/1");
166 
167  // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
168  for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
169  EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
170  EXPECT_LONGS_EQUAL(1, actual.nrParents());
171  KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
172  EXPECT((frontals == KeyVector{1, 2}));
173  for (auto&& it : actual.enumerate()) {
174  const DiscreteValues& v = it.first;
175  EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
176  }
177  }
178 }
179 
180 /* ************************************************************************* */
181 // Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
183  DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2);
184  DiscreteConditional A_given_B(A | B = "1/3 3/1");
185  DiscreteConditional B_given_D(B | D = "1/3 3/1");
186  DiscreteConditional AB_given_D = A_given_B * B_given_D;
187  DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
188 
189  // P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D)
190  for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) {
191  EXPECT_LONGS_EQUAL(3, actual.nrFrontals());
192  EXPECT_LONGS_EQUAL(2, actual.nrParents());
193  KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
194  EXPECT((frontals == KeyVector{0, 1, 2}));
195  KeyVector parents(actual.beginParents(), actual.endParents());
196  EXPECT((parents == KeyVector{3, 4}));
197  for (auto&& it : actual.enumerate()) {
198  const DiscreteValues& v = it.first;
199  EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9);
200  }
201  }
202 }
203 
204 /* ************************************************************************* */
205 // Check calculation of marginals for joint P(A,B)
207  DiscreteKey A(1, 2), B(0, 2);
208  DiscreteConditional conditional(A | B = "1/2 2/1");
209  DiscreteConditional prior(B % "1/2");
210  DiscreteConditional pAB = prior * conditional;
211 
212  // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
213  // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
214  DiscreteConditional actualA = pAB.marginal(A.first);
215  DiscreteConditional pA(A % "5/4");
216  EXPECT(assert_equal(pA, actualA));
217  EXPECT(actualA.frontals() == KeyVector{1});
218  EXPECT_LONGS_EQUAL(0, actualA.nrParents());
219 
220  DiscreteConditional actualB = pAB.marginal(B.first);
221  EXPECT(assert_equal(prior, actualB));
222  EXPECT(actualB.frontals() == KeyVector{0});
223  EXPECT_LONGS_EQUAL(0, actualB.nrParents());
224 }
225 
226 /* ************************************************************************* */
227 // Check calculation of marginals when branches are pruned
228 TEST(DiscreteConditional, marginals2) {  // marginals of P(A,B) = P(B) P(A|B) with different key order than the marginals test
229  DiscreteKey A(0, 2), B(1, 2); // keys swapped relative to the marginals test (A=1, B=0 there); needed to make pruning happen!
230  DiscreteConditional conditional(A | B = "2/2 3/1");
231  DiscreteConditional prior(B % "1/2");
232  DiscreteConditional pAB = prior * conditional;  // joint P(A,B) = P(B) P(A|B)
233  // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
234  // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
235  DiscreteConditional actualA = pAB.marginal(A.first);
236  DiscreteConditional pA(A % "8/4");  // the unnormalized weights above; normalizes to 2/3, 1/3
237  EXPECT(assert_equal(pA, actualA));
238 
239  DiscreteConditional actualB = pAB.marginal(B.first);
240  EXPECT(assert_equal(prior, actualB));  // the B-marginal recovers the prior
241 }
242 
243 /* ************************************************************************* */
244 TEST(DiscreteConditional, likelihood) {  // fixing the frontal X yields a factor over the parent Y
245  DiscreteKey X(0, 2), Y(1, 3);
246  DiscreteConditional conditional(X | Y = "2/8 4/6 5/5");
247 
248  auto actual0 = conditional.likelihood(0);  // X = 0
249  DecisionTreeFactor expected0(Y, "0.2 0.4 0.5");  // P(X=0|Y=y): normalized first entry of each row
250  EXPECT(assert_equal(expected0, *actual0, 1e-9));
251 
252  auto actual1 = conditional.likelihood(1);  // X = 1
253  DecisionTreeFactor expected1(Y, "0.8 0.6 0.5");  // P(X=1|Y=y): normalized second entry of each row
254  EXPECT(assert_equal(expected1, *actual1, 1e-9));
255 }
256 
257 /* ************************************************************************* */
258 // Check choose on P(C|D,E)
260  DiscreteKey C(2, 2), D(4, 2), E(3, 2);
261  DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
262 
263  // Case 1: no given values: no-op
264  DiscreteValues given;
265  auto actual1 = C_given_DE.choose(given);
266  EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
267 
268  // Case 2: 1 given value
269  given[D.first] = 1;
270  auto actual2 = C_given_DE.choose(given);
271  EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
272  EXPECT_LONGS_EQUAL(1, actual2->nrParents());
273  DiscreteConditional expected2(C | E = "1/1 1/4");
274  EXPECT(assert_equal(expected2, *actual2, 1e-9));
275 
276  // Case 3: 2 given values
277  given[E.first] = 0;
278  auto actual3 = C_given_DE.choose(given);
279  EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
280  EXPECT_LONGS_EQUAL(0, actual3->nrParents());
281  DiscreteConditional expected3(C % "1/1");
282  EXPECT(assert_equal(expected3, *actual3, 1e-9));
283 }
284 
285 /* ************************************************************************* */
286 // Check markdown representation looks as expected, no parents.
287 TEST(DiscreteConditional, markdown_prior) {  // markdown for a prior (no parents); values shown as indices
288  DiscreteKey A(Symbol('x', 1), 3);
289  DiscreteConditional conditional(A % "1/2/2");  // normalizes to 0.2/0.4/0.4
290  string expected =
291  " *P(x1):*\n\n"
292  "|x1|value|\n"
293  "|:-:|:-:|\n"
294  "|0|0.2|\n"
295  "|1|0.4|\n"
296  "|2|0.4|\n";
297  string actual = conditional.markdown();  // default key formatter, no value names
298  EXPECT(actual == expected);
299 }
300 
301 /* ************************************************************************* */
302 // Check markdown representation looks as expected, no parents + names.
303 TEST(DiscreteConditional, markdown_prior_names) {  // same prior as markdown_prior, with value names in place of indices
304  Symbol x1('x', 1);
305  DiscreteKey A(x1, 3);
306  DiscreteConditional conditional(A % "1/2/2");
307  string expected =
308  " *P(x1):*\n\n"
309  "|x1|value|\n"
310  "|:-:|:-:|\n"
311  "|A0|0.2|\n"
312  "|A1|0.4|\n"
313  "|A2|0.4|\n";
314  DecisionTreeFactor::Names names{{x1, {"A0", "A1", "A2"}}};  // display names for values 0, 1, 2 of key x1
315  string actual = conditional.markdown(DefaultKeyFormatter, names);
316  EXPECT(actual == expected);
317 }
318 
319 /* ************************************************************************* */
320 // Check markdown representation looks as expected, multivalued.
321 TEST(DiscreteConditional, markdown_multivalued) {  // one table row per parent (b1) value, one column per a1 value
322  DiscreteKey A(Symbol('a', 1), 3), B(Symbol('b', 1), 5);
323  DiscreteConditional conditional(
324  A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");  // five rows (one per b1 value), three weights each (one per a1 value)
325  string expected =
326  " *P(a1|b1):*\n\n"
327  "|*b1*|0|1|2|\n"
328  "|:-:|:-:|:-:|:-:|\n"
329  "|0|0.02|0.88|0.1|\n"
330  "|1|0.02|0.2|0.78|\n"
331  "|2|0.33|0.33|0.34|\n"
332  "|3|0.33|0.33|0.34|\n"
333  "|4|0.95|0.02|0.03|\n";
334  string actual = conditional.markdown();
335  EXPECT(actual == expected);
336 }
337 
338 /* ************************************************************************* */
339 // Check markdown representation looks as expected, two parents + names.
341  DiscreteKey A(2, 2), B(1, 2), C(0, 3);
342  DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
343  string expected =
344  " *P(A|B,C):*\n\n"
345  "|*B*|*C*|T|F|\n"
346  "|:-:|:-:|:-:|:-:|\n"
347  "|-|Zero|0|1|\n"
348  "|-|One|0.25|0.75|\n"
349  "|-|Two|0.5|0.5|\n"
350  "|+|Zero|0.75|0.25|\n"
351  "|+|One|0|1|\n"
352  "|+|Two|1|0|\n";
353  vector<string> keyNames{"C", "B", "A"};
354  auto formatter = [keyNames](Key key) { return keyNames[key]; };
355  DecisionTreeFactor::Names names{
356  {0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}};
357  string actual = conditional.markdown(formatter, names);
358  EXPECT(actual == expected);
359 }
360 
361 /* ************************************************************************* */
362 // Check html representation looks as expected, two parents + names.
364  DiscreteKey A(2, 2), B(1, 2), C(0, 3);
365  DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
366  string expected =
367  "<div>\n"
368  "<p> <i>P(A|B,C):</i></p>\n"
369  "<table class='DiscreteConditional'>\n"
370  " <thead>\n"
371  " <tr><th><i>B</i></th><th><i>C</i></th><th>T</th><th>F</th></tr>\n"
372  " </thead>\n"
373  " <tbody>\n"
374  " <tr><th>-</th><th>Zero</th><td>0</td><td>1</td></tr>\n"
375  " <tr><th>-</th><th>One</th><td>0.25</td><td>0.75</td></tr>\n"
376  " <tr><th>-</th><th>Two</th><td>0.5</td><td>0.5</td></tr>\n"
377  " <tr><th>+</th><th>Zero</th><td>0.75</td><td>0.25</td></tr>\n"
378  " <tr><th>+</th><th>One</th><td>0</td><td>1</td></tr>\n"
379  " <tr><th>+</th><th>Two</th><td>1</td><td>0</td></tr>\n"
380  " </tbody>\n"
381  "</table>\n"
382  "</div>";
383  vector<string> keyNames{"C", "B", "A"};
384  auto formatter = [keyNames](Key key) { return keyNames[key]; };
385  DecisionTreeFactor::Names names{
386  {0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}};
387  string actual = conditional.html(formatter, names);
388  EXPECT(actual == expected);
389 }
390 
391 /* ************************************************************************* */
392 int main() {  // run all registered tests and return the runner's integer result
393  TestResult tr;
394  return TestRegistry::runAllTests(tr);
395 }
396 /* ************************************************************************* */
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as markdown table.
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
const gtsam::Symbol key('X', 0)
const char Y
static int runAllTests(TestResult &result)
double evaluate(const DiscreteValues &values) const
Evaluate by looking up the value in the AlgebraicDecisionTree.
Point2 prior(const Point2 &x)
Prior on a single pose.
Definition: simulated2D.h:88
Matrix expected
Definition: testMatrix.cpp:971
shared_ptr choose(const DiscreteValues &given) const
< DiscreteValues version
FACTOR::const_iterator endFrontals() const
Definition: Conditional.h:182
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
static const T & choose(int layout, const T &col, const T &row)
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
leaf::MyValues values
FACTOR::const_iterator beginParents() const
Definition: Conditional.h:185
static const Key asiaKey
Definition: BFloat16.h:88
Frontals frontals() const
Definition: Conditional.h:143
double f2(const Vector2 &x)
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
DecisionTreeFactor::shared_ptr likelihood(const DiscreteValues &frontalValues) const
EIGEN_DEVICE_FUNC const LogReturnType log() const
static const KeyFormatter DefaultKeyFormatter
Definition: Key.h:43
double error(const DiscreteValues &values) const
Calculate the error for DiscreteValues x, which is -log(probability).
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
const KeyFormatter & formatter
FACTOR::const_iterator beginFrontals() const
Definition: Conditional.h:179
const_iterator end() const
Definition: Factor.h:148
#define Z
Definition: icosphere.cpp:21
std::vector< Row > Table
Definition: Signature.h:60
string html(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of html.
double logProbability(const DiscreteValues &x) const
Log-probability is just -error(x).
FACTOR::const_iterator endParents() const
Definition: Conditional.h:188
#define EXPECT(condition)
Definition: Test.h:150
Array< int, Dynamic, 1 > v
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Array< double, 1, 3 > e(1./3., 0.5, 2.)
static const double r2
static const double r3
std::vector< double > Row
Definition: Signature.h:59
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
traits
Definition: chartTesting.h:28
Point2 pA(size_t i)
DiscreteKey E(5, 2)
TEST(DiscreteConditional, constructors)
static const double r1
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
ArrayXXf table(10, 4)
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
Pose3 x1
Definition: testPose3.cpp:663
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
shared_ptr sum(size_t nrFrontals) const
Create new factor by summing all values with the same separator values.
DiscreteConditional marginal(Key key) const
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:86
#define X
Definition: icosphere.cpp:20
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:102
Marginals marginals(graph, result)
size_t nrParents() const
Definition: Conditional.h:132


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:38:01