HybridGaussianConditional.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
21 #include <gtsam/base/utilities.h>
31 
32 #include <cstddef>
33 #include <memory>
34 
35 namespace gtsam {
36 
37 /* *******************************************************************************/
40  if (auto conditional =
41  std::dynamic_pointer_cast<GaussianConditional>(factor)) {
42  return conditional;
43  } else {
44  throw std::logic_error(
45  "A HybridGaussianConditional unexpectedly contained a non-conditional");
46  }
47 }
48 
49 /* *******************************************************************************/
61  std::optional<size_t> nrFrontals = {};
62  double minNegLogConstant = std::numeric_limits<double>::infinity();
63 
65  using P = std::vector<std::pair<Vector, double>>;
66 
68  template <typename... Args>
69  explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
70  nrFrontals = 1;
71  std::vector<GaussianFactorValuePair> fvs;
72  std::vector<GC::shared_ptr> gcs;
73  fvs.reserve(p.size());
74  gcs.reserve(p.size());
75  for (auto &&[mean, sigma] : p) {
76  auto gaussianConditional =
77  GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma);
78  double value = gaussianConditional->negLogConstant();
80  fvs.emplace_back(gaussianConditional, value);
81  gcs.push_back(gaussianConditional);
82  }
83 
84  pairs = FactorValuePairs({mode}, fvs);
85  }
86 
88  explicit Helper(const Conditionals &conditionals) {
89  auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
90  if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
91  if (!nrFrontals) nrFrontals = gc->nrFrontals();
92  double value = gc->negLogConstant();
94  return {gc, value};
95  };
97  if (!nrFrontals.has_value()) {
98  throw std::runtime_error(
99  "HybridGaussianConditional: need at least one frontal variable. "
100  "Provided conditionals do not contain any frontal variables.");
101  }
102  }
103 
105  explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
106  auto func = [this](const GaussianFactorValuePair &pair) {
107  if (!pair.first) return;
108  auto gc = checkConditional(pair.first);
109  if (!nrFrontals) nrFrontals = gc->nrFrontals();
111  };
112  pairs.visit(func);
113  if (!nrFrontals.has_value()) {
114  throw std::runtime_error(
115  "HybridGaussianConditional: need at least one frontal variable. "
116  "Provided conditionals do not contain any frontal variables.");
117  }
118  }
119 };
120 
121 /* *******************************************************************************/
123  const DiscreteKeys &discreteParents, Helper &&helper, bool pruned)
124  : BaseFactor(discreteParents,
126  [&](const GaussianFactorValuePair
127  &pair) { // subtract minNegLogConstant
129  pair.first, pair.second - helper.minNegLogConstant};
130  },
131  std::move(helper.pairs))),
132  BaseConditional(*helper.nrFrontals),
133  negLogConstant_(helper.minNegLogConstant),
134  pruned_(pruned) {}
135 
137  const DiscreteKey &discreteParent,
138  const std::vector<GaussianConditional::shared_ptr> &conditionals)
139  : HybridGaussianConditional(DiscreteKeys{discreteParent},
140  Conditionals({discreteParent}, conditionals)) {}
141 
143  const DiscreteKey &discreteParent, Key key, //
144  const std::vector<std::pair<Vector, double>> &parameters)
145  : HybridGaussianConditional(DiscreteKeys{discreteParent},
146  Helper(discreteParent, parameters, key)) {}
147 
149  const DiscreteKey &discreteParent, Key key, //
150  const Matrix &A, Key parent,
151  const std::vector<std::pair<Vector, double>> &parameters)
153  DiscreteKeys{discreteParent},
154  Helper(discreteParent, parameters, key, A, parent)) {}
155 
157  const DiscreteKey &discreteParent, Key key, //
158  const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
159  const std::vector<std::pair<Vector, double>> &parameters)
161  DiscreteKeys{discreteParent},
162  Helper(discreteParent, parameters, key, A1, parent1, A2, parent2)) {}
163 
165  const DiscreteKeys &discreteParents,
167  : HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
168 
170  const DiscreteKeys &discreteParents, const FactorValuePairs &pairs,
171  bool pruned)
172  : HybridGaussianConditional(discreteParents, Helper(pairs), pruned) {}
173 
174 /* *******************************************************************************/
177  return Conditionals(factors(), [](auto &&pair) {
178  return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
179  });
180 }
181 
182 /* *******************************************************************************/
184  size_t total = 0;
185  factors().visit([&total](auto &&node) {
186  if (node.first) total += 1;
187  });
188  return total;
189 }
190 
191 /* *******************************************************************************/
193  const DiscreteValues &discreteValues) const {
194  try {
195  auto &[factor, _] = factors()(discreteValues);
196  if (!factor) return nullptr;
197 
198  auto conditional = checkConditional(factor);
199  return conditional;
200  } catch (const std::out_of_range &e) {
201  GTSAM_PRINT(*this);
202  GTSAM_PRINT(discreteValues);
203  throw std::runtime_error(
204  "HybridGaussianConditional::choose: discreteValues does not contain "
205  "all discrete parents.");
206  }
207 }
208 
209 /* *******************************************************************************/
211  double tol) const {
212  const This *e = dynamic_cast<const This *>(&lf);
213  if (e == nullptr) return false;
214 
215  // Factors existence and scalar values are checked in BaseFactor::equals.
216  // Here we check additionally that the factors *are* conditionals
217  // and are equal.
218  auto compareFunc = [tol](const GaussianFactorValuePair &pair1,
219  const GaussianFactorValuePair &pair2) {
220  auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
221  c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
222  return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
223  };
224  return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc);
225 }
226 
227 /* *******************************************************************************/
228 void HybridGaussianConditional::print(const std::string &s,
229  const KeyFormatter &formatter) const {
230  std::cout << (s.empty() ? "" : s + "\n");
232  std::cout << " Discrete Keys = ";
233  for (auto &dk : discreteKeys()) {
234  std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
235  }
236  std::cout << std::endl
237  << " logNormalizationConstant: " << -negLogConstant() << std::endl
238  << std::endl;
239  factors().print(
240  "", [&](Key k) { return formatter(k); },
241  [&](const GaussianFactorValuePair &pair) -> std::string {
242  RedirectCout rd;
243  if (auto gf =
244  std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
245  gf->print("", formatter);
246  return rd.str();
247  } else {
248  return "nullptr";
249  }
250  });
251 }
252 
253 /* ************************************************************************* */
255  // Get all parent keys:
256  const auto range = parents();
257  KeyVector continuousParentKeys(range.begin(), range.end());
258  // Loop over all discrete keys:
259  for (const auto &discreteKey : discreteKeys()) {
260  const Key key = discreteKey.first;
261  // remove that key from continuousParentKeys:
262  continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
263  continuousParentKeys.end(), key),
264  continuousParentKeys.end());
265  }
266  return continuousParentKeys;
267 }
268 
269 /* ************************************************************************* */
271  const VectorValues &given) const {
272  for (auto &&kv : given) {
273  if (given.find(kv.first) == given.end()) {
274  return false;
275  }
276  }
277  return true;
278 }
279 
280 /* ************************************************************************* */
281 std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
282  const VectorValues &given) const {
283  if (!allFrontalsGiven(given)) {
284  throw std::runtime_error(
285  "HybridGaussianConditional::likelihood: given values are missing some "
286  "frontals.");
287  }
288 
289  const DiscreteKeys discreteParentKeys = discreteKeys();
290  const KeyVector continuousParentKeys = continuousParents();
291  const HybridGaussianFactor::FactorValuePairs likelihoods(
292  factors(),
294  if (auto conditional =
295  std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
296  const auto likelihood_m = conditional->likelihood(given);
297  // pair.second == conditional->negLogConstant() - negLogConstant_
298  return {likelihood_m, pair.second};
299  } else {
300  return {nullptr, std::numeric_limits<double>::infinity()};
301  }
302  });
303  return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
304  likelihoods);
305 }
306 
307 /* ************************************************************************* */
308 std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
309  std::set<DiscreteKey> s(discreteKeys.begin(), discreteKeys.end());
310  return s;
311 }
312 
313 /* *******************************************************************************/
314 HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
315  const DiscreteConditional &discreteProbs) const {
316  // Find keys in discreteProbs.keys() but not in this->keys():
317  std::set<Key> mine(this->keys().begin(), this->keys().end());
318  std::set<Key> theirs(discreteProbs.keys().begin(),
319  discreteProbs.keys().end());
320  std::vector<Key> diff;
321  std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
322  std::back_inserter(diff));
323 
324  // Find maximum probability value for every combination of *our* keys.
325  Ordering ordering(diff);
326  auto max = discreteProbs.max(ordering);
327 
328  // Check the max value for every combination of our keys.
329  // If the max value is 0.0, we can prune the corresponding conditional.
330  bool allPruned = true;
331  auto pruner =
332  [&](const Assignment<Key> &choices,
334  // If this choice is zero probability or Gaussian is null, return infinity
335  if (!pair.first || max->evaluate(choices) == 0.0) {
336  return {nullptr, std::numeric_limits<double>::infinity()};
337  } else {
338  allPruned = false;
339  // Add negLogConstant_ back so that the minimum negLogConstant in the
340  // HybridGaussianConditional is set correctly.
341  return {pair.first, pair.second + negLogConstant_};
342  }
343  };
344 
345  FactorValuePairs prunedConditionals = factors().apply(pruner);
346  if (allPruned) return nullptr;
347  return std::make_shared<HybridGaussianConditional>(discreteKeys(),
348  prunedConditionals, true);
349 }
350 
351 /* *******************************************************************************/
352 double HybridGaussianConditional::logProbability(
353  const HybridValues &values) const {
354  auto [factor, _] = factors()(values.discrete());
355  auto conditional = checkConditional(factor);
356  return conditional->logProbability(values.continuous());
357 }
358 
359 /* *******************************************************************************/
360 double HybridGaussianConditional::evaluate(const HybridValues &values) const {
361  auto [factor, _] = factors()(values.discrete());
362  auto conditional = checkConditional(factor);
363  return conditional->evaluate(values.continuous());
364 }
365 
366 /* ************************************************************************ */
367 std::shared_ptr<Factor> HybridGaussianConditional::restrict(
368  const DiscreteValues &assignment) const {
369  throw std::runtime_error(
370  "HybridGaussianConditional::restrict not implemented");
371 }
372 
373 /* ************************************************************************ */
374 } // namespace gtsam
gtsam::Conditional::print
void print(const std::string &s="Conditional", const KeyFormatter &formatter=DefaultKeyFormatter) const
Definition: Conditional-inst.h:30
GaussianFactorGraph.h
Linear Factor Graph where all factors are Gaussians.
gtsam::HybridValues
Definition: HybridValues.h:37
GaussianConditional.h
Conditional Gaussian Base class.
gtsam::HybridGaussianConditional::Helper::Helper
Helper(const Conditionals &conditionals)
Construct from tree of GaussianConditionals.
Definition: HybridGaussianConditional.cpp:88
gtsam::HybridGaussianConditional::shared_ptr
std::shared_ptr< This > shared_ptr
Definition: HybridGaussianConditional.h:60
HybridGaussianConditional.h
A hybrid conditional in the Conditional Linear Gaussian scheme.
s
RealScalar s
Definition: level1_cplx_impl.h:126
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
gtsam::HybridGaussianConditional::choose
GaussianConditional::shared_ptr choose(const DiscreteValues &discreteValues) const
Return the conditional Gaussian for the given discrete assignment.
Definition: HybridGaussianConditional.cpp:192
gtsam::HybridGaussianConditional::Helper
Helper struct for constructing HybridGaussianConditional objects.
Definition: HybridGaussianConditional.cpp:59
keys
const KeyVector keys
Definition: testRegularImplicitSchurFactor.cpp:40
gtsam::HybridGaussianConditional::Helper::Helper
Helper(const DiscreteKey &mode, const P &p, Args &&...args)
Construct from a vector of mean and sigma pairs, plus extra args.
Definition: HybridGaussianConditional.cpp:69
gtsam::DecisionTree::equals
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
Definition: DecisionTree-inl.h:975
simple_graph::factors
const GaussianFactorGraph factors
Definition: testJacobianFactor.cpp:213
gtsam::RedirectCout
Definition: base/utilities.h:16
gtsam::HybridFactor
Definition: HybridFactor.h:51
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
gtsam::Matrix
Eigen::MatrixXd Matrix
Definition: base/Matrix.h:39
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:247
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
gtsam::DecisionTree::print
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
Definition: DecisionTree-inl.h:981
GaussianBayesNet.h
Chordal Bayes Net, the result of eliminating a factor graph.
gtsam::KeyVector
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:92
HybridGaussianFactor.h
A set of GaussianFactors, indexed by a set of discrete keys.
utilities.h
equal_constants::conditionals
const std::vector< GaussianConditional::shared_ptr > conditionals
Definition: testHybridGaussianConditional.cpp:54
sampling::sigma
static const double sigma
Definition: testGaussianBayesNet.cpp:170
pruning_fixture::factor
DecisionTreeFactor factor(D &C &B &A, "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0")
c1
static double c1
Definition: airy.c:54
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
two_mode_measurement::gcs
const std::vector< GaussianConditional::shared_ptr > gcs
Definition: testHybridGaussianConditional.cpp:246
gtsam::checkConditional
GaussianConditional::shared_ptr checkConditional(const GaussianFactor::shared_ptr &factor)
Definition: HybridGaussianConditional.cpp:38
A
Definition: test_numpy_dtypes.cpp:300
gtsam::GaussianFactorValuePair
std::pair< GaussianFactor::shared_ptr, double > GaussianFactorValuePair
Alias for pair of GaussianFactor::shared_pointer and a double value.
Definition: HybridGaussianFactor.h:38
gtsam::VectorValues
Definition: VectorValues.h:74
mode
static const DiscreteKey mode(modeKey, 2)
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
parameters
static ConjugateGradientParameters parameters
Definition: testIterative.cpp:33
gtsam::HybridGaussianConditional::nrComponents
size_t nrComponents() const
Returns the total number of continuous components.
Definition: HybridGaussianConditional.cpp:183
gtsam::GaussianFactor::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: GaussianFactor.h:42
gtsam::GaussianConditional
Definition: GaussianConditional.h:40
A2
static const double A2[]
Definition: expn.h:7
gtsam::HybridGaussianConditional::Helper::P
std::vector< std::pair< Vector, double > > P
Definition: HybridGaussianConditional.cpp:65
gtsam::HybridGaussianConditional
A conditional of gaussian conditionals indexed by discrete variables, as part of a Bayes Network....
Definition: HybridGaussianConditional.h:55
gtsam::DecisionTree::visit
void visit(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:841
gtsam::Assignment< Key >
gtsam::HybridGaussianConditional::continuousParents
KeyVector continuousParents() const
Returns the continuous keys among the parents.
Definition: HybridGaussianConditional.cpp:254
GTSAM_PRINT
#define GTSAM_PRINT(x)
Definition: Testable.h:43
gtsam::HybridGaussianFactor::factors
const FactorValuePairs & factors() const
Getter for GaussianFactor decision tree.
Definition: HybridGaussianFactor.h:150
gtsam::HybridGaussianConditional::likelihood
std::shared_ptr< HybridGaussianFactor > likelihood(const VectorValues &given) const
Definition: HybridGaussianConditional.cpp:281
gtsam::Conditional< HybridGaussianFactor, HybridGaussianConditional >::parents
Parents parents() const
Definition: Conditional.h:148
gtsam::HybridGaussianConditional::Helper::minNegLogConstant
double minNegLogConstant
Definition: HybridGaussianConditional.cpp:62
ordering
static enum @1096 ordering
gtsam::HybridGaussianConditional::Helper::nrFrontals
std::optional< size_t > nrFrontals
Definition: HybridGaussianConditional.cpp:61
gtsam::DiscreteConditional::max
virtual DiscreteFactor::shared_ptr max(const Ordering &keys) const override
Create new factor by maximizing over all values with the same separator.
Definition: DiscreteConditional.cpp:490
key
const gtsam::Symbol key('X', 0)
JacobianFactor.h
gtsam::DecisionTree< Key, GaussianFactorValuePair >
gtsam::HybridGaussianConditional::Helper::pairs
FactorValuePairs pairs
Definition: HybridGaussianConditional.cpp:60
gtsam::DiscreteKeysAsSet
std::set< DiscreteKey > DiscreteKeysAsSet(const DiscreteKeys &discreteKeys)
Return the DiscreteKey vector as a set.
Definition: HybridGaussianConditional.cpp:308
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
different_sigmas::gc
const auto gc
Definition: testHybridBayesNet.cpp:233
gtsam::HybridGaussianConditional::conditionals
const Conditionals conditionals() const
Definition: HybridGaussianConditional.cpp:176
gtsam::HybridGaussianConditional::Conditionals
DecisionTree< Key, GaussianConditional::shared_ptr > Conditionals
typedef for Decision Tree of Gaussian Conditionals
Definition: HybridGaussianConditional.h:65
gtsam
traits
Definition: SFMdata.h:40
DiscreteValues.h
gtsam::HybridGaussianFactor::FactorValuePairs
DecisionTree< Key, GaussianFactorValuePair > FactorValuePairs
typedef for Decision Tree of Gaussian factors and arbitrary value.
Definition: HybridGaussianFactor.h:69
gtsam::mean
Point3 mean(const CONTAINER &points)
mean
Definition: Point3.h:75
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::Factor::keys
const KeyVector & keys() const
Access the factor's involved variable keys.
Definition: Factor.h:143
gtsam::Factor::equals
bool equals(const This &other, double tol=1e-9) const
check equality
Definition: Factor.cpp:42
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
args
Definition: pytypes.h:2212
gtsam::HybridGaussianConditional::allFrontalsGiven
bool allFrontalsGiven(const VectorValues &given) const
Check whether given has values for all frontal keys.
Definition: HybridGaussianConditional.cpp:270
p
float * p
Definition: Tutorial_Map_using.cpp:9
gtsam::HybridGaussianConditional::equals
bool equals(const HybridFactor &lf, double tol=1e-9) const override
Test equality with base HybridFactor.
Definition: HybridGaussianConditional.cpp:210
gtsam::HybridGaussianFactor
Implementation of a discrete-conditioned hybrid factor. Implements a joint discrete-continuous factor...
Definition: HybridGaussianFactor.h:60
A1
static const double A1[]
Definition: expn.h:6
gtsam::GaussianConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: GaussianConditional.h:46
c2
static double c2
Definition: airy.c:55
min
#define min(a, b)
Definition: datatypes.h:19
gtsam::HybridGaussianConditional::negLogConstant
double negLogConstant() const override
Return log normalization constant in negative log space.
Definition: HybridGaussianConditional.h:202
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::HybridGaussianConditional::HybridGaussianConditional
HybridGaussianConditional()=default
Default constructor, mainly for serialization.
Eigen::placeholders::end
static const EIGEN_DEPRECATED end_t end
Definition: IndexedViewHelper.h:181
func
Definition: benchGeometry.cpp:23
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
_
constexpr descr< N - 1 > _(char const (&text)[N])
Definition: descr.h:109
max
#define max(a, b)
Definition: datatypes.h:20
HybridValues.h
gtsam::Ordering
Definition: inference/Ordering.h:33
gtsam::HybridGaussianConditional::Helper::Helper
Helper(const FactorValuePairs &pairs)
Construct from tree of factor/scalar pairs.
Definition: HybridGaussianConditional.cpp:105
gtsam::HybridFactor::discreteKeys
const DiscreteKeys & discreteKeys() const
Return the discrete keys for this factor.
Definition: HybridFactor.h:131
test_callbacks.value
value
Definition: test_callbacks.py:162
gtsam::HybridGaussianConditional::print
void print(const std::string &s="HybridGaussianConditional\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Print utility.
Definition: HybridGaussianConditional.cpp:228
gtsam::RedirectCout::str
std::string str() const
return the string
Definition: utilities.cpp:5
Conditional-inst.h
gtsam::GaussianConditional::sharedMeanAndStddev
static shared_ptr sharedMeanAndStddev(Args &&... args)
Create shared pointer by forwarding arguments to fromMeanAndStddev.
Definition: GaussianConditional.h:126


gtsam
Author(s):
autogenerated on Wed Mar 19 2025 03:01:48