HybridBayesTree.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
30 
31 #include <memory>
32 
33 namespace gtsam {
34 
// Instantiate base class templates for hybrid cliques explicitly, so the
// generic BayesTreeCliqueBase / BayesTree member definitions are compiled for
// HybridBayesTreeClique in this translation unit.
// NOTE(review): relies on the corresponding -inst.h headers being included
// above (not visible in this view) — confirm.
template class BayesTreeCliqueBase<HybridBayesTreeClique,
 HybridGaussianFactorGraph>;
template class BayesTree<HybridBayesTreeClique>;
39 
40 /* ************************************************************************* */
41 bool HybridBayesTree::equals(const This& other, double tol) const {
42  return Base::equals(other, tol);
43 }
44 
45 /* ************************************************************************* */
47  const DiscreteFactorGraph& dfg) const {
49 
50  // Check type of product, and get as TableFactor for efficiency.
51  TableFactor p;
52  if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
53  p = *tf;
54  } else {
55  p = TableFactor(product->toDecisionTreeFactor());
56  }
57  DiscreteValues assignment = TableDistribution(p).argmax();
58  return assignment;
59 }
60 
61 /* ************************************************************************* */
63  DiscreteFactorGraph discrete_fg;
65 
66  auto root = roots_.at(0);
67  // Access the clique and get the underlying hybrid conditional
68  HybridConditional::shared_ptr root_conditional = root->conditional();
69 
70  // The root should be discrete only, we compute the MPE
71  if (root_conditional->isDiscrete()) {
72  auto discrete = root_conditional->asDiscrete<TableDistribution>();
73  discrete_fg.push_back(discrete);
74  mpe = discreteMaxProduct(discrete_fg);
75  } else {
76  mpe = DiscreteValues();
77  }
78 
79  return mpe;
80 }
81 
82 /* ************************************************************************* */
84  DiscreteValues mpe = this->mpe();
85 
87  return HybridValues(values, mpe);
88 }
89 
90 /* ************************************************************************* */
100  // The gaussian bayes tree that will be recursively created.
102  // Flag indicating if all the nodes are valid. Used in optimize().
103  bool valid_;
104 
113  const GaussianBayesTree::sharedNode& parentClique,
114  GaussianBayesTree* gbt, bool valid = true)
115  : assignment_(assignment),
116  parentClique_(parentClique),
117  gaussianbayesTree_(gbt),
118  valid_(valid) {}
119 
120  bool isValid() const { return valid_; }
121 
131  const HybridBayesTree::sharedNode& node,
133  // Extract the gaussian conditional from the Hybrid clique
135 
137  if (hybrid_conditional->isHybrid()) {
138  conditional = (*hybrid_conditional->asHybrid())(parentData.assignment_);
139  } else if (hybrid_conditional->isContinuous()) {
140  conditional = hybrid_conditional->asGaussian();
141  } else {
142  // Discrete only conditional, so we set to empty gaussian conditional
143  conditional = std::make_shared<GaussianConditional>();
144  }
145 
147  if (conditional) {
148  // Create the GaussianClique for the current node
149  clique = std::make_shared<GaussianBayesTree::Node>(conditional);
150  // Add the current clique to the GaussianBayesTree.
151  parentData.gaussianbayesTree_->addClique(clique,
152  parentData.parentClique_);
153  } else {
154  parentData.valid_ = false;
155  }
156 
157  // Create new HybridAssignmentData where the current node is the parent
158  // This will be passed down to the children nodes
159  HybridAssignmentData data(parentData.assignment_, clique,
160  parentData.gaussianbayesTree_, parentData.valid_);
161  return data;
162  }
163 };
164 
165 /* ************************************************************************* */
167  const DiscreteValues& assignment) const {
168  GaussianBayesTree gbt;
169  HybridAssignmentData rootData(assignment, 0, &gbt);
170  {
171  treeTraversal::no_op visitorPost;
172  // Limits OpenMP threads since we're mixing TBB and OpenMP
173  TbbOpenMPMixedScope threadLimiter;
176  visitorPost);
177  }
178 
179  if (!rootData.isValid()) {
180  return GaussianBayesTree();
181  }
182  return gbt;
183 }
184 
185 /* ************************************************************************* */
187  return HybridGaussianFactorGraph(*this).error(values);
188 }
189 
190 /* ************************************************************************* */
192  GaussianBayesTree gbt = this->choose(assignment);
193  // If empty GaussianBayesTree, means a clique is pruned hence invalid
194  if (gbt.size() == 0) {
195  return VectorValues();
196  }
197  VectorValues result = gbt.optimize();
198 
199  // Return the optimized bayes net result.
200  return result;
201 }
202 
203 /* ************************************************************************* */
204 void HybridBayesTree::prune(const size_t maxNrLeaves) {
205  if (!this->roots_.at(0)->conditional()->asDiscrete()) {
206  // Root of the BayesTree is not a discrete clique, so we do nothing.
207  return;
208  }
209 
210  auto prunedDiscreteProbs =
211  this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();
212 
213  // Imperative pruning
214  prunedDiscreteProbs->prune(maxNrLeaves);
215 
217  struct HybridPrunerData {
219  DiscreteConditional::shared_ptr prunedDiscreteProbs;
220  HybridPrunerData(const DiscreteConditional::shared_ptr& prunedDiscreteProbs,
221  const HybridBayesTree::sharedNode& parentClique)
222  : prunedDiscreteProbs(prunedDiscreteProbs) {}
223 
232  static HybridPrunerData AssignmentPreOrderVisitor(
234  HybridPrunerData& parentData) {
235  // Get the conditional
236  HybridConditional::shared_ptr conditional = clique->conditional();
237 
238  // If conditional is hybrid, we prune it.
239  if (conditional->isHybrid()) {
240  auto hybridGaussianCond = conditional->asHybrid();
241 
242  if (!hybridGaussianCond->pruned()) {
243  // Imperative
244  clique->conditional() = std::make_shared<HybridConditional>(
245  hybridGaussianCond->prune(*parentData.prunedDiscreteProbs));
246  }
247  }
248  return parentData;
249  }
250  };
251 
252  HybridPrunerData rootData(prunedDiscreteProbs, 0);
253  {
254  treeTraversal::no_op visitorPost;
255  // Limits OpenMP threads since we're mixing TBB and OpenMP
256  TbbOpenMPMixedScope threadLimiter;
258  *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
259  visitorPost);
260  }
261 }
262 
263 } // namespace gtsam
gtsam::HybridConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: HybridConditional.h:66
TableDistribution.h
gtsam::TableFactor
Definition: TableFactor.h:54
DiscreteBayesNet.h
gtsam::HybridAssignmentData::isValid
bool isValid() const
Definition: HybridBayesTree.cpp:120
gtsam::HybridValues
Definition: HybridValues.h:37
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
treeTraversal-inst.h
gtsam::TableDistribution::argmax
DiscreteValues argmax() const
Return assignment that maximizes value.
Definition: TableDistribution.cpp:125
gtsam::FactorGraph::error
double error(const HybridValues &values) const
Definition: FactorGraph-inst.h:66
DiscreteFactorGraph.h
gtsam::BayesTree< HybridBayesTreeClique >::roots_
Roots roots_
Definition: BayesTree.h:103
gtsam::HybridAssignmentData::HybridAssignmentData
HybridAssignmentData(const DiscreteValues &assignment, const GaussianBayesTree::sharedNode &parentClique, GaussianBayesTree *gbt, bool valid=true)
Construct a new Hybrid Assignment Data object.
Definition: HybridBayesTree.cpp:112
gtsam::HybridBayesTree
Definition: HybridBayesTree.h:62
gtsam::HybridBayesTree::discreteMaxProduct
DiscreteValues discreteMaxProduct(const DiscreteFactorGraph &dfg) const
Definition: HybridBayesTree.cpp:46
HybridBayesNet.h
A Bayes net of Gaussian Conditionals indexed by discrete keys.
gtsam::HybridBayesTree::error
double error(const HybridValues &values) const
Definition: HybridBayesTree.cpp:186
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:247
gtsam::HybridAssignmentData::assignment_
const DiscreteValues assignment_
Definition: HybridBayesTree.cpp:98
GaussianJunctionTree.h
result
Values result
Definition: OdometryOptimize.cpp:8
gtsam::DiscreteFactorGraph::scaledProduct
DiscreteFactor::shared_ptr scaledProduct() const
Return product of all factors as a single factor, which is scaled by the max value to prevent underflow.
Definition: DiscreteFactorGraph.cpp:117
gtsam::HybridAssignmentData::AssignmentPreOrderVisitor
static HybridAssignmentData AssignmentPreOrderVisitor(const HybridBayesTree::sharedNode &node, HybridAssignmentData &parentData)
A function used during tree traversal that operates on each node before visiting the node's children.
Definition: HybridBayesTree.cpp:130
data
int data[]
Definition: Map_placement_new.cpp:1
parentData
DATA & parentData
Definition: treeTraversal-inst.h:45
gtsam::VectorValues
Definition: VectorValues.h:74
gtsam::HybridBayesTree::mpe
DiscreteValues mpe() const
Compute the Most Probable Explanation (MPE) of the discrete variables.
Definition: HybridBayesTree.cpp:62
gtsam::TableDistribution
Definition: TableDistribution.h:39
gtsam::BayesTree< HybridBayesTreeClique >::clique
const sharedClique & clique(Key j) const
Definition: BayesTree.h:156
gtsam::HybridBayesTree::optimize
HybridValues optimize() const
Optimize the hybrid Bayes tree by computing the MPE for the current set of discrete variables and using it to compute the best continuous update delta.
Definition: HybridBayesTree.cpp:83
gtsam::HybridGaussianFactorGraph
Definition: HybridGaussianFactorGraph.h:106
HybridBayesTree.h
Hybrid Bayes Tree, the result of eliminating a HybridJunctionTree.
gtsam::GaussianBayesTree::optimize
VectorValues optimize() const
Definition: GaussianBayesTree.cpp:67
gtsam::HybridAssignmentData::gaussianbayesTree_
GaussianBayesTree * gaussianbayesTree_
Definition: HybridBayesTree.cpp:101
gtsam::HybridBayesTree::choose
GaussianBayesTree choose(const DiscreteValues &assignment) const
Get the Gaussian Bayes Tree which corresponds to a specific discrete value assignment.
Definition: HybridBayesTree.cpp:166
gtsam::DiscreteConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: DiscreteConditional.h:43
gtsam::treeTraversal::DepthFirstForestParallel
void DepthFirstForestParallel(FOREST &forest, DATA &rootData, VISITOR_PRE &visitorPre, VISITOR_POST &visitorPost, int problemSizeThreshold=10)
Definition: treeTraversal-inst.h:156
gtsam
traits
Definition: SFMdata.h:40
gtsam::GaussianBayesTree
Definition: GaussianBayesTree.h:49
gtsam::BayesTree::size
size_t size() const
Definition: BayesTree-inst.h:135
gtsam::DiscreteFactor::shared_ptr
std::shared_ptr< DiscreteFactor > shared_ptr
shared_ptr to this class
Definition: DiscreteFactor.h:45
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::FactorGraph::push_back
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:147
gtsam::TableDistribution::prune
virtual void prune(size_t maxNrAssignments) override
Prune the conditional.
Definition: TableDistribution.cpp:142
p
float * p
Definition: Tutorial_Map_using.cpp:9
equal_constants::hybrid_conditional
const HybridGaussianConditional hybrid_conditional(mode, conditionals)
gtsam::GaussianConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: GaussianConditional.h:46
product
void product(const MatrixType &m)
Definition: product.h:20
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::HybridAssignmentData
Helper class for Depth First Forest traversal on the HybridBayesTree.
Definition: HybridBayesTree.cpp:97
gtsam::BayesTree< GaussianBayesTreeClique >::sharedNode
sharedClique sharedNode
Synonym for sharedClique (TODO: remove)
Definition: BayesTree.h:76
gtsam::HybridBayesTree::prune
void prune(const size_t maxNumberLeaves)
Prune the underlying Bayes tree.
Definition: HybridBayesTree.cpp:204
gtsam::BayesTree< HybridBayesTreeClique >::equals
bool equals(const This &other, double tol=1e-9) const
Definition: BayesTree-inst.h:270
BayesTreeCliqueBase-inst.h
Base class for cliques of a BayesTree.
BayesTree-inst.h
Bayes Tree is a tree of cliques of a Bayes Chain.
gtsam::HybridAssignmentData::valid_
bool valid_
Definition: HybridBayesTree.cpp:103
gtsam::TbbOpenMPMixedScope
Definition: types.h:162
gtsam::HybridAssignmentData::parentClique_
GaussianBayesTree::sharedNode parentClique_
Definition: HybridBayesTree.cpp:99
gtsam::HybridBayesTree::equals
bool equals(const This &other, double tol=1e-9) const
Definition: HybridBayesTree.cpp:41
pybind_wrapper_test_script.other
other
Definition: pybind_wrapper_test_script.py:42
HybridConditional.h


gtsam
Author(s):
autogenerated on Wed Mar 19 2025 03:01:48