HybridBayesNet.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2  * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
3  * Atlanta, Georgia 30332-0415
4  * All Rights Reserved
5  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
6  * See LICENSE for the license information
7  * -------------------------------------------------------------------------- */
8 
25 
26 #include <memory>
27 
// File-scope default random number generator, constructed with the fixed
// seed 42. Later in this file its address is passed as the rng argument in
// `sample(given, &kRandomNumberGenerator)`.
28 // In Wrappers we have no access to this so have a default ready
29 static std::mt19937_64 kRandomNumberGenerator(42);
30 
31 namespace gtsam {
32 
33 /* ************************************************************************* */
// GTSAM-style printing of this Bayes net.
//   s:         caller-supplied string.
//   formatter: KeyFormatter, i.e. a function that converts a Key to a string.
// NOTE(review): listing line 36 (the body statement) is absent from this
// listing, so what is actually printed cannot be confirmed from here.
34 void HybridBayesNet::print(const std::string &s,
35  const KeyFormatter &formatter) const {
37 }
38 
39 /* ************************************************************************* */
// GTSAM-style equality check.
//   bn:  the Bayes net to compare against.
//   tol: tolerance forwarded unchanged to the base-class comparison.
// Returns the result of Base::equals(bn, tol); no extra checks are added here.
40 bool HybridBayesNet::equals(const This &bn, double tol) const {
41  return Base::equals(bn, tol);
42 }
43 
44 /* ************************************************************************* */
46  size_t maxNrLeaves, const std::optional<double> &marginalThreshold,
47  DiscreteValues *fixedValues) const {
48 #if GTSAM_HYBRID_TIMING
49  gttic_(HybridPruning);
50 #endif
51  // Collect all the discrete conditionals. Could be small if already pruned.
52  const DiscreteBayesNet marginal = discreteMarginal();
53 
54  // Prune discrete Bayes net
55  DiscreteValues fixed;
56  auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
57 
58  // Multiply into one big conditional. NOTE: possibly quite expensive.
59  DiscreteConditional pruned;
60  for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
61 
62  // Set the fixed values if requested.
63  if (marginalThreshold && fixedValues) {
64  *fixedValues = fixed;
65  }
66 
68  result.reserve(size());
69 
70  // Go through all the Gaussian conditionals, restrict them according to
71  // fixed values, and then prune further.
72  for (std::shared_ptr<HybridConditional> conditional : *this) {
73  if (conditional->isDiscrete()) continue;
74 
75  // No-op if not a HybridGaussianConditional.
76  if (marginalThreshold) {
77  conditional = std::static_pointer_cast<HybridConditional>(
78  conditional->restrict(fixed));
79  }
80 
81  // Now decide on type what to do:
82  if (auto hgc = conditional->asHybrid()) {
83  // Prune the hybrid Gaussian conditional!
84  auto prunedHybridGaussianConditional = hgc->prune(pruned);
85  if (!prunedHybridGaussianConditional) {
86  throw std::runtime_error(
87  "A HybridGaussianConditional had all its conditionals pruned");
88  }
89  // Type-erase and add to the pruned Bayes Net fragment.
90  result.push_back(prunedHybridGaussianConditional);
91  } else if (conditional->isContinuous()) {
92  // Add the non-Hybrid GaussianConditional conditional
93  result.push_back(conditional);
94  } else
95  throw std::runtime_error(
96  "HybrdiBayesNet::prune: Unknown HybridConditional type.");
97  }
98 
99 #if GTSAM_HYBRID_TIMING
100  gttoc_(HybridPruning);
101 #endif
102 
103  // Add the pruned discrete conditionals to the result.
104  for (const DiscreteConditional::shared_ptr &discrete : prunedBN)
105  result.push_back(discrete);
106 
107  return result;
108 }
109 
110 /* ************************************************************************* */
// HybridBayesNet::discreteMarginal: get the discrete Bayes net P(M).
// NOTE(review): listing lines 111-112 (the opening of the definition and,
// presumably, the declaration of `result`) are absent from this listing.
// Walks the conditionals in order and keeps those for which asDiscrete()
// yields a non-null pointer; all others are skipped.
113  for (auto &&conditional : *this) {
114  if (auto dc = conditional->asDiscrete()) {
115  result.push_back(dc);
116  }
117  }
118  return result;
119 }
120 
121 /* ************************************************************************* */
// HybridBayesNet::choose: get the Gaussian Bayes net P(X|M=m) for a specific
// discrete assignment m.
// NOTE(review): listing lines 122 and 124 (the opening of the definition and,
// presumably, the declaration of `gbn`) are absent from this listing.
//   assignment: discrete values used to select one component from every
//               hybrid conditional.
// Returns `gbn`, filled in the order the conditionals are visited.
// A conditional that matches none of the three casts below is silently
// skipped, because the chain has no final else branch.
123  const DiscreteValues &assignment) const {
125  for (auto &&conditional : *this) {
126  if (auto gm = conditional->asHybrid()) {
127  // If conditional is hybrid, select based on assignment.
128  gbn.push_back(gm->choose(assignment));
129  } else if (auto gc = conditional->asGaussian()) {
130  // If continuous only, add Gaussian conditional.
131  gbn.push_back(gc);
132  } else if (auto dc = conditional->asDiscrete()) {
133  // If conditional is discrete-only, we simply continue.
134  continue;
135  }
136  }
137 
138  return gbn;
139 }
140 
141 /* ************************************************************************* */
// HybridBayesNet::mpe: compute the Most Probable Explanation (MPE) of the
// discrete variables.
// NOTE(review): listing line 142 (the opening of the definition) is absent
// from this listing.
// Builds a DiscreteFactorGraph from the discrete conditionals only and
// returns the result of optimize() on that graph.
143  // Collect all the discrete factors to compute MPE
144  DiscreteFactorGraph discrete_fg;
145 
146  for (auto &&conditional : *this) {
147  if (conditional->isDiscrete()) {
  // A TableDistribution is converted to a DiscreteConditional with the same
  // number of frontals; any other discrete conditional is added as is.
148  if (auto dtc = conditional->asDiscrete<TableDistribution>()) {
149  // The number of keys should be small so should not
150  // be expensive to convert to DiscreteConditional.
151  discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
152  dtc->toDecisionTreeFactor()));
153  } else {
154  discrete_fg.push_back(conditional->asDiscrete());
155  }
156  }
157  }
158 
159  return discrete_fg.optimize();
160 }
161 
162 /* ************************************************************************* */
// HybridBayesNet::optimize(): solve the Bayes net by first computing the MPE
// of the discrete variables and then optimizing the continuous variables
// given that MPE.
// NOTE(review): listing line 163 (the opening of the definition) is absent
// from this listing.
// Returns a HybridValues built from optimize(mpe) and the MPE itself.
164  // Solve for the MPE
165  DiscreteValues mpe = this->mpe();
166 
167  // Given the MPE, compute the optimal continuous values.
168  return HybridValues(optimize(mpe), mpe);
169 }
170 
171 /* ************************************************************************* */
// Overload of optimize that takes a discrete `assignment`.
// NOTE(review): listing line 172 (the opening of the definition) is absent
// from this listing, so the exact signature cannot be confirmed from here.
// Selects the Gaussian Bayes net via choose(assignment) and returns
// gbn.optimize(). If any entry of that net is a null pointer, an empty
// VectorValues is returned instead.
173  GaussianBayesNet gbn = choose(assignment);
174 
175  // Check if there exists a nullptr in the GaussianBayesNet
176  // If yes, return an empty VectorValues
177  if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
178  return VectorValues();
179  }
180  return gbn.optimize();
181 }
182 
183 /* ************************************************************************* */
// Sampling overload that takes already-known values (`given`) and a random
// number generator.
// NOTE(review): listing line 184 (the opening of the definition) is absent
// from this listing.
//   rng: generator forwarded to the Gaussian sampling step.
// Steps: gather the discrete conditionals, sample a discrete assignment
// conditioned on given.discrete(), select the matching Gaussian Bayes net,
// then sample the continuous values conditioned on given.continuous().
// Returns the pair {continuous sample, discrete assignment}.
// NOTE(review): `rng` is passed to gbn.sample(...) but not to
// dbn.sample(...). Confirm whether the discrete sampling step is meant to
// ignore the caller's generator.
185  std::mt19937_64 *rng) const {
186  DiscreteBayesNet dbn;
187  for (auto &&conditional : *this) {
188  if (conditional->isDiscrete()) {
189  // If conditional is discrete-only, we add to the discrete Bayes net.
190  dbn.push_back(conditional->asDiscrete());
191  }
192  }
193  // Sample a discrete assignment.
194  const DiscreteValues assignment = dbn.sample(given.discrete());
195  // Select the continuous Bayes net corresponding to the assignment.
196  GaussianBayesNet gbn = choose(assignment);
197  // Sample from the Gaussian Bayes net.
198  VectorValues sample = gbn.sample(given.continuous(), rng);
199  return {sample, assignment};
200 }
201 
202 /* ************************************************************************* */
// Sample from the Bayes net with nothing given in advance.
//   rng: generator forwarded to sample(given, rng).
// Default-constructs a HybridValues as `given` and delegates to the
// two-argument overload.
203 HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
204  HybridValues given;
205  return sample(given, rng);
206 }
207 
208 /* ************************************************************************* */
// Sampling overload that uses the file-scope default generator.
// NOTE(review): listing line 209 (the opening of the definition) is absent
// from this listing, so the exact signature cannot be confirmed from here.
// Delegates to sample(given, rng) with the address of kRandomNumberGenerator.
210  return sample(given, &kRandomNumberGenerator);
211 }
212 
213 /* ************************************************************************* */
216 }
217 
218 /* ************************************************************************* */
// HybridBayesNet::errorTree: per the generated documentation, computes the
// negative log posterior of all discrete assignments up to a constant,
// returned as an AlgebraicDecisionTree<Key>.
// NOTE(review): listing lines 219 and 221 (the opening of the definition and,
// presumably, the declaration of `result`) are absent from this listing, so
// the initial value of `result` cannot be confirmed from here.
//   continuousValues: values passed to every conditional's errorTree.
// Returns `result` after adding each conditional's errorTree(continuousValues).
220  const VectorValues &continuousValues) const {
222 
223  // Iterate over each conditional.
224  for (auto &&conditional : *this) {
225  result = result + conditional->errorTree(continuousValues);
226  }
227 
228  return result;
229 }
230 
231 /* ************************************************************************* */
// HybridBayesNet::negLogConstant: get the negative log of the normalization
// constant, accumulated over all conditionals.
// NOTE(review): listing line 232 (the opening of the definition) is absent
// from this listing.
//   discrete: optional discrete assignment.
//     - With a value: hybrid and discrete conditionals are first narrowed
//       with choose(*discrete), and Gaussian conditionals are used directly.
//     - Without a value: each conditional's own negLogConstant() is used.
// Returns the sum of the per-conditional constants, starting from 0.0.
// Throws std::runtime_error if an assignment is given and a conditional is
// neither hybrid, Gaussian nor discrete.
233  const std::optional<DiscreteValues> &discrete) const {
234  double negLogNormConst = 0.0;
235  // Iterate over each conditional.
236  for (auto &&conditional : *this) {
237  if (discrete.has_value()) {
238  if (auto gm = conditional->asHybrid()) {
239  negLogNormConst += gm->choose(*discrete)->negLogConstant();
240  } else if (auto gc = conditional->asGaussian()) {
241  negLogNormConst += gc->negLogConstant();
242  } else if (auto dc = conditional->asDiscrete()) {
243  negLogNormConst += dc->choose(*discrete)->negLogConstant();
244  } else {
245  throw std::runtime_error(
246  "Unknown conditional type when computing negLogConstant");
247  }
248  } else {
249  negLogNormConst += conditional->negLogConstant();
250  }
251  }
252  return negLogNormConst;
253 }
254 
255 /* ************************************************************************* */
// HybridBayesNet::discretePosterior: compute the normalized posterior
// P(M|X=x) and return it as a tree.
// NOTE(review): listing lines 256 and 259 (the opening of the definition and,
// presumably, the declaration that binds `p` to the apply(...) result) are
// absent from this listing.
//   continuousValues: the continuous values x at which the error tree is
//                     evaluated.
// Each error is mapped through exp(-error), then the tree is divided by its
// sum so that it is normalized.
257  const VectorValues &continuousValues) const {
258  AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
260  errors.apply([](double error) { return exp(-error); });
261  return p / p.sum();
262 }
263 
264 /* ************************************************************************* */
// HybridBayesNet::evaluate: evaluate the hybrid probability density for the
// given HybridValues.
// NOTE(review): listing line 265 (the opening of the definition) is absent
// from this listing.
// Computed as the exponential of logProbability(values).
266  return exp(logProbability(values));
267 }
268 
269 /* ************************************************************************* */
// HybridBayesNet::toFactorGraph: convert the Bayes net into a
// HybridGaussianFactorGraph, given some measurements.
// NOTE(review): listing lines 270 and 272 (the opening of the definition and,
// presumably, the declaration of `fg`) are absent from this listing.
//   measurements: continuous values treated as observed.
// Returns `fg`, with one entry per conditional in the original order.
// Throws std::runtime_error("Unknown conditional type") when a conditional
// has its frontals in `measurements` but is neither Gaussian nor hybrid.
271  const VectorValues &measurements) const {
273 
274  // For all nodes in the Bayes net, if its frontal variable is in measurements,
275  // replace it by a likelihood factor:
276  for (auto &&conditional : *this) {
277  if (conditional->frontalsIn(measurements)) {
278  if (auto gc = conditional->asGaussian()) {
279  fg.push_back(gc->likelihood(measurements));
280  } else if (auto gm = conditional->asHybrid()) {
281  fg.push_back(gm->likelihood(measurements));
282  } else {
283  throw std::runtime_error("Unknown conditional type");
284  }
285  } else {
  // Frontals not measured: keep the conditional itself in the graph.
286  fg.push_back(conditional);
287  }
288  }
289  return fg;
290 }
291 
292 } // namespace gtsam
gtsam::HybridBayesNet::equals
bool equals(const This &fg, double tol=1e-9) const
GTSAM-style equals.
Definition: HybridBayesNet.cpp:40
TableDistribution.h
DiscreteBayesNet.h
gtsam::HybridValues
Definition: HybridValues.h:37
rng
static std::mt19937 rng
Definition: timeFactorOverhead.cpp:31
gtsam::HybridBayesNet::sample
HybridValues sample() const
Sample using ancestral sampling, use default rng.
Definition: HybridBayesNet.cpp:214
gtsam::DiscreteBayesNet::sample
DiscreteValues sample() const
do ancestral sampling
Definition: DiscreteBayesNet.cpp:53
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
s
RealScalar s
Definition: level1_cplx_impl.h:126
gtsam::BayesNet< HybridConditional >::print
void print(const std::string &s="BayesNet", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Definition: BayesNet-inst.h:31
gtsam::FactorGraph< HybridConditional >::error
double error(const HybridValues &values) const
Definition: FactorGraph-inst.h:66
gtsam::HybridBayesNet
Definition: HybridBayesNet.h:37
DiscreteFactorGraph.h
gtsam::HybridBayesNet::evaluate
double evaluate(const HybridValues &values) const
Evaluate hybrid probability density for given HybridValues.
Definition: HybridBayesNet.cpp:265
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
HybridBayesNet.h
A Bayes net of Gaussian Conditionals indexed by discrete keys.
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:247
DiscreteConditional.h
exp
const EIGEN_DEVICE_FUNC ExpReturnType exp() const
Definition: ArrayCwiseUnaryOps.h:97
result
Values result
Definition: OdometryOptimize.cpp:8
gtsam::HybridValues::continuous
const VectorValues & continuous() const
Return the multi-dimensional vector values.
Definition: HybridValues.cpp:54
gtsam::AlgebraicDecisionTree< Key >
gtsam::HybridBayesNet::optimize
HybridValues optimize() const
Solve the HybridBayesNet by first computing the MPE of all the discrete variables and then optimizing...
Definition: HybridBayesNet.cpp:163
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
gtsam::VectorValues
Definition: VectorValues.h:74
gttoc_
#define gttoc_(label)
Definition: timing.h:273
gtsam::TableDistribution
Definition: TableDistribution.h:39
gttic_
#define gttic_(label)
Definition: timing.h:268
gtsam::kRandomNumberGenerator
static std::mt19937 kRandomNumberGenerator(42)
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
gtsam::HybridBayesNet::errorTree
AlgebraicDecisionTree< Key > errorTree(const VectorValues &continuousValues) const
Compute the negative log posterior log P'(M|x) of all assignments up to a constant,...
Definition: HybridBayesNet.cpp:219
gtsam::DiscreteBayesNet::prune
DiscreteBayesNet prune(size_t maxNrLeaves, const std::optional< double > &marginalThreshold={}, DiscreteValues *fixedValues=nullptr) const
Prune the Bayes net.
Definition: DiscreteBayesNet.cpp:77
gtsam::FactorGraph< HybridConditional >::equals
bool equals(const This &fg, double tol=1e-9) const
Check equality up to tolerance.
Definition: FactorGraph-inst.h:50
gtsam::HybridGaussianFactorGraph
Definition: HybridGaussianFactorGraph.h:106
gtsam::HybridBayesNet::print
void print(const std::string &s="", const KeyFormatter &formatter=DefaultKeyFormatter) const override
GTSAM-style printing.
Definition: HybridBayesNet.cpp:34
gtsam::HybridBayesNet::mpe
DiscreteValues mpe() const
Compute the Most Probable Explanation (MPE) of the discrete variables.
Definition: HybridBayesNet.cpp:142
gtsam::HybridBayesNet::discreteMarginal
DiscreteBayesNet discreteMarginal() const
Get the discrete Bayes Net P(M). As the hybrid Bayes net defines P(X,M) = P(X|M) P(M),...
Definition: HybridBayesNet.cpp:111
gtsam::DecisionTree::apply
DecisionTree apply(const Unary &op) const
Definition: DecisionTree-inl.h:1003
sampling::gbn
static const GaussianBayesNet gbn
Definition: testGaussianBayesNet.cpp:171
gtsam::DiscreteConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: DiscreteConditional.h:43
estimation_fixture::measurements
std::vector< double > measurements
Definition: testHybridEstimation.cpp:52
gtsam::HybridBayesNet::discretePosterior
AlgebraicDecisionTree< Key > discretePosterior(const VectorValues &continuousValues) const
Compute normalized posterior P(M|X=x) and return as a tree.
Definition: HybridBayesNet.cpp:256
gtsam::FactorGraph< HybridConditional >::size
size_t size() const
Definition: FactorGraph.h:297
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
different_sigmas::gc
const auto gc
Definition: testHybridBayesNet.cpp:233
gtsam
traits
Definition: SFMdata.h:40
gtsam::BayesNet< HybridConditional >::logProbability
double logProbability(const HybridValues &x) const
Definition: BayesNet-inst.h:94
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::FactorGraph::push_back
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:147
gtsam::HybridValues::discrete
const DiscreteValues & discrete() const
Return the discrete values.
Definition: HybridValues.cpp:57
p
float * p
Definition: Tutorial_Map_using.cpp:9
gtsam::tol
const G double tol
Definition: Group.h:79
different_sigmas::hgc
const auto hgc
Definition: testHybridBayesNet.cpp:238
gtsam::DiscreteFactorGraph::optimize
DiscreteValues optimize(OptionalOrderingType orderingType={}) const
Find the maximum probable explanation (MPE) by doing max-product.
Definition: DiscreteFactorGraph.cpp:187
gtsam::HybridBayesNet::prune
HybridBayesNet prune(size_t maxNrLeaves, const std::optional< double > &marginalThreshold={}, DiscreteValues *fixedValues=nullptr) const
Prune the Bayes Net such that we have at most maxNrLeaves leaves.
Definition: HybridBayesNet.cpp:45
gtsam::HybridBayesNet::toFactorGraph
HybridGaussianFactorGraph toFactorGraph(const VectorValues &measurements) const
Definition: HybridBayesNet.cpp:270
HybridValues.h
gtsam::GaussianBayesNet
Definition: GaussianBayesNet.h:35
gtsam::HybridBayesNet::choose
GaussianBayesNet choose(const DiscreteValues &assignment) const
Get the Gaussian Bayes net P(X|M=m) corresponding to a specific assignment m for the discrete variabl...
Definition: HybridBayesNet.cpp:122
kRandomNumberGenerator
static std::mt19937_64 kRandomNumberGenerator(42)
gtsam::HybridBayesNet::negLogConstant
double negLogConstant(const std::optional< DiscreteValues > &discrete) const
Get the negative log of the normalization constant corresponding to the joint density represented by ...
Definition: HybridBayesNet.cpp:232


gtsam
Author(s):
autogenerated on Wed Mar 19 2025 03:01:48