HybridBayesTree.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
29 
30 #include <memory>
31 
32 namespace gtsam {
33 
34 // Instantiate base class
35 template class BayesTreeCliqueBase<HybridBayesTreeClique,
36  HybridGaussianFactorGraph>;
37 template class BayesTree<HybridBayesTreeClique>;
38 
39 /* ************************************************************************* */
40 bool HybridBayesTree::equals(const This& other, double tol) const {
41  return Base::equals(other, tol);
42 }
43 
44 /* ************************************************************************* */
// Body of HybridBayesTree::optimize() (its signature line is not shown in
// this listing): compute the discrete MPE from the root clique, then return it
// together with the continuous values `values`.
46  DiscreteFactorGraph discrete_fg;
47  DiscreteValues mpe;
48 
// NOTE(review): only the first root is inspected; for an empty tree, .at(0)
// presumably throws (depends on the Roots container type) — confirm.
49  auto root = roots_.at(0);
50  // Access the clique and get the underlying hybrid conditional
51  HybridConditional::shared_ptr root_conditional = root->conditional();
52 
53  // The root should be discrete-only; compute its MPE by max-product.
54  if (root_conditional->isDiscrete()) {
55  discrete_fg.push_back(root_conditional->asDiscrete());
56  mpe = discrete_fg.optimize();
57  } else {
// A root containing continuous variables cannot yield a discrete MPE here.
58  throw std::runtime_error(
59  "HybridBayesTree root is not discrete-only. Please check elimination "
60  "ordering or use continuous factor graph.");
61  }
62 
// `values` is defined on a line not shown in this listing (presumably the
// continuous solution for `mpe` — confirm).
64  return HybridValues(values, mpe);
65 }
66 
67 /* ************************************************************************* */
// Remaining members of HybridAssignmentData (the struct header and some
// members are not shown in this listing): per-clique traversal state used by
// choose() to build a GaussianBayesTree for one discrete assignment.
77  // The Gaussian Bayes tree that will be recursively created.
79  // Flag indicating if all the nodes are valid. Checked in choose() after the traversal.
80  bool valid_;
81 
90  const GaussianBayesTree::sharedNode& parentClique,
91  GaussianBayesTree* gbt, bool valid = true)
92  : assignment_(assignment),
93  parentClique_(parentClique),
94  gaussianbayesTree_(gbt),
95  valid_(valid) {}
96 
// Returns the valid_ flag; it is cleared by AssignmentPreOrderVisitor when a
// clique yields no Gaussian conditional for the assignment.
97  bool isValid() const { return valid_; }
98 
// Pre-order visitor (its first signature line is not shown in this listing):
// selects the Gaussian conditional of `node` for parentData.assignment_,
// attaches it under parentData.parentClique_, and returns the data that is
// passed on to the node's children.
108  const HybridBayesTree::sharedNode& node,
110  // Extract the gaussian conditional from the Hybrid clique
112 
114  if (hybrid_conditional->isHybrid()) {
115  conditional = (*hybrid_conditional->asHybrid())(parentData.assignment_);
116  } else if (hybrid_conditional->isContinuous()) {
117  conditional = hybrid_conditional->asGaussian();
118  } else {
119  // Discrete-only conditional: insert an empty Gaussian conditional as a placeholder.
120  conditional = std::make_shared<GaussianConditional>();
121  }
122 
124  if (conditional) {
125  // Create the GaussianClique for the current node
126  clique = std::make_shared<GaussianBayesTree::Node>(conditional);
127  // Add the current clique to the GaussianBayesTree.
128  parentData.gaussianbayesTree_->addClique(clique,
129  parentData.parentClique_);
130  } else {
// No Gaussian conditional for this assignment (presumably the hybrid
// conditional was pruned for it): mark the traversal as invalid.
// NOTE(review): this writes to parentData, but children receive the copy
// `data` built below, so a failure below the root-level cliques appears never
// to reach rootData in choose(). Under DepthFirstForestParallel, this write
// and addClique() above may also run concurrently for sibling cliques —
// confirm against treeTraversal-inst.h.
131  parentData.valid_ = false;
132  }
133 
134  // Create new HybridAssignmentData where the current node is the parent
135  // This will be passed down to the children nodes
136  HybridAssignmentData data(parentData.assignment_, clique,
137  parentData.gaussianbayesTree_, parentData.valid_);
138  return data;
139  }
140 };
141 
142 /* ************************************************************************* */
144  const DiscreteValues& assignment) const {
// Tail of HybridBayesTree::choose (its first signature line is not shown in
// this listing): builds the GaussianBayesTree obtained by fixing every hybrid
// conditional to the discrete `assignment`.
145  GaussianBayesTree gbt;
// Root traversal data: no parent clique (0); cliques are added into `gbt`.
146  HybridAssignmentData rootData(assignment, 0, &gbt);
147  {
148  treeTraversal::no_op visitorPost;
149  // Limits OpenMP threads since we're mixing TBB and OpenMP
150  TbbOpenMPMixedScope threadLimiter;
153  visitorPost);
154  }
155 
// If the traversal flagged a clique without a conditional for this
// assignment, signal that to the caller with an empty tree.
156  if (!rootData.isValid()) {
157  return GaussianBayesTree();
158  }
159  return gbt;
160 }
161 
162 /* ************************************************************************* */
// Body of HybridBayesTree::error(values) (its signature is not shown in this
// listing): converts this Bayes tree into a HybridGaussianFactorGraph and
// returns that graph's error at `values`.
164  return HybridGaussianFactorGraph(*this).error(values);
165 }
166 
167 /* ************************************************************************* */
// Body of a function returning VectorValues for a discrete `assignment`
// (its signature is not shown in this listing): solves for the continuous
// variables with the discrete variables fixed to `assignment`.
169  GaussianBayesTree gbt = this->choose(assignment);
170  // An empty GaussianBayesTree means a clique was pruned for this assignment, so it is invalid.
// NOTE(review): the empty VectorValues returned here cannot be told apart
// from a valid empty solution by the caller — confirm this is intended.
171  if (gbt.size() == 0) {
172  return VectorValues();
173  }
174  VectorValues result = gbt.optimize();
175 
176  // Return the continuous solution of the chosen Gaussian Bayes tree.
177  return result;
178 }
179 
180 /* ************************************************************************* */
// Prune the hybrid Bayes tree. First, the discrete conditional at the first
// root is pruned in place with DecisionTreeFactor::prune(maxNrLeaves). Then
// every clique is visited, and each hybrid conditional that is not yet pruned
// is replaced by its version pruned against those discrete probabilities.
// NOTE(review): optimize() checks isDiscrete() on the root and throws
// runtime_error otherwise, but this function does not; if asDiscrete() returns
// null for a non-discrete root, the calls below dereference it. Confirm, and
// consider the same guard.
181 void HybridBayesTree::prune(const size_t maxNrLeaves) {
182  auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
183 
// Prune the discrete probabilities, then overwrite the root conditional's
// decision tree with the pruned one so the tree is updated in place.
184  DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
185  discreteProbs->root_ = prunedDiscreteProbs.root_;
186 
// Traversal payload that carries the pruned discrete probabilities to every
// clique.
188  struct HybridPrunerData {
190  DecisionTreeFactor prunedDiscreteProbs;
// Note: `parentClique` is accepted but not used.
191  HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
192  const HybridBayesTree::sharedNode& parentClique)
193  : prunedDiscreteProbs(prunedDiscreteProbs) {}
194 
// Pre-order visitor: prunes the clique's hybrid conditional, if there is one,
// and passes the same data (the same pruned probabilities) on to the children.
203  static HybridPrunerData AssignmentPreOrderVisitor(
205  HybridPrunerData& parentData) {
206  // Get the conditional
207  HybridConditional::shared_ptr conditional = clique->conditional();
208 
209  // If conditional is hybrid, we prune it.
210  if (conditional->isHybrid()) {
211  auto hybridGaussianCond = conditional->asHybrid();
212 
213  if (!hybridGaussianCond->pruned()) {
214  // Imperative: replace the clique's conditional in place with the pruned one.
215  clique->conditional() = std::make_shared<HybridConditional>(
216  hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
217  }
218  }
219  return parentData;
220  }
221  };
222 
// Root traversal data: no parent clique (0).
223  HybridPrunerData rootData(prunedDiscreteProbs, 0);
224  {
225  treeTraversal::no_op visitorPost;
226  // Limits OpenMP threads since we're mixing TBB and OpenMP
227  TbbOpenMPMixedScope threadLimiter;
229  *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
230  visitorPost);
231  }
232 }
233 
234 } // namespace gtsam
gtsam::HybridConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: HybridConditional.h:66
DiscreteBayesNet.h
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:45
gtsam::HybridAssignmentData::isValid
bool isValid() const
Definition: HybridBayesTree.cpp:97
gtsam::HybridValues
Definition: HybridValues.h:37
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
treeTraversal-inst.h
gtsam::FactorGraph::error
double error(const HybridValues &values) const
Definition: FactorGraph-inst.h:66
DiscreteFactorGraph.h
gtsam::BayesTree< HybridBayesTreeClique >::roots_
Roots roots_
Definition: BayesTree.h:103
gtsam::HybridAssignmentData::HybridAssignmentData
HybridAssignmentData(const DiscreteValues &assignment, const GaussianBayesTree::sharedNode &parentClique, GaussianBayesTree *gbt, bool valid=true)
Construct a new Hybrid Assignment Data object.
Definition: HybridBayesTree.cpp:89
gtsam::HybridBayesTree
Definition: HybridBayesTree.h:62
gtsam::DecisionTreeFactor::prune
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
Definition: DecisionTreeFactor.cpp:508
HybridBayesNet.h
A Bayes net of Gaussian Conditionals indexed by discrete keys.
gtsam::HybridBayesTree::error
double error(const HybridValues &values) const
Definition: HybridBayesTree.cpp:163
different_sigmas::values
HybridValues values
Definition: testHybridBayesNet.cpp:245
gtsam::HybridAssignmentData::assignment_
const DiscreteValues assignment_
Definition: HybridBayesTree.cpp:75
GaussianJunctionTree.h
gtsam::HybridFactor::isHybrid
bool isHybrid() const
True if this is a Discrete-Continuous factor.
Definition: HybridFactor.h:125
result
Values result
Definition: OdometryOptimize.cpp:8
gtsam::HybridAssignmentData::AssignmentPreOrderVisitor
static HybridAssignmentData AssignmentPreOrderVisitor(const HybridBayesTree::sharedNode &node, HybridAssignmentData &parentData)
A function used during tree traversal that operates on each node before visiting the node's children.
Definition: HybridBayesTree.cpp:107
gtsam::DecisionTree::root_
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition: DecisionTree.h:149
data
int data[]
Definition: Map_placement_new.cpp:1
parentData
DATA & parentData
Definition: treeTraversal-inst.h:45
gtsam::VectorValues
Definition: VectorValues.h:74
gtsam::BayesTree< HybridBayesTreeClique >::clique
const sharedClique & clique(Key j) const
Definition: BayesTree.h:155
gtsam::HybridBayesTree::optimize
HybridValues optimize() const
Optimize the hybrid Bayes tree by computing the MPE for the current set of discrete variables and usi...
Definition: HybridBayesTree.cpp:45
gtsam::HybridGaussianFactorGraph
Definition: HybridGaussianFactorGraph.h:106
HybridBayesTree.h
Hybrid Bayes Tree, the result of eliminating a HybridJunctionTree.
gtsam::HybridFactor::isContinuous
bool isContinuous() const
True if this is a factor of continuous variables only.
Definition: HybridFactor.h:122
gtsam::GaussianBayesTree::optimize
VectorValues optimize() const
Definition: GaussianBayesTree.cpp:67
gtsam::HybridAssignmentData::gaussianbayesTree_
GaussianBayesTree * gaussianbayesTree_
Definition: HybridBayesTree.cpp:78
gtsam::HybridBayesTree::choose
GaussianBayesTree choose(const DiscreteValues &assignment) const
Get the Gaussian Bayes Tree which corresponds to a specific discrete value assignment.
Definition: HybridBayesTree.cpp:143
gtsam::treeTraversal::DepthFirstForestParallel
void DepthFirstForestParallel(FOREST &forest, DATA &rootData, VISITOR_PRE &visitorPre, VISITOR_POST &visitorPost, int problemSizeThreshold=10)
Definition: treeTraversal-inst.h:156
gtsam
traits
Definition: SFMdata.h:40
gtsam::GaussianBayesTree
Definition: GaussianBayesTree.h:49
gtsam::BayesTree::size
size_t size() const
Definition: BayesTree-inst.h:133
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::FactorGraph::push_back
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:147
equal_constants::hybrid_conditional
const HybridGaussianConditional hybrid_conditional(mode, conditionals)
gtsam::GaussianConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: GaussianConditional.h:46
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::HybridAssignmentData
Helper class for Depth First Forest traversal on the HybridBayesTree.
Definition: HybridBayesTree.cpp:74
gtsam::BayesTree< GaussianBayesTreeClique >::sharedNode
sharedClique sharedNode
Synonym for sharedClique (TODO: remove)
Definition: BayesTree.h:76
gtsam::HybridBayesTree::prune
void prune(const size_t maxNumberLeaves)
Prune the underlying Bayes tree.
Definition: HybridBayesTree.cpp:181
gtsam::BayesTree< HybridBayesTreeClique >::equals
bool equals(const This &other, double tol=1e-9) const
Definition: BayesTree-inst.h:269
BayesTreeCliqueBase-inst.h
Base class for cliques of a BayesTree.
gtsam::DiscreteFactorGraph::optimize
DiscreteValues optimize(OptionalOrderingType orderingType={}) const
Find the maximum probable explanation (MPE) by doing max-product.
Definition: DiscreteFactorGraph.cpp:209
BayesTree-inst.h
Bayes Tree is a tree of cliques of a Bayes Chain.
gtsam::HybridAssignmentData::valid_
bool valid_
Definition: HybridBayesTree.cpp:80
gtsam::TbbOpenMPMixedScope
Definition: types.h:162
gtsam::HybridAssignmentData::parentClique_
GaussianBayesTree::sharedNode parentClique_
Definition: HybridBayesTree.cpp:76
gtsam::HybridBayesTree::equals
bool equals(const This &other, double tol=1e-9) const
Definition: HybridBayesTree.cpp:40
pybind_wrapper_test_script.other
other
Definition: pybind_wrapper_test_script.py:42
HybridConditional.h


gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:02:22