HybridBayesNet.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2  * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
3  * Atlanta, Georgia 30332-0415
4  * All Rights Reserved
5  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
6  * See LICENSE for the license information
7  * -------------------------------------------------------------------------- */
8 
25 
26 #include <memory>
27 
28 namespace gtsam {
29 
30 /* ************************************************************************* */
31 void HybridBayesNet::print(const std::string &s,
32  const KeyFormatter &formatter) const {
34 }
35 
36 /* ************************************************************************* */
37 bool HybridBayesNet::equals(const This &bn, double tol) const {
38  return Base::equals(bn, tol);
39 }
40 
41 /* ************************************************************************* */
43  size_t maxNrLeaves, const std::optional<double> &marginalThreshold,
44  DiscreteValues *fixedValues) const {
45 #if GTSAM_HYBRID_TIMING
46  gttic_(HybridPruning);
47 #endif
48  // Collect all the discrete conditionals. Could be small if already pruned.
49  const DiscreteBayesNet marginal = discreteMarginal();
50 
51  // Prune discrete Bayes net
52  DiscreteValues fixed;
53  DiscreteBayesNet prunedBN =
54  marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
55 
56  // Multiply into one big conditional. NOTE: possibly quite expensive.
57  DiscreteConditional pruned = prunedBN.joint();
58 
59  // Set the fixed values if requested.
60  if (marginalThreshold && fixedValues) {
61  *fixedValues = fixed;
62  }
63 
65  result.reserve(size());
66 
67  // Go through all the Gaussian conditionals, restrict them according to
68  // fixed values, and then prune further.
69  for (std::shared_ptr<HybridConditional> conditional : *this) {
70  if (conditional->isDiscrete()) continue;
71 
72  // No-op if not a HybridGaussianConditional.
73  if (marginalThreshold) {
74  conditional = std::static_pointer_cast<HybridConditional>(
75  conditional->restrict(fixed));
76  }
77 
78  // Now decide on type what to do:
79  if (auto hgc = conditional->asHybrid()) {
80  // Prune the hybrid Gaussian conditional!
81  auto prunedHybridGaussianConditional = hgc->prune(pruned);
82  if (!prunedHybridGaussianConditional) {
83  throw std::runtime_error(
84  "A HybridGaussianConditional had all its conditionals pruned");
85  }
86  // Type-erase and add to the pruned Bayes Net fragment.
87  result.push_back(prunedHybridGaussianConditional);
88  } else if (conditional->isContinuous()) {
89  // Add the non-Hybrid GaussianConditional conditional
90  result.push_back(conditional);
91  } else
92  throw std::runtime_error(
93  "HybrdiBayesNet::prune: Unknown HybridConditional type.");
94  }
95 
96 #if GTSAM_HYBRID_TIMING
97  gttoc_(HybridPruning);
98 #endif
99 
100  // Add the pruned discrete conditionals to the result.
101  for (const DiscreteConditional::shared_ptr &discrete : prunedBN)
102  result.push_back(discrete);
103 
104  return result;
105 }
106 
107 /* ************************************************************************* */
110  for (auto &&conditional : *this) {
111  if (auto dc = conditional->asDiscrete()) {
112  result.push_back(dc);
113  }
114  }
115  return result;
116 }
117 
118 /* ************************************************************************* */
120  const DiscreteValues &assignment) const {
122  for (auto &&conditional : *this) {
123  if (auto gm = conditional->asHybrid()) {
124  // If conditional is hybrid, select based on assignment.
125  gbn.push_back(gm->choose(assignment));
126  } else if (auto gc = conditional->asGaussian()) {
127  // If continuous only, add Gaussian conditional.
128  gbn.push_back(gc);
129  } else if (auto dc = conditional->asDiscrete()) {
130  // If conditional is discrete-only, we simply continue.
131  continue;
132  }
133  }
134 
135  return gbn;
136 }
137 
138 /* ************************************************************************* */
140  // Collect all the discrete factors to compute MPE
141  DiscreteFactorGraph discrete_fg;
142 
143  for (auto &&conditional : *this) {
144  if (conditional->isDiscrete()) {
145  if (auto dtc = conditional->asDiscrete<TableDistribution>()) {
146  // The number of keys should be small so should not
147  // be expensive to convert to DiscreteConditional.
148  discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
149  dtc->toDecisionTreeFactor()));
150  } else {
151  discrete_fg.push_back(conditional->asDiscrete());
152  }
153  }
154  }
155 
156  return discrete_fg.optimize();
157 }
158 
159 /* ************************************************************************* */
161  // Solve for the MPE
162  DiscreteValues mpe = this->mpe();
163 
164  // Given the MPE, compute the optimal continuous values.
165  return HybridValues(optimize(mpe), mpe);
166 }
167 
168 /* ************************************************************************* */
170  GaussianBayesNet gbn = choose(assignment);
171 
172  // Check if there exists a nullptr in the GaussianBayesNet
173  // If yes, return an empty VectorValues
174  if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
175  return VectorValues();
176  }
177  return gbn.optimize();
178 }
179 
180 /* ************************************************************************* */
182  std::mt19937_64 *rng) const {
183  DiscreteBayesNet dbn;
184  for (auto &&conditional : *this) {
185  if (conditional->isDiscrete()) {
186  // If conditional is discrete-only, we add to the discrete Bayes net.
187  dbn.push_back(conditional->asDiscrete());
188  }
189  }
190  // Sample a discrete assignment.
191  const DiscreteValues assignment = dbn.sample(given.discrete(), rng);
192  // Select the continuous Bayes net corresponding to the assignment.
193  GaussianBayesNet gbn = choose(assignment);
194  // Sample from the Gaussian Bayes net.
195  VectorValues sample = gbn.sample(given.continuous(), rng);
196  return {sample, assignment};
197 }
198 
199 /* ************************************************************************* */
200 HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
201  HybridValues given;
202  return sample(given, rng);
203 }
204 
205 /* ************************************************************************* */
207  const VectorValues &continuousValues) const {
209 
210  // Iterate over each conditional.
211  for (auto &&conditional : *this) {
212  result = result + conditional->errorTree(continuousValues);
213  }
214 
215  return result;
216 }
217 
218 /* ************************************************************************* */
220  const std::optional<DiscreteValues> &discrete) const {
221  double negLogNormConst = 0.0;
222  // Iterate over each conditional.
223  for (auto &&conditional : *this) {
224  if (discrete.has_value()) {
225  if (auto gm = conditional->asHybrid()) {
226  negLogNormConst += gm->choose(*discrete)->negLogConstant();
227  } else if (auto gc = conditional->asGaussian()) {
228  negLogNormConst += gc->negLogConstant();
229  } else if (auto dc = conditional->asDiscrete()) {
230  negLogNormConst += dc->choose(*discrete)->negLogConstant();
231  } else {
232  throw std::runtime_error(
233  "Unknown conditional type when computing negLogConstant");
234  }
235  } else {
236  negLogNormConst += conditional->negLogConstant();
237  }
238  }
239  return negLogNormConst;
240 }
241 
242 /* ************************************************************************* */
244  const VectorValues &continuousValues) const {
245  AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
247  errors.apply([](double error) { return exp(-error); });
248  return p / p.sum();
249 }
250 
251 /* ************************************************************************* */
253  return exp(logProbability(values));
254 }
255 
256 /* ************************************************************************* */
258  const VectorValues &measurements) const {
260 
261  // For all nodes in the Bayes net, if its frontal variable is in measurements,
262  // replace it by a likelihood factor:
263  for (auto &&conditional : *this) {
264  if (conditional->frontalsIn(measurements)) {
265  if (auto gc = conditional->asGaussian()) {
266  fg.push_back(gc->likelihood(measurements));
267  } else if (auto gm = conditional->asHybrid()) {
268  fg.push_back(gm->likelihood(measurements));
269  } else {
270  throw std::runtime_error("Unknown conditional type");
271  }
272  } else {
273  fg.push_back(conditional);
274  }
275  }
276  return fg;
277 }
278 
279 } // namespace gtsam
gtsam::HybridBayesNet::equals
bool equals(const This &fg, double tol=1e-9) const
GTSAM-style equals.
Definition: HybridBayesNet.cpp:37
TableDistribution.h
DiscreteBayesNet.h
gtsam::HybridValues
Definition: HybridValues.h:37
rng
static std::mt19937 rng
Definition: timeFactorOverhead.cpp:31
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
s
RealScalar s
Definition: level1_cplx_impl.h:126
gtsam::BayesNet< HybridConditional >::print
void print(const std::string &s="BayesNet", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Definition: BayesNet-inst.h:31
gtsam::FactorGraph< HybridConditional >::error
double error(const HybridValues &values) const
Definition: FactorGraph-inst.h:66
gtsam::HybridBayesNet
Definition: HybridBayesNet.h:37
DiscreteFactorGraph.h
gtsam::HybridBayesNet::evaluate
double evaluate(const HybridValues &values) const
Evaluate hybrid probability density for given HybridValues.
Definition: HybridBayesNet.cpp:252
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
HybridBayesNet.h
A Bayes net of Gaussian Conditionals indexed by discrete keys.
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:247
DiscreteConditional.h
exp
const EIGEN_DEVICE_FUNC ExpReturnType exp() const
Definition: ArrayCwiseUnaryOps.h:97
result
Values result
Definition: OdometryOptimize.cpp:8
gtsam::HybridValues::continuous
const VectorValues & continuous() const
Return the multi-dimensional vector values.
Definition: HybridValues.cpp:54
gtsam::AlgebraicDecisionTree< Key >
gtsam::HybridBayesNet::optimize
HybridValues optimize() const
Solve the HybridBayesNet by first computing the MPE of all the discrete variables and then optimizing...
Definition: HybridBayesNet.cpp:160
gtsam::DiscreteBayesNet::joint
DiscreteConditional joint() const
Multiply all conditionals into one big joint conditional and return it.
Definition: DiscreteBayesNet.cpp:125
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
gtsam::VectorValues
Definition: VectorValues.h:74
gttoc_
#define gttoc_(label)
Definition: timing.h:273
gtsam::TableDistribution
Definition: TableDistribution.h:39
gttic_
#define gttic_(label)
Definition: timing.h:268
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
gtsam::HybridBayesNet::errorTree
AlgebraicDecisionTree< Key > errorTree(const VectorValues &continuousValues) const
Compute the negative log posterior log P'(M|x) of all assignments up to a constant,...
Definition: HybridBayesNet.cpp:206
gtsam::DiscreteBayesNet::prune
DiscreteBayesNet prune(size_t maxNrLeaves, const std::optional< double > &marginalThreshold={}, DiscreteValues *fixedValues=nullptr) const
Prune the Bayes net.
Definition: DiscreteBayesNet.cpp:78
gtsam::FactorGraph< HybridConditional >::equals
bool equals(const This &fg, double tol=1e-9) const
Check equality up to tolerance.
Definition: FactorGraph-inst.h:50
gtsam::HybridGaussianFactorGraph
Definition: HybridGaussianFactorGraph.h:106
gtsam::HybridBayesNet::print
void print(const std::string &s="", const KeyFormatter &formatter=DefaultKeyFormatter) const override
GTSAM-style printing.
Definition: HybridBayesNet.cpp:31
gtsam::HybridBayesNet::mpe
DiscreteValues mpe() const
Compute the Most Probable Explanation (MPE) of the discrete variables.
Definition: HybridBayesNet.cpp:139
gtsam::HybridBayesNet::discreteMarginal
DiscreteBayesNet discreteMarginal() const
Get the discrete Bayes Net P(M). As the hybrid Bayes net defines P(X,M) = P(X|M) P(M),...
Definition: HybridBayesNet.cpp:108
gtsam::DecisionTree::apply
DecisionTree apply(const Unary &op) const
Definition: DecisionTree-inl.h:1003
sampling::gbn
static const GaussianBayesNet gbn
Definition: testGaussianBayesNet.cpp:171
gtsam::DiscreteConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: DiscreteConditional.h:44
estimation_fixture::measurements
std::vector< double > measurements
Definition: testHybridEstimation.cpp:52
gtsam::HybridBayesNet::discretePosterior
AlgebraicDecisionTree< Key > discretePosterior(const VectorValues &continuousValues) const
Compute normalized posterior P(M|X=x) and return as a tree.
Definition: HybridBayesNet.cpp:243
gtsam::FactorGraph< HybridConditional >::size
size_t size() const
Definition: FactorGraph.h:297
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:38
different_sigmas::gc
const auto gc
Definition: testHybridBayesNet.cpp:233
gtsam
traits
Definition: ABC.h:17
gtsam::BayesNet< HybridConditional >::logProbability
double logProbability(const HybridValues &x) const
Definition: BayesNet-inst.h:94
gtsam::HybridBayesNet::sample
HybridValues sample(const HybridValues &given, std::mt19937_64 *rng=nullptr) const
Sample from an incomplete BayesNet, given missing variables.
Definition: HybridBayesNet.cpp:181
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::FactorGraph::push_back
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:147
gtsam::HybridValues::discrete
const DiscreteValues & discrete() const
Return the discrete values.
Definition: HybridValues.cpp:57
gtsam::DiscreteBayesNet::sample
DiscreteValues sample(std::mt19937_64 *rng=nullptr) const
Do ancestral sampling.
Definition: DiscreteBayesNet.cpp:53
p
float * p
Definition: Tutorial_Map_using.cpp:9
gtsam::tol
const G double tol
Definition: Group.h:79
different_sigmas::hgc
const auto hgc
Definition: testHybridBayesNet.cpp:238
gtsam::DiscreteFactorGraph::optimize
DiscreteValues optimize(OptionalOrderingType orderingType={}) const
Find the most probable explanation (MPE) by doing max-product.
Definition: DiscreteFactorGraph.cpp:187
gtsam::HybridBayesNet::prune
HybridBayesNet prune(size_t maxNrLeaves, const std::optional< double > &marginalThreshold={}, DiscreteValues *fixedValues=nullptr) const
Prune the Bayes Net such that we have at most maxNrLeaves leaves.
Definition: HybridBayesNet.cpp:42
gtsam::HybridBayesNet::toFactorGraph
HybridGaussianFactorGraph toFactorGraph(const VectorValues &measurements) const
Definition: HybridBayesNet.cpp:257
HybridValues.h
gtsam::GaussianBayesNet
Definition: GaussianBayesNet.h:35
gtsam::HybridBayesNet::choose
GaussianBayesNet choose(const DiscreteValues &assignment) const
Get the Gaussian Bayes net P(X|M=m) corresponding to a specific assignment m for the discrete variabl...
Definition: HybridBayesNet.cpp:119
gtsam::HybridBayesNet::negLogConstant
double negLogConstant(const std::optional< DiscreteValues > &discrete) const
Get the negative log of the normalization constant corresponding to the joint density represented by ...
Definition: HybridBayesNet.cpp:219


gtsam
Author(s):
autogenerated on Wed May 28 2025 03:01:25