testDiscreteBayesTree.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3 * GTSAM Copyright 2010-2020, Georgia Tech Research Corporation,
4 * Atlanta, Georgia 30332-0415
5 * All Rights Reserved
6 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8 * See LICENSE for the license information
9 
10 * -------------------------------------------------------------------------- */
11 
12 /*
13  * @file testDiscreteBayesTree.cpp
14  * @date sept 15, 2012
15  * @author Frank Dellaert
16  */
17 
18 #include <gtsam/base/Vector.h>
23 
25 
26 #include <iostream>
27 #include <vector>
28 
29 using namespace std;
30 using namespace gtsam;
/// Compile-time switch for the optional diagnostics in the tests below
/// (printing the Bayes net / Bayes tree, saving .dot files, printing shortcuts).
static constexpr bool debug{false};
32 
33 /* ************************************************************************* */
34 struct TestFixture {
35  vector<DiscreteKey> keys;
37  std::shared_ptr<DiscreteBayesTree> bayesTree;
38 
44  // Define variables.
45  for (int i = 0; i < 15; i++) {
46  DiscreteKey key_i(i, 2);
47  keys.push_back(key_i);
48  }
49 
50  // Create thin-tree Bayesnet.
51  bayesNet.add(keys[14] % "1/3");
52 
53  bayesNet.add(keys[13] | keys[14] = "1/3 3/1");
54  bayesNet.add(keys[12] | keys[14] = "3/1 3/1");
55 
56  bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1");
57  bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1");
58  bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4");
59  bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1");
60 
61  bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1");
62  bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1");
63  bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4");
64  bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1");
65 
66  bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1");
67  bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1");
68  bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4");
69  bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1");
70 
71  // Create a BayesTree out of the Bayes net.
72  bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
73  }
74 };
75 
76 /* ************************************************************************* */
77 TEST(DiscreteBayesTree, ThinTree) {
78  const TestFixture self;
79  const auto& keys = self.keys;
80 
81  if (debug) {
82  GTSAM_PRINT(self.bayesNet);
83  self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
84  }
85 
86  // create a BayesTree out of a Bayes net
87  if (debug) {
88  GTSAM_PRINT(*self.bayesTree);
89  self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
90  }
91 
92  // Check frontals and parents
93  for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
94  auto clique_i = (*self.bayesTree)[i];
95  EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
96  }
97 
98  auto R = self.bayesTree->roots().front();
99 
100  // Check whether BN and BT give the same answer on all configurations
101  auto allPosbValues = DiscreteValues::CartesianProduct(
102  keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
103  keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
104  keys[14]);
105  for (size_t i = 0; i < allPosbValues.size(); ++i) {
106  DiscreteValues x = allPosbValues[i];
107  double expected = self.bayesNet.evaluate(x);
108  double actual = self.bayesTree->evaluate(x);
109  DOUBLES_EQUAL(expected, actual, 1e-9);
110  }
111 
112  // Calculate all some marginals for DiscreteValues==all1
113  Vector marginals = Vector::Zero(15);
114  double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
115  joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
116  joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
117  joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
118  for (size_t i = 0; i < allPosbValues.size(); ++i) {
119  DiscreteValues x = allPosbValues[i];
120  double px = self.bayesTree->evaluate(x);
121  for (size_t i = 0; i < 15; i++)
122  if (x[i]) marginals[i] += px;
123  if (x[12] && x[14]) {
124  joint_12_14 += px;
125  if (x[9]) joint_9_12_14 += px;
126  if (x[8]) joint_8_12_14 += px;
127  }
128  if (x[8] && x[12]) joint_8_12 += px;
129  if (x[2]) {
130  if (x[8]) joint82 += px;
131  if (x[1]) joint12 += px;
132  }
133  if (x[4]) {
134  if (x[2]) joint24 += px;
135  if (x[5]) joint45 += px;
136  if (x[6]) joint46 += px;
137  if (x[11]) joint_4_11 += px;
138  }
139  if (x[11] && x[13]) {
140  joint_11_13 += px;
141  if (x[8] && x[12]) joint_8_11_12_13 += px;
142  if (x[9] && x[12]) joint_9_11_12_13 += px;
143  if (x[14]) {
144  joint_11_13_14 += px;
145  if (x[12]) {
146  joint_11_12_13_14 += px;
147  }
148  }
149  }
150  }
151  DiscreteValues all1 = allPosbValues.back();
152 
153  // check separator marginal P(S0)
154  auto clique = (*self.bayesTree)[0];
155  DiscreteFactorGraph separatorMarginal0 =
156  clique->separatorMarginal(EliminateDiscrete);
157  DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
158 
159  DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
160  DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
161  DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
162 
163  // check separator marginal P(S9), should be P(14)
164  clique = (*self.bayesTree)[9];
165  DiscreteFactorGraph separatorMarginal9 =
166  clique->separatorMarginal(EliminateDiscrete);
167  DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
168 
169  // check separator marginal of root, should be empty
170  clique = (*self.bayesTree)[11];
171  DiscreteFactorGraph separatorMarginal11 =
172  clique->separatorMarginal(EliminateDiscrete);
173  LONGS_EQUAL(0, separatorMarginal11.size());
174 
175  // check shortcut P(S9||R) to root
176  clique = (*self.bayesTree)[9];
177  DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
178  LONGS_EQUAL(1, shortcut.size());
179  DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
180 
181  // check shortcut P(S8||R) to root
182  clique = (*self.bayesTree)[8];
183  shortcut = clique->shortcut(R, EliminateDiscrete);
184  DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
185 
186  // check shortcut P(S2||R) to root
187  clique = (*self.bayesTree)[2];
188  shortcut = clique->shortcut(R, EliminateDiscrete);
189  DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
190 
191  // check shortcut P(S0||R) to root
192  clique = (*self.bayesTree)[0];
193  shortcut = clique->shortcut(R, EliminateDiscrete);
194  DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
195 
196  // calculate all shortcuts to root
197  DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
198  for (auto clique : cliques) {
199  DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
200  if (debug) {
201  clique.second->conditional_->printSignature();
202  shortcut.print("shortcut:");
203  }
204  }
205 
206  // Check all marginals
207  DiscreteFactor::shared_ptr marginalFactor;
208  for (size_t i = 0; i < 15; i++) {
209  marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
210  double actual = (*marginalFactor)(all1);
211  DOUBLES_EQUAL(marginals[i], actual, 1e-9);
212  }
213 
214  DiscreteBayesNet::shared_ptr actualJoint;
215 
216  // Check joint P(8, 2)
217  actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
218  DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
219 
220  // Check joint P(1, 2)
221  actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
222  DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
223 
224  // Check joint P(2, 4)
225  actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
226  DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
227 
228  // Check joint P(4, 5)
229  actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
230  DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
231 
232  // Check joint P(4, 6)
233  actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
234  DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
235 
236  // Check joint P(4, 11)
237  actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
238  DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
239 }
240 
241 /* ************************************************************************* */
243  const TestFixture self;
244  string actual = self.bayesTree->dot();
245  EXPECT(actual ==
246  "digraph G{\n"
247  "0[label=\"13, 11, 6, 7\"];\n"
248  "0->1\n"
249  "1[label=\"14 : 11, 13\"];\n"
250  "1->2\n"
251  "2[label=\"9, 12 : 14\"];\n"
252  "2->3\n"
253  "3[label=\"3 : 9, 12\"];\n"
254  "2->4\n"
255  "4[label=\"2 : 9, 12\"];\n"
256  "2->5\n"
257  "5[label=\"8 : 12, 14\"];\n"
258  "5->6\n"
259  "6[label=\"1 : 8, 12\"];\n"
260  "5->7\n"
261  "7[label=\"0 : 8, 12\"];\n"
262  "1->8\n"
263  "8[label=\"10 : 13, 14\"];\n"
264  "8->9\n"
265  "9[label=\"5 : 10, 13\"];\n"
266  "8->10\n"
267  "10[label=\"4 : 10, 13\"];\n"
268  "}");
269 }
270 
271 /* ************************************************************************* */
272 int main() {
273  TestResult tr;
274  return TestRegistry::runAllTests(tr);
275 }
276 /* ************************************************************************* */
std::pair< DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr > EliminateDiscrete(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
Main elimination function for DiscreteFactorGraph.
static int runAllTests(TestResult &result)
double evaluate(const DiscreteValues &values) const
A Bayes tree representing a Discrete distribution.
Matrix expected
Definition: testMatrix.cpp:971
#define DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:141
Rot2 R(Rot2::fromAngle(0.1))
Definition: BFloat16.h:88
Bayes network.
static constexpr bool debug
size_t size() const
Definition: FactorGraph.h:334
DiscreteBayesNet bayesNet
Eigen::VectorXd Vector
Definition: Vector.h:38
#define EXPECT(condition)
Definition: Test.h:150
RealScalar RealScalar * px
Array< double, 1, 3 > e(1./3., 0.5, 2.)
Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree.
#define LONGS_EQUAL(expected, actual)
Definition: Test.h:134
TEST(DiscreteBayesTree, ThinTree)
traits
Definition: chartTesting.h:28
typedef and functions to augment Eigen's VectorXd
int main()
#define GTSAM_PRINT(x)
Definition: Testable.h:43
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
vector< DiscreteKey > keys
std::shared_ptr< This > shared_ptr
void add(const DiscreteKey &key, const std::string &spec)
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
std::shared_ptr< DiscreteBayesTree > bayesTree
std::shared_ptr< BayesTreeType > eliminateMultifrontal(OptionalOrderingType orderingType={}, const Eliminate &function=EliminationTraitsType::DefaultEliminate, OptionalVariableIndex variableIndex={}) const
const KeyVector keys
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
Marginals marginals(graph, result)
void print(const std::string &s="BayesNet", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Definition: BayesNet-inst.h:31


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:38:01