testDiscreteFactorGraph.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * @file testDiscreteFactorGraph.cpp
14  * @date Feb 14, 2011
15  * @author Duy-Nguyen Ta
16  */
17 
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/Symbol.h>

using namespace std;
using namespace gtsam;
29 
31 
32 /* ************************************************************************* */
33 TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
34  DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
35 
37  graph.add(AI, "1 0 0 1");
38  graph.add(AI, "1 1 1 0");
39  graph.add(A & AI, "1 1 1 0 1 1 1 1 0 1 1 1");
40  graph.add(ME, "0 1 0 0");
41  graph.add(ME, "1 1 1 0");
42  graph.add(A & ME, "1 1 1 0 1 1 1 1 0 1 1 1");
43  graph.add(PC, "1 0 1 0");
44  graph.add(PC, "1 1 1 0");
45  graph.add(A & PC, "1 1 1 0 1 1 1 1 0 1 1 1");
46  graph.add(ME & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
47  graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
48  graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
49 
50  // Check MPE.
51  auto actualMPE = graph.optimize();
52  EXPECT(assert_equal({{0, 2}, {1, 1}, {2, 0}, {3, 0}}, actualMPE));
53 }
54 
55 /* ************************************************************************* */
57 TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
58 
59  // Three keys P1 and P2
60  DiscreteKey P1(0,2), P2(1,2), P3(2,3);
61 
62  // Create the DiscreteFactorGraph
64  graph.add(P1, "0.9 0.3");
65  graph.add(P2, "0.9 0.6");
66  graph.add(P1 & P2, "4 1 10 4");
67 
68  // Instantiate DiscreteValues
70  values[0] = 1;
71  values[1] = 1;
72 
73  // Check if graph evaluation works ( 0.3*0.6*4 )
74  EXPECT_DOUBLES_EQUAL( .72, graph(values), 1e-9);
75 
76  // Creating a new test with third node and adding unary and ternary factors on it
77  graph.add(P3, "0.9 0.2 0.5");
78  graph.add(P1 & P2 & P3, "1 2 3 4 5 6 7 8 9 10 11 12");
79 
80  // Below values lead to selecting the 8th index in the ternary factor table
81  values[0] = 1;
82  values[1] = 0;
83  values[2] = 1;
84 
85  // Check if graph evaluation works (0.3*0.9*1*0.2*8)
86  EXPECT_DOUBLES_EQUAL( 4.32, graph(values), 1e-9);
87 
88  // Below values lead to selecting the 3rd index in the ternary factor table
89  values[0] = 0;
90  values[1] = 1;
91  values[2] = 0;
92 
93  // Check if graph evaluation works (0.9*0.6*1*0.9*4)
94  EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
95 
96  // Check if graph product works
97  DecisionTreeFactor product = graph.product()->toDecisionTreeFactor();
98  EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
99 }
100 
101 /* ************************************************************************* */
103  // Declare keys and ordering
104  DiscreteKey C(0, 2), B(1, 2), A(2, 2);
105 
106  // A simple factor graph (A)-fAC-(C)-fBC-(B)
107  // with smoothness priors
109  graph.add(A & C, "3 1 1 3");
110  graph.add(C & B, "3 1 1 3");
111 
112  // Test EliminateDiscrete
113  const Ordering frontalKeys{0};
114  const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
115 
116  DecisionTreeFactor newFactor =
117  *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
118 
119  // Normalize newFactor by max for comparison with expected
120  auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor();
121 
122  newFactor = newFactor / denominator;
123 
124  // Check Conditional
125  CHECK(conditional);
126  Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
127  DiscreteConditional expectedConditional(signature);
128  EXPECT(assert_equal(expectedConditional, *conditional));
129 
130  // Check Factor
131  CHECK(&newFactor);
132  DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
133  // Normalize by max.
134  denominator =
135  expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor();
136  // Ensure denominator is correct.
137  expectedFactor = expectedFactor / denominator;
138  EXPECT(assert_equal(expectedFactor, newFactor));
139 
140  // Test using elimination tree
141  const Ordering ordering{0, 1, 2};
143  const auto [actual, remainingGraph] = etree.eliminate(&EliminateDiscrete);
144 
145  // Check Bayes net
146  DiscreteBayesNet expectedBayesNet;
147  expectedBayesNet.add(signature);
148  expectedBayesNet.add(B | A = "5/3 3/5");
149  expectedBayesNet.add(A % "1/1");
150  EXPECT(assert_equal(expectedBayesNet, *actual));
151 
152  // Test eliminateSequential
153  DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
154  EXPECT(assert_equal(expectedBayesNet, *actual2));
155 
156  // Test mpe
157  DiscreteValues mpe { {0, 0}, {1, 0}, {2, 0}};
158  auto actualMPE = graph.optimize();
159  EXPECT(assert_equal(mpe, actualMPE));
160  EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
161 
162  // Test sumProduct alias with all orderings:
163  auto mpeProbability = expectedBayesNet(mpe);
164  EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression
165 
166  // Using custom ordering
167  DiscreteBayesNet bayesNet = graph.sumProduct(ordering);
168  EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
169 
170  for (Ordering::OrderingType orderingType :
171  {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
172  Ordering::CUSTOM}) {
173  auto bayesNet = graph.sumProduct(orderingType);
174  EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5);
175  }
176 }
177 
178 /* ************************************************************************* */
180  // Declare a bunch of keys
181  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
182 
183  // Create Factor graph
185  graph.add(C & A, "0.2 0.8 0.3 0.7");
186  graph.add(C & B, "0.1 0.9 0.4 0.6");
187 
188  // Created expected MPE
189  DiscreteValues mpe{{0, 0}, {1, 1}, {2, 1}};
190 
191  // Do max-product with different orderings
192  for (Ordering::OrderingType orderingType :
193  {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
194  Ordering::CUSTOM}) {
195  DiscreteLookupDAG dag = graph.maxProduct(orderingType);
196  auto actualMPE = dag.argmax();
197  EXPECT(assert_equal(mpe, actualMPE));
198  auto actualMPE2 = graph.optimize(); // all in one
199  EXPECT(assert_equal(mpe, actualMPE2));
200  }
201 }
202 
203 /* ************************************************************************* */
204 TEST(DiscreteFactorGraph, marginalIsNotMPE) {
205  // Declare 2 keys
206  DiscreteKey A(0, 2), B(1, 2);
207 
208  // Create Bayes net such that marginal on A is bigger for 0 than 1, but the
209  // MPE does not have A=0.
211  bayesNet.add(B | A = "1/1 1/2");
212  bayesNet.add(A % "10/9");
213 
214  // The expected MPE is A=1, B=1
215  DiscreteValues mpe { {0, 1}, {1, 1} };
216 
217  // Which we verify using max-product:
219  auto actualMPE = graph.optimize();
220  EXPECT(assert_equal(mpe, actualMPE));
221  EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
222 }
223 
224 /* ************************************************************************* */
225 TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
226  // The factor graph in Darwiche09book, page 244
227  DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
228 
229  // Create Factor graph
231  graph.add(S, "0.55 0.45");
232  graph.add(S & C, "0.05 0.95 0.01 0.99");
233  graph.add(C & T1, "0.80 0.20 0.20 0.80");
234  graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
235  graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
236  graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
237 
238  DiscreteValues mpe { {0, 1}, {1, 1}, {2, 1}, {3, 1}, {4, 0}};
239  EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
240  // You can check visually by printing product:
241  // graph.product().print("Darwiche-product");
242 
243  // Check MPE.
244  auto actualMPE = graph.optimize();
245  EXPECT(assert_equal(mpe, actualMPE));
246 
247  // Check Bayes Net
248  const Ordering ordering{0, 1, 2, 3, 4};
249  auto chordal = graph.eliminateSequential(ordering);
250  EXPECT_LONGS_EQUAL(5, chordal->size());
251 
252  // Let us create the Bayes tree here, just for fun, because we don't use it
254  graph.eliminateMultifrontal(ordering);
255  // bayesTree->print("Bayes Tree");
256  EXPECT_LONGS_EQUAL(2, bayesTree->size());
257 }
258 
259 /* ************************************************************************* */
261  // Create Factor graph
263  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
264  graph.add(C & A, "0.2 0.8 0.3 0.7");
265  graph.add(C & B, "0.1 0.9 0.4 0.6");
266 
267  string actual = graph.dot();
268  string expected =
269  "graph {\n"
270  " size=\"5,5\";\n"
271  "\n"
272  " var0[label=\"0\"];\n"
273  " var1[label=\"1\"];\n"
274  " var2[label=\"2\"];\n"
275  "\n"
276  " factor0[label=\"\", shape=point];\n"
277  " var0--factor0;\n"
278  " var1--factor0;\n"
279  " factor1[label=\"\", shape=point];\n"
280  " var0--factor1;\n"
281  " var2--factor1;\n"
282  "}\n";
283  EXPECT(actual == expected);
284 }
285 
286 /* ************************************************************************* */
287 TEST(DiscreteFactorGraph, DotWithNames) {
288  // Create Factor graph
290  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
291  graph.add(C & A, "0.2 0.8 0.3 0.7");
292  graph.add(C & B, "0.1 0.9 0.4 0.6");
293 
294  vector<string> names{"C", "A", "B"};
295  auto formatter = [names](Key key) { return names[key]; };
296  string actual = graph.dot(formatter);
297  string expected =
298  "graph {\n"
299  " size=\"5,5\";\n"
300  "\n"
301  " var0[label=\"C\"];\n"
302  " var1[label=\"A\"];\n"
303  " var2[label=\"B\"];\n"
304  "\n"
305  " factor0[label=\"\", shape=point];\n"
306  " var0--factor0;\n"
307  " var1--factor0;\n"
308  " factor1[label=\"\", shape=point];\n"
309  " var0--factor1;\n"
310  " var2--factor1;\n"
311  "}\n";
312  EXPECT(actual == expected);
313 }
314 
315 /* ************************************************************************* */
316 // Check markdown representation looks as expected.
318  // Create Factor graph
320  DiscreteKey C(0, 2), A(1, 2), B(2, 2);
321  graph.add(C & A, "0.2 0.8 0.3 0.7");
322  graph.add(C & B, "0.1 0.9 0.4 0.6");
323 
324  string expected =
325  "`DiscreteFactorGraph` of size 2\n"
326  "\n"
327  "factor 0:\n"
328  "|C|A|value|\n"
329  "|:-:|:-:|:-:|\n"
330  "|0|0|0.2|\n"
331  "|0|1|0.8|\n"
332  "|1|0|0.3|\n"
333  "|1|1|0.7|\n"
334  "\n"
335  "factor 1:\n"
336  "|C|B|value|\n"
337  "|:-:|:-:|:-:|\n"
338  "|0|0|0.1|\n"
339  "|0|1|0.9|\n"
340  "|1|0|0.4|\n"
341  "|1|1|0.6|\n\n";
342  vector<string> names{"C", "A", "B"};
343  auto formatter = [names](Key key) { return names[key]; };
344  string actual = graph.markdown(formatter);
345  EXPECT(actual == expected);
346 
347  // Make sure values are correctly displayed.
349  values[0] = 1;
350  values[1] = 0;
351  EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
352 }
353 
354 /* ************************************************************************* */
355 int main() {
356 TestResult tr;
357 return TestRegistry::runAllTests(tr);
358 }
359 /* ************************************************************************* */
360 
TestRegistry::runAllTests
static int runAllTests(TestResult &result)
Definition: TestRegistry.cpp:27
gtsam::EliminateDiscrete
std::pair< DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr > EliminateDiscrete(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
Main elimination function for DiscreteFactorGraph.
Definition: DiscreteFactorGraph.cpp:224
gtsam::markdown
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
Definition: DiscreteValues.cpp:130
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:45
B
Matrix< SCALARB, Dynamic, Dynamic, opt_B > B
Definition: bench_gemm.cpp:49
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
EXPECT_LONGS_EQUAL
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
EXPECT
#define EXPECT(condition)
Definition: Test.h:150
TestHarness.h
gtsam::DiscreteBayesTree::shared_ptr
std::shared_ptr< This > shared_ptr
Definition: DiscreteBayesTree.h:80
DiscreteFactorGraph.h
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
gtsam::DiscreteLookupDAG
Definition: DiscreteLookupDAG.h:90
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:245
P3
static double P3[]
Definition: jv.c:563
gtsam::NonlinearFactorGraph::dot
void dot(std::ostream &os, const Values &values, const KeyFormatter &keyFormatter=DefaultKeyFormatter, const GraphvizFormatting &writer=GraphvizFormatting()) const
Output to graphviz format, stream version, with Values/extra options.
Definition: NonlinearFactorGraph.cpp:102
TestableAssertions.h
Provides additional testing facilities for common data structures.
A
Matrix< SCALARA, Dynamic, Dynamic, opt_A > A
Definition: bench_gemm.cpp:48
test
Definition: test.py:1
different_sigmas::bayesNet
const HybridBayesNet bayesNet
Definition: testHybridBayesNet.cpp:242
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
main
int main()
Definition: testDiscreteFactorGraph.cpp:355
gtsam::DiscreteEliminationTree
Elimination tree for discrete factors.
Definition: DiscreteEliminationTree.h:31
P1
static double P1[]
Definition: jv.c:552
DiscreteFactor.h
cholesky::expected
Matrix expected
Definition: testMatrix.cpp:971
gtsam::EliminationTree::eliminate
std::pair< std::shared_ptr< BayesNetType >, std::shared_ptr< FactorGraphType > > eliminate(Eliminate function) const
Definition: EliminationTree-inst.h:225
Symbol.h
T2
static const Pose3 T2(Rot3::Rodrigues(0.3, 0.2, 0.1), P2)
DiscreteEliminationTree.h
BayesNet.h
Bayes network.
DiscreteBayesTree.h
Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree.
EXPECT_DOUBLES_EQUAL
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
ordering
static enum @1096 ordering
TestResult
Definition: TestResult.h:26
key
const gtsam::Symbol key('X', 0)
gtsam::Ordering::OrderingType
OrderingType
Type of ordering to use.
Definition: inference/Ordering.h:40
process_shonan_timing_results.names
dictionary names
Definition: process_shonan_timing_results.py:175
TEST
TEST(DiscreteFactorGraph, test)
Definition: testDiscreteFactorGraph.cpp:102
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
TEST_UNSAFE
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler)
Definition: testDiscreteFactorGraph.cpp:33
gtsam::DiscreteBayesNet::add
void add(const DiscreteKey &key, const std::string &spec)
Definition: DiscreteBayesNet.h:85
C
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
gtsam
traits
Definition: SFMdata.h:40
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
CHECK
#define CHECK(condition)
Definition: Test.h:108
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
std
Definition: BFloat16.h:88
gtsam::DiscreteBayesNet::shared_ptr
std::shared_ptr< This > shared_ptr
Definition: DiscreteBayesNet.h:43
gtsam::assert_equal
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:41
product
void product(const MatrixType &m)
Definition: product.h:20
gtsam::DecisionTreeFactor::max
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override
Create new factor by maximizing over all values with the same separator.
Definition: DecisionTreeFactor.h:205
gtsam::DiscreteLookupDAG::argmax
DiscreteValues argmax(DiscreteValues given=DiscreteValues()) const
argmax by back-substitution, optionally given certain variables.
Definition: DiscreteLookupDAG.cpp:121
T1
static const Similarity3 T1(R, Point3(3.5, -8.2, 4.2), 1)
gtsam::FactorGraph::add
IsDerived< DERIVEDFACTOR > add(std::shared_ptr< DERIVEDFACTOR > factor)
add is a synonym for push_back.
Definition: FactorGraph.h:171
graph
NonlinearFactorGraph graph
Definition: doc/Code/OdometryExample.cpp:2
gtsam::Factor::size
size_t size() const
Definition: Factor.h:160
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
gtsam::Ordering
Definition: inference/Ordering.h:33
gtsam::Signature
Definition: Signature.h:54
S
DiscreteKey S(1, 2)
P2
static double P2[]
Definition: jv.c:557
M
Matrix< RealScalar, Dynamic, Dynamic > M
Definition: bench_gemm.cpp:51


gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:07:10