GaussianMixture.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
21 #include <gtsam/base/utilities.h>
28 
29 namespace gtsam {
30 
32  const KeyVector &continuousFrontals, const KeyVector &continuousParents,
33  const DiscreteKeys &discreteParents,
35  : BaseFactor(CollectKeys(continuousFrontals, continuousParents),
36  discreteParents),
37  BaseConditional(continuousFrontals.size()),
38  conditionals_(conditionals) {
39  // Calculate logConstant_ as the maximum of the log constants of the
40  // conditionals, by visiting the decision tree:
41  logConstant_ = -std::numeric_limits<double>::infinity();
43  [this](const GaussianConditional::shared_ptr &conditional) {
44  if (conditional) {
45  this->logConstant_ = std::max(
46  this->logConstant_, conditional->logNormalizationConstant());
47  }
48  });
49 }
50 
51 /* *******************************************************************************/
53  return conditionals_;
54 }
55 
56 /* *******************************************************************************/
58  KeyVector &&continuousFrontals, KeyVector &&continuousParents,
59  DiscreteKeys &&discreteParents,
60  std::vector<GaussianConditional::shared_ptr> &&conditionals)
61  : GaussianMixture(continuousFrontals, continuousParents, discreteParents,
62  Conditionals(discreteParents, conditionals)) {}
63 
64 /* *******************************************************************************/
66  const KeyVector &continuousFrontals, const KeyVector &continuousParents,
67  const DiscreteKeys &discreteParents,
68  const std::vector<GaussianConditional::shared_ptr> &conditionals)
69  : GaussianMixture(continuousFrontals, continuousParents, discreteParents,
70  Conditionals(discreteParents, conditionals)) {}
71 
72 /* *******************************************************************************/
73 // TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
74 // GaussianMixtureFactor, no?
76  const GaussianFactorGraphTree &sum) const {
77  using Y = GaussianFactorGraph;
78  auto add = [](const Y &graph1, const Y &graph2) {
79  auto result = graph1;
80  result.push_back(graph2);
81  return result;
82  };
83  const auto tree = asGaussianFactorGraphTree();
84  return sum.empty() ? tree : sum.apply(tree, add);
85 }
86 
87 /* *******************************************************************************/
89  auto wrap = [](const GaussianConditional::shared_ptr &gc) {
90  return GaussianFactorGraph{gc};
91  };
92  return {conditionals_, wrap};
93 }
94 
95 /* *******************************************************************************/
97  size_t total = 0;
98  conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) {
99  if (node) total += 1;
100  });
101  return total;
102 }
103 
104 /* *******************************************************************************/
106  const DiscreteValues &discreteValues) const {
107  auto &ptr = conditionals_(discreteValues);
108  if (!ptr) return nullptr;
109  auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
110  if (conditional)
111  return conditional;
112  else
113  throw std::logic_error(
114  "A GaussianMixture unexpectedly contained a non-conditional");
115 }
116 
117 /* *******************************************************************************/
118 bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
119  const This *e = dynamic_cast<const This *>(&lf);
120  if (e == nullptr) return false;
121 
122  // This will return false if either conditionals_ is empty or e->conditionals_
123  // is empty, but not if both are empty or both are not empty:
124  if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
125 
126  // Check the base and the factors:
127  return BaseFactor::equals(*e, tol) &&
128  conditionals_.equals(e->conditionals_,
131  return f1->equals(*(f2), tol);
132  });
133 }
134 
135 /* *******************************************************************************/
136 void GaussianMixture::print(const std::string &s,
137  const KeyFormatter &formatter) const {
138  std::cout << (s.empty() ? "" : s + "\n");
139  if (isContinuous()) std::cout << "Continuous ";
140  if (isDiscrete()) std::cout << "Discrete ";
141  if (isHybrid()) std::cout << "Hybrid ";
143  std::cout << " Discrete Keys = ";
144  for (auto &dk : discreteKeys()) {
145  std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
146  }
147  std::cout << "\n";
149  "", [&](Key k) { return formatter(k); },
150  [&](const GaussianConditional::shared_ptr &gf) -> std::string {
151  RedirectCout rd;
152  if (gf && !gf->empty()) {
153  gf->print("", formatter);
154  return rd.str();
155  } else {
156  return "nullptr";
157  }
158  });
159 }
160 
161 /* ************************************************************************* */
163  // Get all parent keys:
164  const auto range = parents();
165  KeyVector continuousParentKeys(range.begin(), range.end());
166  // Loop over all discrete keys:
167  for (const auto &discreteKey : discreteKeys()) {
168  const Key key = discreteKey.first;
169  // remove that key from continuousParentKeys:
170  continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
171  continuousParentKeys.end(), key),
172  continuousParentKeys.end());
173  }
174  return continuousParentKeys;
175 }
176 
177 /* ************************************************************************* */
179  for (auto &&kv : given) {
180  if (given.find(kv.first) == given.end()) {
181  return false;
182  }
183  }
184  return true;
185 }
186 
187 /* ************************************************************************* */
188 std::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
189  const VectorValues &given) const {
190  if (!allFrontalsGiven(given)) {
191  throw std::runtime_error(
192  "GaussianMixture::likelihood: given values are missing some frontals.");
193  }
194 
195  const DiscreteKeys discreteParentKeys = discreteKeys();
196  const KeyVector continuousParentKeys = continuousParents();
197  const GaussianMixtureFactor::Factors likelihoods(
198  conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
199  const auto likelihood_m = conditional->likelihood(given);
200  const double Cgm_Kgcm =
201  logConstant_ - conditional->logNormalizationConstant();
202  if (Cgm_Kgcm == 0.0) {
203  return likelihood_m;
204  } else {
205  // Add a constant factor to the likelihood in case the noise models
206  // are not all equal.
208  gfg.push_back(likelihood_m);
209  Vector c(1);
210  c << std::sqrt(2.0 * Cgm_Kgcm);
211  auto constantFactor = std::make_shared<JacobianFactor>(c);
212  gfg.push_back(constantFactor);
213  return std::make_shared<JacobianFactor>(gfg);
214  }
215  });
216  return std::make_shared<GaussianMixtureFactor>(
217  continuousParentKeys, discreteParentKeys, likelihoods);
218 }
219 
220 /* ************************************************************************* */
221 std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
222  std::set<DiscreteKey> s;
223  s.insert(discreteKeys.begin(), discreteKeys.end());
224  return s;
225 }
226 
227 /* ************************************************************************* */
235 std::function<GaussianConditional::shared_ptr(
236  const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
238  // Get the discrete keys as sets for the decision tree
239  // and the gaussian mixture.
240  auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
241  auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
242 
243  auto pruner = [discreteProbs, discreteProbsKeySet, gaussianMixtureKeySet](
244  const Assignment<Key> &choices,
245  const GaussianConditional::shared_ptr &conditional)
247  // typecast so we can use this to get probability value
248  const DiscreteValues values(choices);
249 
250  // Case where the gaussian mixture has the same
251  // discrete keys as the decision tree.
252  if (gaussianMixtureKeySet == discreteProbsKeySet) {
253  if (discreteProbs(values) == 0.0) {
254  // empty aka null pointer
255  std::shared_ptr<GaussianConditional> null;
256  return null;
257  } else {
258  return conditional;
259  }
260  } else {
261  std::vector<DiscreteKey> set_diff;
262  std::set_difference(
263  discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
264  gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
265  std::back_inserter(set_diff));
266 
267  const std::vector<DiscreteValues> assignments =
269  for (const DiscreteValues &assignment : assignments) {
270  DiscreteValues augmented_values(values);
271  augmented_values.insert(assignment);
272 
273  // If any one of the sub-branches are non-zero,
274  // we need this conditional.
275  if (discreteProbs(augmented_values) > 0.0) {
276  return conditional;
277  }
278  }
279  // If we are here, it means that all the sub-branches are 0,
280  // so we prune.
281  return nullptr;
282  }
283  };
284  return pruner;
285 }
286 
287 /* *******************************************************************************/
288 void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
289  // Functional which loops over all assignments and create a set of
290  // GaussianConditionals
291  auto pruner = prunerFunc(discreteProbs);
292 
293  auto pruned_conditionals = conditionals_.apply(pruner);
294  conditionals_.root_ = pruned_conditionals.root_;
295 }
296 
297 /* *******************************************************************************/
299  const VectorValues &continuousValues) const {
300  // functor to calculate (double) logProbability value from
301  // GaussianConditional.
302  auto probFunc =
303  [continuousValues](const GaussianConditional::shared_ptr &conditional) {
304  if (conditional) {
305  return conditional->logProbability(continuousValues);
306  } else {
307  // Return arbitrarily small logProbability if conditional is null
308  // Conditional is null if it is pruned out.
309  return -1e20;
310  }
311  };
312  return DecisionTree<Key, double>(conditionals_, probFunc);
313 }
314 
315 /* *******************************************************************************/
317  const VectorValues &continuousValues) const {
318  auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
319  return conditional->error(continuousValues) + //
320  logConstant_ - conditional->logNormalizationConstant();
321  };
322  DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
323  return error_tree;
324 }
325 
326 /* *******************************************************************************/
328  // Directly index to get the conditional, no need to build the whole tree.
329  auto conditional = conditionals_(values.discrete());
330  return conditional->error(values.continuous()) + //
331  logConstant_ - conditional->logNormalizationConstant();
332 }
333 
334 /* *******************************************************************************/
336  auto conditional = conditionals_(values.discrete());
337  return conditional->logProbability(values.continuous());
338 }
339 
340 /* *******************************************************************************/
342  auto conditional = conditionals_(values.discrete());
343  return conditional->evaluate(values.continuous());
344 }
345 
346 } // namespace gtsam
gtsam::GaussianMixture::GaussianMixture
GaussianMixture()=default
Default constructor, mainly for serialization.
gtsam::Conditional< HybridFactor, GaussianMixture >::print
void print(const std::string &s="Conditional", const KeyFormatter &formatter=DefaultKeyFormatter) const
Definition: Conditional-inst.h:30
GaussianFactorGraph.h
Linear Factor Graph where all factors are Gaussians.
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:44
gtsam::HybridValues
Definition: HybridValues.h:38
Y
const char Y
Definition: test/EulerAngles.cpp:31
s
RealScalar s
Definition: level1_cplx_impl.h:126
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
gtsam::HybridFactor::equals
virtual bool equals(const HybridFactor &lf, double tol=1e-9) const
equals
Definition: HybridFactor.cpp:73
gtsam::DecisionTree::empty
bool empty() const
Check if tree is empty.
Definition: DecisionTree.h:269
gtsam::GaussianMixture::logConstant_
double logConstant_
log of the normalization constant.
Definition: GaussianMixture.h:67
c
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
tree
Definition: testExpression.cpp:212
gtsam::DecisionTree::equals
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
Definition: DecisionTree-inl.h:898
gtsam::HybridFactor::isDiscrete
bool isDiscrete() const
True if this is a factor of discrete variables only.
Definition: HybridFactor.h:118
gtsam::GaussianMixture
A conditional of gaussian mixtures indexed by discrete variables, as part of a Bayes Network....
Definition: GaussianMixture.h:53
gtsam::RedirectCout
Definition: base/utilities.h:16
gtsam::HybridFactor
Definition: HybridFactor.h:53
wrap
mxArray * wrap(const Class &value)
Definition: matlab.h:126
gtsam::GaussianMixture::asGaussianFactorGraphTree
GaussianFactorGraphTree asGaussianFactorGraphTree() const
Convert a DecisionTree of factors into a DT of Gaussian FGs.
Definition: GaussianMixture.cpp:88
simple::graph2
NonlinearFactorGraph graph2()
Definition: testInitializePose3.cpp:72
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
f2
double f2(const Vector2 &x)
Definition: testNumericalDerivative.cpp:56
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
gtsam::DecisionTree::print
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
Definition: DecisionTree-inl.h:904
gtsam::HybridFactor::isHybrid
bool isHybrid() const
True if this is a Discrete-Continuous factor.
Definition: HybridFactor.h:124
gtsam::Vector
Eigen::VectorXd Vector
Definition: Vector.h:38
gtsam::KeyVector
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:92
result
Values result
Definition: OdometryOptimize.cpp:8
gtsam::GaussianMixture::continuousParents
KeyVector continuousParents() const
Returns the continuous keys among the parents.
Definition: GaussianMixture.cpp:162
utilities.h
equal_constants::conditionals
const std::vector< GaussianConditional::shared_ptr > conditionals
Definition: testGaussianMixture.cpp:51
gtsam::AlgebraicDecisionTree< Key >
size
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
gtsam::GaussianFactorGraph
Definition: GaussianFactorGraph.h:73
gtsam::GaussianMixture::print
void print(const std::string &s="GaussianMixture\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
Print utility.
Definition: GaussianMixture.cpp:136
gtsam::DecisionTree::root_
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition: DecisionTree.h:150
gtsam::GaussianMixture::conditionals
const Conditionals & conditionals() const
Getter for the underlying Conditionals DecisionTree.
Definition: GaussianMixture.cpp:52
gtsam::VectorValues
Definition: VectorValues.h:74
gtsam::DiscreteValues::CartesianProduct
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
Definition: DiscreteValues.h:85
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
gtsam::GaussianFactor::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: GaussianFactor.h:42
gtsam::GaussianMixture::likelihood
std::shared_ptr< GaussianMixtureFactor > likelihood(const VectorValues &given) const
Definition: GaussianMixture.cpp:188
gtsam::GaussianMixture::add
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const
Merge the Gaussian Factor Graphs in this and sum while maintaining the decision tree structure.
Definition: GaussianMixture.cpp:75
gtsam::DecisionTree::visit
void visit(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:768
GaussianMixture.h
A hybrid conditional in the Conditional Linear Gaussian scheme.
gtsam::CollectKeys
KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys)
Definition: HybridFactor.cpp:23
gtsam::Assignment< Key >
gtsam::HybridFactor::isContinuous
bool isContinuous() const
True if this is a factor of continuous variables only.
Definition: HybridFactor.h:121
gtsam::GaussianMixture::logProbability
AlgebraicDecisionTree< Key > logProbability(const VectorValues &continuousValues) const
Compute logProbability of the GaussianMixture as a tree.
Definition: GaussianMixture.cpp:298
gtsam::Conditional< HybridFactor, GaussianMixture >
gtsam::DecisionTree::apply
DecisionTree apply(const Unary &op) const
Definition: DecisionTree-inl.h:921
gtsam::GaussianMixture::allFrontalsGiven
bool allFrontalsGiven(const VectorValues &given) const
Check whether given has values for all frontal keys.
Definition: GaussianMixture.cpp:178
gtsam::Conditional< HybridFactor, GaussianMixture >::parents
Parents parents() const
Definition: Conditional.h:146
key
const gtsam::Symbol key('X', 0)
gtsam::GaussianMixture::conditionals_
Conditionals conditionals_
a decision tree of Gaussian conditionals.
Definition: GaussianMixture.h:66
gtsam::DecisionTree< Key, GaussianConditional::shared_ptr >
gtsam::DiscreteKeysAsSet
std::set< DiscreteKey > DiscreteKeysAsSet(const DiscreteKeys &discreteKeys)
Return the DiscreteKey vector as a set.
Definition: GaussianMixture.cpp:221
gtsam::GaussianMixture::prune
void prune(const DecisionTreeFactor &discreteProbs)
Prune the decision tree of Gaussian factors according to the discrete probabilities in discreteProbs.
Definition: GaussianMixture.cpp:288
gtsam::GaussianMixture::errorTree
AlgebraicDecisionTree< Key > errorTree(const VectorValues &continuousValues) const
Compute error of the GaussianMixture as a tree.
Definition: GaussianMixture.cpp:316
gtsam
traits
Definition: chartTesting.h:28
DiscreteValues.h
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
leaf::values
leaf::MyValues values
gtsam::FactorGraph::push_back
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:147
gtsam::DiscreteValues::insert
std::pair< iterator, bool > insert(const value_type &value)
Definition: DiscreteValues.h:68
gtsam::GaussianMixture::equals
bool equals(const HybridFactor &lf, double tol=1e-9) const override
Test equality with base HybridFactor.
Definition: GaussianMixture.cpp:118
gtsam::GaussianMixture::operator()
GaussianConditional::shared_ptr operator()(const DiscreteValues &discreteValues) const
Return the conditional Gaussian for the given discrete assignment.
Definition: GaussianMixture.cpp:105
gtsam::GaussianConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: GaussianConditional.h:46
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::GaussianMixture::nrComponents
size_t nrComponents() const
Returns the total number of continuous components.
Definition: GaussianMixture.cpp:96
gtsam::DiscreteFactor::discreteKeys
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
Definition: DiscreteFactor.cpp:32
unary::f1
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
Definition: testExpression.cpp:79
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
max
#define max(a, b)
Definition: datatypes.h:20
gtsam::GaussianMixture::evaluate
double evaluate(const HybridValues &values) const override
Calculate probability density for given values.
Definition: GaussianMixture.cpp:341
HybridValues.h
gtsam::GaussianMixture::prunerFunc
std::function< GaussianConditional::shared_ptr(const Assignment< Key > &, const GaussianConditional::shared_ptr &)> prunerFunc(const DecisionTreeFactor &discreteProbs)
Helper function to get the pruner functor.
Definition: GaussianMixture.cpp:237
gtsam::HybridFactor::discreteKeys
const DiscreteKeys & discreteKeys() const
Return the discrete keys for this factor.
Definition: HybridFactor.h:130
ceres::sqrt
Jet< T, N > sqrt(const Jet< T, N > &f)
Definition: jet.h:418
gtsam::RedirectCout::str
std::string str() const
return the string
Definition: utilities.cpp:5
GaussianMixtureFactor.h
A set of GaussianFactors, indexed by a set of discrete keys.
Conditional-inst.h
gtsam::GaussianMixture::error
double error(const HybridValues &values) const override
Compute the error of this Gaussian Mixture.
Definition: GaussianMixture.cpp:327


gtsam
Author(s):
autogenerated on Thu Jun 13 2024 03:02:23