testDiscreteBayesNet.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * testDiscreteBayesNet.cpp
14  *
15  * @date Feb 27, 2011
16  * @author Frank Dellaert
17  */
18 
22 #include <gtsam/base/debug.h>
23 #include <gtsam/base/Testable.h>
24 #include <gtsam/base/Vector.h>
25 
27 
28 
29 #include <boost/assign/list_inserter.hpp>
30 #include <boost/assign/std/map.hpp>
31 
32 using namespace boost::assign;
33 
34 #include <iostream>
35 #include <string>
36 #include <vector>
37 
38 using namespace std;
39 using namespace gtsam;
40 
41 /* ************************************************************************* */
// Builds a two-variable Bayes net Parent -> Child (both binary), checks the
// probability tables produced by the signature syntax, and checks the
// marginals computed from the equivalent factor graph.
42 TEST(DiscreteBayesNet, bayesNet) {
43  DiscreteBayesNet bayesNet;
44  DiscreteKey Parent(0, 2), Child(1, 2);
45 
// Prior "6/4" normalizes to P(Parent=0) = 0.6, P(Parent=1) = 0.4.
46  auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
// NOTE(review): listing line 48 is missing from this extract, so the CHECK
// below is cut off before its second argument and closing parentheses.
47  CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"),
49  bayesNet.push_back(prior);
50 
// One "a/b" row per Parent value: P(Child | Parent=0) = 7/3 and
// P(Child | Parent=1) = 8/2.
51  auto conditional =
52  boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
// The (single) frontal variable is Child, i.e. key 1.
53  EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
54  Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
55  CHECK(assert_equal(expected, (Potentials::ADT)*conditional));
56  bayesNet.push_back(conditional);
57 
// The last factor in the graph presumably comes from Child | Parent, hence it
// involves two keys.
58  DiscreteFactorGraph fg(bayesNet);
59  LONGS_EQUAL(2, fg.back()->size());
60 
61  // Check the marginals
// expectedMarginal[j] is P(key j = 1): P(Parent=1) = 0.4 and
// P(Child=1) = 0.6*0.3 + 0.4*0.2 = 0.26 (law of total probability).
62  const double expectedMarginal[2]{0.4, 0.6 * 0.3 + 0.4 * 0.2};
// NOTE(review): listing line 63 is missing from this extract; it presumably
// declares the `marginals` object (built from fg) used in the loop below.
64  for (size_t j = 0; j < 2; j++) {
65  Vector FT = marginals.marginalProbabilities(DiscreteKey(j, 2));
66  EXPECT_DOUBLES_EQUAL(expectedMarginal[j], FT[1], 1e-3);
// The two entries of a binary marginal must sum to one.
67  EXPECT_DOUBLES_EQUAL(FT[0], 1.0 - FT[1], 1e-9);
68  }
69 }
70 
71 /* ************************************************************************* */
// NOTE(review): listing line 72 is missing from this extract; it presumably
// holds this test's opening TEST(...) line.
// Builds the eight-variable "Asia" Bayes net (all variables binary), then
// checks root marginals, the MPE with and without evidence, and one sample.
73  DiscreteBayesNet asia;
74  DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
75  Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
76 
// Root priors: P(Asia) = 99/1, P(Smoking) = 50/50.
77  asia.add(Asia % "99/1");
78  asia.add(Smoking % "50/50");
79 
80  asia.add(Tuberculosis | Asia = "99/1 95/5");
81  asia.add(LungCancer | Smoking = "99/1 90/10");
82  asia.add(Bronchitis | Smoking = "70/30 40/60");
83 
// Deterministic table: Either is false only for the first parent combination,
// which appears to encode Tuberculosis OR LungCancer.
84  asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
85 
86  asia.add(XRay | Either = "95/5 2/98");
87  asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
88 
89  // Convert to factor graph
// The last factor (from Dyspnea | Either, Bronchitis) involves three keys.
90  DiscreteFactorGraph fg(asia);
91  LONGS_EQUAL(3, fg.back()->size());
92 
93  // Check the marginals we know (of the parent-less nodes)
// NOTE(review): listing line 94 is missing from this extract; it presumably
// declares the `marginals` object (built from fg) used below.
95  Vector2 va(0.99, 0.01), vs(0.5, 0.5);
96  EXPECT(assert_equal(va, marginals.marginalProbabilities(Asia)));
97  EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
98 
99  // Create solver and eliminate
// NOTE(review): listing line 100 is missing from this extract; it presumably
// declares `ordering`, which then receives keys 0..7 in order through
// boost::assign's operator+=.
101  ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
102  DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
// Bronchitis (key 7) is eliminated last, so the final conditional is its
// marginal: P(Bronchitis=1) = 0.5*0.30 + 0.5*0.60 = 0.45, i.e. 55/45 = 11/9.
103  DiscreteConditional expected2(Bronchitis % "11/9");
104  EXPECT(assert_equal(expected2, *chordal->back()));
105 
106  // solve
// The expected MPE without evidence assigns 0 to every variable.
107  DiscreteFactor::sharedValues actualMPE = chordal->optimize();
108  DiscreteFactor::Values expectedMPE;
109  insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
110  Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
111  LungCancer.first, 0)(Bronchitis.first, 0);
112  EXPECT(assert_equal(expectedMPE, *actualMPE));
113 
114  // add evidence, we were in Asia and we have dyspnea
// Each "0 1" table gives zero weight to value 0, clamping that variable to 1.
115  fg.add(Asia, "0 1");
116  fg.add(Dyspnea, "0 1");
117 
118  // solve again, now with evidence
119  DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
120  DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize();
121  DiscreteFactor::Values expectedMPE2;
122  insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
123  Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
124  LungCancer.first, 0)(Bronchitis.first, 1);
125  EXPECT(assert_equal(expectedMPE2, *actualMPE2));
126 
127  // now sample from it
128  DiscreteFactor::Values expectedSample;
// Turn off the global debug flag for DiscreteConditional::sample.
129  SETDEBUG("DiscreteConditional::sample", false);
// NOTE(review): expectedSample is one specific random draw; this check
// presumably relies on a fixed random-number sequence inside sample() and may
// break if that generator or its seeding changes. Confirm.
130  insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
131  Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(
132  LungCancer.first, 1)(Bronchitis.first, 0);
133  DiscreteFactor::sharedValues actualSample = chordal2->sample();
134  EXPECT(assert_equal(expectedSample, *actualSample));
135 }
136 
137 /* ************************************************************************* */
// NOTE(review): listing line 138 is missing from this extract; it presumably
// holds this test's opening line (the page index lists
// TEST_UNSAFE(DiscreteBayesNet, Sugar)).
// Exercises the signature "syntactic sugar" for building conditionals. There
// are no assertions, so the test passes as long as construction succeeds.
139  DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
140 
141  DiscreteBayesNet bn;
142 
143  // try logic
// E is given two conditionals (OR, then AND). This only exercises the logic
// keywords and does not form a consistent network.
144  bn.add((E | T, L) = "OR");
145  bn.add((E | T, L) = "AND");
146 
147  // try multivalued
// C has three values, so each distribution lists three "/"-separated weights.
// The second table has one such row per value of the binary parent S.
148  bn.add(C % "1/1/2");
149  bn.add(C | S = "1/1/2 5/2/3");
150 }
151 
152 /* ************************************************************************* */
// Test entry point: runs every registered test and returns the result code
// reported by TestRegistry::runAllTests.
153 int main() {
154  TestResult tr;
155  return TestRegistry::runAllTests(tr);
156 }
157 /* ************************************************************************* */
#define CHECK(condition)
Definition: Test.h:109
A insert(1, 2)=0
#define SETDEBUG(S, V)
Definition: debug.h:61
void add(const DiscreteKey &j, SOURCE table)
Key E(std::uint64_t j)
Concept check for values that can be used in unit tests.
static int runAllTests(TestResult &result)
Global debugging flags.
Vector marginalProbabilities(const DiscreteKey &key) const
Point2 prior(const Point2 &x)
Prior on a single pose.
Definition: simulated2D.h:87
static enum @843 ordering
Matrix expected
Definition: testMatrix.cpp:974
TEST_UNSAFE(DiscreteBayesNet, Sugar)
TEST(DiscreteBayesNet, bayesNet)
void add(const Signature &s)
sharedFactor back() const
Definition: FactorGraph.h:342
MatrixXd L
Definition: LLT_example.cpp:6
Definition: Half.h:150
IsDerived< DERIVEDFACTOR > push_back(boost::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:166
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:162
int main()
Eigen::VectorXd Vector
Definition: Vector.h:38
Key back() const
Last key.
Definition: Factor.h:118
Key S(std::uint64_t j)
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:34
#define EXPECT(condition)
Definition: Test.h:151
Eigen::Triplet< double > T
Array< double, 1, 3 > e(1./3., 0.5, 2.)
A class for computing marginals in a DiscreteFactorGraph.
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:37
#define LONGS_EQUAL(expected, actual)
Definition: Test.h:135
traits
Definition: chartTesting.h:28
typedef and functions to augment Eigen's VectorXd
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:42
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:155
boost::shared_ptr< Values > sharedValues
boost::shared_ptr< BayesNetType > eliminateSequential(OptionalOrderingType orderingType=boost::none, const Eliminate &function=EliminationTraitsType::DefaultEliminate, OptionalVariableIndex variableIndex=boost::none) const
boost::shared_ptr< This > shared_ptr
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:61
Marginals marginals(graph, result)
std::ptrdiff_t j


gtsam
Author(s):
autogenerated on Sat May 8 2021 02:46:25