GaussianMixture.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
21 #include <gtsam/base/utilities.h>
28 
29 namespace gtsam {
30 
32  const KeyVector &continuousFrontals, const KeyVector &continuousParents,
33  const DiscreteKeys &discreteParents,
35  : BaseFactor(CollectKeys(continuousFrontals, continuousParents),
36  discreteParents),
37  BaseConditional(continuousFrontals.size()),
38  conditionals_(conditionals) {
39  // Calculate logConstant_ as the maximum of the log constants of the
40  // conditionals, by visiting the decision tree:
41  logConstant_ = -std::numeric_limits<double>::infinity();
43  [this](const GaussianConditional::shared_ptr &conditional) {
44  if (conditional) {
45  this->logConstant_ = std::max(
46  this->logConstant_, conditional->logNormalizationConstant());
47  }
48  });
49 }
50 
51 /* *******************************************************************************/
53  return conditionals_;
54 }
55 
56 /* *******************************************************************************/
58  KeyVector &&continuousFrontals, KeyVector &&continuousParents,
59  DiscreteKeys &&discreteParents,
60  std::vector<GaussianConditional::shared_ptr> &&conditionals)
61  : GaussianMixture(continuousFrontals, continuousParents, discreteParents,
62  Conditionals(discreteParents, conditionals)) {}
63 
64 /* *******************************************************************************/
66  const KeyVector &continuousFrontals, const KeyVector &continuousParents,
67  const DiscreteKeys &discreteParents,
68  const std::vector<GaussianConditional::shared_ptr> &conditionals)
69  : GaussianMixture(continuousFrontals, continuousParents, discreteParents,
70  Conditionals(discreteParents, conditionals)) {}
71 
72 /* *******************************************************************************/
73 // TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
74 // GaussianMixtureFactor, no?
76  const GaussianFactorGraphTree &sum) const {
77  using Y = GaussianFactorGraph;
78  auto add = [](const Y &graph1, const Y &graph2) {
79  auto result = graph1;
80  result.push_back(graph2);
81  return result;
82  };
83  const auto tree = asGaussianFactorGraphTree();
84  return sum.empty() ? tree : sum.apply(tree, add);
85 }
86 
87 /* *******************************************************************************/
89  auto wrap = [](const GaussianConditional::shared_ptr &gc) {
90  return GaussianFactorGraph{gc};
91  };
92  return {conditionals_, wrap};
93 }
94 
95 /* *******************************************************************************/
97  size_t total = 0;
98  conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) {
99  if (node) total += 1;
100  });
101  return total;
102 }
103 
104 /* *******************************************************************************/
106  const DiscreteValues &discreteValues) const {
107  auto &ptr = conditionals_(discreteValues);
108  if (!ptr) return nullptr;
109  auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
110  if (conditional)
111  return conditional;
112  else
113  throw std::logic_error(
114  "A GaussianMixture unexpectedly contained a non-conditional");
115 }
116 
117 /* *******************************************************************************/
118 bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
119  const This *e = dynamic_cast<const This *>(&lf);
120  if (e == nullptr) return false;
121 
122  // This will return false if either conditionals_ is empty or e->conditionals_
123  // is empty, but not if both are empty or both are not empty:
124  if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
125 
126  // Check the base and the factors:
127  return BaseFactor::equals(*e, tol) &&
129  [tol](const GaussianConditional::shared_ptr &f1,
131  return f1->equals(*(f2), tol);
132  });
133 }
134 
135 /* *******************************************************************************/
136 void GaussianMixture::print(const std::string &s,
137  const KeyFormatter &formatter) const {
138  std::cout << (s.empty() ? "" : s + "\n");
139  if (isContinuous()) std::cout << "Continuous ";
140  if (isDiscrete()) std::cout << "Discrete ";
141  if (isHybrid()) std::cout << "Hybrid ";
142  BaseConditional::print("", formatter);
143  std::cout << " Discrete Keys = ";
144  for (auto &dk : discreteKeys()) {
145  std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
146  }
147  std::cout << "\n";
149  "", [&](Key k) { return formatter(k); },
150  [&](const GaussianConditional::shared_ptr &gf) -> std::string {
151  RedirectCout rd;
152  if (gf && !gf->empty()) {
153  gf->print("", formatter);
154  return rd.str();
155  } else {
156  return "nullptr";
157  }
158  });
159 }
160 
161 /* ************************************************************************* */
163  // Get all parent keys:
164  const auto range = parents();
165  KeyVector continuousParentKeys(range.begin(), range.end());
166  // Loop over all discrete keys:
167  for (const auto &discreteKey : discreteKeys()) {
168  const Key key = discreteKey.first;
169  // remove that key from continuousParentKeys:
170  continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
171  continuousParentKeys.end(), key),
172  continuousParentKeys.end());
173  }
174  return continuousParentKeys;
175 }
176 
177 /* ************************************************************************* */
179  for (auto &&kv : given) {
180  if (given.find(kv.first) == given.end()) {
181  return false;
182  }
183  }
184  return true;
185 }
186 
187 /* ************************************************************************* */
188 std::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
189  const VectorValues &given) const {
190  if (!allFrontalsGiven(given)) {
191  throw std::runtime_error(
192  "GaussianMixture::likelihood: given values are missing some frontals.");
193  }
194 
195  const DiscreteKeys discreteParentKeys = discreteKeys();
196  const KeyVector continuousParentKeys = continuousParents();
197  const GaussianMixtureFactor::Factors likelihoods(
198  conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
199  const auto likelihood_m = conditional->likelihood(given);
200  const double Cgm_Kgcm =
201  logConstant_ - conditional->logNormalizationConstant();
202  if (Cgm_Kgcm == 0.0) {
203  return likelihood_m;
204  } else {
205  // Add a constant factor to the likelihood in case the noise models
206  // are not all equal.
208  gfg.push_back(likelihood_m);
209  Vector c(1);
210  c << std::sqrt(2.0 * Cgm_Kgcm);
211  auto constantFactor = std::make_shared<JacobianFactor>(c);
212  gfg.push_back(constantFactor);
213  return std::make_shared<JacobianFactor>(gfg);
214  }
215  });
216  return std::make_shared<GaussianMixtureFactor>(
217  continuousParentKeys, discreteParentKeys, likelihoods);
218 }
219 
220 /* ************************************************************************* */
221 std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
222  std::set<DiscreteKey> s;
223  s.insert(discreteKeys.begin(), discreteKeys.end());
224  return s;
225 }
226 
227 /* ************************************************************************* */
235 std::function<GaussianConditional::shared_ptr(
238  // Get the discrete keys as sets for the decision tree
239  // and the gaussian mixture.
240  auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
241  auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
242 
243  auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
244  const Assignment<Key> &choices,
245  const GaussianConditional::shared_ptr &conditional)
247  // typecast so we can use this to get probability value
248  const DiscreteValues values(choices);
249 
250  // Case where the gaussian mixture has the same
251  // discrete keys as the decision tree.
252  if (gaussianMixtureKeySet == decisionTreeKeySet) {
253  if (decisionTree(values) == 0.0) {
254  // empty aka null pointer
255  std::shared_ptr<GaussianConditional> null;
256  return null;
257  } else {
258  return conditional;
259  }
260  } else {
261  std::vector<DiscreteKey> set_diff;
262  std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
263  gaussianMixtureKeySet.begin(),
264  gaussianMixtureKeySet.end(),
265  std::back_inserter(set_diff));
266 
267  const std::vector<DiscreteValues> assignments =
269  for (const DiscreteValues &assignment : assignments) {
270  DiscreteValues augmented_values(values);
271  augmented_values.insert(assignment);
272 
273  // If any one of the sub-branches are non-zero,
274  // we need this conditional.
275  if (decisionTree(augmented_values) > 0.0) {
276  return conditional;
277  }
278  }
279  // If we are here, it means that all the sub-branches are 0,
280  // so we prune.
281  return nullptr;
282  }
283  };
284  return pruner;
285 }
286 
287 /* *******************************************************************************/
288 void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
289  auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
290  auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
291  // Functional which loops over all assignments and create a set of
292  // GaussianConditionals
293  auto pruner = prunerFunc(decisionTree);
294 
295  auto pruned_conditionals = conditionals_.apply(pruner);
296  conditionals_.root_ = pruned_conditionals.root_;
297 }
298 
299 /* *******************************************************************************/
301  const VectorValues &continuousValues) const {
302  // functor to calculate (double) logProbability value from
303  // GaussianConditional.
304  auto probFunc =
305  [continuousValues](const GaussianConditional::shared_ptr &conditional) {
306  if (conditional) {
307  return conditional->logProbability(continuousValues);
308  } else {
309  // Return arbitrarily small logProbability if conditional is null
310  // Conditional is null if it is pruned out.
311  return -1e20;
312  }
313  };
314  return DecisionTree<Key, double>(conditionals_, probFunc);
315 }
316 
317 /* *******************************************************************************/
319  const VectorValues &continuousValues) const {
320  auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
321  return conditional->error(continuousValues) + //
322  logConstant_ - conditional->logNormalizationConstant();
323  };
324  DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
325  return errorTree;
326 }
327 
328 /* *******************************************************************************/
330  // Directly index to get the conditional, no need to build the whole tree.
331  auto conditional = conditionals_(values.discrete());
332  return conditional->error(values.continuous()) + //
333  logConstant_ - conditional->logNormalizationConstant();
334 }
335 
336 /* *******************************************************************************/
338  auto conditional = conditionals_(values.discrete());
339  return conditional->logProbability(values.continuous());
340 }
341 
342 /* *******************************************************************************/
344  auto conditional = conditionals_(values.discrete());
345  return conditional->evaluate(values.continuous());
346 }
347 
348 } // namespace gtsam
const gtsam::Symbol key('X', 0)
A set of GaussianFactors, indexed by a set of discrete keys.
A hybrid conditional in the Conditional Linear Gaussian scheme.
const char Y
std::shared_ptr< This > shared_ptr
shared_ptr to this class
AlgebraicDecisionTree< Key > logProbability(const VectorValues &continuousValues) const
Compute logProbability of the GaussianMixture as a tree.
#define max(a, b)
Definition: datatypes.h:20
std::function< GaussianConditional::shared_ptr(const Assignment< Key > &, const GaussianConditional::shared_ptr &)> prunerFunc(const DecisionTreeFactor &decisionTree)
Helper function to get the pruner functor.
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const
Merge the Gaussian Factor Graphs in this and sum while maintaining the decision tree structure...
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
void visit(Func f) const
Visit all leaves in depth-first fashion.
KeyVector continuousParents() const
Returns the continuous keys among the parents.
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:190
const DiscreteKeys & discreteKeys() const
Return the discrete keys for this factor.
Definition: HybridFactor.h:129
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
leaf::MyValues values
const VectorValues & continuous() const
Return the multi-dimensional vector values.
Definition: HybridValues.h:89
double f2(const Vector2 &x)
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
const KeyFormatter & formatter
double logConstant_
log of the normalization constant.
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition: DecisionTree.h:136
A conditional of gaussian mixtures indexed by discrete variables, as part of a Bayes Network...
bool equals(const HybridFactor &lf, double tol=1e-9) const override
Test equality with base HybridFactor.
GaussianConditional::shared_ptr operator()(const DiscreteValues &discreteValues) const
Return the conditional Gaussian for the given discrete assignment.
Eigen::VectorXd Vector
Definition: Vector.h:38
Values result
std::string str() const
return the string
Definition: utilities.cpp:5
virtual bool equals(const HybridFactor &lf, double tol=1e-9) const
equals
std::set< DiscreteKey > DiscreteKeysAsSet(const DiscreteKeys &discreteKeys)
Return the DiscreteKey vector as a set.
double error(const HybridValues &values) const override
Compute the error of this Gaussian Mixture.
void print(const std::string &s="GaussianMixture\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Print utility.
Array< double, 1, 3 > e(1./3., 0.5, 2.)
RealScalar s
size_t nrComponents() const
Returns the total number of continuous components.
Linear Factor Graph where all factors are Gaussians.
bool isContinuous() const
True if this is a factor of continuous variables only.
Definition: HybridFactor.h:120
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
DecisionTree apply(const Unary &op) const
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
std::shared_ptr< This > shared_ptr
shared_ptr to this class
NonlinearFactorGraph graph2()
traits
Definition: chartTesting.h:28
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
bool empty() const
Check if tree is empty.
Definition: DecisionTree.h:240
std::shared_ptr< GaussianMixtureFactor > likelihood(const VectorValues &given) const
bool allFrontalsGiven(const VectorValues &given) const
Check whether given has values for all frontal keys.
Conditionals conditionals_
a decision tree of Gaussian conditionals.
bool isHybrid() const
True if this is a Discrete-Continuous factor.
Definition: HybridFactor.h:123
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
Double_ range(const Point2_ &p, const Point2_ &q)
mxArray * wrap(const Class &value)
Definition: matlab.h:126
const Conditionals & conditionals() const
Getter for the underlying Conditionals DecisionTree.
const DiscreteValues & discrete() const
Return the discrete values.
Definition: HybridValues.h:92
void print(const std::string &s="Conditional", const KeyFormatter &formatter=DefaultKeyFormatter) const
const std::vector< GaussianConditional::shared_ptr > conditionals
Jet< T, N > sqrt(const Jet< T, N > &f)
Definition: jet.h:418
const G double tol
Definition: Group.h:86
std::pair< iterator, bool > insert(const value_type &value)
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:86
void prune(const DecisionTreeFactor &decisionTree)
Prune the decision tree of Gaussian factors as per the discrete decisionTree.
double evaluate(const HybridValues &values) const override
Calculate probability density for given values.
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:102
bool isDiscrete() const
True if this is a factor of discrete variables only.
Definition: HybridFactor.h:117
GaussianFactorGraphTree asGaussianFactorGraphTree() const
Convert a DecisionTree of factors into a DT of Gaussian FGs.
KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys)
GaussianMixture()=default
Default constructor, mainly for serialization.
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:34:15