BayesTree-inst.h
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
21 #pragma once
22 
26 #include <gtsam/base/timing.h>
27 
28 #include <fstream>
29 #include <queue>
30 
31 namespace gtsam {
32 
33  /* ************************************************************************* */
34  template<class CLIQUE>
37  for (const sharedClique& root : roots_) getCliqueData(root, &stats);
38  return stats;
39  }
40 
41  /* ************************************************************************* */
42  template <class CLIQUE>
44  BayesTreeCliqueData* stats) const {
45  const auto conditional = clique->conditional();
46  stats->conditionalSizes.push_back(conditional->nrFrontals());
47  stats->separatorSizes.push_back(conditional->nrParents());
48  for (sharedClique c : clique->children) {
49  getCliqueData(c, stats);
50  }
51  }
52 
53  /* ************************************************************************* */
54  template<class CLIQUE>
56  size_t count = 0;
57  for(const sharedClique& root: roots_)
58  count += root->numCachedSeparatorMarginals();
59  return count;
60  }
61 
62  /* ************************************************************************* */
63  template <class CLIQUE>
64  void BayesTree<CLIQUE>::dot(std::ostream& os,
65  const KeyFormatter& keyFormatter) const {
66  if (roots_.empty())
67  throw std::invalid_argument(
68  "the root of Bayes tree has not been initialized!");
69  os << "digraph G{\n";
70  for (const sharedClique& root : roots_) dot(os, root, keyFormatter);
71  os << "}";
72  std::flush(os);
73  }
74 
75  /* ************************************************************************* */
76  template <class CLIQUE>
77  std::string BayesTree<CLIQUE>::dot(const KeyFormatter& keyFormatter) const {
78  std::stringstream ss;
79  dot(ss, keyFormatter);
80  return ss.str();
81  }
82 
83  /* ************************************************************************* */
84  template <class CLIQUE>
85  void BayesTree<CLIQUE>::saveGraph(const std::string& filename,
86  const KeyFormatter& keyFormatter) const {
87  std::ofstream of(filename.c_str());
88  dot(of, keyFormatter);
89  of.close();
90  }
91 
92  /* ************************************************************************* */
93  template <class CLIQUE>
94  void BayesTree<CLIQUE>::dot(std::ostream& s, sharedClique clique,
95  const KeyFormatter& keyFormatter,
96  int parentnum) const {
97  static int num = 0;
98  bool first = true;
99  std::stringstream out;
100  out << num;
101  std::string parent = out.str();
102  parent += "[label=\"";
103 
104  for (Key key : clique->conditional_->frontals()) {
105  if (!first) parent += ", ";
106  first = false;
107  parent += keyFormatter(key);
108  }
109 
110  if (clique->parent()) {
111  parent += " : ";
112  s << parentnum << "->" << num << "\n";
113  }
114 
115  first = true;
116  for (Key parentKey : clique->conditional_->parents()) {
117  if (!first) parent += ", ";
118  first = false;
119  parent += keyFormatter(parentKey);
120  }
121  parent += "\"];\n";
122  s << parent;
123  parentnum = num;
124 
125  for (sharedClique c : clique->children) {
126  num++;
127  dot(s, c, keyFormatter, parentnum);
128  }
129  }
130 
131  /* ************************************************************************* */
132  template<class CLIQUE>
133  size_t BayesTree<CLIQUE>::size() const {
134  size_t size = 0;
135  for(const sharedClique& clique: roots_)
136  size += clique->treeSize();
137  return size;
138  }
139 
140  /* ************************************************************************* */
141  template<class CLIQUE>
142  void BayesTree<CLIQUE>::addClique(const sharedClique& clique, const sharedClique& parent_clique) {
143  for(Key j: clique->conditional()->frontals())
144  nodes_[j] = clique;
145  if (parent_clique != nullptr) {
146  clique->parent_ = parent_clique;
147  parent_clique->children.push_back(clique);
148  } else {
149  roots_.push_back(clique);
150  }
151  }
152 
153  /* ************************************************************************* */
154  namespace {
155  template <class FACTOR, class CLIQUE>
156  struct _pushCliqueFunctor {
157  _pushCliqueFunctor(FactorGraph<FACTOR>* graph_) : graph(graph_) {}
158  FactorGraph<FACTOR>* graph;
159  int operator()(const std::shared_ptr<CLIQUE>& clique, int dummy) {
160  graph->push_back(clique->conditional_);
161  return 0;
162  }
163  };
164  } // namespace
165 
166  /* ************************************************************************* */
167  template <class CLIQUE>
169  FactorGraph<FactorType>* graph) const {
170  // Traverse the BayesTree and add all conditionals to this graph
171  int data = 0; // Unused
172  _pushCliqueFunctor<FactorType, CLIQUE> functor(graph);
173  treeTraversal::DepthFirstForest(*this, data, functor);
174  }
175 
176  /* ************************************************************************* */
177  template<class CLIQUE>
179  *this = other;
180  }
181 
182  /* ************************************************************************* */
183 
189  template<class CLIQUE>
191  /* Because tree nodes are hold by both root_ and nodes_, we need to clear nodes_ manually first and
192  * reduce the reference count of each node by 1. Otherwise, the nodes will not be properly deleted
193  * during the BFS process.
194  */
195  nodes_.clear();
196  for (auto&& root: roots_) {
197  std::queue<sharedClique> bfs_queue;
198 
199  // first, move the root to the queue
200  bfs_queue.push(root);
201  root = nullptr; // now the root node is owned by the queue
202 
203  // do a BFS on the tree, for each node, add its children to the queue, and then delete it from the queue
204  // So if the reference count of the node is 1, it will be deleted, and because its children are in the queue,
205  // the deletion of the node will not trigger a recursive deletion of the children.
206  while (!bfs_queue.empty()) {
207  // move the ownership of the front node from the queue to the current variable
208  auto current = bfs_queue.front();
209  bfs_queue.pop();
210 
211  // add the children of the current node to the queue, so that the queue will also own the children nodes.
212  for (auto child: current->children) {
213  bfs_queue.push(child);
214  } // leaving the scope of current will decrease the reference count of the current node by 1, and if the reference count is 0,
215  // the node will be deleted. Because the children are in the queue, the deletion of the node will not trigger a recursive
216  // deletion of the children.
217  }
218  }
219  }
220 
221  /* ************************************************************************* */
/* ************************************************************************* */
namespace {
/// Pre-order visitor used to deep-copy a forest. It copies `node`, drops the
/// copy's children (they are re-attached as each child is cloned in turn),
/// hooks the copy under the already-cloned `parentPointer`, and returns the
/// copy so that it serves as the parent for `node`'s own children.
template<typename NODE>
std::shared_ptr<NODE>
BayesTreeCloneForestVisitorPre(const std::shared_ptr<NODE>& node, const std::shared_ptr<NODE>& parentPointer)
{
  auto copy = std::make_shared<NODE>(*node);
  copy->children.clear();
  copy->parent_ = parentPointer;
  parentPointer->children.push_back(copy);
  return copy;
}
}
235 
236  /* ************************************************************************* */
237  template<class CLIQUE>
239  this->clear();
240  std::shared_ptr<Clique> rootContainer = std::make_shared<Clique>();
241  treeTraversal::DepthFirstForest(other, rootContainer, BayesTreeCloneForestVisitorPre<Clique>);
242  for(const sharedClique& root: rootContainer->children) {
243  root->parent_ = typename Clique::weak_ptr(); // Reset the parent since it's set to the dummy clique
244  insertRoot(root);
245  }
246  return *this;
247  }
248 
249  /* ************************************************************************* */
250  template<class CLIQUE>
251  void BayesTree<CLIQUE>::print(const std::string& s, const KeyFormatter& keyFormatter) const {
252  std::cout << s << ": cliques: " << size() << ", variables: " << nodes_.size() << std::endl;
253  treeTraversal::PrintForest(*this, s, keyFormatter);
254  }
255 
256  /* ************************************************************************* */
257  // binary predicate to test equality of a pair for use in equals
258  template<class CLIQUE>
260  const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v1,
261  const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v2
262  ) {
263  return v1.first == v2.first &&
264  ((!v1.second && !v2.second) || (v1.second && v2.second && v1.second->equals(*v2.second)));
265  }
266 
267  /* ************************************************************************* */
268  template<class CLIQUE>
270  return size()==other.size() &&
271  std::equal(nodes_.begin(), nodes_.end(), other.nodes_.begin(), &check_sharedCliques<CLIQUE>);
272  }
273 
274  /* ************************************************************************* */
275  template<class CLIQUE>
276  template<class CONTAINER>
277  Key BayesTree<CLIQUE>::findParentClique(const CONTAINER& parents) const {
278  typename CONTAINER::const_iterator lowestOrderedParent = min_element(parents.begin(), parents.end());
279  assert(lowestOrderedParent != parents.end());
280  return *lowestOrderedParent;
281  }
282 
283  /* ************************************************************************* */
284  template<class CLIQUE>
286  // Add each frontal variable of this root node
287  for(const Key& j: subtree->conditional()->frontals()) {
288  bool inserted = nodes_.insert({j, subtree}).second;
289  assert(inserted); (void)inserted;
290  }
291  // Fill index for each child
293  for(const sharedClique& child: subtree->children) {
294  fillNodesIndex(child); }
295  }
296 
297  /* ************************************************************************* */
298  template<class CLIQUE>
300  roots_.push_back(subtree); // Add to roots
301  fillNodesIndex(subtree); // Populate nodes index
302  }
303 
304  /* ************************************************************************* */
305  // First finds clique marginal then marginalizes that
306  /* ************************************************************************* */
307  template<class CLIQUE>
310  {
311  gttic(BayesTree_marginalFactor);
312 
313  // get clique containing Key j
314  sharedClique clique = this->clique(j);
315 
316  // calculate or retrieve its marginal P(C) = P(F,S)
317  FactorGraphType cliqueMarginal = clique->marginal2(function);
318 
319  // Now, marginalize out everything that is not variable j
320  BayesNetType marginalBN =
321  *cliqueMarginal.marginalMultifrontalBayesNet(Ordering{j}, function);
322 
323  // The Bayes net should contain only one conditional for variable j, so return it
324  return marginalBN.front();
325  }
326 
327  /* ************************************************************************* */
328  // Find two cliques, their joint, then marginalizes
329  /* ************************************************************************* */
330  template<class CLIQUE>
332  BayesTree<CLIQUE>::joint(Key j1, Key j2, const Eliminate& function) const
333  {
334  gttic(BayesTree_joint);
335  return std::make_shared<FactorGraphType>(*jointBayesNet(j1, j2, function));
336  }
337 
338  /* ************************************************************************* */
339  template<class CLIQUE>
341  BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
342  {
343  gttic(BayesTree_jointBayesNet);
344  // get clique C1 and C2
345  sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
346 
347  gttic(Lowest_common_ancestor);
348  // Find lowest common ancestor clique
349  sharedClique B; {
350  // Build two paths to the root
351  FastList<sharedClique> path1, path2; {
352  sharedClique p = C1;
353  while(p) {
354  path1.push_front(p);
355  p = p->parent();
356  }
357  } {
358  sharedClique p = C2;
359  while(p) {
360  path2.push_front(p);
361  p = p->parent();
362  }
363  }
364  // Find the path intersection
365  typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
366  if(*p1 == *p2)
367  B = *p1;
368  while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
369  B = *p1;
370  ++p1;
371  ++p2;
372  }
373  }
374  gttoc(Lowest_common_ancestor);
375 
376  // Build joint on all involved variables
377  FactorGraphType p_BC1C2;
378 
379  if(B)
380  {
381  // Compute marginal on lowest common ancestor clique
382  gttic(LCA_marginal);
383  FactorGraphType p_B = B->marginal2(function);
384  gttoc(LCA_marginal);
385 
386  // Compute shortcuts of the requested cliques given the lowest common ancestor
387  gttic(Clique_shortcuts);
388  BayesNetType p_C1_Bred = C1->shortcut(B, function);
389  BayesNetType p_C2_Bred = C2->shortcut(B, function);
390  gttoc(Clique_shortcuts);
391 
392  // Factor the shortcuts to be conditioned on the full root
393  // Get the set of variables to eliminate, which is C1\B.
394  gttic(Full_root_factoring);
395  std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
396  KeyVector C1_minus_B; {
397  KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
398  for(const Key j: *B->conditional()) {
399  C1_minus_B_set.erase(j); }
400  C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
401  }
402  // Factor into C1\B | B.
403  p_C1_B =
404  FactorGraphType(p_C1_Bred)
405  .eliminatePartialMultifrontal(Ordering(C1_minus_B), function)
406  .first;
407  }
408  std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
409  KeyVector C2_minus_B; {
410  KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
411  for(const Key j: *B->conditional()) {
412  C2_minus_B_set.erase(j); }
413  C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
414  }
415  // Factor into C2\B | B.
416  p_C2_B =
417  FactorGraphType(p_C2_Bred)
418  .eliminatePartialMultifrontal(Ordering(C2_minus_B), function)
419  .first;
420  }
421  gttoc(Full_root_factoring);
422 
423  gttic(Variable_joint);
424  p_BC1C2.push_back(p_B);
425  p_BC1C2.push_back(*p_C1_B);
426  p_BC1C2.push_back(*p_C2_B);
427  if(C1 != B)
428  p_BC1C2.push_back(C1->conditional());
429  if(C2 != B)
430  p_BC1C2.push_back(C2->conditional());
431  gttoc(Variable_joint);
432  }
433  else
434  {
435  // The nodes have no common ancestor, they're in different trees, so they're joint is just the
436  // product of their marginals.
437  gttic(Disjoint_marginals);
438  p_BC1C2.push_back(C1->marginal2(function));
439  p_BC1C2.push_back(C2->marginal2(function));
440  gttoc(Disjoint_marginals);
441  }
442 
443  // now, marginalize out everything that is not variable j1 or j2
444  return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function);
445  }
446 
447  /* ************************************************************************* */
448  template<class CLIQUE>
450  // Remove all nodes and clear the root pointer
451  nodes_.clear();
452  roots_.clear();
453  }
454 
455  /* ************************************************************************* */
456  template<class CLIQUE>
458  for(const sharedClique& root: roots_) {
459  root->deleteCachedShortcuts();
460  }
461  }
462 
463  /* ************************************************************************* */
464  template<class CLIQUE>
466  {
467  if (clique->isRoot()) {
468  typename Roots::iterator root = std::find(roots_.begin(), roots_.end(), clique);
469  if(root != roots_.end())
470  roots_.erase(root);
471  } else { // detach clique from parent
472  sharedClique parent = clique->parent_.lock();
473  typename Roots::iterator child = std::find(parent->children.begin(), parent->children.end(), clique);
474  assert(child != parent->children.end());
475  parent->children.erase(child);
476  }
477 
478  // orphan my children
479  for(sharedClique child: clique->children)
480  child->parent_ = typename Clique::weak_ptr();
481 
482  for(Key j: clique->conditional()->frontals()) {
483  nodes_.unsafe_erase(j);
484  }
485  }
486 
487  /* ************************************************************************* */
488  template <class CLIQUE>
490  Cliques* orphans) {
491  // base case is nullptr, if so we do nothing and return empties above
492  if (clique) {
493  // remove the clique from orphans in case it has been added earlier
494  orphans->remove(clique);
495 
496  // remove me
497  this->removeClique(clique);
498 
499  // remove path above me
500  this->removePath(typename Clique::shared_ptr(clique->parent_.lock()), bn,
501  orphans);
502 
503  // add children to list of orphans (splice also removed them from
504  // clique->children_)
505  orphans->insert(orphans->begin(), clique->children.begin(),
506  clique->children.end());
507  clique->children.clear();
508 
509  bn->push_back(clique->conditional_);
510  }
511  }
512 
513  /* *************************************************************************
514  */
515  template <class CLIQUE>
517  Cliques* orphans) {
518  gttic(removetop);
519  // process each key of the new factor
520  for (const Key& j : keys) {
521  // get the clique
522  // TODO(frank): Nodes will be searched again in removeClique
523  typename Nodes::const_iterator node = nodes_.find(j);
524  if (node != nodes_.end()) {
525  // remove path from clique to root
526  this->removePath(node->second, bn, orphans);
527  }
528  }
529 
530  // Delete cachedShortcuts for each orphan subtree
531  // TODO(frank): Consider Improving
532  for (sharedClique& orphan : *orphans) orphan->deleteCachedShortcuts();
533  }
534 
535  /* ************************************************************************* */
536  template<class CLIQUE>
538  const sharedClique& subtree)
539  {
540  // Result clique list
541  Cliques cliques;
542  cliques.push_back(subtree);
543 
544  // Remove the first clique from its parents
545  if(!subtree->isRoot())
546  subtree->parent()->children.erase(std::find(
547  subtree->parent()->children.begin(), subtree->parent()->children.end(), subtree));
548  else
549  roots_.erase(std::find(roots_.begin(), roots_.end(), subtree));
550 
551  // Add all subtree cliques and erase the children and parent of each
552  for(typename Cliques::iterator clique = cliques.begin(); clique != cliques.end(); ++clique)
553  {
554  // Add children
555  for(const sharedClique& child: (*clique)->children) {
556  cliques.push_back(child); }
557 
558  // Delete cached shortcuts
559  (*clique)->deleteCachedShortcutsNonRecursive();
560 
561  // Remove this node from the nodes index
562  for(Key j: (*clique)->conditional()->frontals()) {
563  nodes_.unsafe_erase(j); }
564 
565  // Erase the parent and children pointers
566  (*clique)->parent_.reset();
567  (*clique)->children.clear();
568  }
569 
570  return cliques;
571  }
572 
573 }
void removePath(sharedClique clique, BayesNetType *bn, Cliques *orphans)
const gtsam::Symbol key('X', 0)
std::shared_ptr< FactorGraphType > sharedFactorGraph
Definition: BayesTree.h:84
void removeTop(const KeyVector &keys, BayesNetType *bn, Cliques *orphans)
FactorGraphType::Eliminate Eliminate
Definition: BayesTree.h:85
size_t size() const
Vector v2
This & operator=(const This &other)
double dot(const V1 &a, const V2 &b)
Definition: Vector.h:195
Vector v1
Vector3f p1
void addClique(const sharedClique &clique, const sharedClique &parent_clique=sharedClique())
std::ofstream out("Result.txt")
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
std::shared_ptr< BayesNetType > sharedBayesNet
Definition: BayesTree.h:80
bool equals(const This &other, double tol=1e-9) const
std::shared_ptr< ConditionalType > sharedConditional
Definition: BayesTree.h:78
NonlinearFactorGraph graph
bool check_sharedCliques(const std::pair< Key, typename BayesTree< CLIQUE >::sharedClique > &v1, const std::pair< Key, typename BayesTree< CLIQUE >::sharedClique > &v2)
std::weak_ptr< This > weak_ptr
void fillNodesIndex(const sharedClique &subtree)
Key findParentClique(const CONTAINER &parents) const
bool stats
std::shared_ptr< Clique > sharedClique
Shared pointer to a clique.
Definition: BayesTree.h:74
void insertRoot(const sharedClique &subtree)
void DepthFirstForest(FOREST &forest, DATA &rootData, VISITOR_PRE &visitorPre, VISITOR_POST &visitorPost)
void deleteCachedShortcuts()
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
#define gttic(label)
Definition: timing.h:295
Cliques removeSubtree(const sharedClique &subtree)
BayesTreeCliqueData getCliqueData() const
FastVector< std::size_t > conditionalSizes
Definition: BayesTree.h:49
void saveGraph(const std::string &filename, const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
output to file with graphviz format.
FastVector< std::size_t > separatorSizes
Definition: BayesTree.h:50
int data[]
RealScalar s
void PrintForest(const FOREST &forest, std::string str, const KeyFormatter &keyFormatter)
void print(const std::string &s="", const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
void addFactorsToGraph(FactorGraph< FactorType > *graph) const
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
HybridBayesTreeClique ::FactorGraphType FactorGraphType
Definition: BayesTree.h:83
static std::stringstream ss
Definition: testBTree.cpp:31
Bayes Tree is a tree of cliques of a Bayes Chain.
traits
Definition: chartTesting.h:28
void dot(std::ostream &os, const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
Output to graphviz format, stream version.
ofstream os("timeSchurFactors.csv")
#define gttoc(label)
Definition: timing.h:296
std::shared_ptr< This > shared_ptr
void removeClique(sharedClique clique)
float * p
static Point3 p2
HybridBayesTreeClique ::BayesNetType BayesNetType
Definition: BayesTree.h:79
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
const G double tol
Definition: Group.h:86
const KeyVector keys
sharedFactorGraph joint(Key j1, Key j2, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:86
size_t numCachedSeparatorMarginals() const
bool equal(const T &obj1, const T &obj2, double tol)
Definition: Testable.h:85
internal::enable_if< internal::valid_indexed_view_overload< RowIndices, ColIndices >::value &&internal::traits< typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::ReturnAsIndexedView, typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::type operator()(const RowIndices &rowIndices, const ColIndices &colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
sharedConditional marginalFactor(Key j, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:102
std::ptrdiff_t j
Timing utilities.


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:33:57