48 for (
auto &&conditional : *
this) {
49 if (conditional->isDiscrete()) {
52 dtFactor = dtFactor *
f;
55 return std::make_shared<DecisionTreeFactor>(dtFactor);
66 std::function<double(const Assignment<Key> &, double)>
prunerFunc(
71 std::set<DiscreteKey> decisionTreeKeySet =
73 std::set<DiscreteKey> conditionalKeySet =
76 auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
78 double probability) ->
double {
80 double pruned_prob = 0.0;
86 if (conditionalKeySet == decisionTreeKeySet) {
87 if (prunedDecisionTree(values) == 0) {
95 std::set<Key> valuesKeys;
96 for (
auto kvp : values) {
97 valuesKeys.insert(kvp.first);
99 std::set<Key> conditionalKeys;
100 for (
auto kvp : conditionalKeySet) {
101 conditionalKeys.insert(kvp.first);
104 if (conditionalKeys != valuesKeys) {
106 std::vector<Key> missing_keys;
107 std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
108 valuesKeys.begin(), valuesKeys.end(),
109 std::back_inserter(missing_keys));
111 for (
auto missing_key : missing_keys) {
112 values[missing_key] = 0;
119 std::vector<DiscreteKey> set_diff;
120 std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
121 conditionalKeySet.begin(), conditionalKeySet.end(),
122 std::back_inserter(set_diff));
125 const std::vector<DiscreteValues> assignments =
129 augmented_values.
insert(assignment);
133 if (prunedDecisionTree(augmented_values) > 0.0) {
151 for (
size_t i = 0;
i < this->
size();
i++) {
153 if (conditional->isDiscrete()) {
154 auto discrete = conditional->asDiscrete();
163 KeyVector frontals(discrete->frontals().begin(),
164 discrete->frontals().end());
165 auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
166 frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
167 conditional = std::make_shared<HybridConditional>(prunedDiscrete);
170 this->
at(
i) = conditional;
194 for (
auto &&conditional : *
this) {
195 if (
auto gm = conditional->
asMixture()) {
197 auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
198 prunedGaussianMixture->
prune(decisionTree);
201 prunedBayesNetFragment.
push_back(prunedGaussianMixture);
205 prunedBayesNetFragment.
push_back(conditional);
209 return prunedBayesNetFragment;
216 for (
auto &&conditional : *
this) {
217 if (
auto gm = conditional->
asMixture()) {
220 }
else if (
auto gc = conditional->
asGaussian()) {
223 }
else if (
auto dc = conditional->
asDiscrete()) {
236 for (
auto &&conditional : *
this) {
255 if (std::find(gbn.
begin(), gbn.
end(),
nullptr) != gbn.
end()) {
263 std::mt19937_64 *
rng)
const {
265 for (
auto &&conditional : *
this) {
277 return {
sample, assignment};
283 return sample(given, rng);
302 for (
auto &&conditional : *
this) {
303 if (
auto gm = conditional->
asMixture()) {
306 result = result + gm->logProbability(continuousValues);
307 }
else if (
auto gc = conditional->
asGaussian()) {
313 result = result.
apply([logProbability](
double leaf_value) {
316 }
else if (
auto dc = conditional->
asDiscrete()) {
318 result = result.
apply(
320 return leaf_value + dc->logProbability(
DiscreteValues(assignment));
332 return tree.
apply([](
double log) {
return exp(log); });
347 for (
auto &&conditional : *
this) {
350 fg.
push_back(gc->likelihood(measurements));
351 }
else if (
auto gm = conditional->
asMixture()) {
352 fg.
push_back(gm->likelihood(measurements));
354 throw std::runtime_error(
"Unknown conditional type");
bool equals(const This &fg, double tol=1e-9) const
Check equality up to tolerance.
static std::mt19937 kRandomNumberGenerator(42)
GaussianBayesNet choose(const DiscreteValues &assignment) const
Get the Gaussian Bayes Net which corresponds to a specific discrete value assignment.
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
VectorValues sample(std::mt19937_64 *rng) const
bool equals(const This &fg, double tol=1e-9) const
GTSAM-style equals.
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
const DiscreteKeys & discreteKeys() const
Return the discrete keys for this factor.
DecisionTreeFactor::shared_ptr discreteConditionals() const
Get all the discrete conditionals as a decision tree factor.
static const GaussianBayesNet gbn
double evaluate(const HybridValues &values) const
Evaluate hybrid probability density for given HybridValues.
const VectorValues & continuous() const
Return the multi-dimensional vector values.
DiscreteValues optimize(OptionalOrderingType orderingType={}) const
Find the maximum probable explanation (MPE) by doing max-product.
EIGEN_DEVICE_FUNC const LogReturnType log() const
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
const KeyFormatter & formatter
A Bayes net of Gaussian Conditionals indexed by discrete keys.
DiscreteConditional::shared_ptr asDiscrete() const
Return conditional as a DiscreteConditional.
EIGEN_DEVICE_FUNC const ExpReturnType exp() const
VectorValues optimize() const
const_iterator begin() const
bool frontalsIn(const VectorValues &measurements) const
Check if VectorValues measurements contains all frontal keys.
GaussianConditional::shared_ptr asGaussian() const
Return HybridConditional as a GaussianConditional.
HybridValues optimize() const
Solve the HybridBayesNet by first computing the MPE of all the discrete variables and then optimizing...
std::set< DiscreteKey > DiscreteKeysAsSet(const DiscreteKeys &discreteKeys)
Return the DiscreteKey vector as a set.
HybridBayesNet prune(size_t maxNrLeaves)
Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
void push_back(std::shared_ptr< HybridConditional > conditional)
Add a hybrid conditional using a shared_ptr.
DecisionTree apply(const Unary &op) const
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
std::shared_ptr< This > shared_ptr
shared_ptr to this class
const_iterator end() const
void print(const std::string &s="", const KeyFormatter &formatter=DefaultKeyFormatter) const override
GTSAM-style printing.
std::shared_ptr< DecisionTreeFactor > shared_ptr
GaussianMixture::shared_ptr asMixture() const
Return HybridConditional as a GaussianMixture.
const sharedFactor at(size_t i) const
static std::mt19937_64 kRandomNumberGenerator(42)
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree)
Update the discrete conditionals with the pruned versions.
const DiscreteValues & discrete() const
Return the discrete values.
const KeyVector & keys() const
Access the factor's involved variable keys.
std::function< double(const Assignment< Key > &, double)> prunerFunc(const DecisionTreeFactor &prunedDecisionTree, const HybridConditional &conditional)
Helper function to get the pruner functional.
std::pair< iterator, bool > insert(const value_type &value)
AlgebraicDecisionTree< Key > logProbability(const VectorValues &continuousValues) const
Compute conditional error for each discrete assignment, and return as a tree.
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
bool isDiscrete() const
True if this is a factor of discrete variables only.
void print(const std::string &s="BayesNet", const KeyFormatter &formatter=DefaultKeyFormatter) const override
DiscreteValues sample() const
do ancestral sampling
HybridGaussianFactorGraph toFactorGraph(const VectorValues &measurements) const
HybridValues sample() const
Sample using ancestral sampling, use default rng.