38 conditionals_(conditionals) {
46 this->
logConstant_, conditional->logNormalizationConstant());
60 std::vector<GaussianConditional::shared_ptr> &&
conditionals)
68 const std::vector<GaussianConditional::shared_ptr> &
conditionals)
69 :
GaussianMixture(continuousFrontals, continuousParents, discreteParents,
108 if (!ptr)
return nullptr;
113 throw std::logic_error(
114 "A GaussianMixture unexpectedly contained a non-conditional");
119 const This *
e =
dynamic_cast<const This *
>(&lf);
120 if (e ==
nullptr)
return false;
131 return f1->equals(*(f2), tol);
138 std::cout << (s.empty() ?
"" : s +
"\n");
141 if (
isHybrid()) std::cout <<
"Hybrid ";
143 std::cout <<
" Discrete Keys = ";
145 std::cout <<
"(" <<
formatter(dk.first) <<
", " << dk.second <<
"), ";
152 if (gf && !gf->empty()) {
153 gf->print(
"", formatter);
168 const Key key = discreteKey.first;
170 continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
171 continuousParentKeys.end(),
key),
172 continuousParentKeys.end());
174 return continuousParentKeys;
179 for (
auto &&kv : given) {
180 if (given.find(kv.first) == given.end()) {
191 throw std::runtime_error(
192 "GaussianMixture::likelihood: given values are missing some frontals.");
199 const auto likelihood_m = conditional->likelihood(given);
200 const double Cgm_Kgcm =
202 if (Cgm_Kgcm == 0.0) {
211 auto constantFactor = std::make_shared<JacobianFactor>(
c);
213 return std::make_shared<JacobianFactor>(gfg);
216 return std::make_shared<GaussianMixtureFactor>(
217 continuousParentKeys, discreteParentKeys, likelihoods);
222 std::set<DiscreteKey>
s;
223 s.insert(discreteKeys.begin(), discreteKeys.end());
243 auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
252 if (gaussianMixtureKeySet == decisionTreeKeySet) {
253 if (decisionTree(values) == 0.0) {
255 std::shared_ptr<GaussianConditional> null;
261 std::vector<DiscreteKey> set_diff;
262 std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
263 gaussianMixtureKeySet.begin(),
264 gaussianMixtureKeySet.end(),
265 std::back_inserter(set_diff));
267 const std::vector<DiscreteValues> assignments =
271 augmented_values.
insert(assignment);
275 if (decisionTree(augmented_values) > 0.0) {
307 return conditional->logProbability(continuousValues);
321 return conditional->error(continuousValues) +
332 return conditional->error(values.
continuous()) +
339 return conditional->logProbability(values.
continuous());
345 return conditional->evaluate(values.
continuous());
const gtsam::Symbol key('X', 0)
A set of GaussianFactors, indexed by a set of discrete keys.
A hybrid conditional in the Conditional Linear Gaussian scheme.
std::shared_ptr< This > shared_ptr
shared_ptr to this class
AlgebraicDecisionTree< Key > logProbability(const VectorValues &continuousValues) const
Compute logProbability of the GaussianMixture as a tree.
std::function< GaussianConditional::shared_ptr(const Assignment< Key > &, const GaussianConditional::shared_ptr &)> prunerFunc(const DecisionTreeFactor &decisionTree)
Helper function to get the pruner functor.
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const
Merge the Gaussian Factor Graphs in this and sum while maintaining the decision tree structure...
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
void visit(Func f) const
Visit all leaves in depth-first fashion.
KeyVector continuousParents() const
Returns the continuous keys among the parents.
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
const DiscreteKeys & discreteKeys() const
Return the discrete keys for this factor.
const VectorValues & continuous() const
Return the multi-dimensional vector values.
double f2(const Vector2 &x)
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
const KeyFormatter & formatter
double logConstant_
log of the normalization constant.
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
A conditional of gaussian mixtures indexed by discrete variables, as part of a Bayes Network...
bool equals(const HybridFactor &lf, double tol=1e-9) const override
Test equality with base HybridFactor.
GaussianConditional::shared_ptr operator()(const DiscreteValues &discreteValues) const
Return the conditional Gaussian for the given discrete assignment.
std::string str() const
Return the string representation.
virtual bool equals(const HybridFactor &lf, double tol=1e-9) const
Check equality with another HybridFactor, up to tolerance tol.
std::set< DiscreteKey > DiscreteKeysAsSet(const DiscreteKeys &discreteKeys)
Return the DiscreteKey vector as a set.
double error(const HybridValues &values) const override
Compute the error of this Gaussian Mixture.
void print(const std::string &s="GaussianMixture\, const KeyFormatter &formatter=DefaultKeyFormatter) const override
Print utility.
Array< double, 1, 3 > e(1./3., 0.5, 2.)
size_t nrComponents() const
Returns the total number of continuous components.
Linear Factor Graph where all factors are Gaussians.
bool isContinuous() const
True if this is a factor of continuous variables only.
bool equals(const DecisionTree &other, const CompareFunc &compare=&DefaultCompare) const
DecisionTree apply(const Unary &op) const
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
std::shared_ptr< This > shared_ptr
shared_ptr to this class
NonlinearFactorGraph graph2()
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
bool empty() const
Check if tree is empty.
std::shared_ptr< GaussianMixtureFactor > likelihood(const VectorValues &given) const
bool allFrontalsGiven(const VectorValues &given) const
Check whether given has values for all frontal keys.
Conditionals conditionals_
a decision tree of Gaussian conditionals.
bool isHybrid() const
True if this is a Discrete-Continuous factor.
Point2 f1(const Point3 &p, OptionalJacobian< 2, 3 > H)
Double_ range(const Point2_ &p, const Point2_ &q)
mxArray * wrap(const Class &value)
const Conditionals & conditionals() const
Getter for the underlying Conditionals DecisionTree.
const DiscreteValues & discrete() const
Return the discrete values.
void print(const std::string &s="Conditional", const KeyFormatter &formatter=DefaultKeyFormatter) const
const std::vector< GaussianConditional::shared_ptr > conditionals
Jet< T, N > sqrt(const Jet< T, N > &f)
std::pair< iterator, bool > insert(const value_type &value)
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
void prune(const DecisionTreeFactor &decisionTree)
Prune the decision tree of Gaussian factors as per the discrete decisionTree.
double evaluate(const HybridValues &values) const override
Calculate probability density for given values.
std::uint64_t Key
Integer nonlinear key type.
bool isDiscrete() const
True if this is a factor of discrete variables only.
GaussianFactorGraphTree asGaussianFactorGraphTree() const
Convert a DecisionTree of factors into a DT of Gaussian FGs.
KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys)
GaussianMixture()=default
Default constructor, mainly for serialization.
DiscreteKeys is a set of keys that can be assembled using the & operator.