DiscreteBayesNet.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
23 
24 namespace gtsam {
25 
26 // Instantiate base class
27 template class FactorGraph<DiscreteConditional>;
28 
29 /* ************************************************************************* */
30 bool DiscreteBayesNet::equals(const This& bn, double tol) const {
31  return Base::equals(bn, tol);
32 }
33 
34 /* ************************************************************************* */
36  // evaluate all conditionals and add
37  double result = 0.0;
38  for (const DiscreteConditional::shared_ptr& conditional : *this)
39  result += conditional->logProbability(values);
40  return result;
41 }
42 
43 /* ************************************************************************* */
45  // evaluate all conditionals and multiply
46  double result = 1.0;
47  for (const DiscreteConditional::shared_ptr& conditional : *this)
48  result *= (*conditional)(values);
49  return result;
50 }
51 
52 /* ************************************************************************* */
53 DiscreteValues DiscreteBayesNet::sample(std::mt19937_64* rng) const {
55  return sample(result, rng);
56 }
57 
59  std::mt19937_64* rng) const {
60  // sample each node in turn in topological sort order (parents first)
61  for (auto it = std::make_reverse_iterator(end());
62  it != std::make_reverse_iterator(begin()); ++it) {
63  const DiscreteConditional::shared_ptr& conditional = *it;
64  // Sample the conditional only if value for j not already in result
65  const Key j = conditional->firstFrontalKey();
66  if (result.count(j) == 0) {
67  conditional->sampleInPlace(&result, rng);
68  }
69  }
70  return result;
71 }
72 
73 /* ************************************************************************* */
74 // The implementation is: build the entire joint into one factor and then prune.
75 // NOTE(Frank): This can be quite expensive *unless* the factors have already
76 // been pruned before. Another, possibly faster approach is branch and bound
77 // search to find the K-best leaves and then create a single pruned conditional.
79  size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
80  DiscreteValues* fixedValues) const {
81  // Multiply into one big conditional. NOTE: possibly quite expensive.
83 
84  // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
85  DiscreteConditional pruned = joint;
86  pruned.prune(maxNrLeaves);
87 
88  DiscreteValues deadModesValues;
89  // If we have a dead mode threshold and discrete variables left after pruning,
90  // then we run dead mode removal.
91  if (marginalThreshold && pruned.keys().size() > 0) {
93  for (auto dkey : pruned.discreteKeys()) {
94  const Vector probabilities = marginals.marginalProbabilities(dkey);
95 
96  int index = -1;
97  auto threshold = (probabilities.array() > *marginalThreshold);
98  // If atleast 1 value is non-zero, then we can find the index
99  // Else if all are zero, index would be set to 0 which is incorrect
100  if (!threshold.isZero()) {
101  threshold.maxCoeff(&index);
102  }
103 
104  if (index >= 0) {
105  deadModesValues.emplace(dkey.first, index);
106  }
107  }
108 
109  // Remove the modes (imperative)
110  pruned.removeDiscreteModes(deadModesValues);
111 
112  // Set the fixed values if requested.
113  if (fixedValues) {
114  *fixedValues = deadModesValues;
115  }
116  }
117 
118  // Return the resulting DiscreteBayesNet.
120  if (pruned.keys().size() > 0) result.push_back(pruned);
121  return result;
122 }
123 
124 /* *********************************************************************** */
127  for (const DiscreteConditional::shared_ptr& conditional : *this)
128  joint = joint * (*conditional);
129 
130  return joint;
131 }
132 
133 /* *********************************************************************** */
135  const KeyFormatter& keyFormatter,
136  const DiscreteFactor::Names& names) const {
137  using std::endl;
138  std::stringstream ss;
139  ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
140  for (const DiscreteConditional::shared_ptr& conditional : *this)
141  ss << conditional->markdown(keyFormatter, names) << endl;
142  return ss.str();
143 }
144 
145 /* *********************************************************************** */
146 std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
147  const DiscreteFactor::Names& names) const {
148  using std::endl;
149  std::stringstream ss;
150  ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
151  for (const DiscreteConditional::shared_ptr& conditional : *this)
152  ss << conditional->html(keyFormatter, names) << endl;
153  return ss.str();
154 }
155 
156 /* ************************************************************************* */
157 } // namespace gtsam
DiscreteBayesNet.h
rng
static std::mt19937 rng
Definition: timeFactorOverhead.cpp:31
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
gtsam::DiscreteBayesNet::equals
bool equals(const This &bn, double tol=1e-9) const
Definition: DiscreteBayesNet.cpp:30
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:247
DiscreteConditional.h
gtsam::Vector
Eigen::VectorXd Vector
Definition: Vector.h:39
result
Values result
Definition: OdometryOptimize.cpp:8
FactorGraph-inst.h
Factor Graph Base Class.
ss
static std::stringstream ss
Definition: testBTree.cpp:31
gtsam::DiscreteBayesNet::joint
DiscreteConditional joint() const
Multiply all conditionals into one big joint conditional and return it.
Definition: DiscreteBayesNet.cpp:125
gtsam::DiscreteBayesNet::logProbability
double logProbability(const DiscreteValues &values) const
Definition: DiscreteBayesNet.cpp:35
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
j
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
gtsam::DiscreteBayesNet::prune
DiscreteBayesNet prune(size_t maxNrLeaves, const std::optional< double > &marginalThreshold={}, DiscreteValues *fixedValues=nullptr) const
Prune the Bayes net.
Definition: DiscreteBayesNet.cpp:78
gtsam::FactorGraph< DiscreteConditional >::equals
bool equals(const This &fg, double tol=1e-9) const
Check equality up to tolerance.
Definition: FactorGraph-inst.h:50
gtsam::DiscreteBayesNet::markdown
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const DiscreteFactor::Names &names={}) const
Render as markdown tables.
Definition: DiscreteBayesNet.cpp:134
gtsam::DiscreteBayesNet::evaluate
double evaluate(const DiscreteValues &values) const
Definition: DiscreteBayesNet.cpp:44
gtsam::DiscreteMarginals
Definition: DiscreteMarginals.h:34
DiscreteMarginals.h
A class for computing marginals in a DiscreteFactorGraph.
gtsam::DiscreteConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: DiscreteConditional.h:44
gtsam::FactorGraph< DiscreteConditional >::size
size_t size() const
Definition: FactorGraph.h:297
process_shonan_timing_results.names
dictionary names
Definition: process_shonan_timing_results.py:175
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:38
gtsam
traits
Definition: ABC.h:17
make_reverse_iterator
std::reverse_iterator< Iterator > make_reverse_iterator(Iterator i)
Definition: stl_iterators.cpp:16
gtsam::DiscreteFactor::Names
DiscreteValues::Names Names
Translation table from values to strings.
Definition: DiscreteFactor.h:190
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::Factor::keys
const KeyVector & keys() const
Access the factor's involved variable keys.
Definition: Factor.h:143
gtsam::FactorGraph< DiscreteConditional >::end
const_iterator end() const
Definition: FactorGraph.h:342
gtsam::DiscreteBayesNet::sample
DiscreteValues sample(std::mt19937_64 *rng=nullptr) const
do ancestral sampling
Definition: DiscreteBayesNet.cpp:53
gtsam::DiscreteBayesNet::html
std::string html(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const DiscreteFactor::Names &names={}) const
Render as html tables.
Definition: DiscreteBayesNet.cpp:146
gtsam::FactorGraph< DiscreteConditional >::begin
const_iterator begin() const
Definition: FactorGraph.h:339
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::DiscreteFactor::discreteKeys
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
Definition: DiscreteFactor.cpp:37
marginals
Marginals marginals(graph, result)
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
gtsam::DiscreteConditional::removeDiscreteModes
void removeDiscreteModes(const DiscreteValues &given)
Remove the discrete modes whose assignments are given to us. Only applies to discrete conditionals.
Definition: DiscreteConditional.cpp:509
gtsam::DiscreteConditional::prune
virtual void prune(size_t maxNrAssignments)
Prune the conditional.
Definition: DiscreteConditional.cpp:502


gtsam
Author(s):
autogenerated on Wed May 28 2025 03:01:13