testDiscreteBayesNet.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
12 /*
13  * testDiscreteBayesNet.cpp
14  *
15  * @date Feb 27, 2011
16  * @author Frank Dellaert
17  */
18 
22 #include <gtsam/base/debug.h>
23 #include <gtsam/base/Testable.h>
24 #include <gtsam/base/Vector.h>
25 
27 
28 #include <iostream>
29 #include <string>
30 #include <vector>
31 
32 using namespace std;
33 using namespace gtsam;
34 
// Variables of the Asia network shared by the tests below, grouped by key index:
// each DiscreteKey is (key, cardinality), and every variable here is binary.
35 static const DiscreteKey Asia(0, 2), Dyspnea(1, 2), XRay(2, 2),
36  Tuberculosis(3, 2), Smoking(4, 2), Either(5, 2), LungCancer(6, 2), Bronchitis(7, 2);
37 
39 
40 /* ************************************************************************* */
// Builds the two-node net P(Parent) P(Child | Parent), checks both conditional
// tables through their ADT conversion, converts the net to a DiscreteFactorGraph,
// and checks the P(X = 1) marginal of each variable.
41 TEST(DiscreteBayesNet, bayesNet) {
42  DiscreteBayesNet bayesNet;
43  DiscreteKey Parent(0, 2), Child(1, 2);
44 
// The prior "6/4" is expected to normalize to the table [0.6, 0.4].
45  auto prior = std::make_shared<DiscreteConditional>(Parent % "6/4");
46  CHECK(assert_equal(ADT({Parent}, "0.6 0.4"),
47  (ADT)*prior));
48  bayesNet.push_back(prior);
49 
// One "a/b" row per parent value: Parent=0 -> 7/3, Parent=1 -> 8/2. The table
// over (Child, Parent) is expected to flatten to "0.7 0.8 0.3 0.2".
50  auto conditional =
51  std::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
// The first frontal key must be Child (key 1).
52  EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
53  ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
54  CHECK(assert_equal(expected, (ADT)*conditional));
55  bayesNet.push_back(conditional);
56 
57  DiscreteFactorGraph fg(bayesNet);
// The last factor (presumably the converted P(Child | Parent)) spans two variables.
58  LONGS_EQUAL(2, fg.back()->size());
59 
60  // Check the marginals
// expectedMarginal[j] is P(key j = 1): 0.4 for Parent, and for Child
// 0.6 * 0.3 + 0.4 * 0.2 = 0.26, obtained by summing out Parent.
61  const double expectedMarginal[2]{0.4, 0.6 * 0.3 + 0.4 * 0.2};
// NOTE(review): listing line 62 is missing from this rendering. `marginals` is used
// below without a visible declaration; presumably it is a DiscreteMarginals built
// from `fg`. Confirm against the real source.
63  for (size_t j = 0; j < 2; j++) {
64  Vector FT = marginals.marginalProbabilities(DiscreteKey(j, 2));
// NOTE(review): the expected marginals are exact, so 1e-3 looks looser than needed.
65  EXPECT_DOUBLES_EQUAL(expectedMarginal[j], FT[1], 1e-3);
// Binary variable: the two entries must sum to one.
66  EXPECT_DOUBLES_EQUAL(FT[0], 1.0 - FT[1], 1e-9);
67  }
68 }
69 
70 /* ************************************************************************* */
// NOTE(review): listing line 71 (this test's TEST(...) header) is missing from this
// rendering, as are lines 91, 98, 104 and 112. The names `marginals`, `chordal` and
// `chordal2` used below are presumably declared on those lines.
// The test builds the 8-variable Asia network and converts it to a factor graph.
// It then checks the root marginals, eliminates in key order, compares evaluate()
// with logProbability() at the optimum, adds evidence, and draws a sample.
72  DiscreteBayesNet asia;
73 
// Root priors; the second call uses the `key % "spec"` signature form.
74  asia.add(Asia, "99/1");
75  asia.add(Smoking % "50/50"); // Signature version
76 
77  asia.add(Tuberculosis | Asia = "99/1 95/5");
78  asia.add(LungCancer | Smoking = "99/1 90/10");
79  asia.add(Bronchitis | Smoking = "70/30 40/60");
80 
// Deterministic table. "F T T T" presumably encodes Either = Tuberculosis OR
// LungCancer (false only when both parents are 0); confirm the row order.
81  asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
82 
83  asia.add(XRay | Either = "95/5 2/98");
84  asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
85 
86  // Convert to factor graph
87  DiscreteFactorGraph fg(asia);
// The last factor, from P(Dyspnea | Either, Bronchitis), spans three variables.
88  LONGS_EQUAL(3, fg.back()->size());
89 
90  // Check the marginals we know (of the parent-less nodes)
// The root marginals equal their priors: Asia 99/1, Smoking 50/50.
92  Vector2 va(0.99, 0.01), vs(0.5, 0.5);
93  EXPECT(assert_equal(va, marginals.marginalProbabilities(Asia)));
94  EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
95 
96  // Create solver and eliminate
97  const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7};
// Bronchitis (key 7) comes last in the ordering, so the final conditional is
// expected to be its marginal: P(B=0) = 0.5*0.7 + 0.5*0.4 = 0.55 and
// P(B=1) = 0.45, i.e. "11/9".
99  DiscreteConditional expected2(Bronchitis % "11/9");
100  EXPECT(assert_equal(expected2, *chordal->back()));
101 
102  // Check evaluate and logProbability
103  auto result = fg.optimize();
// The line below continues a comparison whose first half (listing line 104) is
// missing here. It compares against log(evaluate(result)) with tolerance 1e-9.
105  std::log(asia.evaluate(result)), 1e-9);
106 
107  // add evidence, we were in Asia and we have dyspnea
// Each "0 1" table puts all weight on state 1 of its variable.
108  fg.add(Asia, "0 1");
109  fg.add(Dyspnea, "0 1");
110 
111  // solve again, now with evidence
// NOTE(review): this repeats line 100's check on the pre-evidence `chordal`, so the
// evidence-conditioned elimination is never validated. `chordal2` (used at line
// 121) looks like the intended subject. However, expected2 is the no-evidence
// Bronchitis marginal, so switching to chordal2 also requires recomputing the
// expected value under the evidence. TODO.
113  EXPECT(assert_equal(expected2, *chordal->back()));
114 
115  // now sample from it
// The expected sample agrees with the evidence (Asia = 1, Dyspnea = 1).
// NOTE(review): a fixed expected sample presumably relies on sample() using a
// deterministic random generator; confirm.
116  DiscreteValues expectedSample{{Asia.first, 1}, {Dyspnea.first, 1},
117  {XRay.first, 1}, {Tuberculosis.first, 0},
118  {Smoking.first, 1}, {Either.first, 1},
119  {LungCancer.first, 1}, {Bronchitis.first, 0}};
120  SETDEBUG("DiscreteConditional::sample", false);
121  auto actualSample = chordal2->sample();
122  EXPECT(assert_equal(expectedSample, actualSample));
123 }
124 
125 /* ************************************************************************* */
// NOTE(review): listing line 126 (this test's TEST(...) header) is missing from this
// rendering.
// Exercises the signature "sugar" syntax only: logic-keyword tables and
// multi-valued (cardinality 3) variables. The net deliberately holds two
// conditionals on E and two on C, so it is not meant to be a consistent model.
// NOTE(review): there is no assertion, so the test passes whenever construction does
// not throw. Consider asserting something, e.g. that bn ends up with 4 conditionals.
127  DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
128 
129  DiscreteBayesNet bn;
130 
131  // try logic
132  bn.add((E | T, L) = "OR");
133  bn.add((E | T, L) = "AND");
134 
135  // try multivalued
// C has three states: one weight each in the prior, and one row of three
// weights per value of S in the conditional.
136  bn.add(C % "1/1/2");
137  bn.add(C | S = "1/1/2 5/2/3");
138 }
139 
140 /* ************************************************************************* */
// NOTE(review): listing line 141 (this test's TEST(...) header) is missing from this
// rendering.
// Checks the graphviz output of a five-variable fragment of the Asia network.
// The expected text has one node per variable and one edge per parent -> child
// relation: Tuberculosis->Either, LungCancer->Either, Smoking->LungCancer and
// Asia->Tuberculosis.
142  DiscreteBayesNet fragment;
143  fragment.add(Asia % "99/1");
144  fragment.add(Smoking % "50/50");
145 
146  fragment.add(Tuberculosis | Asia = "99/1 95/5");
147  fragment.add(LungCancer | Smoking = "99/1 90/10");
148  fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
149 
150  string actual = fragment.dot();
// Exact string comparison: the test is sensitive to the node/edge emission order
// and to whitespace in dot()'s output.
151  EXPECT(actual ==
152  "digraph {\n"
153  " size=\"5,5\";\n"
154  "\n"
155  " var0[label=\"0\"];\n"
156  " var3[label=\"3\"];\n"
157  " var4[label=\"4\"];\n"
158  " var5[label=\"5\"];\n"
159  " var6[label=\"6\"];\n"
160  "\n"
161  " var3->var5\n"
162  " var6->var5\n"
163  " var4->var6\n"
164  " var0->var3\n"
165  "}");
166 }
167 
168 /* ************************************************************************* */
169 // Check markdown representation looks as expected.
// NOTE(review): listing line 170 (this test's TEST(...) header) is missing from this
// rendering.
171  DiscreteBayesNet fragment;
172  fragment.add(Asia % "99/1");
173  fragment.add(Smoking | Asia = "8/2 7/3");
174 
// Each conditional is rendered as a heading followed by a table. For
// P(Smoking|Asia), the rows are Asia's values and the columns Smoking's values.
175  string expected =
176  "`DiscreteBayesNet` of size 2\n"
177  "\n"
178  " *P(Asia):*\n\n"
179  "|Asia|value|\n"
180  "|:-:|:-:|\n"
181  "|0|0.99|\n"
182  "|1|0.01|\n"
183  "\n"
184  " *P(Smoking|Asia):*\n\n"
185  "|*Asia*|0|1|\n"
186  "|:-:|:-:|:-:|\n"
187  "|0|0.8|0.2|\n"
188  "|1|0.7|0.3|\n\n";
// Maps key 0 to "Asia" and every other key to "Smoking". That is adequate only
// because the fragment contains exactly keys 0 (Asia) and 4 (Smoking).
189  auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; };
190  string actual = fragment.markdown(formatter);
191  EXPECT(actual == expected);
192 }
193 
194 /* ************************************************************************* */
// Test entry point: runs every registered TEST and returns the registry's result code.
195 int main() {
196  TestResult testResult;
197  return TestRegistry::runAllTests(testResult);
198 }
199 /* ************************************************************************* */
const gtsam::Symbol key('X', 0)
#define CHECK(condition)
Definition: Test.h:108
sharedFactor back() const
Definition: FactorGraph.h:370
#define SETDEBUG(S, V)
Definition: debug.h:61
Concept check for values that can be used in unit tests.
static int runAllTests(TestResult &result)
Global debugging flags.
double evaluate(const DiscreteValues &values) const
Point2 prior(const Point2 &x)
Prior on a single pose.
Definition: simulated2D.h:88
Matrix expected
Definition: testMatrix.cpp:971
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:190
TEST(DiscreteBayesNet, bayesNet)
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:40
string markdown(const DiscreteValues &values, const KeyFormatter &keyFormatter, const DiscreteValues::Names &names)
Free version of markdown.
void dot(std::ostream &os, const KeyFormatter &keyFormatter=DefaultKeyFormatter, const DotWriter &writer=DotWriter()) const
Output to graphviz format, stream version.
Definition: BayesNet-inst.h:45
MatrixXd L
Definition: LLT_example.cpp:6
Definition: BFloat16.h:88
DiscreteKey S(1, 2)
DiscreteValues optimize(OptionalOrderingType orderingType={}) const
Find the maximum probable explanation (MPE) by doing max-product.
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const DiscreteFactor::Names &names={}) const
Render as markdown tables.
EIGEN_DEVICE_FUNC const LogReturnType log() const
static enum @1107 ordering
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
const KeyFormatter & formatter
Key back() const
Last key.
Definition: Factor.h:136
int main()
Eigen::VectorXd Vector
Definition: Vector.h:38
Values result
static const DiscreteKey XRay(2, 2)
static const DiscreteKey Tuberculosis(3, 2)
#define EXPECT(condition)
Definition: Test.h:150
Eigen::Triplet< double > T
double logProbability(const DiscreteValues &values) const
Vector marginalProbabilities(const DiscreteKey &key) const
std::shared_ptr< BayesNetType > eliminateSequential(OptionalOrderingType orderingType={}, const Eliminate &function=EliminationTraitsType::DefaultEliminate, OptionalVariableIndex variableIndex={}) const
Array< double, 1, 3 > e(1./3., 0.5, 2.)
A class for computing marginals in a DiscreteFactorGraph.
static const DiscreteKey Asia(0, 2)
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
#define LONGS_EQUAL(expected, actual)
Definition: Test.h:134
traits
Definition: chartTesting.h:28
typedef and functions to augment Eigen's VectorXd
DiscreteKey E(5, 2)
#define EXPECT_LONGS_EQUAL(expected, actual)
Definition: Test.h:154
std::shared_ptr< This > shared_ptr
static const DiscreteKey LungCancer(6, 2)
void add(const DiscreteKey &key, const std::string &spec)
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
static const DiscreteKey Dyspnea(1, 2)
static const DiscreteKey Smoking(4, 2)
static const DiscreteKey Bronchitis(7, 2)
AlgebraicDecisionTree< Key > ADT
static const DiscreteKey Either(5, 2)
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:102
Marginals marginals(graph, result)
std::ptrdiff_t j


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:38:01