56 template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
60 using std::dynamic_pointer_cast;
65 const std::shared_ptr<Factor> &
f) {
67 throw std::runtime_error(s +
" not implemented for factor type " +
76 index,
KeyVector(discrete_keys.begin(), discrete_keys.end()),
true);
84 if (gfgTree.
empty()) {
107 if (
auto gf = dynamic_pointer_cast<GaussianFactor>(
f)) {
109 }
else if (
auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(
f)) {
110 result = gmf->add(result);
111 }
else if (
auto gm = dynamic_pointer_cast<GaussianMixture>(
f)) {
112 result = gm->add(result);
113 }
else if (
auto hc = dynamic_pointer_cast<HybridConditional>(
f)) {
114 if (
auto gm =
hc->asMixture()) {
115 result = gm->add(result);
116 }
else if (
auto g =
hc->asGaussian()) {
123 }
else if (dynamic_pointer_cast<DecisionTreeFactor>(
f)) {
140 static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
144 for (
auto &
f : factors) {
145 if (
auto gf = dynamic_pointer_cast<GaussianFactor>(
f)) {
147 }
else if (
auto orphan = dynamic_pointer_cast<OrphanWrapper>(
f)) {
150 }
else if (
auto hc = dynamic_pointer_cast<HybridConditional>(
f)) {
151 auto gc =
hc->asGaussian();
160 return {std::make_shared<HybridConditional>(
result.first),
result.second};
164 static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
169 for (
auto &
f : factors) {
170 if (
auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(
f)) {
172 }
else if (
auto orphan = dynamic_pointer_cast<OrphanWrapper>(
f)) {
175 }
else if (
auto hc = dynamic_pointer_cast<HybridConditional>(
f)) {
176 auto dc =
hc->asDiscrete();
187 return {std::make_shared<HybridConditional>(
result.first),
result.second};
205 static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
209 const std::set<DiscreteKey> &discreteSeparatorSet) {
212 DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
213 discreteSeparatorSet.end());
224 using Result = std::pair<std::shared_ptr<GaussianConditional>,
230 return {
nullptr,
nullptr};
257 auto gaussianMixture = std::make_shared<GaussianMixture>(
258 frontalKeys, continuousSeparator, discreteSeparator,
conditionals);
260 if (continuousSeparator.empty()) {
267 auto probability = [&](
const Result &pair) ->
double {
270 const auto &factor = pair.second;
271 if (!factor)
return 1.0;
272 return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant();
277 std::make_shared<HybridConditional>(gaussianMixture),
278 std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
284 auto correct = [&](
const Result &pair) {
285 const auto &factor = pair.second;
288 if (!hf)
throw std::runtime_error(
"Expected HessianFactor!");
289 hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
291 eliminationResults.
visit(correct);
293 const auto mixtureFactor = std::make_shared<GaussianMixtureFactor>(
294 continuousSeparator, discreteSeparator, newFactors);
296 return {std::make_shared<HybridConditional>(gaussianMixture),
315 std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
371 for (
auto &&factor : factors) {
372 separatorKeys.insert(factor->begin(), factor->end());
375 for (
auto &k : frontalKeys) {
376 separatorKeys.erase(k);
380 auto mapFromKeyToDiscreteKey = factors.discreteKeyMap();
383 std::set<DiscreteKey> discreteFrontals;
384 KeySet continuousFrontals;
385 for (
auto &k : frontalKeys) {
386 if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
387 discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
389 continuousFrontals.insert(k);
394 std::set<DiscreteKey> discreteSeparatorSet;
396 for (
auto &k : separatorKeys) {
397 if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
398 discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
400 continuousSeparator.push_back(k);
405 const bool discrete_only =
406 continuousFrontals.empty() && continuousSeparator.empty();
413 }
else if (mapFromKeyToDiscreteKey.empty()) {
419 discreteSeparatorSet);
433 if (
auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(
f)) {
435 error_tree = error_tree + gaussianMixture->error(continuousValues);
436 }
else if (
auto gaussian = dynamic_pointer_cast<GaussianFactor>(
f)) {
439 double error = gaussian->error(continuousValues);
441 error_tree = error_tree.
apply(
442 [error](
double leaf_value) {
return leaf_value +
error; });
443 }
else if (dynamic_pointer_cast<DecisionTreeFactor>(
f)) {
A set of GaussianFactors, indexed by a set of discrete keys.
A hybrid conditional in the Conditional Linear Gaussian scheme.
KeySet discreteKeySet() const
Get all the discrete keys in the factor graph, as a set.
static void throwRuntimeError(const std::string &s, const std::shared_ptr< Factor > &f)
std::pair< DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr > EliminateDiscrete(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
Main elimination function for DiscreteFactorGraph.
void visit(Func f) const
Visit all leaves in depth-first fashion.
An assignment from labels to a discrete value index (size_t)
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph)
Return a Colamd constrained ordering where the discrete keys are eliminated after the continuous keys...
const GaussianFactorGraph factors
std::pair< std::shared_ptr< GaussianConditional >, std::shared_ptr< GaussianFactor > > EliminatePreferCholesky(const GaussianFactorGraph &factors, const Ordering &keys)
NonlinearFactorGraph graph
void g(const string &key, int i)
EIGEN_DEVICE_FUNC const ExpReturnType exp() const
static GaussianFactorGraphTree addGaussian(const GaussianFactorGraphTree &gfgTree, const GaussianFactor::shared_ptr &factor)
std::pair< DecisionTree< L, T1 >, DecisionTree< L, T2 > > unzip(const DecisionTree< L, std::pair< T1, T2 > > &input)
unzip a DecisionTree with std::pair values.
AlgebraicDecisionTree< Key > error(const VectorValues &continuousValues) const
Compute error for each discrete assignment, and return as a tree.
DecisionTree< Key, GaussianFactorGraph > GaussianFactorGraphTree
Alias for DecisionTree of GaussianFactorGraphs.
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
static std::pair< HybridConditional::shared_ptr, std::shared_ptr< Factor > > continuousElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys)
Contains the HessianFactor class, a general quadratic factor.
std::shared_ptr< GaussianFactor > sharedFactor
Conditional Gaussian Base class.
Linear Factor Graph where all factors are Gaussians.
DecisionTree apply(const Unary &op) const
AlgebraicDecisionTree< Key > probPrime(const VectorValues &continuousValues) const
Compute unnormalized probability for each discrete assignment, and return as a tree.
Linearized Hybrid factor graph that uses type erasure.
std::shared_ptr< This > shared_ptr
shared_ptr to this class
bool empty() const
Check if tree is empty.
std::pair< HybridConditional::shared_ptr, std::shared_ptr< Factor > > EliminateHybrid(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys)
Main elimination function for HybridGaussianFactorGraph.
std::string demangle(const char *name)
Pretty print Value type name.
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum)
A Gaussian factor using the canonical parameters (information form)
GaussianFactorGraphTree assembleGraphTree() const
Create a decision tree of factor graphs out of this hybrid factor graph.
graph add(PriorFactor< Pose2 >(1, priorMean, priorNoise))
const std::vector< GaussianConditional::shared_ptr > conditionals
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
static Ordering ColamdConstrainedLast(const FACTOR_GRAPH &graph, const KeyVector &constrainLast, bool forceOrder=false)
static std::pair< HybridConditional::shared_ptr, std::shared_ptr< Factor > > hybridElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys, const KeyVector &continuousSeparator, const std::set< DiscreteKey > &discreteSeparatorSet)
FastVector< sharedFactor > factors_
static std::pair< HybridConditional::shared_ptr, std::shared_ptr< Factor > > discreteElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys)
negation< all_of< negation< Ts >... > > any_of
DiscreteKeys is a set of keys that can be assembled using the & operator.