testDiscreteFactorGraph.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * @file testDiscreteFactorGraph.cpp
14  * @date Feb 14, 2011
15  * @author Duy-Nguyen Ta
16  */
17 
25 #include <gtsam/inference/Symbol.h>
26 
27 using namespace std;
28 using namespace gtsam;
29 
31 
32 /* ************************************************************************* */
33 TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
34  DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
35 
37  graph.add(AI, "1 0 0 1");
38  graph.add(AI, "1 1 1 0");
39  graph.add(A & AI, "1 1 1 0 1 1 1 1 0 1 1 1");
40  graph.add(ME, "0 1 0 0");
41  graph.add(ME, "1 1 1 0");
42  graph.add(A & ME, "1 1 1 0 1 1 1 1 0 1 1 1");
43  graph.add(PC, "1 0 1 0");
44  graph.add(PC, "1 1 1 0");
45  graph.add(A & PC, "1 1 1 0 1 1 1 1 0 1 1 1");
46  graph.add(ME & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
47  graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
48  graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
49 
50  // Check MPE.
51  auto actualMPE = graph.optimize();
52  EXPECT(assert_equal({{0, 2}, {1, 1}, {2, 0}, {3, 0}}, actualMPE));
53 }
54 
55 /* ************************************************************************* */
57 TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
58 
59  // Three keys P1 and P2
60  DiscreteKey P1(0,2), P2(1,2), P3(2,3);
61 
62  // Create the DiscreteFactorGraph
64  graph.add(P1, "0.9 0.3");
65  graph.add(P2, "0.9 0.6");
66  graph.add(P1 & P2, "4 1 10 4");
67 
68  // Instantiate DiscreteValues
70  values[0] = 1;
71  values[1] = 1;
72 
73  // Check if graph evaluation works ( 0.3*0.6*4 )
74  EXPECT_DOUBLES_EQUAL( .72, graph(values), 1e-9);
75 
76  // Creating a new test with third node and adding unary and ternary factors on it
77  graph.add(P3, "0.9 0.2 0.5");
78  graph.add(P1 & P2 & P3, "1 2 3 4 5 6 7 8 9 10 11 12");
79 
80  // Below values lead to selecting the 8th index in the ternary factor table
81  values[0] = 1;
82  values[1] = 0;
83  values[2] = 1;
84 
85  // Check if graph evaluation works (0.3*0.9*1*0.2*8)
86  EXPECT_DOUBLES_EQUAL( 4.32, graph(values), 1e-9);
87 
88  // Below values lead to selecting the 3rd index in the ternary factor table
89  values[0] = 0;
90  values[1] = 1;
91  values[2] = 0;
92 
93  // Check if graph evaluation works (0.9*0.6*1*0.9*4)
94  EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
95 
96  // Check if graph product works
97  DecisionTreeFactor product = graph.product();
98  EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
99 }
100 
101 /* ************************************************************************* */
103  // Declare keys and ordering
104  DiscreteKey C(0, 2), B(1, 2), A(2, 2);
105 
106  // A simple factor graph (A)-fAC-(C)-fBC-(B)
107  // with smoothness priors
109  graph.add(A & C, "3 1 1 3");
110  graph.add(C & B, "3 1 1 3");
111 
112  // Test EliminateDiscrete
113  const Ordering frontalKeys{0};
114  const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
115 
116  DecisionTreeFactor newFactor = *newFactorPtr;
117 
118  // Normalize newFactor by max for comparison with expected
119  auto normalization = newFactor.max(newFactor.size());
120 
121  newFactor = newFactor / *normalization;
122 
123  // Check Conditional
124  CHECK(conditional);
125  Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
126  DiscreteConditional expectedConditional(signature);
127  EXPECT(assert_equal(expectedConditional, *conditional));
128 
129  // Check Factor
130  CHECK(&newFactor);
131  DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
132  // Normalize by max.
133  normalization = expectedFactor.max(expectedFactor.size());
134  // Ensure normalization is correct.
135  expectedFactor = expectedFactor / *normalization;
136  EXPECT(assert_equal(expectedFactor, newFactor));
137 
138  // Test using elimination tree
139  const Ordering ordering{0, 1, 2};
141  const auto [actual, remainingGraph] = etree.eliminate(&EliminateDiscrete);
142 
143  // Check Bayes net
144  DiscreteBayesNet expectedBayesNet;
145  expectedBayesNet.add(signature);
146  expectedBayesNet.add(B | A = "5/3 3/5");
147  expectedBayesNet.add(A % "1/1");
148  EXPECT(assert_equal(expectedBayesNet, *actual));
149 
150  // Test eliminateSequential
151  DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
152  EXPECT(assert_equal(expectedBayesNet, *actual2));
153 
154  // Test mpe
155  DiscreteValues mpe { {0, 0}, {1, 0}, {2, 0}};
156  auto actualMPE = graph.optimize();
157  EXPECT(assert_equal(mpe, actualMPE));
158  EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
159 
160  // Test sumProduct alias with all orderings:
161  auto mpeProbability = expectedBayesNet(mpe);
162  EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
163 
164  // Using custom ordering
165  DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
166  EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
167 
168  for (Ordering::OrderingType orderingType :
169  {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
170  Ordering::CUSTOM}) {
171  auto bayesNet = graph.sumProduct(orderingType);
172  EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
173  }
174 }
175 
176 /* ************************************************************************* */
178  // Declare a bunch of keys
179  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
180 
181  // Create Factor graph
183  graph.add(C & A, "0.2 0.8 0.3 0.7");
184  graph.add(C & B, "0.1 0.9 0.4 0.6");
185 
186  // Created expected MPE
187  DiscreteValues mpe{{0, 0}, {1, 1}, {2, 1}};
188 
189  // Do max-product with different orderings
190  for (Ordering::OrderingType orderingType :
191  {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
192  Ordering::CUSTOM}) {
193  DiscreteLookupDAG dag = graph.maxProduct(orderingType);
194  auto actualMPE = dag.argmax();
195  EXPECT(assert_equal(mpe, actualMPE));
196  auto actualMPE2 = graph.optimize(); // all in one
197  EXPECT(assert_equal(mpe, actualMPE2));
198  }
199 }
200 
201 /* ************************************************************************* */
202 TEST(DiscreteFactorGraph, marginalIsNotMPE) {
203  // Declare 2 keys
204  DiscreteKey A(0, 2), B(1, 2);
205 
206  // Create Bayes net such that marginal on A is bigger for 0 than 1, but the
207  // MPE does not have A=0.
208  DiscreteBayesNet bayesNet;
209  bayesNet.add(B | A = "1/1 1/2");
210  bayesNet.add(A % "10/9");
211 
212  // The expected MPE is A=1, B=1
213  DiscreteValues mpe { {0, 1}, {1, 1} };
214 
215  // Which we verify using max-product:
216  DiscreteFactorGraph graph(bayesNet);
217  auto actualMPE = graph.optimize();
218  EXPECT(assert_equal(mpe, actualMPE));
219  EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
220 }
221 
222 /* ************************************************************************* */
223 TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
224  // The factor graph in Darwiche09book, page 244
225  DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
226 
227  // Create Factor graph
229  graph.add(S, "0.55 0.45");
230  graph.add(S & C, "0.05 0.95 0.01 0.99");
231  graph.add(C & T1, "0.80 0.20 0.20 0.80");
232  graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
233  graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
234  graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
235 
236  DiscreteValues mpe { {0, 1}, {1, 1}, {2, 1}, {3, 1}, {4, 0}};
237  EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
238  // You can check visually by printing product:
239  // graph.product().print("Darwiche-product");
240 
241  // Check MPE.
242  auto actualMPE = graph.optimize();
243  EXPECT(assert_equal(mpe, actualMPE));
244 
245  // Check Bayes Net
246  const Ordering ordering{0, 1, 2, 3, 4};
247  auto chordal = graph.eliminateSequential(ordering);
248  EXPECT_LONGS_EQUAL(5, chordal->size());
249 
250  // Let us create the Bayes tree here, just for fun, because we don't use it
252  graph.eliminateMultifrontal(ordering);
253  // bayesTree->print("Bayes Tree");
254  EXPECT_LONGS_EQUAL(2, bayesTree->size());
255 }
256 
257 /* ************************************************************************* */
259  // Create Factor graph
261  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
262  graph.add(C & A, "0.2 0.8 0.3 0.7");
263  graph.add(C & B, "0.1 0.9 0.4 0.6");
264 
265  string actual = graph.dot();
266  string expected =
267  "graph {\n"
268  " size=\"5,5\";\n"
269  "\n"
270  " var0[label=\"0\"];\n"
271  " var1[label=\"1\"];\n"
272  " var2[label=\"2\"];\n"
273  "\n"
274  " factor0[label=\"\", shape=point];\n"
275  " var0--factor0;\n"
276  " var1--factor0;\n"
277  " factor1[label=\"\", shape=point];\n"
278  " var0--factor1;\n"
279  " var2--factor1;\n"
280  "}\n";
281  EXPECT(actual == expected);
282 }
283 
284 /* ************************************************************************* */
285 TEST(DiscreteFactorGraph, DotWithNames) {
286  // Create Factor graph
288  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
289  graph.add(C & A, "0.2 0.8 0.3 0.7");
290  graph.add(C & B, "0.1 0.9 0.4 0.6");
291 
292  vector<string> names{"C", "A", "B"};
293  auto formatter = [names](Key key) { return names[key]; };
294  string actual = graph.dot(formatter);
295  string expected =
296  "graph {\n"
297  " size=\"5,5\";\n"
298  "\n"
299  " var0[label=\"C\"];\n"
300  " var1[label=\"A\"];\n"
301  " var2[label=\"B\"];\n"
302  "\n"
303  " factor0[label=\"\", shape=point];\n"
304  " var0--factor0;\n"
305  " var1--factor0;\n"
306  " factor1[label=\"\", shape=point];\n"
307  " var0--factor1;\n"
308  " var2--factor1;\n"
309  "}\n";
310  EXPECT(actual == expected);
311 }
312 
313 /* ************************************************************************* */
314 // Check markdown representation looks as expected.
316  // Create Factor graph
318  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
319  graph.add(C & A, "0.2 0.8 0.3 0.7");
320  graph.add(C & B, "0.1 0.9 0.4 0.6");
321 
322  string expected =
323  "`DiscreteFactorGraph` of size 2\n"
324  "\n"
325  "factor 0:\n"
326  "|C|A|value|\n"
327  "|:-:|:-:|:-:|\n"
328  "|0|0|0.2|\n"
329  "|0|1|0.8|\n"
330  "|1|0|0.3|\n"
331  "|1|1|0.7|\n"
332  "\n"
333  "factor 1:\n"
334  "|C|B|value|\n"
335  "|:-:|:-:|:-:|\n"
336  "|0|0|0.1|\n"
337  "|0|1|0.9|\n"
338  "|1|0|0.4|\n"
339  "|1|1|0.6|\n\n";
340  vector<string> names{"C", "A", "B"};
341  auto formatter = [names](Key key) { return names[key]; };
342  string actual = graph.markdown(formatter);
343  EXPECT(actual == expected);
344 
345  // Make sure values are correctly displayed.
347  values[0] = 1;
348  values[1] = 0;
349  EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
350 }
351 
352 /* ************************************************************************* */
353 int main() {
354 TestResult tr;
355 return TestRegistry::runAllTests(tr);
356 }
357 /* ************************************************************************* */
358 
TestRegistry::runAllTests
static int runAllTests(TestResult &result)
Definition: TestRegistry.cpp:27
gtsam::markdown
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
Definition: DiscreteValues.cpp:130
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:44
B
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:98
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
EXPECT_LONGS_EQUAL
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
EXPECT
#define EXPECT(condition)
Definition: Test.h:150
TestHarness.h
gtsam::DiscreteBayesTree::shared_ptr
std::shared_ptr< This > shared_ptr
Definition: DiscreteBayesTree.h:80
DiscreteFactorGraph.h
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
gtsam::DiscreteLookupDAG
Definition: DiscreteLookupDAG.h:77
P3
static double P3[]
Definition: jv.c:563
gtsam::NonlinearFactorGraph::dot
void dot(std::ostream &os, const Values &values, const KeyFormatter &keyFormatter=DefaultKeyFormatter, const GraphvizFormatting &writer=GraphvizFormatting()) const
Output to graphviz format, stream version, with Values/extra options.
Definition: NonlinearFactorGraph.cpp:102
TestableAssertions.h
Provides additional testing facilities for common data structures.
A
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
test
Definition: test.py:1
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
main
int main()
Definition: testDiscreteFactorGraph.cpp:353
gtsam::DiscreteEliminationTree
Elimination tree for discrete factors.
Definition: DiscreteEliminationTree.h:31
gtsam::EliminateDiscrete
std::pair< DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr > EliminateDiscrete(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
Main elimination function for DiscreteFactorGraph.
Definition: DiscreteFactorGraph.cpp:205
P1
static double P1[]
Definition: jv.c:552
DiscreteFactor.h
cholesky::expected
Matrix expected
Definition: testMatrix.cpp:971
gtsam::EliminationTree::eliminate
std::pair< std::shared_ptr< BayesNetType >, std::shared_ptr< FactorGraphType > > eliminate(Eliminate function) const
Definition: EliminationTree-inst.h:224
Symbol.h
T2
static const Pose3 T2(Rot3::Rodrigues(0.3, 0.2, 0.1), P2)
DiscreteEliminationTree.h
BayesNet.h
Bayes network.
DiscreteBayesTree.h
Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree.
EXPECT_DOUBLES_EQUAL
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
ordering
static enum @1096 ordering
TestResult
Definition: TestResult.h:26
key
const gtsam::Symbol key('X', 0)
gtsam::Ordering::OrderingType
OrderingType
Type of ordering to use.
Definition: inference/Ordering.h:40
process_shonan_timing_results.names
dictionary names
Definition: process_shonan_timing_results.py:175
TEST
TEST(DiscreteFactorGraph, test)
Definition: testDiscreteFactorGraph.cpp:102
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
TEST_UNSAFE
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler)
Definition: testDiscreteFactorGraph.cpp:33
gtsam::DiscreteBayesNet::add
void add(const DiscreteKey &key, const std::string &spec)
Definition: DiscreteBayesNet.h:85
C
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
gtsam
traits
Definition: chartTesting.h:28
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
leaf::values
leaf::MyValues values
CHECK
#define CHECK(condition)
Definition: Test.h:108
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
std
Definition: BFloat16.h:88
gtsam::DecisionTreeFactor::max
shared_ptr max(size_t nrFrontals) const
Create new factor by maximizing over all values with the same separator.
Definition: DecisionTreeFactor.h:172
gtsam::DiscreteBayesNet::shared_ptr
std::shared_ptr< This > shared_ptr
Definition: DiscreteBayesNet.h:43
gtsam::assert_equal
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
product
void product(const MatrixType &m)
Definition: product.h:20
gtsam::DiscreteLookupDAG::argmax
DiscreteValues argmax(DiscreteValues given=DiscreteValues()) const
argmax by back-substitution, optionally given certain variables.
Definition: DiscreteLookupDAG.cpp:120
T1
static const Similarity3 T1(R, Point3(3.5, -8.2, 4.2), 1)
gtsam::FactorGraph::add
IsDerived< DERIVEDFACTOR > add(std::shared_ptr< DERIVEDFACTOR > factor)
add is a synonym for push_back.
Definition: FactorGraph.h:171
graph
NonlinearFactorGraph graph
Definition: doc/Code/OdometryExample.cpp:2
gtsam::Factor::size
size_t size() const
Definition: Factor.h:159
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
gtsam::Ordering
Definition: inference/Ordering.h:33
gtsam::Signature
Definition: Signature.h:54
S
DiscreteKey S(1, 2)
P2
static double P2[]
Definition: jv.c:557
M
Matrix< RealScalar, Dynamic, Dynamic > M
Definition: bench_gemm.cpp:51


gtsam
Author(s):
autogenerated on Tue Jun 25 2024 03:05:42