HybridBayesNet.cpp
Go to the documentation of this file.
/* ----------------------------------------------------------------------------
 * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
 * Atlanta, Georgia 30332-0415
 * All Rights Reserved
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 * See LICENSE for the license information
 * -------------------------------------------------------------------------- */

23 
// Default random-number generator for the sampling overloads that take no
// rng argument (e.g. when called from language wrappers, which have no access
// to a std::mt19937_64). The fixed seed (42) makes such sampling
// deterministic across runs.
static std::mt19937_64 kRandomNumberGenerator(42);
26 
27 namespace gtsam {
28 
29 /* ************************************************************************* */
30 void HybridBayesNet::print(const std::string &s,
31  const KeyFormatter &formatter) const {
33 }
34 
35 /* ************************************************************************* */
36 bool HybridBayesNet::equals(const This &bn, double tol) const {
37  return Base::equals(bn, tol);
38 }
39 
40 /* ************************************************************************* */
48 std::function<double(const Assignment<Key> &, double)> prunerFunc(
49  const DecisionTreeFactor &prunedDiscreteProbs,
50  const HybridConditional &conditional) {
51  // Get the discrete keys as sets for the decision tree
52  // and the Gaussian mixture.
53  std::set<DiscreteKey> discreteProbsKeySet =
54  DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
55  std::set<DiscreteKey> conditionalKeySet =
56  DiscreteKeysAsSet(conditional.discreteKeys());
57 
58  auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
59  const Assignment<Key> &choices,
60  double probability) -> double {
61  // This corresponds to 0 probability
62  double pruned_prob = 0.0;
63 
64  // typecast so we can use this to get probability value
65  DiscreteValues values(choices);
66  // Case where the Gaussian mixture has the same
67  // discrete keys as the decision tree.
68  if (conditionalKeySet == discreteProbsKeySet) {
69  if (prunedDiscreteProbs(values) == 0) {
70  return pruned_prob;
71  } else {
72  return probability;
73  }
74  } else {
75  // Due to branch merging (aka pruning) in DecisionTree, it is possible we
76  // get a `values` which doesn't have the full set of keys.
77  std::set<Key> valuesKeys;
78  for (auto kvp : values) {
79  valuesKeys.insert(kvp.first);
80  }
81  std::set<Key> conditionalKeys;
82  for (auto kvp : conditionalKeySet) {
83  conditionalKeys.insert(kvp.first);
84  }
85  // If true, then values is missing some keys
86  if (conditionalKeys != valuesKeys) {
87  // Get the keys present in conditionalKeys but not in valuesKeys
88  std::vector<Key> missing_keys;
89  std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
90  valuesKeys.begin(), valuesKeys.end(),
91  std::back_inserter(missing_keys));
92  // Insert missing keys with a default assignment.
93  for (auto missing_key : missing_keys) {
94  values[missing_key] = 0;
95  }
96  }
97 
98  // Now we generate the full assignment by enumerating
99  // over all keys in the prunedDiscreteProbs.
100  // First we find the differing keys
101  std::vector<DiscreteKey> set_diff;
102  std::set_difference(discreteProbsKeySet.begin(),
103  discreteProbsKeySet.end(), conditionalKeySet.begin(),
104  conditionalKeySet.end(),
105  std::back_inserter(set_diff));
106 
107  // Now enumerate over all assignments of the differing keys
108  const std::vector<DiscreteValues> assignments =
110  for (const DiscreteValues &assignment : assignments) {
111  DiscreteValues augmented_values(values);
112  augmented_values.insert(assignment);
113 
114  // If any one of the sub-branches are non-zero,
115  // we need this probability.
116  if (prunedDiscreteProbs(augmented_values) > 0.0) {
117  return probability;
118  }
119  }
120  // If we are here, it means that all the sub-branches are 0,
121  // so we prune.
122  return pruned_prob;
123  }
124  };
125  return pruner;
126 }
127 
128 /* ************************************************************************* */
130  size_t maxNrLeaves) {
131  // Get the joint distribution of only the discrete keys
132  // The joint discrete probability.
133  DiscreteConditional discreteProbs;
134 
135  std::vector<size_t> discrete_factor_idxs;
136  // Record frontal keys so we can maintain ordering
137  Ordering discrete_frontals;
138 
139  for (size_t i = 0; i < this->size(); i++) {
140  auto conditional = this->at(i);
141  if (conditional->isDiscrete()) {
142  discreteProbs = discreteProbs * (*conditional->asDiscrete());
143 
144  Ordering conditional_keys(conditional->frontals());
145  discrete_frontals += conditional_keys;
146  discrete_factor_idxs.push_back(i);
147  }
148  }
149 
150  const DecisionTreeFactor prunedDiscreteProbs =
151  discreteProbs.prune(maxNrLeaves);
152 
153  // Eliminate joint probability back into conditionals
154  DiscreteFactorGraph dfg{prunedDiscreteProbs};
155  DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
156 
157  // Assign pruned discrete conditionals back at the correct indices.
158  for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
159  size_t idx = discrete_factor_idxs.at(i);
160  this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
161  }
162 
163  return prunedDiscreteProbs;
164 }
165 
166 /* ************************************************************************* */
168  DecisionTreeFactor prunedDiscreteProbs =
169  this->pruneDiscreteConditionals(maxNrLeaves);
170 
171  /* To prune, we visitWith every leaf in the GaussianMixture.
172  * For each leaf, using the assignment we can check the discrete decision tree
173  * for 0.0 probability, then just set the leaf to a nullptr.
174  *
175  * We can later check the GaussianMixture for just nullptrs.
176  */
177 
178  HybridBayesNet prunedBayesNetFragment;
179 
180  // Go through all the conditionals in the
181  // Bayes Net and prune them as per prunedDiscreteProbs.
182  for (auto &&conditional : *this) {
183  if (auto gm = conditional->asMixture()) {
184  // Make a copy of the Gaussian mixture and prune it!
185  auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
186  prunedGaussianMixture->prune(prunedDiscreteProbs); // imperative :-(
187 
188  // Type-erase and add to the pruned Bayes Net fragment.
189  prunedBayesNetFragment.push_back(prunedGaussianMixture);
190 
191  } else {
192  // Add the non-GaussianMixture conditional
193  prunedBayesNetFragment.push_back(conditional);
194  }
195  }
196 
197  return prunedBayesNetFragment;
198 }
199 
200 /* ************************************************************************* */
202  const DiscreteValues &assignment) const {
204  for (auto &&conditional : *this) {
205  if (auto gm = conditional->asMixture()) {
206  // If conditional is hybrid, select based on assignment.
207  gbn.push_back((*gm)(assignment));
208  } else if (auto gc = conditional->asGaussian()) {
209  // If continuous only, add Gaussian conditional.
210  gbn.push_back(gc);
211  } else if (auto dc = conditional->asDiscrete()) {
212  // If conditional is discrete-only, we simply continue.
213  continue;
214  }
215  }
216 
217  return gbn;
218 }
219 
220 /* ************************************************************************* */
222  // Collect all the discrete factors to compute MPE
223  DiscreteBayesNet discrete_bn;
224  for (auto &&conditional : *this) {
225  if (conditional->isDiscrete()) {
226  discrete_bn.push_back(conditional->asDiscrete());
227  }
228  }
229 
230  // Solve for the MPE
231  DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
232 
233  // Given the MPE, compute the optimal continuous values.
234  return HybridValues(optimize(mpe), mpe);
235 }
236 
237 /* ************************************************************************* */
239  GaussianBayesNet gbn = choose(assignment);
240 
241  // Check if there exists a nullptr in the GaussianBayesNet
242  // If yes, return an empty VectorValues
243  if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
244  return VectorValues();
245  }
246  return gbn.optimize();
247 }
248 
249 /* ************************************************************************* */
251  std::mt19937_64 *rng) const {
252  DiscreteBayesNet dbn;
253  for (auto &&conditional : *this) {
254  if (conditional->isDiscrete()) {
255  // If conditional is discrete-only, we add to the discrete Bayes net.
256  dbn.push_back(conditional->asDiscrete());
257  }
258  }
259  // Sample a discrete assignment.
260  const DiscreteValues assignment = dbn.sample(given.discrete());
261  // Select the continuous Bayes net corresponding to the assignment.
262  GaussianBayesNet gbn = choose(assignment);
263  // Sample from the Gaussian Bayes net.
264  VectorValues sample = gbn.sample(given.continuous(), rng);
265  return {sample, assignment};
266 }
267 
268 /* ************************************************************************* */
269 HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
270  HybridValues given;
271  return sample(given, rng);
272 }
273 
274 /* ************************************************************************* */
276  return sample(given, &kRandomNumberGenerator);
277 }
278 
279 /* ************************************************************************* */
282 }
283 
284 /* ************************************************************************* */
286  const VectorValues &continuousValues) const {
288 
289  // Iterate over each conditional.
290  for (auto &&conditional : *this) {
291  if (auto gm = conditional->asMixture()) {
292  // If conditional is hybrid, compute error for all assignments.
293  result = result + gm->errorTree(continuousValues);
294 
295  } else if (auto gc = conditional->asGaussian()) {
296  // If continuous, get the error and add it to the result
297  double error = gc->error(continuousValues);
298  // Add the computed error to every leaf of the result tree.
299  result = result.apply(
300  [error](double leaf_value) { return leaf_value + error; });
301 
302  } else if (auto dc = conditional->asDiscrete()) {
303  // If discrete, add the discrete error in the right branch
304  result = result.apply(
305  [dc](const Assignment<Key> &assignment, double leaf_value) {
306  return leaf_value + dc->error(DiscreteValues(assignment));
307  });
308  }
309  }
310 
311  return result;
312 }
313 
314 /* ************************************************************************* */
316  const VectorValues &continuousValues) const {
318 
319  // Iterate over each conditional.
320  for (auto &&conditional : *this) {
321  if (auto gm = conditional->asMixture()) {
322  // If conditional is hybrid, select based on assignment and compute
323  // logProbability.
324  result = result + gm->logProbability(continuousValues);
325  } else if (auto gc = conditional->asGaussian()) {
326  // If continuous, get the (double) logProbability and add it to the
327  // result
328  double logProbability = gc->logProbability(continuousValues);
329  // Add the computed logProbability to every leaf of the logProbability
330  // tree.
331  result = result.apply([logProbability](double leaf_value) {
332  return leaf_value + logProbability;
333  });
334  } else if (auto dc = conditional->asDiscrete()) {
335  // If discrete, add the discrete logProbability in the right branch
336  result = result.apply(
337  [dc](const Assignment<Key> &assignment, double leaf_value) {
338  return leaf_value + dc->logProbability(DiscreteValues(assignment));
339  });
340  }
341  }
342 
343  return result;
344 }
345 
346 /* ************************************************************************* */
348  const VectorValues &continuousValues) const {
349  AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
350  return tree.apply([](double log) { return exp(log); });
351 }
352 
353 /* ************************************************************************* */
355  return exp(logProbability(values));
356 }
357 
358 /* ************************************************************************* */
360  const VectorValues &measurements) const {
362 
363  // For all nodes in the Bayes net, if its frontal variable is in measurements,
364  // replace it by a likelihood factor:
365  for (auto &&conditional : *this) {
366  if (conditional->frontalsIn(measurements)) {
367  if (auto gc = conditional->asGaussian()) {
368  fg.push_back(gc->likelihood(measurements));
369  } else if (auto gm = conditional->asMixture()) {
370  fg.push_back(gm->likelihood(measurements));
371  } else {
372  throw std::runtime_error("Unknown conditional type");
373  }
374  } else {
375  fg.push_back(conditional);
376  }
377  }
378  return fg;
379 }
380 
381 } // namespace gtsam
gtsam::HybridBayesNet::equals
bool equals(const This &fg, double tol=1e-9) const
GTSAM-style equals.
Definition: HybridBayesNet.cpp:36
DiscreteBayesNet.h
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:44
gtsam::HybridValues
Definition: HybridValues.h:38
rng
static std::mt19937 rng
Definition: timeFactorOverhead.cpp:31
gtsam::HybridBayesNet::sample
HybridValues sample() const
Sample using ancestral sampling, use default rng.
Definition: HybridBayesNet.cpp:280
gtsam::DiscreteBayesNet::sample
DiscreteValues sample() const
do ancestral sampling
Definition: DiscreteBayesNet.cpp:52
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:98
gtsam::HybridConditional
Definition: HybridConditional.h:59
s
RealScalar s
Definition: level1_cplx_impl.h:126
gtsam::BayesNet< HybridConditional >::print
void print(const std::string &s="BayesNet", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Definition: BayesNet-inst.h:31
gtsam::prunerFunc
std::function< double(const Assignment< Key > &, double)> prunerFunc(const DecisionTreeFactor &prunedDiscreteProbs, const HybridConditional &conditional)
Helper function to get the pruner functional.
Definition: HybridBayesNet.cpp:48
gtsam::FactorGraph< HybridConditional >::error
double error(const HybridValues &values) const
Definition: FactorGraph-inst.h:66
gtsam::HybridBayesNet
Definition: HybridBayesNet.h:35
DiscreteFactorGraph.h
gtsam::HybridBayesNet::evaluate
double evaluate(const HybridValues &values) const
Evaluate hybrid probability density for given HybridValues.
Definition: HybridBayesNet.cpp:354
gtsam::HybridValues::continuous
const VectorValues & continuous() const
Return the multi-dimensional vector values.
Definition: HybridValues.h:89
tree
Definition: testExpression.cpp:212
gtsam::DecisionTreeFactor::prune
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
Definition: DecisionTreeFactor.cpp:371
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
HybridBayesNet.h
A Bayes net of Gaussian Conditionals indexed by discrete keys.
log
const EIGEN_DEVICE_FUNC LogReturnType log() const
Definition: ArrayCwiseUnaryOps.h:128
exp
const EIGEN_DEVICE_FUNC ExpReturnType exp() const
Definition: ArrayCwiseUnaryOps.h:97
result
Values result
Definition: OdometryOptimize.cpp:8
gtsam::AlgebraicDecisionTree< Key >
gtsam::FactorGraph< HybridConditional >::at
const sharedFactor at(size_t i) const
Definition: FactorGraph.h:306
gtsam::HybridBayesNet::optimize
HybridValues optimize() const
Solve the HybridBayesNet by first computing the MPE of all the discrete variables and then optimizing...
Definition: HybridBayesNet.cpp:221
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
gtsam::VectorValues
Definition: VectorValues.h:74
gtsam::DiscreteValues::CartesianProduct
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
Definition: DiscreteValues.h:85
gtsam::kRandomNumberGenerator
static std::mt19937 kRandomNumberGenerator(42)
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
gtsam::HybridBayesNet::errorTree
AlgebraicDecisionTree< Key > errorTree(const VectorValues &continuousValues) const
Compute conditional error for each discrete assignment, and return as a tree.
Definition: HybridBayesNet.cpp:285
gtsam::FactorGraph< HybridConditional >::equals
bool equals(const This &fg, double tol=1e-9) const
Check equality up to tolerance.
Definition: FactorGraph-inst.h:50
gtsam::HybridGaussianFactorGraph
Definition: HybridGaussianFactorGraph.h:104
gtsam::HybridBayesNet::print
void print(const std::string &s="", const KeyFormatter &formatter=DefaultKeyFormatter) const override
GTSAM-style printing.
Definition: HybridBayesNet.cpp:30
gtsam::Assignment< Key >
gtsam::HybridBayesNet::pruneDiscreteConditionals
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves)
Prune all the discrete conditionals.
Definition: HybridBayesNet.cpp:129
sampling::gbn
static const GaussianBayesNet gbn
Definition: testGaussianBayesNet.cpp:170
gtsam::HybridBayesNet::prune
HybridBayesNet prune(size_t maxNrLeaves)
Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
Definition: HybridBayesNet.cpp:167
gtsam::FactorGraph< HybridConditional >::size
size_t size() const
Definition: FactorGraph.h:297
gtsam::DiscreteKeysAsSet
std::set< DiscreteKey > DiscreteKeysAsSet(const DiscreteKeys &discreteKeys)
Return the DiscreteKey vector as a set.
Definition: GaussianMixture.cpp:221
gtsam::HybridValues::discrete
const DiscreteValues & discrete() const
Return the discrete values.
Definition: HybridValues.h:92
gtsam::HybridBayesNet::push_back
void push_back(std::shared_ptr< HybridConditional > conditional)
Add a hybrid conditional using a shared_ptr.
Definition: HybridBayesNet.h:69
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
gtsam
traits
Definition: chartTesting.h:28
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::FactorGraph::push_back
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:147
leaf::values
leaf::MyValues values
gtsam::DiscreteValues::insert
std::pair< iterator, bool > insert(const value_type &value)
Definition: DiscreteValues.h:68
gtsam::DiscreteBayesNet::shared_ptr
std::shared_ptr< This > shared_ptr
Definition: DiscreteBayesNet.h:43
gtsam::HybridBayesNet::logProbability
AlgebraicDecisionTree< Key > logProbability(const VectorValues &continuousValues) const
Compute log probability for each discrete assignment, and return as a tree.
Definition: HybridBayesNet.cpp:315
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::DiscreteFactor::discreteKeys
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
Definition: DiscreteFactor.cpp:32
gtsam::DiscreteFactorGraph::optimize
DiscreteValues optimize(OptionalOrderingType orderingType={}) const
Find the maximum probable explanation (MPE) by doing max-product.
Definition: DiscreteFactorGraph.cpp:189
gtsam::HybridBayesNet::toFactorGraph
HybridGaussianFactorGraph toFactorGraph(const VectorValues &measurements) const
Definition: HybridBayesNet.cpp:359
HybridValues.h
gtsam::Ordering
Definition: inference/Ordering.h:33
gtsam::HybridFactor::discreteKeys
const DiscreteKeys & discreteKeys() const
Return the discrete keys for this factor.
Definition: HybridFactor.h:130
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
gtsam::GaussianBayesNet
Definition: GaussianBayesNet.h:35
gtsam::HybridBayesNet::choose
GaussianBayesNet choose(const DiscreteValues &assignment) const
Get the Gaussian Bayes Net which corresponds to a specific discrete value assignment.
Definition: HybridBayesNet.cpp:201
kRandomNumberGenerator
static std::mt19937_64 kRandomNumberGenerator(42)


gtsam
Author(s):
autogenerated on Thu Jun 13 2024 03:02:33