HybridGaussianConditional.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
21 #include <gtsam/base/utilities.h>
31 
32 #include <cstddef>
33 #include <memory>
34 
35 namespace gtsam {
36 
37 /* *******************************************************************************/
40  if (auto conditional =
41  std::dynamic_pointer_cast<GaussianConditional>(factor)) {
42  return conditional;
43  } else {
44  throw std::logic_error(
45  "A HybridGaussianConditional unexpectedly contained a non-conditional");
46  }
47 }
48 
49 /* *******************************************************************************/
61  std::optional<size_t> nrFrontals = {};
62  double minNegLogConstant = std::numeric_limits<double>::infinity();
63 
65  using P = std::vector<std::pair<Vector, double>>;
66 
68  template <typename... Args>
69  explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
70  nrFrontals = 1;
71  std::vector<GaussianFactorValuePair> fvs;
72  std::vector<GC::shared_ptr> gcs;
73  fvs.reserve(p.size());
74  gcs.reserve(p.size());
75  for (auto &&[mean, sigma] : p) {
76  auto gaussianConditional =
77  GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma);
78  double value = gaussianConditional->negLogConstant();
80  fvs.emplace_back(gaussianConditional, value);
81  gcs.push_back(gaussianConditional);
82  }
83 
84  pairs = FactorValuePairs({mode}, fvs);
85  }
86 
88  explicit Helper(const Conditionals &conditionals) {
89  auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
90  if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
91  if (!nrFrontals) nrFrontals = gc->nrFrontals();
92  double value = gc->negLogConstant();
94  return {gc, value};
95  };
97  if (!nrFrontals.has_value()) {
98  throw std::runtime_error(
99  "HybridGaussianConditional: need at least one frontal variable. "
100  "Provided conditionals do not contain any frontal variables.");
101  }
102  }
103 
105  explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
106  auto func = [this](const GaussianFactorValuePair &pair) {
107  if (!pair.first) return;
108  auto gc = checkConditional(pair.first);
109  if (!nrFrontals) nrFrontals = gc->nrFrontals();
111  };
112  pairs.visit(func);
113  if (!nrFrontals.has_value()) {
114  throw std::runtime_error(
115  "HybridGaussianConditional: need at least one frontal variable. "
116  "Provided conditionals do not contain any frontal variables.");
117  }
118  }
119 };
120 
121 /* *******************************************************************************/
123  const DiscreteKeys &discreteParents, Helper &&helper, bool pruned)
124  : BaseFactor(discreteParents,
126  [&](const GaussianFactorValuePair
127  &pair) { // subtract minNegLogConstant
129  pair.first, pair.second - helper.minNegLogConstant};
130  },
131  std::move(helper.pairs))),
132  BaseConditional(*helper.nrFrontals),
133  negLogConstant_(helper.minNegLogConstant),
134  pruned_(pruned) {}
135 
137  const DiscreteKey &discreteParent,
138  const std::vector<GaussianConditional::shared_ptr> &conditionals)
139  : HybridGaussianConditional(DiscreteKeys{discreteParent},
140  Conditionals({discreteParent}, conditionals)) {}
141 
143  const DiscreteKey &discreteParent, Key key, //
144  const std::vector<std::pair<Vector, double>> &parameters)
145  : HybridGaussianConditional(DiscreteKeys{discreteParent},
146  Helper(discreteParent, parameters, key)) {}
147 
149  const DiscreteKey &discreteParent, Key key, //
150  const Matrix &A, Key parent,
151  const std::vector<std::pair<Vector, double>> &parameters)
153  DiscreteKeys{discreteParent},
154  Helper(discreteParent, parameters, key, A, parent)) {}
155 
157  const DiscreteKey &discreteParent, Key key, //
158  const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
159  const std::vector<std::pair<Vector, double>> &parameters)
161  DiscreteKeys{discreteParent},
162  Helper(discreteParent, parameters, key, A1, parent1, A2, parent2)) {}
163 
165  const DiscreteKeys &discreteParents,
167  : HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
168 
170  const DiscreteKeys &discreteParents, const FactorValuePairs &pairs,
171  bool pruned)
172  : HybridGaussianConditional(discreteParents, Helper(pairs), pruned) {}
173 
174 /* *******************************************************************************/
177  return Conditionals(factors(), [](auto &&pair) {
178  return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
179  });
180 }
181 
182 /* *******************************************************************************/
184  size_t total = 0;
185  factors().visit([&total](auto &&node) {
186  if (node.first) total += 1;
187  });
188  return total;
189 }
190 
191 /* *******************************************************************************/
193  const DiscreteValues &discreteValues) const {
194  auto &[factor, _] = factors()(discreteValues);
195  if (!factor) return nullptr;
196 
197  auto conditional = checkConditional(factor);
198  return conditional;
199 }
200 
201 /* *******************************************************************************/
203  double tol) const {
204  const This *e = dynamic_cast<const This *>(&lf);
205  if (e == nullptr) return false;
206 
207  // Factors existence and scalar values are checked in BaseFactor::equals.
208  // Here we check additionally that the factors *are* conditionals
209  // and are equal.
210  auto compareFunc = [tol](const GaussianFactorValuePair &pair1,
211  const GaussianFactorValuePair &pair2) {
212  auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
213  c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
214  return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
215  };
216  return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc);
217 }
218 
219 /* *******************************************************************************/
220 void HybridGaussianConditional::print(const std::string &s,
221  const KeyFormatter &formatter) const {
222  std::cout << (s.empty() ? "" : s + "\n");
224  std::cout << " Discrete Keys = ";
225  for (auto &dk : discreteKeys()) {
226  std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
227  }
228  std::cout << std::endl
229  << " logNormalizationConstant: " << -negLogConstant() << std::endl
230  << std::endl;
231  factors().print(
232  "", [&](Key k) { return formatter(k); },
233  [&](const GaussianFactorValuePair &pair) -> std::string {
234  RedirectCout rd;
235  if (auto gf =
236  std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
237  gf->print("", formatter);
238  return rd.str();
239  } else {
240  return "nullptr";
241  }
242  });
243 }
244 
245 /* ************************************************************************* */
247  // Get all parent keys:
248  const auto range = parents();
249  KeyVector continuousParentKeys(range.begin(), range.end());
250  // Loop over all discrete keys:
251  for (const auto &discreteKey : discreteKeys()) {
252  const Key key = discreteKey.first;
253  // remove that key from continuousParentKeys:
254  continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
255  continuousParentKeys.end(), key),
256  continuousParentKeys.end());
257  }
258  return continuousParentKeys;
259 }
260 
261 /* ************************************************************************* */
263  const VectorValues &given) const {
264  for (auto &&kv : given) {
265  if (given.find(kv.first) == given.end()) {
266  return false;
267  }
268  }
269  return true;
270 }
271 
272 /* ************************************************************************* */
273 std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
274  const VectorValues &given) const {
275  if (!allFrontalsGiven(given)) {
276  throw std::runtime_error(
277  "HybridGaussianConditional::likelihood: given values are missing some "
278  "frontals.");
279  }
280 
281  const DiscreteKeys discreteParentKeys = discreteKeys();
282  const KeyVector continuousParentKeys = continuousParents();
283  const HybridGaussianFactor::FactorValuePairs likelihoods(
284  factors(),
286  if (auto conditional =
287  std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
288  const auto likelihood_m = conditional->likelihood(given);
289  // pair.second == conditional->negLogConstant() - negLogConstant_
290  return {likelihood_m, pair.second};
291  } else {
292  return {nullptr, std::numeric_limits<double>::infinity()};
293  }
294  });
295  return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
296  likelihoods);
297 }
298 
299 /* ************************************************************************* */
300 std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
301  std::set<DiscreteKey> s(discreteKeys.begin(), discreteKeys.end());
302  return s;
303 }
304 
305 /* *******************************************************************************/
306 HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
307  const DecisionTreeFactor &discreteProbs) const {
308  // Find keys in discreteProbs.keys() but not in this->keys():
309  std::set<Key> mine(this->keys().begin(), this->keys().end());
310  std::set<Key> theirs(discreteProbs.keys().begin(),
311  discreteProbs.keys().end());
312  std::vector<Key> diff;
313  std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
314  std::back_inserter(diff));
315 
316  // Find maximum probability value for every combination of our keys.
317  Ordering keys(diff);
318  auto max = discreteProbs.max(keys);
319 
320  // Check the max value for every combination of our keys.
321  // If the max value is 0.0, we can prune the corresponding conditional.
322  auto pruner =
323  [&](const Assignment<Key> &choices,
325  if (max->evaluate(choices) == 0.0)
326  return {nullptr, std::numeric_limits<double>::infinity()};
327  else {
328  // Add negLogConstant_ back so that the minimum negLogConstant in the
329  // HybridGaussianConditional is set correctly.
330  return {pair.first, pair.second + negLogConstant_};
331  }
332  };
333 
334  FactorValuePairs prunedConditionals = factors().apply(pruner);
335  return std::make_shared<HybridGaussianConditional>(discreteKeys(),
336  prunedConditionals, true);
337 }
338 
339 /* *******************************************************************************/
340 double HybridGaussianConditional::logProbability(
341  const HybridValues &values) const {
342  auto [factor, _] = factors()(values.discrete());
343  auto conditional = checkConditional(factor);
344  return conditional->logProbability(values.continuous());
345 }
346 
347 /* *******************************************************************************/
348 double HybridGaussianConditional::evaluate(const HybridValues &values) const {
349  auto [factor, _] = factors()(values.discrete());
350  auto conditional = checkConditional(factor);
351  return conditional->evaluate(values.continuous());
352 }
353 
354 } // namespace gtsam
gtsam::Conditional::print
void print(const std::string &s="Conditional", const KeyFormatter &formatter=DefaultKeyFormatter) const
Definition: Conditional-inst.h:30
GaussianFactorGraph.h
Linear Factor Graph where all factors are Gaussians.
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:45
gtsam::HybridValues
Definition: HybridValues.h:37
GaussianConditional.h
Conditional Gaussian Base class.
gtsam::HybridGaussianConditional::Helper::Helper
Helper(const Conditionals &conditionals)
Construct from tree of GaussianConditionals.
Definition: HybridGaussianConditional.cpp:88
gtsam::HybridGaussianConditional::shared_ptr
std::shared_ptr< This > shared_ptr
Definition: HybridGaussianConditional.h:59
HybridGaussianConditional.h
A hybrid conditional in the Conditional Linear Gaussian scheme.
s
RealScalar s
Definition: level1_cplx_impl.h:126
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
gtsam::HybridGaussianConditional::choose
GaussianConditional::shared_ptr choose(const DiscreteValues &discreteValues) const
Return the conditional Gaussian for the given discrete assignment.
Definition: HybridGaussianConditional.cpp:192
gtsam::HybridGaussianConditional::Helper
Helper struct for constructing HybridGaussianConditional objects.
Definition: HybridGaussianConditional.cpp:59
keys
const KeyVector keys
Definition: testRegularImplicitSchurFactor.cpp:40
gtsam::HybridGaussianConditional::Helper::Helper
Helper(const DiscreteKey &mode, const P &p, Args &&...args)
Construct from a vector of mean and sigma pairs, plus extra args.
Definition: HybridGaussianConditional.cpp:69
gtsam::DecisionTree::equals
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
Definition: DecisionTree-inl.h:972
simple_graph::factors
const GaussianFactorGraph factors
Definition: testJacobianFactor.cpp:213
gtsam::RedirectCout
Definition: base/utilities.h:16
gtsam::HybridFactor
Definition: HybridFactor.h:51
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
gtsam::Matrix
Eigen::MatrixXd Matrix
Definition: base/Matrix.h:39
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:245
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
gtsam::DecisionTree::print
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
Definition: DecisionTree-inl.h:978
GaussianBayesNet.h
Chordal Bayes Net, the result of eliminating a factor graph.
gtsam::KeyVector
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:92
HybridGaussianFactor.h
A set of GaussianFactors, indexed by a set of discrete keys.
utilities.h
equal_constants::conditionals
const std::vector< GaussianConditional::shared_ptr > conditionals
Definition: testHybridGaussianConditional.cpp:53
sampling::sigma
static const double sigma
Definition: testGaussianBayesNet.cpp:170
pruning_fixture::factor
DecisionTreeFactor factor(D &C &B &A, "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0")
c1
static double c1
Definition: airy.c:54
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
gtsam::checkConditional
GaussianConditional::shared_ptr checkConditional(const GaussianFactor::shared_ptr &factor)
Definition: HybridGaussianConditional.cpp:38
A
Definition: test_numpy_dtypes.cpp:298
gtsam::GaussianFactorValuePair
std::pair< GaussianFactor::shared_ptr, double > GaussianFactorValuePair
Alias for pair of GaussianFactor::shared_pointer and a double value.
Definition: HybridGaussianFactor.h:38
gtsam::VectorValues
Definition: VectorValues.h:74
mode
static const DiscreteKey mode(modeKey, 2)
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
parameters
static ConjugateGradientParameters parameters
Definition: testIterative.cpp:33
gtsam::HybridGaussianConditional::nrComponents
size_t nrComponents() const
Returns the total number of continuous components.
Definition: HybridGaussianConditional.cpp:183
gtsam::GaussianFactor::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: GaussianFactor.h:42
gtsam::GaussianConditional
Definition: GaussianConditional.h:40
A2
static const double A2[]
Definition: expn.h:7
gtsam::HybridGaussianConditional::Helper::P
std::vector< std::pair< Vector, double > > P
Definition: HybridGaussianConditional.cpp:65
gtsam::HybridGaussianConditional
A conditional of gaussian conditionals indexed by discrete variables, as part of a Bayes Network....
Definition: HybridGaussianConditional.h:54
gtsam::DecisionTree::visit
void visit(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:842
gtsam::Assignment< Key >
gtsam::HybridGaussianConditional::continuousParents
KeyVector continuousParents() const
Returns the continuous keys among the parents.
Definition: HybridGaussianConditional.cpp:246
gtsam::HybridGaussianFactor::factors
const FactorValuePairs & factors() const
Getter for GaussianFactor decision tree.
Definition: HybridGaussianFactor.h:150
gtsam::HybridGaussianConditional::likelihood
std::shared_ptr< HybridGaussianFactor > likelihood(const VectorValues &given) const
Definition: HybridGaussianConditional.cpp:273
gtsam::Conditional< HybridGaussianFactor, HybridGaussianConditional >::parents
Parents parents() const
Definition: Conditional.h:148
gtsam::HybridGaussianConditional::Helper::minNegLogConstant
double minNegLogConstant
Definition: HybridGaussianConditional.cpp:62
gtsam::HybridGaussianConditional::Helper::nrFrontals
std::optional< size_t > nrFrontals
Definition: HybridGaussianConditional.cpp:61
key
const gtsam::Symbol key('X', 0)
JacobianFactor.h
gtsam::DecisionTree< Key, GaussianFactorValuePair >
gtsam::HybridGaussianConditional::Helper::pairs
FactorValuePairs pairs
Definition: HybridGaussianConditional.cpp:60
gtsam::DiscreteKeysAsSet
std::set< DiscreteKey > DiscreteKeysAsSet(const DiscreteKeys &discreteKeys)
Return the DiscreteKey vector as a set.
Definition: HybridGaussianConditional.cpp:300
different_sigmas::gc
const auto gc
Definition: testHybridBayesNet.cpp:231
gtsam::HybridGaussianConditional::conditionals
const Conditionals conditionals() const
Definition: HybridGaussianConditional.cpp:176
gtsam::HybridGaussianConditional::Conditionals
DecisionTree< Key, GaussianConditional::shared_ptr > Conditionals
typedef for Decision Tree of Gaussian Conditionals
Definition: HybridGaussianConditional.h:64
gtsam
traits
Definition: SFMdata.h:40
DiscreteValues.h
gtsam::HybridGaussianFactor::FactorValuePairs
DecisionTree< Key, GaussianFactorValuePair > FactorValuePairs
typedef for Decision Tree of Gaussian factors and arbitrary value.
Definition: HybridGaussianFactor.h:69
gtsam::mean
Point3 mean(const CONTAINER &points)
mean
Definition: Point3.h:75
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::Factor::keys
const KeyVector & keys() const
Access the factor's involved variable keys.
Definition: Factor.h:143
gtsam::Factor::equals
bool equals(const This &other, double tol=1e-9) const
check equality
Definition: Factor.cpp:42
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
args
Definition: pytypes.h:2210
gtsam::HybridGaussianConditional::allFrontalsGiven
bool allFrontalsGiven(const VectorValues &given) const
Check whether given has values for all frontal keys.
Definition: HybridGaussianConditional.cpp:262
p
float * p
Definition: Tutorial_Map_using.cpp:9
gtsam::HybridGaussianConditional::equals
bool equals(const HybridFactor &lf, double tol=1e-9) const override
Test equality with base HybridFactor.
Definition: HybridGaussianConditional.cpp:202
gtsam::HybridGaussianFactor
Implementation of a discrete-conditioned hybrid factor. Implements a joint discrete-continuous factor...
Definition: HybridGaussianFactor.h:60
A1
static const double A1[]
Definition: expn.h:6
gtsam::GaussianConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: GaussianConditional.h:46
c2
static double c2
Definition: airy.c:55
min
#define min(a, b)
Definition: datatypes.h:19
gtsam::HybridGaussianConditional::negLogConstant
double negLogConstant() const override
Return log normalization constant in negative log space.
Definition: HybridGaussianConditional.h:201
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::DecisionTreeFactor::max
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override
Create new factor by maximizing over all values with the same separator.
Definition: DecisionTreeFactor.h:205
gtsam::HybridGaussianConditional::HybridGaussianConditional
HybridGaussianConditional()=default
Default constructor, mainly for serialization.
Eigen::placeholders::end
static const EIGEN_DEPRECATED end_t end
Definition: IndexedViewHelper.h:181
func
Definition: benchGeometry.cpp:23
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
_
constexpr descr< N - 1 > _(char const (&text)[N])
Definition: descr.h:109
max
#define max(a, b)
Definition: datatypes.h:20
HybridValues.h
gtsam::Ordering
Definition: inference/Ordering.h:33
gtsam::HybridGaussianConditional::Helper::Helper
Helper(const FactorValuePairs &pairs)
Construct from tree of factor/scalar pairs.
Definition: HybridGaussianConditional.cpp:105
gtsam::HybridFactor::discreteKeys
const DiscreteKeys & discreteKeys() const
Return the discrete keys for this factor.
Definition: HybridFactor.h:131
test_callbacks.value
value
Definition: test_callbacks.py:160
gtsam::HybridGaussianConditional::print
void print(const std::string &s="HybridGaussianConditional\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Print utility.
Definition: HybridGaussianConditional.cpp:220
gtsam::RedirectCout::str
std::string str() const
return the string
Definition: utilities.cpp:5
Conditional-inst.h
gtsam::GaussianConditional::sharedMeanAndStddev
static shared_ptr sharedMeanAndStddev(Args &&... args)
Create shared pointer by forwarding arguments to fromMeanAndStddev.
Definition: GaussianConditional.h:105


gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:02:22