Go to the documentation of this file.
54 for (
auto &&conditional : marginal) {
55 joint = joint * (*conditional);
74 for (
auto &&conditional : *
this) {
75 if (
auto hgc = conditional->asHybrid()) {
77 auto prunedHybridGaussianConditional =
hgc->prune(pruned);
80 result.push_back(prunedHybridGaussianConditional);
81 }
else if (
auto gc = conditional->asGaussian()) {
94 for (
auto &&conditional : *
this) {
95 if (
auto dc = conditional->asDiscrete()) {
106 for (
auto &&conditional : *
this) {
107 if (
auto gm = conditional->asHybrid()) {
109 gbn.push_back(gm->choose(assignment));
110 }
else if (
auto gc = conditional->asGaussian()) {
113 }
else if (
auto dc = conditional->asDiscrete()) {
127 for (
auto &&conditional : *
this) {
128 if (conditional->isDiscrete()) {
129 discrete_fg.
push_back(conditional->asDiscrete());
146 if (std::find(
gbn.begin(),
gbn.end(),
nullptr) !=
gbn.end()) {
149 return gbn.optimize();
154 std::mt19937_64 *
rng)
const {
156 for (
auto &&conditional : *
this) {
157 if (conditional->isDiscrete()) {
159 dbn.
push_back(conditional->asDiscrete());
168 return {
sample, assignment};
193 for (
auto &&conditional : *
this) {
194 result =
result + conditional->errorTree(continuousValues);
202 const std::optional<DiscreteValues> &discrete)
const {
203 double negLogNormConst = 0.0;
205 for (
auto &&conditional : *
this) {
206 if (discrete.has_value()) {
207 if (
auto gm = conditional->asHybrid()) {
208 negLogNormConst += gm->choose(*discrete)->negLogConstant();
209 }
else if (
auto gc = conditional->asGaussian()) {
210 negLogNormConst +=
gc->negLogConstant();
211 }
else if (
auto dc = conditional->asDiscrete()) {
212 negLogNormConst += dc->choose(*discrete)->negLogConstant();
214 throw std::runtime_error(
215 "Unknown conditional type when computing negLogConstant");
218 negLogNormConst += conditional->negLogConstant();
221 return negLogNormConst;
245 for (
auto &&conditional : *
this) {
247 if (
auto gc = conditional->asGaussian()) {
249 }
else if (
auto gm = conditional->asHybrid()) {
252 throw std::runtime_error(
"Unknown conditional type");
bool equals(const This &fg, double tol=1e-9) const
GTSAM-style equals.
HybridValues sample() const
Sample using ancestral sampling, use default rng.
DiscreteValues sample() const
do ancestral sampling
void print(const std::string &s="BayesNet", const KeyFormatter &formatter=DefaultKeyFormatter) const override
double error(const HybridValues &values) const
double evaluate(const HybridValues &values) const
Evaluate hybrid probability density for given HybridValues.
HybridBayesNet prune(size_t maxNrLeaves) const
Prune the Bayes Net such that we have at most maxNrLeaves leaves.
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
const KeyFormatter & formatter
A Bayes net of Gaussian Conditionals indexed by discrete keys.
const EIGEN_DEVICE_FUNC ExpReturnType exp() const
const VectorValues & continuous() const
Return the multi-dimensional vector values.
HybridValues optimize() const
Solve the HybridBayesNet by first computing the MPE of all the discrete variables and then optimizing...
static std::mt19937 kRandomNumberGenerator(42)
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
AlgebraicDecisionTree< Key > errorTree(const VectorValues &continuousValues) const
Compute the negative log posterior log P'(M|x) of all assignments up to a constant,...
bool equals(const This &fg, double tol=1e-9) const
Check equality up to tolerance.
void print(const std::string &s="", const KeyFormatter &formatter=DefaultKeyFormatter) const override
GTSAM-style printing.
DiscreteBayesNet discreteMarginal() const
Get the discrete Bayes Net P(M). As the hybrid Bayes net defines P(X,M) = P(X|M) P(M),...
DecisionTree apply(const Unary &op) const
static const GaussianBayesNet gbn
AlgebraicDecisionTree< Key > discretePosterior(const VectorValues &continuousValues) const
Compute normalized posterior P(M|X=x) and return as a tree.
double logProbability(const HybridValues &x) const
std::vector< double > measurements
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
const DiscreteValues & discrete() const
Return the discrete values.
DiscreteValues optimize(OptionalOrderingType orderingType={}) const
Find the maximum probable explanation (MPE) by doing max-product.
HybridGaussianFactorGraph toFactorGraph(const VectorValues &measurements) const
GaussianBayesNet choose(const DiscreteValues &assignment) const
Get the Gaussian Bayes net P(X|M=m) corresponding to a specific assignment m for the discrete variabl...
static std::mt19937_64 kRandomNumberGenerator(42)
double negLogConstant(const std::optional< DiscreteValues > &discrete) const
Get the negative log of the normalization constant corresponding to the joint density represented by ...
gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:02:22