HybridGaussianConditional.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
21 #include <gtsam/base/utilities.h>
31 
32 #include <cstddef>
33 #include <memory>
34 
35 namespace gtsam {
36 
37 /* *******************************************************************************/
39  const GaussianFactor::shared_ptr &factor) {
40  if (auto conditional =
41  std::dynamic_pointer_cast<GaussianConditional>(factor)) {
42  return conditional;
43  } else {
44  throw std::logic_error(
45  "A HybridGaussianConditional unexpectedly contained a non-conditional");
46  }
47 }
48 
49 /* *******************************************************************************/
61  std::optional<size_t> nrFrontals = {};
62  double minNegLogConstant = std::numeric_limits<double>::infinity();
63 
65  using P = std::vector<std::pair<Vector, double>>;
66 
68  template <typename... Args>
69  explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
70  nrFrontals = 1;
71  std::vector<GaussianFactorValuePair> fvs;
72  std::vector<GC::shared_ptr> gcs;
73  fvs.reserve(p.size());
74  gcs.reserve(p.size());
75  for (auto &&[mean, sigma] : p) {
76  auto gaussianConditional =
77  GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma);
78  double value = gaussianConditional->negLogConstant();
80  fvs.emplace_back(gaussianConditional, value);
81  gcs.push_back(gaussianConditional);
82  }
83 
84  pairs = FactorValuePairs({mode}, fvs);
85  }
86 
88  explicit Helper(const Conditionals &conditionals) {
89  auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
90  if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
91  if (!nrFrontals) nrFrontals = gc->nrFrontals();
92  double value = gc->negLogConstant();
94  return {gc, value};
95  };
97  if (!nrFrontals.has_value()) {
98  throw std::runtime_error(
99  "HybridGaussianConditional: need at least one frontal variable. "
100  "Provided conditionals do not contain any frontal variables.");
101  }
102  }
103 
105  explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
106  auto func = [this](const GaussianFactorValuePair &pair) {
107  if (!pair.first) return;
108  auto gc = checkConditional(pair.first);
109  if (!nrFrontals) nrFrontals = gc->nrFrontals();
111  };
112  pairs.visit(func);
113  if (!nrFrontals.has_value()) {
114  throw std::runtime_error(
115  "HybridGaussianConditional: need at least one frontal variable. "
116  "Provided conditionals do not contain any frontal variables.");
117  }
118  }
119 };
120 
121 /* *******************************************************************************/
123  const DiscreteKeys &discreteParents, Helper &&helper)
124  : BaseFactor(discreteParents,
126  [&](const GaussianFactorValuePair
127  &pair) { // subtract minNegLogConstant
129  pair.first, pair.second - helper.minNegLogConstant};
130  },
131  std::move(helper.pairs))),
132  BaseConditional(*helper.nrFrontals),
133  negLogConstant_(helper.minNegLogConstant) {}
134 
136  const DiscreteKey &discreteParent,
137  const std::vector<GaussianConditional::shared_ptr> &conditionals)
138  : HybridGaussianConditional(DiscreteKeys{discreteParent},
139  Conditionals({discreteParent}, conditionals)) {}
140 
142  const DiscreteKey &discreteParent, Key key, //
143  const std::vector<std::pair<Vector, double>> &parameters)
144  : HybridGaussianConditional(DiscreteKeys{discreteParent},
145  Helper(discreteParent, parameters, key)) {}
146 
148  const DiscreteKey &discreteParent, Key key, //
149  const Matrix &A, Key parent,
150  const std::vector<std::pair<Vector, double>> &parameters)
152  DiscreteKeys{discreteParent},
153  Helper(discreteParent, parameters, key, A, parent)) {}
154 
156  const DiscreteKey &discreteParent, Key key, //
157  const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
158  const std::vector<std::pair<Vector, double>> &parameters)
160  DiscreteKeys{discreteParent},
161  Helper(discreteParent, parameters, key, A1, parent1, A2, parent2)) {}
162 
164  const DiscreteKeys &discreteParents,
166  : HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
167 
169  const DiscreteKeys &discreteParents, const FactorValuePairs &pairs)
170  : HybridGaussianConditional(discreteParents, Helper(pairs)) {}
171 
172 /* *******************************************************************************/
175  return Conditionals(factors(), [](auto &&pair) {
176  return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
177  });
178 }
179 
180 /* *******************************************************************************/
182  size_t total = 0;
183  factors().visit([&total](auto &&node) {
184  if (node.first) total += 1;
185  });
186  return total;
187 }
188 
189 /* *******************************************************************************/
191  const DiscreteValues &discreteValues) const {
192  auto &[factor, _] = factors()(discreteValues);
193  if (!factor) return nullptr;
194 
195  auto conditional = checkConditional(factor);
196  return conditional;
197 }
198 
199 /* *******************************************************************************/
201  double tol) const {
202  const This *e = dynamic_cast<const This *>(&lf);
203  if (e == nullptr) return false;
204 
205  // Factors existence and scalar values are checked in BaseFactor::equals.
206  // Here we check additionally that the factors *are* conditionals
207  // and are equal.
208  auto compareFunc = [tol](const GaussianFactorValuePair &pair1,
209  const GaussianFactorValuePair &pair2) {
210  auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
211  c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
212  return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
213  };
214  return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc);
215 }
216 
217 /* *******************************************************************************/
218 void HybridGaussianConditional::print(const std::string &s,
219  const KeyFormatter &formatter) const {
220  std::cout << (s.empty() ? "" : s + "\n");
222  std::cout << " Discrete Keys = ";
223  for (auto &dk : discreteKeys()) {
224  std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
225  }
226  std::cout << std::endl
227  << " logNormalizationConstant: " << -negLogConstant() << std::endl
228  << std::endl;
229  factors().print(
230  "", [&](Key k) { return formatter(k); },
231  [&](const GaussianFactorValuePair &pair) -> std::string {
232  RedirectCout rd;
233  if (auto gf =
234  std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
235  gf->print("", formatter);
236  return rd.str();
237  } else {
238  return "nullptr";
239  }
240  });
241 }
242 
243 /* ************************************************************************* */
245  // Get all parent keys:
246  const auto range = parents();
247  KeyVector continuousParentKeys(range.begin(), range.end());
248  // Loop over all discrete keys:
249  for (const auto &discreteKey : discreteKeys()) {
250  const Key key = discreteKey.first;
251  // remove that key from continuousParentKeys:
252  continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
253  continuousParentKeys.end(), key),
254  continuousParentKeys.end());
255  }
256  return continuousParentKeys;
257 }
258 
259 /* ************************************************************************* */
261  const VectorValues &given) const {
262  for (auto &&kv : given) {
263  if (given.find(kv.first) == given.end()) {
264  return false;
265  }
266  }
267  return true;
268 }
269 
270 /* ************************************************************************* */
271 std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
272  const VectorValues &given) const {
273  if (!allFrontalsGiven(given)) {
274  throw std::runtime_error(
275  "HybridGaussianConditional::likelihood: given values are missing some "
276  "frontals.");
277  }
278 
279  const DiscreteKeys discreteParentKeys = discreteKeys();
280  const KeyVector continuousParentKeys = continuousParents();
281  const HybridGaussianFactor::FactorValuePairs likelihoods(
282  factors(),
284  if (auto conditional =
285  std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
286  const auto likelihood_m = conditional->likelihood(given);
287  // pair.second == conditional->negLogConstant() - negLogConstant_
288  return {likelihood_m, pair.second};
289  } else {
290  return {nullptr, std::numeric_limits<double>::infinity()};
291  }
292  });
293  return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
294  likelihoods);
295 }
296 
297 /* ************************************************************************* */
298 std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
299  std::set<DiscreteKey> s(discreteKeys.begin(), discreteKeys.end());
300  return s;
301 }
302 
303 /* *******************************************************************************/
304 HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
305  const DecisionTreeFactor &discreteProbs) const {
306  // Find keys in discreteProbs.keys() but not in this->keys():
307  std::set<Key> mine(this->keys().begin(), this->keys().end());
308  std::set<Key> theirs(discreteProbs.keys().begin(),
309  discreteProbs.keys().end());
310  std::vector<Key> diff;
311  std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
312  std::back_inserter(diff));
313 
314  // Find maximum probability value for every combination of our keys.
315  Ordering keys(diff);
316  auto max = discreteProbs.max(keys);
317 
318  // Check the max value for every combination of our keys.
319  // If the max value is 0.0, we can prune the corresponding conditional.
320  auto pruner =
321  [&](const Assignment<Key> &choices,
323  if (max->evaluate(choices) == 0.0)
324  return {nullptr, std::numeric_limits<double>::infinity()};
325  else {
326  // Add negLogConstant_ back so that the minimum negLogConstant in the
327  // HybridGaussianConditional is set correctly.
328  return {pair.first, pair.second + negLogConstant_};
329  }
330  };
331 
332  FactorValuePairs prunedConditionals = factors().apply(pruner);
333  return std::make_shared<HybridGaussianConditional>(discreteKeys(),
334  prunedConditionals);
335 }
336 
337 /* *******************************************************************************/
338 double HybridGaussianConditional::logProbability(
339  const HybridValues &values) const {
340  auto [factor, _] = factors()(values.discrete());
341  auto conditional = checkConditional(factor);
342  return conditional->logProbability(values.continuous());
343 }
344 
345 /* *******************************************************************************/
346 double HybridGaussianConditional::evaluate(const HybridValues &values) const {
347  auto [factor, _] = factors()(values.discrete());
348  auto conditional = checkConditional(factor);
349  return conditional->evaluate(values.continuous());
350 }
351 
352 } // namespace gtsam
gtsam::Conditional::print
void print(const std::string &s="Conditional", const KeyFormatter &formatter=DefaultKeyFormatter) const
Definition: Conditional-inst.h:30
GaussianFactorGraph.h
Linear Factor Graph where all factors are Gaussians.
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:44
gtsam::HybridValues
Definition: HybridValues.h:37
GaussianConditional.h
Conditional Gaussian Base class.
gtsam::HybridGaussianConditional::Helper::Helper
Helper(const Conditionals &conditionals)
Construct from tree of GaussianConditionals.
Definition: HybridGaussianConditional.cpp:88
gtsam::HybridGaussianConditional::shared_ptr
std::shared_ptr< This > shared_ptr
Definition: HybridGaussianConditional.h:59
HybridGaussianConditional.h
A hybrid conditional in the Conditional Linear Gaussian scheme.
s
RealScalar s
Definition: level1_cplx_impl.h:126
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
gtsam::HybridGaussianConditional::choose
GaussianConditional::shared_ptr choose(const DiscreteValues &discreteValues) const
Return the conditional Gaussian for the given discrete assignment.
Definition: HybridGaussianConditional.cpp:190
gtsam::HybridGaussianConditional::Helper
Helper struct for constructing HybridGaussianConditional objects.
Definition: HybridGaussianConditional.cpp:59
keys
const KeyVector keys
Definition: testRegularImplicitSchurFactor.cpp:40
gtsam::HybridGaussianConditional::Helper::Helper
Helper(const DiscreteKey &mode, const P &p, Args &&...args)
Construct from a vector of mean and sigma pairs, plus extra args.
Definition: HybridGaussianConditional.cpp:69
gtsam::DecisionTree::equals
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
Definition: DecisionTree-inl.h:972
simple_graph::factors
const GaussianFactorGraph factors
Definition: testJacobianFactor.cpp:213
gtsam::RedirectCout
Definition: base/utilities.h:16
gtsam::HybridFactor
Definition: HybridFactor.h:51
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
gtsam::Matrix
Eigen::MatrixXd Matrix
Definition: base/Matrix.h:39
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:245
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
gtsam::DecisionTree::print
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
Definition: DecisionTree-inl.h:978
GaussianBayesNet.h
Chordal Bayes Net, the result of eliminating a factor graph.
gtsam::KeyVector
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:92
HybridGaussianFactor.h
A set of GaussianFactors, indexed by a set of discrete keys.
utilities.h
equal_constants::conditionals
const std::vector< GaussianConditional::shared_ptr > conditionals
Definition: testHybridGaussianConditional.cpp:53
sampling::sigma
static const double sigma
Definition: testGaussianBayesNet.cpp:170
c1
static double c1
Definition: airy.c:54
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
gtsam::checkConditional
GaussianConditional::shared_ptr checkConditional(const GaussianFactor::shared_ptr &factor)
Definition: HybridGaussianConditional.cpp:38
A
Definition: test_numpy_dtypes.cpp:298
gtsam::GaussianFactorValuePair
std::pair< GaussianFactor::shared_ptr, double > GaussianFactorValuePair
Alias for pair of GaussianFactor::shared_pointer and a double value.
Definition: HybridGaussianFactor.h:38
gtsam::VectorValues
Definition: VectorValues.h:74
mode
static const DiscreteKey mode(modeKey, 2)
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
parameters
static ConjugateGradientParameters parameters
Definition: testIterative.cpp:33
gtsam::HybridGaussianConditional::nrComponents
size_t nrComponents() const
Returns the total number of continuous components.
Definition: HybridGaussianConditional.cpp:181
gtsam::GaussianFactor::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: GaussianFactor.h:42
gtsam::GaussianConditional
Definition: GaussianConditional.h:40
A2
static const double A2[]
Definition: expn.h:7
gtsam::HybridGaussianConditional::Helper::P
std::vector< std::pair< Vector, double > > P
Definition: HybridGaussianConditional.cpp:65
gtsam::HybridGaussianConditional
A conditional of gaussian conditionals indexed by discrete variables, as part of a Bayes Network....
Definition: HybridGaussianConditional.h:54
gtsam::DecisionTree::visit
void visit(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:842
gtsam::Assignment< Key >
gtsam::HybridGaussianConditional::continuousParents
KeyVector continuousParents() const
Returns the continuous keys among the parents.
Definition: HybridGaussianConditional.cpp:244
gtsam::HybridGaussianFactor::factors
const FactorValuePairs & factors() const
Getter for GaussianFactor decision tree.
Definition: HybridGaussianFactor.h:150
gtsam::HybridGaussianConditional::likelihood
std::shared_ptr< HybridGaussianFactor > likelihood(const VectorValues &given) const
Definition: HybridGaussianConditional.cpp:271
gtsam::Conditional< HybridGaussianFactor, HybridGaussianConditional >::parents
Parents parents() const
Definition: Conditional.h:148
gtsam::HybridGaussianConditional::Helper::minNegLogConstant
double minNegLogConstant
Definition: HybridGaussianConditional.cpp:62
gtsam::HybridGaussianConditional::Helper::nrFrontals
std::optional< size_t > nrFrontals
Definition: HybridGaussianConditional.cpp:61
key
const gtsam::Symbol key('X', 0)
JacobianFactor.h
gtsam::DecisionTree< Key, GaussianFactorValuePair >
gtsam::HybridGaussianConditional::Helper::pairs
FactorValuePairs pairs
Definition: HybridGaussianConditional.cpp:60
gtsam::DiscreteKeysAsSet
std::set< DiscreteKey > DiscreteKeysAsSet(const DiscreteKeys &discreteKeys)
Return the DiscreteKey vector as a set.
Definition: HybridGaussianConditional.cpp:298
different_sigmas::gc
const auto gc
Definition: testHybridBayesNet.cpp:231
gtsam::HybridGaussianConditional::conditionals
const Conditionals conditionals() const
Definition: HybridGaussianConditional.cpp:174
gtsam::HybridGaussianConditional::Conditionals
DecisionTree< Key, GaussianConditional::shared_ptr > Conditionals
typedef for Decision Tree of Gaussian Conditionals
Definition: HybridGaussianConditional.h:64
gtsam
traits
Definition: SFMdata.h:40
DiscreteValues.h
gtsam::HybridGaussianFactor::FactorValuePairs
DecisionTree< Key, GaussianFactorValuePair > FactorValuePairs
typedef for Decision Tree of Gaussian factors and arbitrary value.
Definition: HybridGaussianFactor.h:69
gtsam::mean
Point3 mean(const CONTAINER &points)
mean
Definition: Point3.h:70
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::Factor::keys
const KeyVector & keys() const
Access the factor's involved variable keys.
Definition: Factor.h:143
gtsam::Factor::equals
bool equals(const This &other, double tol=1e-9) const
check equality
Definition: Factor.cpp:42
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
args
Definition: pytypes.h:2210
gtsam::DecisionTreeFactor::max
shared_ptr max(size_t nrFrontals) const
Create new factor by maximizing over all values with the same separator.
Definition: DecisionTreeFactor.h:172
gtsam::HybridGaussianConditional::allFrontalsGiven
bool allFrontalsGiven(const VectorValues &given) const
Check whether given has values for all frontal keys.
Definition: HybridGaussianConditional.cpp:260
p
float * p
Definition: Tutorial_Map_using.cpp:9
gtsam::HybridGaussianConditional::equals
bool equals(const HybridFactor &lf, double tol=1e-9) const override
Test equality with base HybridFactor.
Definition: HybridGaussianConditional.cpp:200
gtsam::HybridGaussianFactor
Implementation of a discrete-conditioned hybrid factor. Implements a joint discrete-continuous factor...
Definition: HybridGaussianFactor.h:60
A1
static const double A1[]
Definition: expn.h:6
gtsam::GaussianConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: GaussianConditional.h:46
c2
static double c2
Definition: airy.c:55
min
#define min(a, b)
Definition: datatypes.h:19
gtsam::HybridGaussianConditional::negLogConstant
double negLogConstant() const override
Return log normalization constant in negative log space.
Definition: HybridGaussianConditional.h:197
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::HybridGaussianConditional::HybridGaussianConditional
HybridGaussianConditional()=default
Default constructor, mainly for serialization.
Eigen::placeholders::end
static const EIGEN_DEPRECATED end_t end
Definition: IndexedViewHelper.h:181
func
Definition: benchGeometry.cpp:23
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
_
constexpr descr< N - 1 > _(char const (&text)[N])
Definition: descr.h:109
max
#define max(a, b)
Definition: datatypes.h:20
HybridValues.h
gtsam::Ordering
Definition: inference/Ordering.h:33
gtsam::HybridGaussianConditional::Helper::Helper
Helper(const FactorValuePairs &pairs)
Construct from tree of factor/scalar pairs.
Definition: HybridGaussianConditional.cpp:105
gtsam::HybridFactor::discreteKeys
const DiscreteKeys & discreteKeys() const
Return the discrete keys for this factor.
Definition: HybridFactor.h:131
test_callbacks.value
value
Definition: test_callbacks.py:160
gtsam::HybridGaussianConditional::print
void print(const std::string &s="HybridGaussianConditional\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Print utility.
Definition: HybridGaussianConditional.cpp:218
gtsam::RedirectCout::str
std::string str() const
return the string
Definition: utilities.cpp:5
Conditional-inst.h
gtsam::GaussianConditional::sharedMeanAndStddev
static shared_ptr sharedMeanAndStddev(Args &&... args)
Create shared pointer by forwarding arguments to fromMeanAndStddev.
Definition: GaussianConditional.h:105


gtsam
Author(s):
autogenerated on Sat Nov 16 2024 04:02:26