testDiscreteBayesTree.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3 * GTSAM Copyright 2010-2020, Georgia Tech Research Corporation,
4 * Atlanta, Georgia 30332-0415
5 * All Rights Reserved
6 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8 * See LICENSE for the license information
9 
10 * -------------------------------------------------------------------------- */
11 
12 /*
13  * @file testDiscreteBayesTree.cpp
14  * @date Sept 15, 2012
15  * @author Frank Dellaert
16  */
17 
18 #include <gtsam/base/Vector.h>
19 #include <gtsam/inference/Symbol.h>
24 
26 
27 #include <iostream>
28 #include <vector>
29 
30 using namespace gtsam;
31 static constexpr bool debug = false;
32 
33 /* ************************************************************************* */
34 struct TestFixture {
36  std::vector<DiscreteValues> assignments;
38  std::shared_ptr<DiscreteBayesTree> bayesTree;
39 
45  // Define variables.
46  for (int i = 0; i < 15; i++) {
47  DiscreteKey key_i(i, 2);
48  keys.push_back(key_i);
49  }
50 
51  // Enumerate all assignments.
53 
54  // Create thin-tree Bayesnet.
55  bayesNet.add(keys[14] % "1/3");
56 
57  bayesNet.add(keys[13] | keys[14] = "1/3 3/1");
58  bayesNet.add(keys[12] | keys[14] = "3/1 3/1");
59 
60  bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1");
61  bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1");
62  bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4");
63  bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1");
64 
65  bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1");
66  bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1");
67  bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4");
68  bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1");
69 
70  bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1");
71  bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1");
72  bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4");
73  bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1");
74 
75  // Create a BayesTree out of the Bayes net.
76  bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
77  }
78 };
79 
80 /* ************************************************************************* */
81 // Check that BN and BT give the same answer on all configurations
82 TEST(DiscreteBayesTree, ThinTree) {
83  TestFixture self;
84 
85  if (debug) {
86  GTSAM_PRINT(self.bayesNet);
87  self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
88  }
89 
90  // create a BayesTree out of a Bayes net
91  if (debug) {
92  GTSAM_PRINT(*self.bayesTree);
93  self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
94  }
95 
96  // Check frontals and parents
97  for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
98  auto clique_i = (*self.bayesTree)[i];
99  EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
100  }
101 
102  for (const auto& x : self.assignments) {
103  double expected = self.bayesNet.evaluate(x);
104  double actual = self.bayesTree->evaluate(x);
105  DOUBLES_EQUAL(expected, actual, 1e-9);
106  }
107 }
108 
109 /* ************************************************************************* */
110 // Check calculation of separator marginals
111 TEST(DiscreteBayesTree, SeparatorMarginals) {
112  TestFixture self;
113 
114  // Calculate some marginals for DiscreteValues==all1
115  double marginal_14 = 0, joint_8_12 = 0;
116  for (auto& x : self.assignments) {
117  double px = self.bayesTree->evaluate(x);
118  if (x[8] && x[12]) joint_8_12 += px;
119  if (x[14]) marginal_14 += px;
120  }
121  DiscreteValues all1 = self.assignments.back();
122 
123  // check separator marginal P(S0)
124  auto clique = (*self.bayesTree)[0];
125  DiscreteFactorGraph separatorMarginal0 =
126  clique->separatorMarginal(EliminateDiscrete);
127  DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
128 
129  // check separator marginal P(S9), should be P(14)
130  clique = (*self.bayesTree)[9];
131  DiscreteFactorGraph separatorMarginal9 =
132  clique->separatorMarginal(EliminateDiscrete);
133  DOUBLES_EQUAL(marginal_14, separatorMarginal9(all1), 1e-9);
134 
135  // check separator marginal of root, should be empty
136  clique = (*self.bayesTree)[11];
137  DiscreteFactorGraph separatorMarginal11 =
138  clique->separatorMarginal(EliminateDiscrete);
139  LONGS_EQUAL(0, separatorMarginal11.size());
140 }
141 
142 /* ************************************************************************* */
143 // Check shortcuts in the tree
144 TEST(DiscreteBayesTree, Shortcuts) {
145  TestFixture self;
146 
147  // Calculate some marginals for DiscreteValues==all1
148  double joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0,
149  joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
150  for (auto& x : self.assignments) {
151  double px = self.bayesTree->evaluate(x);
152  if (x[11] && x[13]) {
153  joint_11_13 += px;
154  if (x[8] && x[12]) joint_8_11_12_13 += px;
155  if (x[9] && x[12]) joint_9_11_12_13 += px;
156  if (x[14]) {
157  joint_11_13_14 += px;
158  if (x[12]) {
159  joint_11_12_13_14 += px;
160  }
161  }
162  }
163  }
164  DiscreteValues all1 = self.assignments.back();
165 
166  auto R = self.bayesTree->roots().front();
167 
168  // check shortcut P(S9||R) to root
169  auto clique = (*self.bayesTree)[9];
170  DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
171  LONGS_EQUAL(1, shortcut.size());
172  DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
173 
174  // check shortcut P(S8||R) to root
175  clique = (*self.bayesTree)[8];
176  shortcut = clique->shortcut(R, EliminateDiscrete);
177  DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
178 
179  // check shortcut P(S2||R) to root
180  clique = (*self.bayesTree)[2];
181  shortcut = clique->shortcut(R, EliminateDiscrete);
182  DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
183 
184  // check shortcut P(S0||R) to root
185  clique = (*self.bayesTree)[0];
186  shortcut = clique->shortcut(R, EliminateDiscrete);
187  DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
188 
189  // calculate all shortcuts to root
190  DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
191  for (auto clique : cliques) {
192  DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
193  if (debug) {
194  clique.second->conditional_->printSignature();
195  shortcut.print("shortcut:");
196  }
197  }
198 }
199 
200 /* ************************************************************************* */
201 // Check all marginals
202 TEST(DiscreteBayesTree, MarginalFactors) {
203  TestFixture self;
204 
205  Vector marginals = Vector::Zero(15);
206  for (size_t i = 0; i < self.assignments.size(); ++i) {
207  DiscreteValues& x = self.assignments[i];
208  double px = self.bayesTree->evaluate(x);
209  for (size_t i = 0; i < 15; i++)
210  if (x[i]) marginals[i] += px;
211  }
212 
213  // Check all marginals
214  DiscreteValues all1 = self.assignments.back();
215  for (size_t i = 0; i < 15; i++) {
216  auto marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
217  double actual = (*marginalFactor)(all1);
218  DOUBLES_EQUAL(marginals[i], actual, 1e-9);
219  }
220 }
221 
222 /* ************************************************************************* */
223 // Check a number of joint marginals.
225  TestFixture self;
226 
227  // Calculate some marginals for DiscreteValues==all1
228  Vector marginals = Vector::Zero(15);
229  double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
230  joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
231  for (size_t i = 0; i < self.assignments.size(); ++i) {
232  DiscreteValues& x = self.assignments[i];
233  double px = self.bayesTree->evaluate(x);
234  for (size_t i = 0; i < 15; i++)
235  if (x[i]) marginals[i] += px;
236  if (x[12] && x[14]) {
237  joint_12_14 += px;
238  if (x[9]) joint_9_12_14 += px;
239  if (x[8]) joint_8_12_14 += px;
240  }
241  if (x[2]) {
242  if (x[8]) joint82 += px;
243  if (x[1]) joint12 += px;
244  }
245  if (x[4]) {
246  if (x[2]) joint24 += px;
247  if (x[5]) joint45 += px;
248  if (x[6]) joint46 += px;
249  if (x[11]) joint_4_11 += px;
250  }
251  }
252 
253  // regression tests:
254  DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
255  DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
256  DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
257 
258  DiscreteValues all1 = self.assignments.back();
259  DiscreteBayesNet::shared_ptr actualJoint;
260 
261  // Check joint P(8, 2)
262  actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
263  DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
264 
265  // Check joint P(1, 2)
266  actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
267  DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
268 
269  // Check joint P(2, 4)
270  actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
271  DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
272 
273  // Check joint P(4, 5)
274  actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
275  DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
276 
277  // Check joint P(4, 6)
278  actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
279  DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
280 
281  // Check joint P(4, 11)
282  actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
283  DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
284 }
285 
286 /* ************************************************************************* */
288  TestFixture self;
289  std::string actual = self.bayesTree->dot();
290  EXPECT(actual ==
291  "digraph G{\n"
292  "0[label=\"13, 11, 6, 7\"];\n"
293  "0->1\n"
294  "1[label=\"14 : 11, 13\"];\n"
295  "1->2\n"
296  "2[label=\"9, 12 : 14\"];\n"
297  "2->3\n"
298  "3[label=\"3 : 9, 12\"];\n"
299  "2->4\n"
300  "4[label=\"2 : 9, 12\"];\n"
301  "2->5\n"
302  "5[label=\"8 : 12, 14\"];\n"
303  "5->6\n"
304  "6[label=\"1 : 8, 12\"];\n"
305  "5->7\n"
306  "7[label=\"0 : 8, 12\"];\n"
307  "1->8\n"
308  "8[label=\"10 : 13, 14\"];\n"
309  "8->9\n"
310  "9[label=\"5 : 10, 13\"];\n"
311  "8->10\n"
312  "10[label=\"4 : 10, 13\"];\n"
313  "}");
314 }
315 
316 /* ************************************************************************* */
317 // Check that we can have a multi-frontal lookup table
321 
322  // Make a small planning-like graph: 3 states, 2 actions
324  const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
325  const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
326 
327  // Constraint on start and goal
328  graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
329  graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
330 
331  // Should I stay or should I go?
332  // "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
333  const double r = 10;
334  std::vector<double> table{
335  r, 0, 0, 0, r, 0, // x1 = 0
336  0, r, 0, 0, 0, r, // x1 = 1
337  0, 0, r, 0, 0, r // x1 = 2
338  };
341 
342  // eliminate for MPE (maximum probable explanation).
343  Ordering ordering{A(2), X(3), X(1), A(1), X(2)};
344  auto lookup = graph.eliminateMultifrontal(ordering, EliminateForMPE);
345 
346  // Check that the lookup table is correct
347  EXPECT_LONGS_EQUAL(2, lookup->size());
348  auto lookup_x1_a1_x2 = (*lookup)[X(1)]->conditional();
349  EXPECT_LONGS_EQUAL(3, lookup_x1_a1_x2->frontals().size());
350  // check that sum is 1.0 (not 100, as we now normalize)
352  EXPECT_DOUBLES_EQUAL(1.0, (*lookup_x1_a1_x2->sum(3))(empty), 1e-9);
353  // And that only non-zero reward is for x1 a1 x2 == 0 1 1
354  EXPECT_DOUBLES_EQUAL(1.0, (*lookup_x1_a1_x2)({{X(1),0},{A(1),1},{X(2),1}}), 1e-9);
355 
356  auto lookup_a2_x3 = (*lookup)[X(3)]->conditional();
357  // check that the sum depends on x2 and is non-zero only for x2 \in {1,2}
358  auto sum_x2 = lookup_a2_x3->sum(2);
359  EXPECT_DOUBLES_EQUAL(0, (*sum_x2)({{X(2),0}}), 1e-9);
360  EXPECT_DOUBLES_EQUAL(1.0, (*sum_x2)({{X(2),1}}), 1e-9);
361  EXPECT_DOUBLES_EQUAL(2.0, (*sum_x2)({{X(2),2}}), 1e-9);
362  EXPECT_LONGS_EQUAL(2, lookup_a2_x3->frontals().size());
363  // And that the non-zero rewards are for
364  // x2 a2 x3 == 1 1 2
365  EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),1},{A(2),1},{X(3),2}}), 1e-9);
366  // x2 a2 x3 == 2 0 2
367  EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),0},{X(3),2}}), 1e-9);
368  // x2 a2 x3 == 2 1 2
369  EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
370 }
371 
372 /* ************************************************************************* */
373 int main() {
374  TestResult tr;
375  return TestRegistry::runAllTests(tr);
376 }
377 /* ************************************************************************* */
TestRegistry::runAllTests
static int runAllTests(TestResult &result)
Definition: TestRegistry.cpp:27
DiscreteBayesNet.h
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:98
main
int main()
Definition: testDiscreteBayesTree.cpp:373
Vector.h
typedef and functions to augment Eigen's VectorXd
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
gtsam::BayesNet::print
void print(const std::string &s="BayesNet", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Definition: BayesNet-inst.h:31
EXPECT_LONGS_EQUAL
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
EXPECT
#define EXPECT(condition)
Definition: Test.h:150
TestHarness.h
keys
const KeyVector keys
Definition: testRegularImplicitSchurFactor.cpp:40
DiscreteFactorGraph.h
TestFixture::TestFixture
TestFixture()
Definition: testDiscreteBayesTree.cpp:44
x
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
Definition: gnuplot_common_settings.hh:12
gtsam::BayesTree< DiscreteBayesTreeClique >::Nodes
ConcurrentMap< Key, sharedClique > Nodes
Definition: BayesTree.h:92
gtsam::symbol_shorthand::A
Key A(std::uint64_t j)
Definition: inference/Symbol.h:148
X
#define X
Definition: icosphere.cpp:20
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
gtsam::Vector
Eigen::VectorXd Vector
Definition: Vector.h:38
x3
Pose3 x3(Rot3::Ypr(M_PI/4.0, 0.0, 0.0), l2)
gtsam::EliminateableFactorGraph::eliminateMultifrontal
std::shared_ptr< BayesTreeType > eliminateMultifrontal(OptionalOrderingType orderingType={}, const Eliminate &function=EliminationTraitsType::DefaultEliminate, OptionalVariableIndex variableIndex={}) const
Definition: EliminateableFactorGraph-inst.h:89
A
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
TestFixture::bayesTree
std::shared_ptr< DiscreteBayesTree > bayesTree
Definition: testDiscreteBayesTree.cpp:38
align_3::a1
Point2 a1
Definition: testPose2.cpp:769
gtsam::symbol_shorthand::X
Key X(std::uint64_t j)
Definition: inference/Symbol.h:171
table
ArrayXXf table(10, 4)
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
gtsam::DiscreteValues::CartesianProduct
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
Definition: DiscreteValues.h:85
x1
Pose3 x1
Definition: testPose3.cpp:663
gtsam::EliminateDiscrete
std::pair< DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr > EliminateDiscrete(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
Main elimination function for DiscreteFactorGraph.
Definition: DiscreteFactorGraph.cpp:205
DOUBLES_EQUAL
#define DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:141
debug
static constexpr bool debug
Definition: testDiscreteBayesTree.cpp:31
gtsam::DiscreteBayesNet::evaluate
double evaluate(const DiscreteValues &values) const
Definition: DiscreteBayesNet.cpp:43
cholesky::expected
Matrix expected
Definition: testMatrix.cpp:971
TestFixture::bayesNet
DiscreteBayesNet bayesNet
Definition: testDiscreteBayesTree.cpp:37
Symbol.h
GTSAM_PRINT
#define GTSAM_PRINT(x)
Definition: Testable.h:43
BayesNet.h
Bayes network.
gtsam::EliminateForMPE
std::pair< DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr > EliminateForMPE(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
Alternate elimination function for MPE (maximum probable explanation) that creates non-normalized lookup tables.
Definition: DiscreteFactorGraph.cpp:116
DiscreteBayesTree.h
Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree.
EXPECT_DOUBLES_EQUAL
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
ordering
static enum @1096 ordering
TestResult
Definition: TestResult.h:26
gtsam::FactorGraph::size
size_t size() const
Definition: FactorGraph.h:297
gtsam::DiscreteBayesNet::add
void add(const DiscreteKey &key, const std::string &spec)
Definition: DiscreteBayesNet.h:85
empty
Definition: test_copy_move.cpp:19
gtsam
traits
Definition: chartTesting.h:28
gtsam::TEST
TEST(SmartFactorBase, Pinhole)
Definition: testSmartFactorBase.cpp:38
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
align_3::a2
Point2 a2
Definition: testPose2.cpp:770
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
gtsam::DiscreteBayesNet::shared_ptr
std::shared_ptr< This > shared_ptr
Definition: DiscreteBayesNet.h:43
TestFixture::assignments
std::vector< DiscreteValues > assignments
Definition: testDiscreteBayesTree.cpp:36
gtsam::DiscreteBayesTree
A Bayes tree representing a Discrete distribution.
Definition: DiscreteBayesTree.h:73
TestFixture
Definition: testDiscreteBayesTree.cpp:34
gtsam::FactorGraph::add
IsDerived< DERIVEDFACTOR > add(std::shared_ptr< DERIVEDFACTOR > factor)
add is a synonym for push_back.
Definition: FactorGraph.h:171
graph
NonlinearFactorGraph graph
Definition: doc/Code/OdometryExample.cpp:2
marginals
Marginals marginals(graph, result)
gtsam::Ordering
Definition: inference/Ordering.h:33
x2
Pose3 x2(Rot3::Ypr(0.0, 0.0, 0.0), l2)
LONGS_EQUAL
#define LONGS_EQUAL(expected, actual)
Definition: Test.h:134
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
TestFixture::keys
DiscreteKeys keys
Definition: testDiscreteBayesTree.cpp:35
R
Rot2 R(Rot2::fromAngle(0.1))
px
RealScalar RealScalar * px
Definition: level1_cplx_impl.h:28


gtsam
Author(s):
autogenerated on Tue Jun 25 2024 03:05:42