HybridBayesNet.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2  * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
3  * Atlanta, Georgia 30332-0415
4  * All Rights Reserved
5  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
6  * See LICENSE for the license information
7  * -------------------------------------------------------------------------- */
8 
24 
25 #include <memory>
26 
// Language wrappers cannot hand us a random number generator, so keep a
// default one ready here. It is seeded with a fixed value (42) so sampling
// without an explicit generator is reproducible from program start.
namespace {
std::mt19937_64 kRandomNumberGenerator(42);
}  // namespace
29 
30 namespace gtsam {
31 
32 /* ************************************************************************* */
// GTSAM-style printing of this hybrid Bayes net.
// @param s         label string (presumably printed as a header — confirm).
// @param formatter function used to convert each Key to a string.
33 void HybridBayesNet::print(const std::string &s,
34  const KeyFormatter &formatter) const {
36 }
37 
38 /* ************************************************************************* */
39 bool HybridBayesNet::equals(const This &bn, double tol) const {
40  return Base::equals(bn, tol);
41 }
42 
43 /* ************************************************************************* */
44 // The implementation is: build the entire joint into one factor and then prune.
45 // TODO(Frank): This can be quite expensive *unless* the factors have already
46 // been pruned before. Another, possibly faster approach is branch and bound
47 // search to find the K-best leaves and then create a single pruned conditional.
//
// @param maxNrLeaves Maximum number of discrete assignments (leaves) kept in
//        the pruned discrete joint.
// @return A new HybridBayesNet that holds:
//         - one DiscreteConditional built from the pruned discrete joint,
//         - every hybrid Gaussian conditional, pruned against that joint,
//         - every purely Gaussian conditional, added as-is (same pointer).
//         The original discrete conditionals are not added individually,
//         because the single pruned joint replaces them.
48 HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
49  // Collect all the discrete conditionals. Could be small if already pruned.
50  const DiscreteBayesNet marginal = discreteMarginal();
51 
52  // Multiply into one big conditional. NOTE: possibly quite expensive.
  // NOTE(review): relies on a default-constructed DiscreteConditional acting
  // as the multiplicative identity for the first product — confirm.
53  DiscreteConditional joint;
54  for (auto &&conditional : marginal) {
55  joint = joint * (*conditional);
56  }
57 
58  // Prune the joint. NOTE: again, possibly quite expensive.
59  const DecisionTreeFactor pruned = joint.prune(maxNrLeaves);
60 
61  // Create the result, starting with the pruned discrete joint.
  // pruned.size() is the number of keys in `pruned`; passing it as the first
  // argument presumably marks every key as frontal, i.e. a joint P(M) — confirm.
63  result.emplace_shared<DiscreteConditional>(pruned.size(), pruned);
64 
65  /* To prune, we visitWith every leaf in the HybridGaussianConditional.
66  * For each leaf, using the assignment we can check the discrete decision tree
67  * for 0.0 probability, then just set the leaf to a nullptr.
68  *
69  * We can later check the HybridGaussianConditional for just nullptrs.
70  */
71 
72  // Go through all the Gaussian conditionals in the Bayes Net and prune them as
73  // per pruned Discrete joint.
74  for (auto &&conditional : *this) {
75  if (auto hgc = conditional->asHybrid()) {
76  // Prune the hybrid Gaussian conditional!
77  auto prunedHybridGaussianConditional = hgc->prune(pruned);
78 
79  // Type-erase and add to the pruned Bayes Net fragment.
80  result.push_back(prunedHybridGaussianConditional);
81  } else if (auto gc = conditional->asGaussian()) {
82  // Add the non-HybridGaussianConditional conditional
83  result.push_back(gc);
84  }
85  // We ignore DiscreteConditional as they are already pruned and added.
86  }
87 
88  return result;
89 }
90 
91 /* ************************************************************************* */
  // Build the discrete Bayes net P(M). Keep only the conditionals for which
  // asDiscrete() yields a non-null pointer, in the order they appear in this
  // Bayes net. Hybrid and purely Gaussian conditionals are skipped.
94  for (auto &&conditional : *this) {
95  if (auto dc = conditional->asDiscrete()) {
96  result.push_back(dc);
97  }
98  }
99  return result;
100 }
101 
102 /* ************************************************************************* */
104  const DiscreteValues &assignment) const {
  // Build the Gaussian Bayes net P(X | M = assignment):
  //  - hybrid conditionals contribute the component selected by `assignment`,
  //  - purely Gaussian conditionals are added as-is,
  //  - discrete conditionals (and any conditional matching none of the three
  //    branches below) are skipped.
  // NOTE(review): after prune(), gm->choose() may return nullptr for a pruned
  // assignment. That nullptr is pushed into the result, and callers such as
  // optimize(assignment) check for it.
106  for (auto &&conditional : *this) {
107  if (auto gm = conditional->asHybrid()) {
108  // If conditional is hybrid, select based on assignment.
109  gbn.push_back(gm->choose(assignment));
110  } else if (auto gc = conditional->asGaussian()) {
111  // If continuous only, add Gaussian conditional.
112  gbn.push_back(gc);
113  } else if (auto dc = conditional->asDiscrete()) {
114  // If conditional is discrete-only, we simply continue.
115  continue;
116  }
117  }
118 
119  return gbn;
120 }
121 
122 /* ************************************************************************* */
  // MPE-then-solve: find the most probable discrete assignment from the purely
  // discrete conditionals, then optimize the continuous variables for it.
  // NOTE(review): only discrete conditionals enter the MPE computation below.
  // Gaussian/hybrid conditionals do not influence the chosen assignment, so
  // confirm this matches the intended MAP semantics.
124  // Collect all the discrete factors to compute MPE
125  DiscreteFactorGraph discrete_fg;
126 
127  for (auto &&conditional : *this) {
128  if (conditional->isDiscrete()) {
129  discrete_fg.push_back(conditional->asDiscrete());
130  }
131  }
132 
133  // Solve for the MPE
134  DiscreteValues mpe = discrete_fg.optimize();
135 
136  // Given the MPE, compute the optimal continuous values.
  // (optimize(mpe) returns an empty VectorValues if `mpe` selects a component
  // that prune() removed.)
137  return HybridValues(optimize(mpe), mpe);
138 }
139 
140 /* ************************************************************************* */
  // Solve the Gaussian Bayes net selected by `assignment`.
142  GaussianBayesNet gbn = choose(assignment);
143 
144  // Check if there exists a nullptr in the GaussianBayesNet
145  // If yes, return an empty VectorValues
  // A nullptr comes from a hybrid component that prune() removed (pruned
  // leaves are set to nullptr). No continuous solution exists for such an
  // assignment, and the empty result signals that to the caller.
146  if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
147  return VectorValues();
148  }
149  return gbn.optimize();
150 }
151 
152 /* ************************************************************************* */
154  std::mt19937_64 *rng) const {
  // Ancestral sampling in two stages:
  //  1. Sample the discrete variables from the purely discrete conditionals.
  //     given.discrete() is passed as known values (presumably kept fixed —
  //     confirm).
  //  2. Sample the continuous variables from the Gaussian Bayes net selected
  //     by that assignment, conditioning on given.continuous().
  // NOTE(review): this loop appears to collect the same conditionals as
  // discreteMarginal(); consider reusing it.
155  DiscreteBayesNet dbn;
156  for (auto &&conditional : *this) {
157  if (conditional->isDiscrete()) {
158  // If conditional is discrete-only, we add to the discrete Bayes net.
159  dbn.push_back(conditional->asDiscrete());
160  }
161  }
162  // Sample a discrete assignment.
  // NOTE(review): `rng` is not passed to dbn.sample(), so the discrete draw
  // does not use the caller's generator. Seeding `rng` alone therefore does
  // not make the discrete part reproducible.
163  const DiscreteValues assignment = dbn.sample(given.discrete());
164  // Select the continuous Bayes net corresponding to the assignment.
165  GaussianBayesNet gbn = choose(assignment);
166  // Sample from the Gaussian Bayes net.
167  VectorValues sample = gbn.sample(given.continuous(), rng);
168  return {sample, assignment};
169 }
170 
171 /* ************************************************************************* */
172 HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
173  HybridValues given;
174  return sample(given, rng);
175 }
176 
177 /* ************************************************************************* */
  // Sample using the file-level default generator (fixed seed 42). Its state
  // persists across calls, so successive calls continue one random stream.
  // NOTE(review): the generator is a mutable global shared by every caller in
  // this file, and concurrent calls are not synchronized.
179  return sample(given, &kRandomNumberGenerator);
180 }
181 
182 /* ************************************************************************* */
185 }
186 
187 /* ************************************************************************* */
189  const VectorValues &continuousValues) const {
191 
192  // Accumulate the error tree of every conditional, each evaluated at
  // continuousValues. Each term is a decision tree over the discrete keys, so
  // the sum gives one error value per discrete assignment.
193  for (auto &&conditional : *this) {
194  result = result + conditional->errorTree(continuousValues);
195  }
196 
197  return result;
198 }
199 
200 /* ************************************************************************* */
202  const std::optional<DiscreteValues> &discrete) const {
  // Sum of the negative log normalization constants of all conditionals.
  // With a discrete assignment, hybrid and discrete conditionals are first
  // specialized to it via choose(). Without one, each conditional's own
  // negLogConstant() is used.
  // Throws std::runtime_error if, with an assignment, a conditional is neither
  // hybrid, Gaussian, nor discrete.
203  double negLogNormConst = 0.0;
204  // Iterate over each conditional.
205  for (auto &&conditional : *this) {
206  if (discrete.has_value()) {
207  if (auto gm = conditional->asHybrid()) {
  // NOTE(review): after prune(), choose() may return nullptr for a pruned
  // assignment (optimize(assignment) checks for this). The result is
  // dereferenced here without a check.
208  negLogNormConst += gm->choose(*discrete)->negLogConstant();
209  } else if (auto gc = conditional->asGaussian()) {
210  negLogNormConst += gc->negLogConstant();
211  } else if (auto dc = conditional->asDiscrete()) {
212  negLogNormConst += dc->choose(*discrete)->negLogConstant();
213  } else {
214  throw std::runtime_error(
215  "Unknown conditional type when computing negLogConstant");
216  }
217  } else {
218  negLogNormConst += conditional->negLogConstant();
219  }
220  }
221  return negLogNormConst;
222 }
223 
224 /* ************************************************************************* */
226  const VectorValues &continuousValues) const {
  // Normalized posterior P(M | X = continuousValues): exponentiate the negated
  // error of every discrete assignment, then divide by the total.
227  AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
229  errors.apply([](double error) { return exp(-error); });
  // NOTE(review): in double precision, exp(-error) underflows to 0 once error
  // exceeds roughly 745. If every leaf has such an error, p.sum() is 0 and the
  // division presumably yields NaN. Subtracting the minimum error before
  // exponentiating leaves the normalized result unchanged and avoids this.
230  return p / p.sum();
231 }
232 
233 /* ************************************************************************* */
  // Density = exp(log density). logProbability() comes from the BayesNet base
  // class. NOTE(review): for very negative log densities this underflows to 0.
235  return exp(logProbability(values));
236 }
237 
238 /* ************************************************************************* */
240  const VectorValues &measurements) const {
  // Convert to a HybridGaussianFactorGraph:
  //  - conditionals for which frontalsIn(measurements) holds (presumably:
  //    their frontal values are all measured) become factors produced by
  //    likelihood(measurements);
  //  - all other conditionals are added unchanged.
  // Throws std::runtime_error if a measured conditional is neither Gaussian
  // nor hybrid (e.g. a discrete conditional).
242 
243  // For all nodes in the Bayes net, if its frontal variable is in measurements,
244  // replace it by a likelihood factor:
245  for (auto &&conditional : *this) {
246  if (conditional->frontalsIn(measurements)) {
247  if (auto gc = conditional->asGaussian()) {
248  fg.push_back(gc->likelihood(measurements));
249  } else if (auto gm = conditional->asHybrid()) {
250  fg.push_back(gm->likelihood(measurements));
251  } else {
252  throw std::runtime_error("Unknown conditional type");
253  }
254  } else {
255  fg.push_back(conditional);
256  }
257  }
258  return fg;
259 }
260 
261 } // namespace gtsam
gtsam::HybridBayesNet::equals
bool equals(const This &fg, double tol=1e-9) const
GTSAM-style equals.
Definition: HybridBayesNet.cpp:39
DiscreteBayesNet.h
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:45
gtsam::HybridValues
Definition: HybridValues.h:37
rng
static std::mt19937 rng
Definition: timeFactorOverhead.cpp:31
gtsam::HybridBayesNet::sample
HybridValues sample() const
Sample using ancestral sampling, use default rng.
Definition: HybridBayesNet.cpp:183
gtsam::DiscreteBayesNet::sample
DiscreteValues sample() const
do ancestral sampling
Definition: DiscreteBayesNet.cpp:52
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
s
RealScalar s
Definition: level1_cplx_impl.h:126
gtsam::BayesNet< HybridConditional >::print
void print(const std::string &s="BayesNet", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Definition: BayesNet-inst.h:31
gtsam::FactorGraph< HybridConditional >::error
double error(const HybridValues &values) const
Definition: FactorGraph-inst.h:66
gtsam::HybridBayesNet
Definition: HybridBayesNet.h:37
DiscreteFactorGraph.h
gtsam::HybridBayesNet::evaluate
double evaluate(const HybridValues &values) const
Evaluate hybrid probability density for given HybridValues.
Definition: HybridBayesNet.cpp:234
gtsam::HybridBayesNet::prune
HybridBayesNet prune(size_t maxNrLeaves) const
Prune the Bayes Net such that we have at most maxNrLeaves leaves.
Definition: HybridBayesNet.cpp:48
gtsam::DecisionTreeFactor::prune
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
Definition: DecisionTreeFactor.cpp:508
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
HybridBayesNet.h
A Bayes net of Gaussian Conditionals indexed by discrete keys.
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:245
DiscreteConditional.h
exp
const EIGEN_DEVICE_FUNC ExpReturnType exp() const
Definition: ArrayCwiseUnaryOps.h:97
result
Values result
Definition: OdometryOptimize.cpp:8
gtsam::HybridValues::continuous
const VectorValues & continuous() const
Return the multi-dimensional vector values.
Definition: HybridValues.cpp:54
gtsam::AlgebraicDecisionTree< Key >
gtsam::HybridBayesNet::optimize
HybridValues optimize() const
Solve the HybridBayesNet by first computing the MPE of all the discrete variables and then optimizing...
Definition: HybridBayesNet.cpp:123
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
gtsam::VectorValues
Definition: VectorValues.h:74
gtsam::kRandomNumberGenerator
static std::mt19937 kRandomNumberGenerator(42)
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
gtsam::HybridBayesNet::errorTree
AlgebraicDecisionTree< Key > errorTree(const VectorValues &continuousValues) const
Compute the negative log posterior log P'(M|x) of all assignments up to a constant,...
Definition: HybridBayesNet.cpp:188
gtsam::FactorGraph< HybridConditional >::equals
bool equals(const This &fg, double tol=1e-9) const
Check equality up to tolerance.
Definition: FactorGraph-inst.h:50
gtsam::HybridGaussianFactorGraph
Definition: HybridGaussianFactorGraph.h:106
gtsam::HybridBayesNet::print
void print(const std::string &s="", const KeyFormatter &formatter=DefaultKeyFormatter) const override
GTSAM-style printing.
Definition: HybridBayesNet.cpp:33
gtsam::HybridBayesNet::discreteMarginal
DiscreteBayesNet discreteMarginal() const
Get the discrete Bayes Net P(M). As the hybrid Bayes net defines P(X,M) = P(X|M) P(M),...
Definition: HybridBayesNet.cpp:92
gtsam::DecisionTree::apply
DecisionTree apply(const Unary &op) const
Definition: DecisionTree-inl.h:1000
sampling::gbn
static const GaussianBayesNet gbn
Definition: testGaussianBayesNet.cpp:171
gtsam::HybridBayesNet::discretePosterior
AlgebraicDecisionTree< Key > discretePosterior(const VectorValues &continuousValues) const
Compute normalized posterior P(M|X=x) and return as a tree.
Definition: HybridBayesNet.cpp:225
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
different_sigmas::gc
const auto gc
Definition: testHybridBayesNet.cpp:231
gtsam
traits
Definition: SFMdata.h:40
gtsam::BayesNet< HybridConditional >::logProbability
double logProbability(const HybridValues &x) const
Definition: BayesNet-inst.h:94
estimation_fixture::measurements
std::vector< double > measurements
Definition: testHybridEstimation.cpp:51
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::FactorGraph::push_back
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:147
gtsam::HybridValues::discrete
const DiscreteValues & discrete() const
Return the discrete values.
Definition: HybridValues.cpp:57
p
float * p
Definition: Tutorial_Map_using.cpp:9
gtsam::tol
const G double tol
Definition: Group.h:79
different_sigmas::hgc
const auto hgc
Definition: testHybridBayesNet.cpp:236
gtsam::DiscreteFactorGraph::optimize
DiscreteValues optimize(OptionalOrderingType orderingType={}) const
Find the maximum probable explanation (MPE) by doing max-product.
Definition: DiscreteFactorGraph.cpp:209
gtsam::HybridBayesNet::toFactorGraph
HybridGaussianFactorGraph toFactorGraph(const VectorValues &measurements) const
Definition: HybridBayesNet.cpp:239
gtsam::Factor::size
size_t size() const
Definition: Factor.h:160
HybridValues.h
gtsam::GaussianBayesNet
Definition: GaussianBayesNet.h:35
gtsam::HybridBayesNet::choose
GaussianBayesNet choose(const DiscreteValues &assignment) const
Get the Gaussian Bayes net P(X|M=m) corresponding to a specific assignment m for the discrete variabl...
Definition: HybridBayesNet.cpp:103
kRandomNumberGenerator
static std::mt19937_64 kRandomNumberGenerator(42)
gtsam::HybridBayesNet::negLogConstant
double negLogConstant(const std::optional< DiscreteValues > &discrete) const
Get the negative log of the normalization constant corresponding to the joint density represented by ...
Definition: HybridBayesNet.cpp:201


gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:02:22