HybridBayesTree.cpp
Go to the documentation of this file.
/* ----------------------------------------------------------------------------

 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 * Atlanta, Georgia 30332-0415
 * All Rights Reserved
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)

 * See LICENSE for the license information

 * -------------------------------------------------------------------------- */
11 
28 
29 namespace gtsam {
30 
31 // Instantiate base class
32 template class BayesTreeCliqueBase<HybridBayesTreeClique,
33  HybridGaussianFactorGraph>;
34 template class BayesTree<HybridBayesTreeClique>;
35 
36 /* ************************************************************************* */
37 bool HybridBayesTree::equals(const This& other, double tol) const {
38  return Base::equals(other, tol);
39 }
40 
41 /* ************************************************************************* */
43  DiscreteBayesNet dbn;
44  DiscreteValues mpe;
45 
46  auto root = roots_.at(0);
47  // Access the clique and get the underlying hybrid conditional
48  HybridConditional::shared_ptr root_conditional = root->conditional();
49 
50  // The root should be discrete only, we compute the MPE
51  if (root_conditional->isDiscrete()) {
52  dbn.push_back(root_conditional->asDiscrete());
53  mpe = DiscreteFactorGraph(dbn).optimize();
54  } else {
55  throw std::runtime_error(
56  "HybridBayesTree root is not discrete-only. Please check elimination "
57  "ordering or use continuous factor graph.");
58  }
59 
61  return HybridValues(values, mpe);
62 }
63 
64 /* ************************************************************************* */
74  // The gaussian bayes tree that will be recursively created.
76  // Flag indicating if all the nodes are valid. Used in optimize().
77  bool valid_;
78 
87  const GaussianBayesTree::sharedNode& parentClique,
88  GaussianBayesTree* gbt, bool valid = true)
89  : assignment_(assignment),
90  parentClique_(parentClique),
91  gaussianbayesTree_(gbt),
92  valid_(valid) {}
93 
94  bool isValid() const { return valid_; }
95 
105  const HybridBayesTree::sharedNode& node,
107  // Extract the gaussian conditional from the Hybrid clique
108  HybridConditional::shared_ptr hybrid_conditional = node->conditional();
109 
111  if (hybrid_conditional->isHybrid()) {
112  conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
113  } else if (hybrid_conditional->isContinuous()) {
114  conditional = hybrid_conditional->asGaussian();
115  } else {
116  // Discrete only conditional, so we set to empty gaussian conditional
117  conditional = std::make_shared<GaussianConditional>();
118  }
119 
121  if (conditional) {
122  // Create the GaussianClique for the current node
123  clique = std::make_shared<GaussianBayesTree::Node>(conditional);
124  // Add the current clique to the GaussianBayesTree.
125  parentData.gaussianbayesTree_->addClique(clique,
126  parentData.parentClique_);
127  } else {
128  parentData.valid_ = false;
129  }
130 
131  // Create new HybridAssignmentData where the current node is the parent
132  // This will be passed down to the children nodes
133  HybridAssignmentData data(parentData.assignment_, clique,
134  parentData.gaussianbayesTree_, parentData.valid_);
135  return data;
136  }
137 };
138 
139 /* *************************************************************************
140  */
142  const DiscreteValues& assignment) const {
143  GaussianBayesTree gbt;
144  HybridAssignmentData rootData(assignment, 0, &gbt);
145  {
146  treeTraversal::no_op visitorPost;
147  // Limits OpenMP threads since we're mixing TBB and OpenMP
148  TbbOpenMPMixedScope threadLimiter;
151  visitorPost);
152  }
153 
154  if (!rootData.isValid()) {
155  return GaussianBayesTree();
156  }
157  return gbt;
158 }
159 
160 /* *************************************************************************
161  */
163  GaussianBayesTree gbt = this->choose(assignment);
164  // If empty GaussianBayesTree, means a clique is pruned hence invalid
165  if (gbt.size() == 0) {
166  return VectorValues();
167  }
168  VectorValues result = gbt.optimize();
169 
170  // Return the optimized bayes net result.
171  return result;
172 }
173 
174 /* ************************************************************************* */
175 void HybridBayesTree::prune(const size_t maxNrLeaves) {
176  auto decisionTree =
177  this->roots_.at(0)->conditional()->asDiscrete();
178 
179  DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
180  decisionTree->root_ = prunedDecisionTree.root_;
181 
183  struct HybridPrunerData {
185  DecisionTreeFactor prunedDecisionTree;
186  HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
187  const HybridBayesTree::sharedNode& parentClique)
188  : prunedDecisionTree(prunedDecisionTree) {}
189 
198  static HybridPrunerData AssignmentPreOrderVisitor(
200  HybridPrunerData& parentData) {
201  // Get the conditional
202  HybridConditional::shared_ptr conditional = clique->conditional();
203 
204  // If conditional is hybrid, we prune it.
205  if (conditional->isHybrid()) {
206  auto gaussianMixture = conditional->asMixture();
207 
208  gaussianMixture->prune(parentData.prunedDecisionTree);
209  }
210  return parentData;
211  }
212  };
213 
214  HybridPrunerData rootData(prunedDecisionTree, 0);
215  {
216  treeTraversal::no_op visitorPost;
217  // Limits OpenMP threads since we're mixing TBB and OpenMP
218  TbbOpenMPMixedScope threadLimiter;
220  *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
221  visitorPost);
222  }
223 }
224 
225 } // namespace gtsam
DecisionTreeFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
std::shared_ptr< This > shared_ptr
shared_ptr to this class
sharedClique sharedNode
Synonym for sharedClique (TODO: remove)
Definition: BayesTree.h:76
size_t size() const
void addClique(const sharedClique &clique, const sharedClique &parent_clique=sharedClique())
const DiscreteValues assignment_
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:190
leaf::MyValues values
Base class for cliques of a BayesTree.
bool equals(const This &other, double tol=1e-9) const
static HybridAssignmentData AssignmentPreOrderVisitor(const HybridBayesTree::sharedNode &node, HybridAssignmentData &parentData)
A function used during tree traversal that operates on each node before visiting the node's children...
DiscreteValues optimize(OptionalOrderingType orderingType={}) const
Find the maximum probable explanation (MPE) by doing max-product.
void prune(const size_t maxNumberLeaves)
Prune the underlying Bayes tree.
GaussianBayesTree choose(const DiscreteValues &assignment) const
Get the Gaussian Bayes Tree which corresponds to a specific discrete value assignment.
A Bayes net of Gaussian Conditionals indexed by discrete keys.
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition: DecisionTree.h:136
Values result
DATA & parentData
VectorValues optimize() const
void DepthFirstForestParallel(FOREST &forest, DATA &rootData, VISITOR_PRE &visitorPre, VISITOR_POST &visitorPost, int problemSizeThreshold=10)
Hybrid Bayes Tree, the result of eliminating a HybridJunctionTree.
HybridValues optimize() const
Optimize the hybrid Bayes tree by computing the MPE for the current set of discrete variables and usi...
int data[]
GaussianBayesTree::sharedNode parentClique_
Bayes Tree is a tree of cliques of a Bayes Chain.
std::shared_ptr< This > shared_ptr
shared_ptr to this class
traits
Definition: chartTesting.h:28
GaussianBayesTree * gaussianbayesTree_
HybridAssignmentData(const DiscreteValues &assignment, const GaussianBayesTree::sharedNode &parentClique, GaussianBayesTree *gbt, bool valid=true)
Construct a new Hybrid Assignment Data object.
const sharedClique & clique(Key j) const
Definition: BayesTree.h:155
const G double tol
Definition: Group.h:86
bool equals(const This &other, double tol=1e-9) const
Helper class for Depth First Forest traversal on the HybridBayesTree.


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:34:20