HybridGaussianFactorGraph.cpp
/* ----------------------------------------------------------------------------

 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 * Atlanta, Georgia 30332-0415
 * All Rights Reserved
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)

 * See LICENSE for the license information

 * -------------------------------------------------------------------------- */

#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridJunctionTree.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <gtsam/inference/Key.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/HessianFactor.h>
#include <gtsam/linear/VectorValues.h>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// #define HYBRID_TIMING

namespace gtsam {

// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;

using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;

using std::dynamic_pointer_cast;

/* ************************************************************************ */
// Throw a runtime exception for method specified in string s, and factor f:
static void throwRuntimeError(const std::string &s,
                              const std::shared_ptr<Factor> &f) {
  auto &fr = *f;
  throw std::runtime_error(s + " not implemented for factor type " +
                           demangle(typeid(fr).name()) + ".");
}

/* ************************************************************************ */
const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
  KeySet discrete_keys = graph.discreteKeySet();
  const VariableIndex index(graph);
  return Ordering::ColamdConstrainedLast(
      index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
}
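// Example (sketch, assuming a HybridGaussianFactorGraph named hfg): the
// constrained ordering is typically used for sequential elimination, e.g.
//   const Ordering ordering = HybridOrdering(hfg);
//   const auto bayesNet = hfg.eliminateSequential(ordering);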

/* ************************************************************************ */
static GaussianFactorGraphTree addGaussian(
    const GaussianFactorGraphTree &gfgTree,
    const GaussianFactor::shared_ptr &factor) {
  // If the decision tree is not initialized, then initialize it.
  if (gfgTree.empty()) {
    GaussianFactorGraph result{factor};
    return GaussianFactorGraphTree(result);
  } else {
    // Otherwise, add the factor to every leaf's factor graph.
    auto add = [&factor](const GaussianFactorGraph &graph) {
      auto result = graph;
      result.push_back(factor);
      return result;
    };
    return gfgTree.apply(add);
  }
}

/* ************************************************************************ */
// TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
  gttic(assembleGraphTree);

  GaussianFactorGraphTree result;

  for (auto &f : factors_) {
    // TODO(dellaert): just use a virtual method defined in HybridFactor.
    if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
      result = addGaussian(result, gf);
    } else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
      result = gmf->add(result);
    } else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
      result = gm->add(result);
    } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
      if (auto gm = hc->asMixture()) {
        result = gm->add(result);
      } else if (auto g = hc->asGaussian()) {
        result = addGaussian(result, g);
      } else {
        // Has to be discrete.
        // TODO(dellaert): in C++20, we can use std::visit.
        continue;
      }
    } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
      // Don't do anything for discrete-only factors
      // since we want to eliminate continuous values only.
      continue;
    } else {
      // TODO(dellaert): there was an unattributed comment here: We need to
      // handle the case where the object is actually a BayesTreeOrphanWrapper!
      throwRuntimeError("gtsam::assembleGraphTree", f);
    }
  }

  gttoc(assembleGraphTree);

  return result;
}
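// For intuition: if the graph involves a single binary key m, the returned
// tree has two leaves, each holding the GaussianFactorGraph obtained by
// selecting, for every factor, the Gaussian component for that assignment of m.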

/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
continuousElimination(const HybridGaussianFactorGraph &factors,
                      const Ordering &frontalKeys) {
  GaussianFactorGraph gfg;
  for (auto &f : factors) {
    if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
      gfg.push_back(gf);
    } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
      // Ignore orphaned clique.
      // TODO(dellaert): is this correct? If so explain here.
    } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
      auto gc = hc->asGaussian();
      if (!gc) throwRuntimeError("continuousElimination", f);
      gfg.push_back(gc);
    } else {
      throwRuntimeError("continuousElimination", f);
    }
  }

  auto result = EliminatePreferCholesky(gfg, frontalKeys);
  return {std::make_shared<HybridConditional>(result.first), result.second};
}

/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
                    const Ordering &frontalKeys) {
  DiscreteFactorGraph dfg;

  for (auto &f : factors) {
    if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
      dfg.push_back(dtf);
    } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
      // Ignore orphaned clique.
      // TODO(dellaert): is this correct? If so explain here.
    } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
      auto dc = hc->asDiscrete();
      if (!dc) throwRuntimeError("discreteElimination", f);
      dfg.push_back(dc);
    } else {
      throwRuntimeError("discreteElimination", f);
    }
  }

  // NOTE: This does sum-product. For max-product, use EliminateForMPE.
  auto result = EliminateDiscrete(dfg, frontalKeys);

  return {std::make_shared<HybridConditional>(result.first), result.second};
}

/* ************************************************************************ */
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor, which does not count as
// empty.
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
  auto emptyGaussian = [](const GaussianFactorGraph &graph) {
    bool hasNull =
        std::any_of(graph.begin(), graph.end(),
                    [](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
    return hasNull ? GaussianFactorGraph() : graph;
  };
  return GaussianFactorGraphTree(sum, emptyGaussian);
}

/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
hybridElimination(const HybridGaussianFactorGraph &factors,
                  const Ordering &frontalKeys,
                  const KeyVector &continuousSeparator,
                  const std::set<DiscreteKey> &discreteSeparatorSet) {
  // NOTE: since we use the special JunctionTree,
  // the only possibility is continuous conditioned on discrete.
  DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
                                 discreteSeparatorSet.end());

  // Collect all the factors to create a set of Gaussian factor graphs in a
  // decision tree indexed by all discrete keys involved.
  GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree();

  // Convert factor graphs with a nullptr to an empty factor graph.
  // This is done after assembly since it is non-trivial to keep track of which
  // FG has a nullptr as we're looping over the factors.
  factorGraphTree = removeEmpty(factorGraphTree);

  using Result = std::pair<std::shared_ptr<GaussianConditional>,
                           GaussianMixtureFactor::sharedFactor>;

  // This is the elimination method on the leaf nodes
  auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
    if (graph.empty()) {
      return {nullptr, nullptr};
    }

#ifdef HYBRID_TIMING
    gttic_(hybrid_eliminate);
#endif

    auto result = EliminatePreferCholesky(graph, frontalKeys);

#ifdef HYBRID_TIMING
    gttoc_(hybrid_eliminate);
#endif

    return result;
  };

  // Perform elimination!
  DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);

#ifdef HYBRID_TIMING
  tictoc_print_();
#endif

  // Separate out decision tree into conditionals and remaining factors.
  const auto [conditionals, newFactors] = unzip(eliminationResults);

  // Create the GaussianMixture from the conditionals
  auto gaussianMixture = std::make_shared<GaussianMixture>(
      frontalKeys, continuousSeparator, discreteSeparator, conditionals);

  if (continuousSeparator.empty()) {
    // If there are no more continuous parents, then we create a
    // DiscreteFactor here, with the error for each discrete choice.

    // Integrate the probability mass in the last continuous conditional using
    // the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
    // discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
    auto probability = [&](const Result &pair) -> double {
      static const VectorValues kEmpty;
      // If the factor is not null, it has no keys, just contains the residual.
      const auto &factor = pair.second;
      if (!factor) return 1.0;  // TODO(dellaert): not loving this.
      return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant();
    };
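    // Derivation sketch: normalizationConstant() of the conditional is
    // k_m = 1/sqrt(det(2π Σ_m)), so exp(-error(μ;m)) / k_m is exactly the
    // exp(-error(μ;m)) * sqrt(det(2π Σ_m)) quantity described above.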

    DecisionTree<Key, double> probabilities(eliminationResults, probability);
    return {
        std::make_shared<HybridConditional>(gaussianMixture),
        std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
  } else {
    // Otherwise, we create a resulting GaussianMixtureFactor on the separator,
    // taking care to correct for the conditional's normalization constant.

    // Correct for the normalization constant used up by the conditional
    auto correct = [&](const Result &pair) {
      const auto &factor = pair.second;
      if (!factor) return;
      auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
      if (!hf) throw std::runtime_error("Expected HessianFactor!");
      hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
    };
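    // Why 2.0 * logNormalizationConstant (sketch): a HessianFactor stores
    // error(x) = 0.5 * (x'Gx - 2 x'g + f), so adding 2*log(k_m) to the
    // constant term f raises the error by log(k_m) and scales exp(-error)
    // by 1/k_m. This restores the constant k_m introduced by the normalized
    // conditional, so conditional * factor still matches the leaf graph.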
    eliminationResults.visit(correct);

    const auto mixtureFactor = std::make_shared<GaussianMixtureFactor>(
        continuousSeparator, discreteSeparator, newFactors);

    return {std::make_shared<HybridConditional>(gaussianMixture),
            mixtureFactor};
  }
}

/* ************************************************************************
 * Function to eliminate variables **under the following assumptions**:
 * 1. When the ordering is fully continuous, and the graph only contains
 * continuous and hybrid factors
 * 2. When the ordering is fully discrete, and the graph only contains discrete
 * factors
 *
 * Any usage outside of this is considered incorrect.
 *
 * \warning This function is not meant to be used with arbitrary hybrid factor
 * graphs. For example, if there exist continuous parents and one tries to
 * eliminate a discrete variable (as specified in the ordering), the result
 * will be INCORRECT and NO error will be raised.
 */
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>  //
EliminateHybrid(const HybridGaussianFactorGraph &factors,
                const Ordering &frontalKeys) {
  // NOTE: Because we are in the Conditional Gaussian regime there are only
  // a few cases:
  // 1. continuous variable, make a Gaussian Mixture if there are hybrid
  // factors;
  // 2. continuous variable, we make a Gaussian Factor if there are no hybrid
  // factors;
  // 3. discrete variable, no continuous factor is allowed
  // (escapes Conditional Gaussian regime), if discrete only we do the discrete
  // elimination.

  // However it is not that simple. During elimination it is possible that the
  // multifrontal needs to eliminate an ordering that contains both Gaussian and
  // hybrid variables, for example x1, c1.
  // In this scenario, we will have a density P(x1, c1) that is a Conditional
  // Linear Gaussian P(x1|c1)P(c1) (see Murphy02).

  // The issue is: how can we know which variable is discrete if we unify
  // Values? Obviously we can tell using the factors, but is that fast?

  // In the case of multifrontal, we will need to use a constrained ordering
  // so that the discrete parts will be guaranteed to be eliminated last!
  // Because of all these reasons, we carefully consider how to
  // implement the hybrid factors so that we do not get poor performance.

  // The first thing is how to represent the GaussianMixture.
  // A likely scenario is that the incoming factors will have different
  // levels of discrete keys. For example, imagine we are going to eliminate the
  // fragment: $\phi(x1,c1,c2)$, $\phi(x1,c2,c3)$, which is perfectly valid.
  // Now we will need to know how to retrieve the corresponding continuous
  // densities for the assignment (c1,c2,c3) (OR (c2,c3,c1), note there is NO
  // defined order!). We also need to consider when there is pruning. Two
  // mixture factors could have different pruning patterns - one could have
  // (c1=0,c2=1) pruned, and another could have (c2=0,c3=1) pruned, and this
  // creates a big problem in how to identify the intersection of non-pruned
  // branches.

  // Our approach is to first build the collection of all discrete keys. After
  // that we enumerate the space of all key combinations *lazily* so that the
  // exploration branch terminates whenever an assignment yields NULL in any of
  // the hybrid factors.

  // When the number of assignments is large we may encounter stack overflows.
  // However this is also the case with iSAM2, so no pressure :)

  // PREPROCESS: Identify the nature of the current elimination

  // TODO(dellaert): just check the factors:
  // 1. if all factors are discrete, then we can do discrete elimination:
  // 2. if all factors are continuous, then we can do continuous elimination:
  // 3. if not, we do hybrid elimination:

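  // For intuition (hypothetical example): eliminating frontal {x1} from
  // {GaussianMixtureFactor φ(x1,x2|m1), JacobianFactor φ(x1)} is case 3 and
  // yields a GaussianMixture P(x1|x2,m1) plus a GaussianMixtureFactor on
  // (x2, m1); eliminating frontal {m1} from DecisionTreeFactors only is case 1.
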
  // First, identify the separator keys, i.e. all keys that are not frontal.
  KeySet separatorKeys;
  for (auto &&factor : factors) {
    separatorKeys.insert(factor->begin(), factor->end());
  }
  // remove frontals from separator
  for (auto &k : frontalKeys) {
    separatorKeys.erase(k);
  }

  // Build a map from keys to DiscreteKeys
  auto mapFromKeyToDiscreteKey = factors.discreteKeyMap();

  // Fill in discrete frontals and continuous frontals.
  std::set<DiscreteKey> discreteFrontals;
  KeySet continuousFrontals;
  for (auto &k : frontalKeys) {
    if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
      discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
    } else {
      continuousFrontals.insert(k);
    }
  }

  // Fill in discrete separator keys and continuous separator keys.
  std::set<DiscreteKey> discreteSeparatorSet;
  KeyVector continuousSeparator;
  for (auto &k : separatorKeys) {
    if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
      discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
    } else {
      continuousSeparator.push_back(k);
    }
  }

  // Check if we have any continuous keys:
  const bool discrete_only =
      continuousFrontals.empty() && continuousSeparator.empty();

  // NOTE: We should really defer the product here because of pruning

  if (discrete_only) {
    // Case 1: we are only dealing with discrete
    return discreteElimination(factors, frontalKeys);
  } else if (mapFromKeyToDiscreteKey.empty()) {
    // Case 2: we are only dealing with continuous
    return continuousElimination(factors, frontalKeys);
  } else {
    // Case 3: we are now in hybrid land!
    return hybridElimination(factors, frontalKeys, continuousSeparator,
                             discreteSeparatorSet);
  }
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
    const VectorValues &continuousValues) const {
  AlgebraicDecisionTree<Key> error_tree(0.0);

  // Iterate over each factor.
  for (auto &f : factors_) {
    // TODO(dellaert): just use a virtual method defined in HybridFactor.
    AlgebraicDecisionTree<Key> factor_error;

    if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
      // Compute factor error and add it.
      error_tree = error_tree + gaussianMixture->error(continuousValues);
    } else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
      // If continuous only, get the (double) error
      // and add it to the error_tree
      double error = gaussian->error(continuousValues);
      // Add the gaussian factor error to every leaf of the error tree.
      error_tree = error_tree.apply(
          [error](double leaf_value) { return leaf_value + error; });
    } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
      // If the factor is discrete-only, we skip it: discrete factors do not
      // contribute to the continuous error.
      continue;
    } else {
      throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
    }
  }

  return error_tree;
}

/* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
  double error = this->error(values);
  // NOTE: The 0.5 term is handled by each factor
  return std::exp(-error);
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
    const VectorValues &continuousValues) const {
  AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
  AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
    // NOTE: The 0.5 term is handled by each factor
    return exp(-error);
  });
  return prob_tree;
}

}  // namespace gtsam