HybridBayesNet.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2  * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
3  * Atlanta, Georgia 30332-0415
4  * All Rights Reserved
5  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
6  * See LICENSE for the license information
7  * -------------------------------------------------------------------------- */
8 
23 
// Fallback generator for sampling when the caller supplies none: the wrappers
// have no access to a generator of their own. Fixed seed 42.
static std::mt19937_64 kRandomNumberGenerator{42};
26 
27 namespace gtsam {
28 
29 /* ************************************************************************* */
30 void HybridBayesNet::print(const std::string &s,
31  const KeyFormatter &formatter) const {
32  Base::print(s, formatter);
33 }
34 
35 /* ************************************************************************* */
36 bool HybridBayesNet::equals(const This &bn, double tol) const {
37  return Base::equals(bn, tol);
38 }
39 
40 /* ************************************************************************* */
42  AlgebraicDecisionTree<Key> decisionTree;
43 
44  // The canonical decision tree factor which will get
45  // the discrete conditionals added to it.
46  DecisionTreeFactor dtFactor;
47 
48  for (auto &&conditional : *this) {
49  if (conditional->isDiscrete()) {
50  // Convert to a DecisionTreeFactor and add it to the main factor.
51  DecisionTreeFactor f(*conditional->asDiscrete());
52  dtFactor = dtFactor * f;
53  }
54  }
55  return std::make_shared<DecisionTreeFactor>(dtFactor);
56 }
57 
58 /* ************************************************************************* */
66 std::function<double(const Assignment<Key> &, double)> prunerFunc(
67  const DecisionTreeFactor &prunedDecisionTree,
68  const HybridConditional &conditional) {
69  // Get the discrete keys as sets for the decision tree
70  // and the Gaussian mixture.
71  std::set<DiscreteKey> decisionTreeKeySet =
72  DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
73  std::set<DiscreteKey> conditionalKeySet =
74  DiscreteKeysAsSet(conditional.discreteKeys());
75 
76  auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
77  const Assignment<Key> &choices,
78  double probability) -> double {
79  // This corresponds to 0 probability
80  double pruned_prob = 0.0;
81 
82  // typecast so we can use this to get probability value
83  DiscreteValues values(choices);
84  // Case where the Gaussian mixture has the same
85  // discrete keys as the decision tree.
86  if (conditionalKeySet == decisionTreeKeySet) {
87  if (prunedDecisionTree(values) == 0) {
88  return pruned_prob;
89  } else {
90  return probability;
91  }
92  } else {
93  // Due to branch merging (aka pruning) in DecisionTree, it is possible we
94  // get a `values` which doesn't have the full set of keys.
95  std::set<Key> valuesKeys;
96  for (auto kvp : values) {
97  valuesKeys.insert(kvp.first);
98  }
99  std::set<Key> conditionalKeys;
100  for (auto kvp : conditionalKeySet) {
101  conditionalKeys.insert(kvp.first);
102  }
103  // If true, then values is missing some keys
104  if (conditionalKeys != valuesKeys) {
105  // Get the keys present in conditionalKeys but not in valuesKeys
106  std::vector<Key> missing_keys;
107  std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
108  valuesKeys.begin(), valuesKeys.end(),
109  std::back_inserter(missing_keys));
110  // Insert missing keys with a default assignment.
111  for (auto missing_key : missing_keys) {
112  values[missing_key] = 0;
113  }
114  }
115 
116  // Now we generate the full assignment by enumerating
117  // over all keys in the prunedDecisionTree.
118  // First we find the differing keys
119  std::vector<DiscreteKey> set_diff;
120  std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
121  conditionalKeySet.begin(), conditionalKeySet.end(),
122  std::back_inserter(set_diff));
123 
124  // Now enumerate over all assignments of the differing keys
125  const std::vector<DiscreteValues> assignments =
127  for (const DiscreteValues &assignment : assignments) {
128  DiscreteValues augmented_values(values);
129  augmented_values.insert(assignment);
130 
131  // If any one of the sub-branches are non-zero,
132  // we need this probability.
133  if (prunedDecisionTree(augmented_values) > 0.0) {
134  return probability;
135  }
136  }
137  // If we are here, it means that all the sub-branches are 0,
138  // so we prune.
139  return pruned_prob;
140  }
141  };
142  return pruner;
143 }
144 
145 /* ************************************************************************* */
147  const DecisionTreeFactor &prunedDecisionTree) {
148  KeyVector prunedTreeKeys = prunedDecisionTree.keys();
149 
150  // Loop with index since we need it later.
151  for (size_t i = 0; i < this->size(); i++) {
152  HybridConditional::shared_ptr conditional = this->at(i);
153  if (conditional->isDiscrete()) {
154  auto discrete = conditional->asDiscrete();
155 
156  // Apply prunerFunc to the underlying AlgebraicDecisionTree
157  auto discreteTree =
158  std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
159  DecisionTreeFactor::ADT prunedDiscreteTree =
160  discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));
161 
162  // Create the new (hybrid) conditional
163  KeyVector frontals(discrete->frontals().begin(),
164  discrete->frontals().end());
165  auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
166  frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
167  conditional = std::make_shared<HybridConditional>(prunedDiscrete);
168 
169  // Add it back to the BayesNet
170  this->at(i) = conditional;
171  }
172  }
173 }
174 
175 /* ************************************************************************* */
177  // Get the decision tree of only the discrete keys
179  const auto decisionTree = discreteConditionals->prune(maxNrLeaves);
180 
181  this->updateDiscreteConditionals(decisionTree);
182 
183  /* To Prune, we visitWith every leaf in the GaussianMixture.
184  * For each leaf, using the assignment we can check the discrete decision tree
185  * for 0.0 probability, then just set the leaf to a nullptr.
186  *
187  * We can later check the GaussianMixture for just nullptrs.
188  */
189 
190  HybridBayesNet prunedBayesNetFragment;
191 
192  // Go through all the conditionals in the
193  // Bayes Net and prune them as per decisionTree.
194  for (auto &&conditional : *this) {
195  if (auto gm = conditional->asMixture()) {
196  // Make a copy of the Gaussian mixture and prune it!
197  auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
198  prunedGaussianMixture->prune(decisionTree); // imperative :-(
199 
200  // Type-erase and add to the pruned Bayes Net fragment.
201  prunedBayesNetFragment.push_back(prunedGaussianMixture);
202 
203  } else {
204  // Add the non-GaussianMixture conditional
205  prunedBayesNetFragment.push_back(conditional);
206  }
207  }
208 
209  return prunedBayesNetFragment;
210 }
211 
212 /* ************************************************************************* */
214  const DiscreteValues &assignment) const {
216  for (auto &&conditional : *this) {
217  if (auto gm = conditional->asMixture()) {
218  // If conditional is hybrid, select based on assignment.
219  gbn.push_back((*gm)(assignment));
220  } else if (auto gc = conditional->asGaussian()) {
221  // If continuous only, add Gaussian conditional.
222  gbn.push_back(gc);
223  } else if (auto dc = conditional->asDiscrete()) {
224  // If conditional is discrete-only, we simply continue.
225  continue;
226  }
227  }
228 
229  return gbn;
230 }
231 
232 /* ************************************************************************* */
234  // Collect all the discrete factors to compute MPE
235  DiscreteBayesNet discrete_bn;
236  for (auto &&conditional : *this) {
237  if (conditional->isDiscrete()) {
238  discrete_bn.push_back(conditional->asDiscrete());
239  }
240  }
241 
242  // Solve for the MPE
243  DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
244 
245  // Given the MPE, compute the optimal continuous values.
246  return HybridValues(optimize(mpe), mpe);
247 }
248 
249 /* ************************************************************************* */
251  GaussianBayesNet gbn = choose(assignment);
252 
253  // Check if there exists a nullptr in the GaussianBayesNet
254  // If yes, return an empty VectorValues
255  if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
256  return VectorValues();
257  }
258  return gbn.optimize();
259 }
260 
261 /* ************************************************************************* */
263  std::mt19937_64 *rng) const {
264  DiscreteBayesNet dbn;
265  for (auto &&conditional : *this) {
266  if (conditional->isDiscrete()) {
267  // If conditional is discrete-only, we add to the discrete Bayes net.
268  dbn.push_back(conditional->asDiscrete());
269  }
270  }
271  // Sample a discrete assignment.
272  const DiscreteValues assignment = dbn.sample(given.discrete());
273  // Select the continuous Bayes net corresponding to the assignment.
274  GaussianBayesNet gbn = choose(assignment);
275  // Sample from the Gaussian Bayes net.
276  VectorValues sample = gbn.sample(given.continuous(), rng);
277  return {sample, assignment};
278 }
279 
280 /* ************************************************************************* */
281 HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
282  HybridValues given;
283  return sample(given, rng);
284 }
285 
286 /* ************************************************************************* */
288  return sample(given, &kRandomNumberGenerator);
289 }
290 
291 /* ************************************************************************* */
294 }
295 
296 /* ************************************************************************* */
298  const VectorValues &continuousValues) const {
300 
301  // Iterate over each conditional.
302  for (auto &&conditional : *this) {
303  if (auto gm = conditional->asMixture()) {
304  // If conditional is hybrid, select based on assignment and compute
305  // logProbability.
306  result = result + gm->logProbability(continuousValues);
307  } else if (auto gc = conditional->asGaussian()) {
308  // If continuous, get the (double) logProbability and add it to the
309  // result
310  double logProbability = gc->logProbability(continuousValues);
311  // Add the computed logProbability to every leaf of the logProbability
312  // tree.
313  result = result.apply([logProbability](double leaf_value) {
314  return leaf_value + logProbability;
315  });
316  } else if (auto dc = conditional->asDiscrete()) {
317  // If discrete, add the discrete logProbability in the right branch
318  result = result.apply(
319  [dc](const Assignment<Key> &assignment, double leaf_value) {
320  return leaf_value + dc->logProbability(DiscreteValues(assignment));
321  });
322  }
323  }
324 
325  return result;
326 }
327 
328 /* ************************************************************************* */
330  const VectorValues &continuousValues) const {
331  AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
332  return tree.apply([](double log) { return exp(log); });
333 }
334 
335 /* ************************************************************************* */
337  return exp(logProbability(values));
338 }
339 
340 /* ************************************************************************* */
342  const VectorValues &measurements) const {
344 
345  // For all nodes in the Bayes net, if its frontal variable is in measurements,
346  // replace it by a likelihood factor:
347  for (auto &&conditional : *this) {
348  if (conditional->frontalsIn(measurements)) {
349  if (auto gc = conditional->asGaussian()) {
350  fg.push_back(gc->likelihood(measurements));
351  } else if (auto gm = conditional->asMixture()) {
352  fg.push_back(gm->likelihood(measurements));
353  } else {
354  throw std::runtime_error("Unknown conditional type");
355  }
356  } else {
357  fg.push_back(conditional);
358  }
359  }
360  return fg;
361 }
362 
363 } // namespace gtsam
bool equals(const This &fg, double tol=1e-9) const
Check equality up to tolerance.
static std::mt19937 kRandomNumberGenerator(42)
GaussianBayesNet choose(const DiscreteValues &assignment) const
Get the Gaussian Bayes Net which corresponds to a specific discrete value assignment.
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
VectorValues sample(std::mt19937_64 *rng) const
bool equals(const This &fg, double tol=1e-9) const
GTSAM-style equals.
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:190
static std::mt19937 rng
const DiscreteKeys & discreteKeys() const
Return the discrete keys for this factor.
Definition: HybridFactor.h:129
DecisionTreeFactor::shared_ptr discreteConditionals() const
Get all the discrete conditionals as a decision tree factor.
leaf::MyValues values
static const GaussianBayesNet gbn
double evaluate(const HybridValues &values) const
Evaluate hybrid probability density for given HybridValues.
const VectorValues & continuous() const
Return the multi-dimensional vector values.
Definition: HybridValues.h:89
DiscreteValues optimize(OptionalOrderingType orderingType={}) const
Find the maximum probable explanation (MPE) by doing max-product.
EIGEN_DEVICE_FUNC const LogReturnType log() const
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
const KeyFormatter & formatter
A Bayes net of Gaussian Conditionals indexed by discrete keys.
DiscreteConditional::shared_ptr asDiscrete() const
Return conditional as a DiscreteConditional.
EIGEN_DEVICE_FUNC const ExpReturnType exp() const
VectorValues optimize() const
const_iterator begin() const
Definition: FactorGraph.h:361
bool frontalsIn(const VectorValues &measurements) const
Check if VectorValues measurements contains all frontal keys.
GaussianConditional::shared_ptr asGaussian() const
Return HybridConditional as a GaussianConditional.
Values result
HybridValues optimize() const
Solve the HybridBayesNet by first computing the MPE of all the discrete variables and then optimizing...
std::set< DiscreteKey > DiscreteKeysAsSet(const DiscreteKeys &discreteKeys)
Return the DiscreteKey vector as a set.
HybridBayesNet prune(size_t maxNrLeaves)
Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
RealScalar s
void push_back(std::shared_ptr< HybridConditional > conditional)
Add a hybrid conditional using a shared_ptr.
DecisionTree apply(const Unary &op) const
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
std::shared_ptr< This > shared_ptr
shared_ptr to this class
const_iterator end() const
Definition: FactorGraph.h:364
void print(const std::string &s="", const KeyFormatter &formatter=DefaultKeyFormatter) const override
GTSAM-style printing.
traits
Definition: chartTesting.h:28
std::shared_ptr< DecisionTreeFactor > shared_ptr
GaussianMixture::shared_ptr asMixture() const
Return HybridConditional as a GaussianMixture.
const sharedFactor at(size_t i) const
Definition: FactorGraph.h:343
static std::mt19937_64 kRandomNumberGenerator(42)
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree)
Update the discrete conditionals with the pruned versions.
const DiscreteValues & discrete() const
Return the discrete values.
Definition: HybridValues.h:92
const KeyVector & keys() const
Access the factor's involved variable keys.
Definition: Factor.h:142
std::function< double(const Assignment< Key > &, double)> prunerFunc(const DecisionTreeFactor &prunedDecisionTree, const HybridConditional &conditional)
Helper function to get the pruner functional.
const G double tol
Definition: Group.h:86
std::pair< iterator, bool > insert(const value_type &value)
AlgebraicDecisionTree< Key > logProbability(const VectorValues &continuousValues) const
Compute conditional error for each discrete assignment, and return as a tree.
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:86
bool isDiscrete() const
True if this is a factor of discrete variables only.
Definition: HybridFactor.h:117
void print(const std::string &s="BayesNet", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Definition: BayesNet-inst.h:31
DiscreteValues sample() const
do ancestral sampling
HybridGaussianFactorGraph toFactorGraph(const VectorValues &measurements) const
HybridValues sample() const
Sample using ancestral sampling, use default rng.


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:34:20