testDiscreteFactorGraph.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * @file testDiscreteFactorGraph.cpp
14  * @date Feb 14, 2011
15  * @author Duy-Nguyen Ta
16  */
17 
23 
25 
26 using namespace std;
27 using namespace gtsam;
28 
29 /* ************************************************************************* */
30 TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
31  DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
32 
34  graph.add(AI, "1 0 0 1");
35  graph.add(AI, "1 1 1 0");
36  graph.add(A & AI, "1 1 1 0 1 1 1 1 0 1 1 1");
37  graph.add(ME, "0 1 0 0");
38  graph.add(ME, "1 1 1 0");
39  graph.add(A & ME, "1 1 1 0 1 1 1 1 0 1 1 1");
40  graph.add(PC, "1 0 1 0");
41  graph.add(PC, "1 1 1 0");
42  graph.add(A & PC, "1 1 1 0 1 1 1 1 0 1 1 1");
43  graph.add(ME & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
44  graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
45  graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
46 
47  // Check MPE.
48  auto actualMPE = graph.optimize();
49  EXPECT(assert_equal({{0, 2}, {1, 1}, {2, 0}, {3, 0}}, actualMPE));
50 }
51 
52 /* ************************************************************************* */
54 TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
55 
56  // Three keys P1 and P2
57  DiscreteKey P1(0,2), P2(1,2), P3(2,3);
58 
59  // Create the DiscreteFactorGraph
61  graph.add(P1, "0.9 0.3");
62  graph.add(P2, "0.9 0.6");
63  graph.add(P1 & P2, "4 1 10 4");
64 
65  // Instantiate DiscreteValues
67  values[0] = 1;
68  values[1] = 1;
69 
70  // Check if graph evaluation works ( 0.3*0.6*4 )
71  EXPECT_DOUBLES_EQUAL( .72, graph(values), 1e-9);
72 
73  // Creating a new test with third node and adding unary and ternary factors on it
74  graph.add(P3, "0.9 0.2 0.5");
75  graph.add(P1 & P2 & P3, "1 2 3 4 5 6 7 8 9 10 11 12");
76 
77  // Below values lead to selecting the 8th index in the ternary factor table
78  values[0] = 1;
79  values[1] = 0;
80  values[2] = 1;
81 
82  // Check if graph evaluation works (0.3*0.9*1*0.2*8)
83  EXPECT_DOUBLES_EQUAL( 4.32, graph(values), 1e-9);
84 
85  // Below values lead to selecting the 3rd index in the ternary factor table
86  values[0] = 0;
87  values[1] = 1;
88  values[2] = 0;
89 
90  // Check if graph evaluation works (0.9*0.6*1*0.9*4)
91  EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
92 
93  // Check if graph product works
95  EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
96 }
97 
98 /* ************************************************************************* */
100  // Declare keys and ordering
101  DiscreteKey C(0, 2), B(1, 2), A(2, 2);
102 
103  // A simple factor graph (A)-fAC-(C)-fBC-(B)
104  // with smoothness priors
106  graph.add(A & C, "3 1 1 3");
107  graph.add(C & B, "3 1 1 3");
108 
109  // Test EliminateDiscrete
110  const Ordering frontalKeys{0};
111  const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
112 
113  DecisionTreeFactor newFactor = *newFactorPtr;
114 
115  // Normalize newFactor by max for comparison with expected
116  auto normalization = newFactor.max(newFactor.size());
117 
118  newFactor = newFactor / *normalization;
119 
120  // Check Conditional
121  CHECK(conditional);
122  Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
123  DiscreteConditional expectedConditional(signature);
124  EXPECT(assert_equal(expectedConditional, *conditional));
125 
126  // Check Factor
127  CHECK(&newFactor);
128  DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
129  // Normalize by max.
130  normalization = expectedFactor.max(expectedFactor.size());
131  // Ensure normalization is correct.
132  expectedFactor = expectedFactor / *normalization;
133  EXPECT(assert_equal(expectedFactor, newFactor));
134 
135  // Test using elimination tree
136  const Ordering ordering{0, 1, 2};
137  DiscreteEliminationTree etree(graph, ordering);
138  const auto [actual, remainingGraph] = etree.eliminate(&EliminateDiscrete);
139 
140  // Check Bayes net
141  DiscreteBayesNet expectedBayesNet;
142  expectedBayesNet.add(signature);
143  expectedBayesNet.add(B | A = "5/3 3/5");
144  expectedBayesNet.add(A % "1/1");
145  EXPECT(assert_equal(expectedBayesNet, *actual));
146 
147  // Test eliminateSequential
149  EXPECT(assert_equal(expectedBayesNet, *actual2));
150 
151  // Test mpe
152  DiscreteValues mpe { {0, 0}, {1, 0}, {2, 0}};
153  auto actualMPE = graph.optimize();
154  EXPECT(assert_equal(mpe, actualMPE));
155  EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
156 
157  // Test sumProduct alias with all orderings:
158  auto mpeProbability = expectedBayesNet(mpe);
159  EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
160 
161  // Using custom ordering
162  DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
163  EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
164 
165  for (Ordering::OrderingType orderingType :
166  {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
167  Ordering::CUSTOM}) {
168  auto bayesNet = graph.sumProduct(orderingType);
169  EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
170  }
171 }
172 
173 /* ************************************************************************* */
175  // Declare a bunch of keys
176  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
177 
178  // Create Factor graph
180  graph.add(C & A, "0.2 0.8 0.3 0.7");
181  graph.add(C & B, "0.1 0.9 0.4 0.6");
182 
183  // Created expected MPE
184  DiscreteValues mpe{{0, 0}, {1, 1}, {2, 1}};
185 
186  // Do max-product with different orderings
187  for (Ordering::OrderingType orderingType :
188  {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
189  Ordering::CUSTOM}) {
190  DiscreteLookupDAG dag = graph.maxProduct(orderingType);
191  auto actualMPE = dag.argmax();
192  EXPECT(assert_equal(mpe, actualMPE));
193  auto actualMPE2 = graph.optimize(); // all in one
194  EXPECT(assert_equal(mpe, actualMPE2));
195  }
196 }
197 
198 /* ************************************************************************* */
199 TEST(DiscreteFactorGraph, marginalIsNotMPE) {
200  // Declare 2 keys
201  DiscreteKey A(0, 2), B(1, 2);
202 
203  // Create Bayes net such that marginal on A is bigger for 0 than 1, but the
204  // MPE does not have A=0.
205  DiscreteBayesNet bayesNet;
206  bayesNet.add(B | A = "1/1 1/2");
207  bayesNet.add(A % "10/9");
208 
209  // The expected MPE is A=1, B=1
210  DiscreteValues mpe { {0, 1}, {1, 1} };
211 
212  // Which we verify using max-product:
213  DiscreteFactorGraph graph(bayesNet);
214  auto actualMPE = graph.optimize();
215  EXPECT(assert_equal(mpe, actualMPE));
216  EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
217 }
218 
219 /* ************************************************************************* */
220 TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
221  // The factor graph in Darwiche09book, page 244
222  DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
223 
224  // Create Factor graph
226  graph.add(S, "0.55 0.45");
227  graph.add(S & C, "0.05 0.95 0.01 0.99");
228  graph.add(C & T1, "0.80 0.20 0.20 0.80");
229  graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
230  graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
231  graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
232 
233  DiscreteValues mpe { {0, 1}, {1, 1}, {2, 1}, {3, 1}, {4, 0}};
234  EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
235  // You can check visually by printing product:
236  // graph.product().print("Darwiche-product");
237 
238  // Check MPE.
239  auto actualMPE = graph.optimize();
240  EXPECT(assert_equal(mpe, actualMPE));
241 
242  // Check Bayes Net
243  const Ordering ordering{0, 1, 2, 3, 4};
244  auto chordal = graph.eliminateSequential(ordering);
245  EXPECT_LONGS_EQUAL(5, chordal->size());
246 
247  // Let us create the Bayes tree here, just for fun, because we don't use it
250  // bayesTree->print("Bayes Tree");
251  EXPECT_LONGS_EQUAL(2, bayesTree->size());
252 }
253 
254 /* ************************************************************************* */
256  // Create Factor graph
258  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
259  graph.add(C & A, "0.2 0.8 0.3 0.7");
260  graph.add(C & B, "0.1 0.9 0.4 0.6");
261 
262  string actual = graph.dot();
263  string expected =
264  "graph {\n"
265  " size=\"5,5\";\n"
266  "\n"
267  " var0[label=\"0\"];\n"
268  " var1[label=\"1\"];\n"
269  " var2[label=\"2\"];\n"
270  "\n"
271  " factor0[label=\"\", shape=point];\n"
272  " var0--factor0;\n"
273  " var1--factor0;\n"
274  " factor1[label=\"\", shape=point];\n"
275  " var0--factor1;\n"
276  " var2--factor1;\n"
277  "}\n";
278  EXPECT(actual == expected);
279 }
280 
281 /* ************************************************************************* */
282 TEST(DiscreteFactorGraph, DotWithNames) {
283  // Create Factor graph
285  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
286  graph.add(C & A, "0.2 0.8 0.3 0.7");
287  graph.add(C & B, "0.1 0.9 0.4 0.6");
288 
289  vector<string> names{"C", "A", "B"};
290  auto formatter = [names](Key key) { return names[key]; };
291  string actual = graph.dot(formatter);
292  string expected =
293  "graph {\n"
294  " size=\"5,5\";\n"
295  "\n"
296  " var0[label=\"C\"];\n"
297  " var1[label=\"A\"];\n"
298  " var2[label=\"B\"];\n"
299  "\n"
300  " factor0[label=\"\", shape=point];\n"
301  " var0--factor0;\n"
302  " var1--factor0;\n"
303  " factor1[label=\"\", shape=point];\n"
304  " var0--factor1;\n"
305  " var2--factor1;\n"
306  "}\n";
307  EXPECT(actual == expected);
308 }
309 
310 /* ************************************************************************* */
311 // Check markdown representation looks as expected.
313  // Create Factor graph
315  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
316  graph.add(C & A, "0.2 0.8 0.3 0.7");
317  graph.add(C & B, "0.1 0.9 0.4 0.6");
318 
319  string expected =
320  "`DiscreteFactorGraph` of size 2\n"
321  "\n"
322  "factor 0:\n"
323  "|C|A|value|\n"
324  "|:-:|:-:|:-:|\n"
325  "|0|0|0.2|\n"
326  "|0|1|0.8|\n"
327  "|1|0|0.3|\n"
328  "|1|1|0.7|\n"
329  "\n"
330  "factor 1:\n"
331  "|C|B|value|\n"
332  "|:-:|:-:|:-:|\n"
333  "|0|0|0.1|\n"
334  "|0|1|0.9|\n"
335  "|1|0|0.4|\n"
336  "|1|1|0.6|\n\n";
337  vector<string> names{"C", "A", "B"};
338  auto formatter = [names](Key key) { return names[key]; };
339  string actual = graph.markdown(formatter);
340  EXPECT(actual == expected);
341 
342  // Make sure values are correctly displayed.
344  values[0] = 1;
345  values[1] = 0;
346  EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
347 }
348 /* ************************************************************************* */
349 int main() {
350 TestResult tr;
351 return TestRegistry::runAllTests(tr);
352 }
353 /* ************************************************************************* */
354 
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
const gtsam::Symbol key('X', 0)
#define CHECK(condition)
Definition: Test.h:108
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const DiscreteFactor::Names &names={}) const
Render as markdown tables.
std::pair< DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr > EliminateDiscrete(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
Main elimination function for DiscreteFactorGraph.
static int runAllTests(TestResult &result)
Matrix expected
Definition: testMatrix.cpp:971
size_t size() const
Definition: Factor.h:159
DecisionTreeFactor product() const
Definition: test.py:1
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
shared_ptr max(size_t nrFrontals) const
Create new factor by maximizing over all values with the same separator.
leaf::MyValues values
Definition: BFloat16.h:88
static const Pose3 T2(Rot3::Rodrigues(0.3, 0.2, 0.1), P2)
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
DiscreteKey S(1, 2)
NonlinearFactorGraph graph
Bayes network.
DiscreteValues optimize(OptionalOrderingType orderingType={}) const
Find the maximum probable explanation (MPE) by doing max-product.
static enum @1107 ordering
static const Matrix93 P3
Definition: SO3.cpp:352
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
const KeyFormatter & formatter
#define EXPECT(condition)
Definition: Test.h:150
TEST(DiscreteFactorGraph, test)
std::shared_ptr< BayesNetType > eliminateSequential(OptionalOrderingType orderingType={}, const Eliminate &function=EliminationTraitsType::DefaultEliminate, OptionalVariableIndex variableIndex={}) const
Array< double, 1, 3 > e(1./3., 0.5, 2.)
Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree.
OrderingType
Type of ordering to use.
void dot(std::ostream &os, const KeyFormatter &keyFormatter=DefaultKeyFormatter, const DotWriter &writer=DotWriter()) const
Output to graphviz format, stream version.
DiscreteBayesNet sumProduct(OptionalOrderingType orderingType={}) const
Implement the sum-product algorithm.
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
static const Point3 P2(3.5,-8.2, 4.2)
traits
Definition: chartTesting.h:28
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler)
static const Similarity3 T1(R, Point3(3.5, -8.2, 4.2), 1)
std::shared_ptr< This > shared_ptr
DiscreteLookupDAG maxProduct(OptionalOrderingType orderingType={}) const
Implement the max-product algorithm.
std::pair< std::shared_ptr< BayesNetType >, std::shared_ptr< FactorGraphType > > eliminate(Eliminate function) const
void add(const DiscreteKey &key, const std::string &spec)
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
std::shared_ptr< BayesTreeType > eliminateMultifrontal(OptionalOrderingType orderingType={}, const Eliminate &function=EliminationTraitsType::DefaultEliminate, OptionalVariableIndex variableIndex={}) const
Elimination tree for discrete factors.
DiscreteValues argmax(DiscreteValues given=DiscreteValues()) const
argmax by back-substitution, optionally given certain variables.
std::shared_ptr< This > shared_ptr
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:102
void product(const MatrixType &m)
Definition: product.h:20


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:38:01