testDiscreteMarginals.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /**
13  * @file testDiscreteMarginals.cpp
14  * @date Jun 7, 2012
15  * @author Abhijit Kundu
16  * @author Richard Roberts
17  * @author Frank Dellaert
18  */
19 
21 
23 
24 using namespace std;
25 using namespace gtsam;
26 
27 /* ************************************************************************* */
29  size_t nrStates = 2;
30  DiscreteKey Cathy(1, nrStates), Heather(2, nrStates), Mark(3, nrStates),
31  Allison(4, nrStates);
33 
34  // add node potentials
35  graph.add(Cathy, "1 3");
36  graph.add(Heather, "9 1");
37  graph.add(Mark, "1 3");
38  graph.add(Allison, "9 1");
39 
40  // add edge potentials
41  graph.add(Cathy & Heather, "2 1 1 2");
42  graph.add(Heather & Mark, "2 1 1 2");
43  graph.add(Mark & Allison, "2 1 1 2");
44 
46  DiscreteFactor::shared_ptr actualC = marginals(Cathy.first);
48 
49  values[Cathy.first] = 0;
50  EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6);
51 
52  Vector actualCvector = marginals.marginalProbabilities(Cathy);
53  EXPECT(assert_equal(Vector2(0.359631, 0.640369), actualCvector, 1e-6));
54 
55  actualCvector = marginals.marginalProbabilities(Mark);
56  EXPECT(assert_equal(Vector2(0.48628, 0.51372), actualCvector, 1e-6));
57 }
58 
59 /* ************************************************************************* */
61 
62  const int nrNodes = 10;
63  const size_t nrStates = 7;
64 
65  // define variables
66  vector<DiscreteKey> key;
67  for (int i = 0; i < nrNodes; i++) {
68  DiscreteKey key_i(i, nrStates);
69  key.push_back(key_i);
70  }
71 
72  // create graph
74 
75  // add node potentials
76  graph.add(key[0], ".3 .6 .1 0 0 0 0");
77  for (int i = 1; i < nrNodes; i++)
78  graph.add(key[i], "1 1 1 1 1 1 1");
79 
80  const std::string edgePotential = ".08 .9 .01 0 0 0 .01 "
81  ".03 .95 .01 0 0 0 .01 "
82  ".06 .06 .75 .05 .05 .02 .01 "
83  "0 0 0 .3 .6 .09 .01 "
84  "0 0 0 .02 .95 .02 .01 "
85  "0 0 0 .01 .01 .97 .01 "
86  "0 0 0 0 0 0 1";
87 
88  // add edge potentials
89  for (int i = 0; i < nrNodes - 1; i++)
90  graph.add(key[i] & key[i + 1], edgePotential);
91 
93  DiscreteFactor::shared_ptr actualC = marginals(key[2].first);
95 
96  values[key[2].first] = 0;
97  EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4);
98 }
99 
100 /* ************************************************************************* */
102 
103  const int nrNodes = 5;
104  const size_t nrStates = 2;
105 
106  // define variables
107  vector<DiscreteKey> key;
108  for (int i = 0; i < nrNodes; i++) {
109  DiscreteKey key_i(i, nrStates);
110  key.push_back(key_i);
111  }
112 
113  // create graph and add three truss potentials
115  graph.add(key[0] & key[2] & key[4],"2 2 2 2 1 1 1 1");
116  graph.add(key[1] & key[3] & key[4],"1 1 1 1 2 2 2 2");
117  graph.add(key[2] & key[3] & key[4],"1 1 1 1 1 1 1 1");
118  DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal();
119 // bayesTree->print("Bayes Tree");
120  typedef DiscreteBayesTreeClique Clique;
121 
122  Clique expected0(std::make_shared<DiscreteConditional>((key[0] | key[2], key[4]) = "2/1 2/1 2/1 2/1"));
123  Clique::shared_ptr actual0 = (*bayesTree)[0];
124  EXPECT(assert_equal(expected0, *actual0));
125 
126  Clique expected1(std::make_shared<DiscreteConditional>((key[1] | key[3], key[4]) = "1/2 1/2 1/2 1/2"));
127  Clique::shared_ptr actual1 = (*bayesTree)[1];
128 // EXPECT(assert_equal(expected1, *actual1)); // TODO, correct but fails
129 
130  // Create Marginals instance
132 
133  // test 0
134  DecisionTreeFactor expectedM0(key[0],"0.666667 0.333333");
136  EXPECT(assert_equal(expectedM0, *std::dynamic_pointer_cast<DecisionTreeFactor>(actualM0),1e-5));
137 
138  // test 1
139  DecisionTreeFactor expectedM1(key[1],"0.333333 0.666667");
141  EXPECT(assert_equal(expectedM1, *std::dynamic_pointer_cast<DecisionTreeFactor>(actualM1),1e-5));
142 }
143 
144 /* ************************************************************************* */
145 // Second truss example with non-trivial factors
147  const int nrNodes = 5;
148  const size_t nrStates = 2;
149 
150  // define variables
151  vector<DiscreteKey> key;
152  for (int i = 0; i < nrNodes; i++) {
153  DiscreteKey key_i(i, nrStates);
154  key.push_back(key_i);
155  }
156 
157  // create graph and add three truss potentials
159  graph.add(key[0] & key[2] & key[4], "1 2 3 4 5 6 7 8");
160  graph.add(key[1] & key[3] & key[4], "1 2 3 4 5 6 7 8");
161  graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
162 
163  // Calculate the marginals by brute force
164  auto allPosbValues = DiscreteValues::CartesianProduct(
165  key[0] & key[1] & key[2] & key[3] & key[4]);
166  Vector T = Z_5x1, F = Z_5x1;
167  for (size_t i = 0; i < allPosbValues.size(); ++i) {
168  DiscreteValues x = allPosbValues[i];
169  double px = graph(x);
170  for (size_t j = 0; j < 5; j++)
171  if (x[j])
172  T[j] += px;
173  else
174  F[j] += px;
175  }
176 
177  // Check all marginals from Marginals against the brute-force values (sequential solver disabled)
178  // DiscreteSequentialSolver solver(graph);
180  for (size_t j = 0; j < 5; j++) {
181  double sum = T[j] + F[j];
182  T[j] /= sum;
183  F[j] /= sum;
184 
185  // Marginals
186  const vector<double> table{F[j], T[j]};
187  DecisionTreeFactor expectedM(key[j], table);
190  expectedM, *std::dynamic_pointer_cast<DecisionTreeFactor>(actualM)));
191  }
192 }
193 
194 /* ************************************************************************* */
// Test-runner entry point: builds a TestResult and returns the int produced by
// TestRegistry::runAllTests(tr) as the process exit status.
// NOTE(review): presumably this executes every test registered in this file and
// the returned value reflects failures — confirm against TestRegistry.cpp.
195 int main() {
196  TestResult tr;
197  return TestRegistry::runAllTests(tr);
198 }
199 /* ************************************************************************* */
200 
TestRegistry::runAllTests
static int runAllTests(TestResult &result)
Definition: TestRegistry.cpp:27
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:44
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:98
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
EXPECT
#define EXPECT(condition)
Definition: Test.h:150
gtsam::Vector2
Eigen::Vector2d Vector2
Definition: Vector.h:42
TestHarness.h
gtsam::DiscreteBayesTree::shared_ptr
std::shared_ptr< This > shared_ptr
Definition: DiscreteBayesTree.h:80
x
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
Definition: gnuplot_common_settings.hh:12
TEST_UNSAFE
TEST_UNSAFE(DiscreteMarginals, UGM_small)
Definition: testDiscreteMarginals.cpp:28
main
int main()
Definition: testDiscreteMarginals.cpp:195
gtsam::Vector
Eigen::VectorXd Vector
Definition: Vector.h:38
table
ArrayXXf table(10, 4)
j
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2
gtsam::DiscreteMarginals
Definition: DiscreteMarginals.h:33
Eigen::Triplet< double >
gtsam::symbol_shorthand::F
Key F(std::uint64_t j)
Definition: inference/Symbol.h:153
EXPECT_DOUBLES_EQUAL
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
DiscreteMarginals.h
A class for computing marginals in a DiscreteFactorGraph.
TestResult
Definition: TestResult.h:26
key
const gtsam::Symbol key('X', 0)
gtsam
traits
Definition: chartTesting.h:28
gtsam::DiscreteFactor::shared_ptr
std::shared_ptr< DiscreteFactor > shared_ptr
shared_ptr to this class
Definition: DiscreteFactor.h:44
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
leaf::values
leaf::MyValues values
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
std
Definition: BFloat16.h:88
gtsam::assert_equal
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
gtsam::DiscreteBayesTreeClique
Definition: DiscreteBayesTree.h:39
gtsam::FactorGraph::add
IsDerived< DERIVEDFACTOR > add(std::shared_ptr< DERIVEDFACTOR > factor)
add is a synonym for push_back.
Definition: FactorGraph.h:171
graph
NonlinearFactorGraph graph
Definition: doc/Code/OdometryExample.cpp:2
marginals
Marginals marginals(graph, result)
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
px
RealScalar RealScalar * px
Definition: level1_cplx_impl.h:28


gtsam
Author(s):
autogenerated on Tue Jun 25 2024 03:05:42