testDiscreteBayesTree.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3 * GTSAM Copyright 2010-2020, Georgia Tech Research Corporation,
4 * Atlanta, Georgia 30332-0415
5 * All Rights Reserved
6 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8 * See LICENSE for the license information
9 
10 * -------------------------------------------------------------------------- */
11 
12 /*
13  * @file testDiscreteBayesTree.cpp
14  * @date sept 15, 2012
15  * @author Frank Dellaert
16  */
17 
18 #include <gtsam/base/Vector.h>
23 
24 #include <boost/assign/std/vector.hpp>
25 using namespace boost::assign;
26 
28 
29 #include <vector>
30 
31 using namespace std;
32 using namespace gtsam;
33 
// When true, the test below prints the Bayes net / Bayes tree and its shortcuts,
// and writes .dot graph files under /tmp.
static bool debug{false};
35 
36 /* ************************************************************************* */
37 
39  const int nrNodes = 15;
40  const size_t nrStates = 2;
41 
42  // define variables
43  vector<DiscreteKey> key;
44  for (int i = 0; i < nrNodes; i++) {
45  DiscreteKey key_i(i, nrStates);
46  key.push_back(key_i);
47  }
48 
49  // create a thin-tree Bayesnet, a la Jean-Guillaume
50  DiscreteBayesNet bayesNet;
51  bayesNet.add(key[14] % "1/3");
52 
53  bayesNet.add(key[13] | key[14] = "1/3 3/1");
54  bayesNet.add(key[12] | key[14] = "3/1 3/1");
55 
56  bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
57  bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
58  bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
59  bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
60 
61  bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
62  bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
63  bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
64  bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
65 
66  bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
67  bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
68  bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
69  bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
70 
71  if (debug) {
72  GTSAM_PRINT(bayesNet);
73  bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
74  }
75 
76  // create a BayesTree out of a Bayes net
77  auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
78  if (debug) {
79  GTSAM_PRINT(*bayesTree);
80  bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
81  }
82 
83  // Check frontals and parents
84  for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
85  auto clique_i = (*bayesTree)[i];
86  EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
87  }
88 
89  auto R = bayesTree->roots().front();
90 
91  // Check whether BN and BT give the same answer on all configurations
92  vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
93  key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
94  key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
95  for (size_t i = 0; i < allPosbValues.size(); ++i) {
96  DiscreteFactor::Values x = allPosbValues[i];
97  double expected = bayesNet.evaluate(x);
98  double actual = bayesTree->evaluate(x);
99  DOUBLES_EQUAL(expected, actual, 1e-9);
100  }
101 
102  // Calculate all some marginals for Values==all1
103  Vector marginals = Vector::Zero(15);
104  double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
105  joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
106  joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
107  joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
108  for (size_t i = 0; i < allPosbValues.size(); ++i) {
109  DiscreteFactor::Values x = allPosbValues[i];
110  double px = bayesTree->evaluate(x);
111  for (size_t i = 0; i < 15; i++)
112  if (x[i]) marginals[i] += px;
113  if (x[12] && x[14]) {
114  joint_12_14 += px;
115  if (x[9]) joint_9_12_14 += px;
116  if (x[8]) joint_8_12_14 += px;
117  }
118  if (x[8] && x[12]) joint_8_12 += px;
119  if (x[2]) {
120  if (x[8]) joint82 += px;
121  if (x[1]) joint12 += px;
122  }
123  if (x[4]) {
124  if (x[2]) joint24 += px;
125  if (x[5]) joint45 += px;
126  if (x[6]) joint46 += px;
127  if (x[11]) joint_4_11 += px;
128  }
129  if (x[11] && x[13]) {
130  joint_11_13 += px;
131  if (x[8] && x[12]) joint_8_11_12_13 += px;
132  if (x[9] && x[12]) joint_9_11_12_13 += px;
133  if (x[14]) {
134  joint_11_13_14 += px;
135  if (x[12]) {
136  joint_11_12_13_14 += px;
137  }
138  }
139  }
140  }
141  DiscreteFactor::Values all1 = allPosbValues.back();
142 
143  // check separator marginal P(S0)
144  auto clique = (*bayesTree)[0];
145  DiscreteFactorGraph separatorMarginal0 =
146  clique->separatorMarginal(EliminateDiscrete);
147  DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
148 
149  // check separator marginal P(S9), should be P(14)
150  clique = (*bayesTree)[9];
151  DiscreteFactorGraph separatorMarginal9 =
152  clique->separatorMarginal(EliminateDiscrete);
153  DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
154 
155  // check separator marginal of root, should be empty
156  clique = (*bayesTree)[11];
157  DiscreteFactorGraph separatorMarginal11 =
158  clique->separatorMarginal(EliminateDiscrete);
159  LONGS_EQUAL(0, separatorMarginal11.size());
160 
161  // check shortcut P(S9||R) to root
162  clique = (*bayesTree)[9];
163  DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
164  LONGS_EQUAL(1, shortcut.size());
165  DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
166 
167  // check shortcut P(S8||R) to root
168  clique = (*bayesTree)[8];
169  shortcut = clique->shortcut(R, EliminateDiscrete);
170  DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
171 
172  // check shortcut P(S2||R) to root
173  clique = (*bayesTree)[2];
174  shortcut = clique->shortcut(R, EliminateDiscrete);
175  DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
176 
177  // check shortcut P(S0||R) to root
178  clique = (*bayesTree)[0];
179  shortcut = clique->shortcut(R, EliminateDiscrete);
180  DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
181 
182  // calculate all shortcuts to root
183  DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
184  for (auto clique : cliques) {
185  DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
186  if (debug) {
187  clique.second->conditional_->printSignature();
188  shortcut.print("shortcut:");
189  }
190  }
191 
192  // Check all marginals
193  DiscreteFactor::shared_ptr marginalFactor;
194  for (size_t i = 0; i < 15; i++) {
195  marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
196  double actual = (*marginalFactor)(all1);
197  DOUBLES_EQUAL(marginals[i], actual, 1e-9);
198  }
199 
200  DiscreteBayesNet::shared_ptr actualJoint;
201 
202  // Check joint P(8, 2)
203  actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
204  DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
205 
206  // Check joint P(1, 2)
207  actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
208  DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
209 
210  // Check joint P(2, 4)
211  actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
212  DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
213 
214  // Check joint P(4, 5)
215  actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
216  DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
217 
218  // Check joint P(4, 6)
219  actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
220  DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
221 
222  // Check joint P(4, 11)
223  actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
224  DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
225 }
226 
227 /* ************************************************************************* */
228 int main() {
229  TestResult tr;
230  return TestRegistry::runAllTests(tr);
231 }
232 /* ************************************************************************* */
TEST_UNSAFE(DiscreteBayesTree, ThinTree)
size_t size() const
Definition: FactorGraph.h:306
std::vector< Assignment< L > > cartesianProduct(const std::vector< std::pair< L, size_t > > &keys)
Get Cartesian product consisting all possible configurations.
Definition: Assignment.h:62
static int runAllTests(TestResult &result)
Matrix expected
Definition: testMatrix.cpp:974
boost::shared_ptr< BayesTreeType > eliminateMultifrontal(OptionalOrderingType orderingType=boost::none, const Eliminate &function=EliminationTraitsType::DefaultEliminate, OptionalVariableIndex variableIndex=boost::none) const
#define DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:142
Rot2 R(Rot2::fromAngle(0.1))
void add(const Signature &s)
Definition: Half.h:150
Bayes network.
Eigen::VectorXd Vector
Definition: Vector.h:38
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:34
RealScalar RealScalar * px
static bool debug
Array< double, 1, 3 > e(1./3., 0.5, 2.)
Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree.
std::pair< DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr > EliminateDiscrete(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
boost::shared_ptr< DiscreteFactor > shared_ptr
shared_ptr to this class
#define LONGS_EQUAL(expected, actual)
Definition: Test.h:135
traits
Definition: chartTesting.h:28
typedef and functions to augment Eigen's VectorXd
void saveGraph(const std::string &s, const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
Definition: BayesNet-inst.h:38
int main()
#define GTSAM_PRINT(x)
Definition: Testable.h:41
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:155
boost::shared_ptr< This > shared_ptr
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
double evaluate(const DiscreteConditional::Values &values) const
Marginals marginals(graph, result)
void print(const std::string &s="BayesNet", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Definition: BayesNet-inst.h:31


gtsam
Author(s):
autogenerated on Sat May 8 2021 02:46:25