BayesTree-inst.h
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
21 #pragma once
22 
26 #include <gtsam/base/timing.h>
27 
28 #include <boost/optional.hpp>
29 #include <boost/assign/list_of.hpp>
30 #include <fstream>
31 
32 using boost::assign::cref_list_of;
33 
34 namespace gtsam {
35 
36  /* ************************************************************************* */
// getCliqueData(): gathers frontal-count and separator-count statistics for
// every clique in the forest by running the recursive overload on each root.
// Declared as `BayesTreeCliqueData getCliqueData() const`; the local `stats`
// that is filled and returned is declared on a line not shown here.
37  template<class CLIQUE>
40  for (const sharedClique& root : roots_) getCliqueData(root, &stats);
41  return stats;
42  }
43 
44  /* ************************************************************************* */
// Recursive helper for getCliqueData(): appends this clique's number of frontal
// variables to stats->conditionalSizes and its number of parent (separator)
// variables to stats->separatorSizes, then recurses into every child. The
// entries are therefore recorded in depth-first pre-order.
45  template <class CLIQUE>
47  BayesTreeCliqueData* stats) const {
48  const auto conditional = clique->conditional();
49  stats->conditionalSizes.push_back(conditional->nrFrontals());
50  stats->separatorSizes.push_back(conditional->nrParents());
51  for (sharedClique c : clique->children) {
52  getCliqueData(c, stats);
53  }
54  }
55 
56  /* ************************************************************************* */
// numCachedSeparatorMarginals(): sums root->numCachedSeparatorMarginals() over
// all roots of the forest.
// NOTE(review): presumably the per-root call counts its whole subtree — confirm
// in the clique class.
57  template<class CLIQUE>
59  size_t count = 0;
60  for(const sharedClique& root: roots_)
61  count += root->numCachedSeparatorMarginals();
62  return count;
63  }
64 
65  /* ************************************************************************* */
66  template<class CLIQUE>
67  void BayesTree<CLIQUE>::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const {
68  if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!");
69  std::ofstream of(s.c_str());
70  of<< "digraph G{\n";
71  for(const sharedClique& root: roots_)
72  saveGraph(of, root, keyFormatter);
73  of<<"}";
74  of.close();
75  }
76 
77  /* ************************************************************************* */
78  template<class CLIQUE>
79  void BayesTree<CLIQUE>::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const {
80  static int num = 0;
81  bool first = true;
82  std::stringstream out;
83  out << num;
84  std::string parent = out.str();
85  parent += "[label=\"";
86 
87  for (Key index : clique->conditional_->frontals()) {
88  if (!first) parent += ",";
89  first = false;
90  parent += indexFormatter(index);
91  }
92 
93  if (clique->parent()) {
94  parent += " : ";
95  s << parentnum << "->" << num << "\n";
96  }
97 
98  first = true;
99  for (Key sep : clique->conditional_->parents()) {
100  if (!first) parent += ",";
101  first = false;
102  parent += indexFormatter(sep);
103  }
104  parent += "\"];\n";
105  s << parent;
106  parentnum = num;
107 
108  for (sharedClique c : clique->children) {
109  num++;
110  saveGraph(s, c, indexFormatter, parentnum);
111  }
112  }
113 
114  /* ************************************************************************* */
115  template<class CLIQUE>
116  size_t BayesTree<CLIQUE>::size() const {
117  size_t size = 0;
118  for(const sharedClique& clique: roots_)
119  size += clique->treeSize();
120  return size;
121  }
122 
123  /* ************************************************************************* */
124  template<class CLIQUE>
125  void BayesTree<CLIQUE>::addClique(const sharedClique& clique, const sharedClique& parent_clique) {
126  for(Key j: clique->conditional()->frontals())
127  nodes_[j] = clique;
128  if (parent_clique != nullptr) {
129  clique->parent_ = parent_clique;
130  parent_clique->children.push_back(clique);
131  } else {
132  roots_.push_back(clique);
133  }
134  }
135 
136  /* ************************************************************************* */
137  namespace {
138  template <class FACTOR, class CLIQUE>
139  struct _pushCliqueFunctor {
140  _pushCliqueFunctor(FactorGraph<FACTOR>* graph_) : graph(graph_) {}
141  FactorGraph<FACTOR>* graph;
142  int operator()(const boost::shared_ptr<CLIQUE>& clique, int dummy) {
143  graph->push_back(clique->conditional_);
144  return 0;
145  }
146  };
147  } // namespace
148 
149  /* ************************************************************************* */
// addFactorsToGraph(): appends the conditional of every clique in the forest to
// `graph` (declared as `void addFactorsToGraph(FactorGraph<FactorType>* graph) const`).
150  template <class CLIQUE>
153  // Traverse the BayesTree and add all conditionals to this graph
154  int data = 0; // Unused per-node traversal data; the functor ignores it and returns 0
155  _pushCliqueFunctor<FactorType, CLIQUE> functor(graph);
156  treeTraversal::DepthFirstForest(*this, data, functor);
157  }
158 
159  /* ************************************************************************* */
// Copy constructor (its signature line is not shown here): delegates to
// operator=, which clones every clique of `other` into a new forest.
160  template<class CLIQUE>
162  *this = other;
163  }
164 
165  /* ************************************************************************* */
166  namespace {
167  template<typename NODE>
168  boost::shared_ptr<NODE>
169  BayesTreeCloneForestVisitorPre(const boost::shared_ptr<NODE>& node, const boost::shared_ptr<NODE>& parentPointer)
170  {
171  // Clone the current node and add it to its cloned parent
172  boost::shared_ptr<NODE> clone = boost::make_shared<NODE>(*node);
173  clone->children.clear();
174  clone->parent_ = parentPointer;
175  parentPointer->children.push_back(clone);
176  return clone;
177  }
178  }
179 
180  /* ************************************************************************* */
// operator=(const This& other): replaces this tree with a deep copy of `other`.
// Every clique is cloned by BayesTreeCloneForestVisitorPre under a temporary
// dummy root. The clones of other's roots are then detached from that dummy and
// registered via insertRoot(), which also rebuilds the nodes index.
// NOTE(review): self-assignment (`t = t`) looks unsafe. clear() empties this tree
// before `other` (== *this) is traversed, so the result would be empty.
// Consider an `if (this == &other) return *this;` guard — confirm.
181  template<class CLIQUE>
183  this->clear();
184  boost::shared_ptr<Clique> rootContainer = boost::make_shared<Clique>();
185  treeTraversal::DepthFirstForest(other, rootContainer, BayesTreeCloneForestVisitorPre<Clique>);
186  for(const sharedClique& root: rootContainer->children) {
187  root->parent_ = typename Clique::weak_ptr(); // Reset the parent since it's set to the dummy clique
188  insertRoot(root);
189  }
190  return *this;
191  }
192 
193  /* ************************************************************************* */
194  template<class CLIQUE>
195  void BayesTree<CLIQUE>::print(const std::string& s, const KeyFormatter& keyFormatter) const {
196  std::cout << s << ": cliques: " << size() << ", variables: " << nodes_.size() << std::endl;
197  treeTraversal::PrintForest(*this, s, keyFormatter);
198  }
199 
200  /* ************************************************************************* */
201 // Binary predicate used by BayesTree::equals() to compare two (key, clique) entries of the nodes index.
// Entries are equal when their keys match and either both clique pointers are
// null, or both are non-null and the cliques compare equal via equals(). Note
// that equals() is called without an explicit tolerance argument.
202  template<class CLIQUE>
204  const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v1,
205  const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v2
206  ) {
207  return v1.first == v2.first &&
208  ((!v1.second && !v2.second) || (v1.second && v2.second && v1.second->equals(*v2.second)));
209  }
210 
211  /* ************************************************************************* */
212  template<class CLIQUE>
213  bool BayesTree<CLIQUE>::equals(const BayesTree<CLIQUE>& other, double tol) const {
214  return size()==other.size() &&
215  std::equal(nodes_.begin(), nodes_.end(), other.nodes_.begin(), &check_sharedCliques<CLIQUE>);
216  }
217 
218  /* ************************************************************************* */
219  template<class CLIQUE>
220  template<class CONTAINER>
221  Key BayesTree<CLIQUE>::findParentClique(const CONTAINER& parents) const {
222  typename CONTAINER::const_iterator lowestOrderedParent = min_element(parents.begin(), parents.end());
223  assert(lowestOrderedParent != parents.end());
224  return *lowestOrderedParent;
225  }
226 
227  /* ************************************************************************* */
// fillNodesIndex(subtree): records every frontal key of every clique in
// `subtree` in nodes_, recursing into the children.
// NOTE(review): the assert only fires in debug builds. With NDEBUG, a key that is
// already indexed presumably stays mapped to its old clique (map-like insert
// does not overwrite) — confirm for the nodes_ container type.
228  template<class CLIQUE>
230  // Add each frontal variable of this clique (the root of `subtree`)
231  for(const Key& j: subtree->conditional()->frontals()) {
232  bool inserted = nodes_.insert(std::make_pair(j, subtree)).second;
233  assert(inserted); (void)inserted;
234  }
235  // Fill index for each child
237  for(const sharedClique& child: subtree->children) {
238  fillNodesIndex(child); }
239  }
240 
241  /* ************************************************************************* */
// insertRoot(subtree): adds `subtree` as a new root of the forest and indexes
// the frontal keys of all its cliques via fillNodesIndex().
// It does not touch subtree->parent_; operator= resets that before calling this.
242  template<class CLIQUE>
244  roots_.push_back(subtree); // Add to roots
245  fillNodesIndex(subtree); // Populate nodes index
246  }
247 
248  /* ************************************************************************* */
249  // First finds clique marginal then marginalizes that
250  /* ************************************************************************* */
// marginalFactor(j, function): marginal on the single variable j, returned as a
// conditional. Declared as `sharedConditional marginalFactor(Key j,
// const Eliminate& function = EliminationTraitsType::DefaultEliminate) const`.
// It finds the clique holding j, takes that clique's marginal, then eliminates
// every other variable so that only j remains.
251  template<class CLIQUE>
254  {
255  gttic(BayesTree_marginalFactor);
256 
257  // get clique containing Key j
258  sharedClique clique = this->clique(j);
259 
260  // calculate or retrieve its marginal P(C) = P(F,S)
261  FactorGraphType cliqueMarginal = clique->marginal2(function);
262 
263  // Now, marginalize out everything that is not variable j
264  BayesNetType marginalBN = *cliqueMarginal.marginalMultifrontalBayesNet(
265  Ordering(cref_list_of<1,Key>(j)), function);
266 
267  // The Bayes net should contain only one conditional for variable j, so return it
268  return marginalBN.front();
269  }
270 
271  /* ************************************************************************* */
272  // Find two cliques, their joint, then marginalizes
273  /* ************************************************************************* */
// joint(j1, j2, function): joint marginal on j1 and j2 as a factor graph. It
// copies the Bayes net returned by jointBayesNet() into a new FactorGraphType.
274  template<class CLIQUE>
276  BayesTree<CLIQUE>::joint(Key j1, Key j2, const Eliminate& function) const
277  {
278  gttic(BayesTree_joint);
279  return boost::make_shared<FactorGraphType>(*jointBayesNet(j1, j2, function));
280  }
281 
282  /* ************************************************************************* */
// jointBayesNet(j1, j2, function): joint marginal on j1 and j2 as a Bayes net.
// Steps:
//  1. Find the lowest common ancestor B of the cliques holding j1 and j2 by
//     comparing their root-to-clique paths.
//  2. If B exists, combine p(B) with the shortcuts of C1 and C2 given B
//     (re-eliminated so they are conditioned on B) and the cliques' own
//     conditionals.
//  3. If there is no common ancestor, the cliques are in different trees and
//     the product of their marginals is used.
//  4. Finally, marginalize everything except j1 and j2.
283  template<class CLIQUE>
285  BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
286  {
287  gttic(BayesTree_jointBayesNet);
288  // get clique C1 and C2
289  sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
290 
291  gttic(Lowest_common_ancestor);
292  // Find lowest common ancestor clique
293  sharedClique B; {
294  // Build two paths to the root
295  FastList<sharedClique> path1, path2; {
296  sharedClique p = C1;
297  while(p) {
298  path1.push_front(p);
299  p = p->parent();
300  }
301  } {
302  sharedClique p = C2;
303  while(p) {
304  path2.push_front(p);
305  p = p->parent();
306  }
307  }
308  // Find the path intersection
// Both paths start at their tree's root. B is the last clique on which they still agree.
// NOTE(review): *p1 / *p2 are dereferenced without checking that the paths are
// non-empty, which assumes operator[] never yields a null clique — confirm.
// The check just below is redundant with the loop's first iteration.
309  typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
310  if(*p1 == *p2)
311  B = *p1;
312  while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
313  B = *p1;
314  ++p1;
315  ++p2;
316  }
317  }
318  gttoc(Lowest_common_ancestor);
319 
320  // Build joint on all involved variables
321  FactorGraphType p_BC1C2;
322 
323  if(B)
324  {
325  // Compute marginal on lowest common ancestor clique
326  gttic(LCA_marginal);
327  FactorGraphType p_B = B->marginal2(function);
328  gttoc(LCA_marginal);
329 
330  // Compute shortcuts of the requested cliques given the lowest common ancestor
331  gttic(Clique_shortcuts);
332  BayesNetType p_C1_Bred = C1->shortcut(B, function);
333  BayesNetType p_C2_Bred = C2->shortcut(B, function);
334  gttoc(Clique_shortcuts);
335 
336  // Factor the shortcuts to be conditioned on the full root
337  // Get the set of variables to eliminate: the parents (separator) of C1 that are not variables of B.
338  gttic(Full_root_factoring);
339  boost::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
340  KeyVector C1_minus_B; {
341  KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
342  for(const Key j: *B->conditional()) {
343  C1_minus_B_set.erase(j); }
344  C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
345  }
346  // Factor into C1\B | B.
// Only the eliminated Bayes tree is kept; the remaining factors are discarded.
347  sharedFactorGraph temp_remaining;
348  boost::tie(p_C1_B, temp_remaining) =
349  FactorGraphType(p_C1_Bred).eliminatePartialMultifrontal(Ordering(C1_minus_B), function);
350  }
351  boost::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
352  KeyVector C2_minus_B; {
353  KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
354  for(const Key j: *B->conditional()) {
355  C2_minus_B_set.erase(j); }
356  C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
357  }
358  // Factor into C2\B | B.
359  sharedFactorGraph temp_remaining;
360  boost::tie(p_C2_B, temp_remaining) =
361  FactorGraphType(p_C2_Bred).eliminatePartialMultifrontal(Ordering(C2_minus_B), function);
362  }
363  gttoc(Full_root_factoring);
364 
365  gttic(Variable_joint);
366  p_BC1C2 += p_B;
367  p_BC1C2 += *p_C1_B;
368  p_BC1C2 += *p_C2_B;
// A clique's own conditional is added only if it is not B itself, since B is already covered by p_B.
369  if(C1 != B)
370  p_BC1C2 += C1->conditional();
371  if(C2 != B)
372  p_BC1C2 += C2->conditional();
373  gttoc(Variable_joint);
374  }
375  else
376  {
377  // The nodes have no common ancestor, they're in different trees, so their joint is just the
378  // product of their marginals.
379  gttic(Disjoint_marginals);
380  p_BC1C2 += C1->marginal2(function);
381  p_BC1C2 += C2->marginal2(function);
382  gttoc(Disjoint_marginals);
383  }
384 
385  // now, marginalize out everything that is not variable j1 or j2
386  return p_BC1C2.marginalMultifrontalBayesNet(Ordering(cref_list_of<2,Key>(j1)(j2)), function);
387  }
388 
389  /* ************************************************************************* */
// clear(): empties the nodes index and the list of roots, leaving an empty forest.
390  template<class CLIQUE>
392  // Remove all nodes and clear the list of roots
393  nodes_.clear();
394  roots_.clear();
395  }
396 
397  /* ************************************************************************* */
// deleteCachedShortcuts(): asks every root clique to discard its cached shortcuts.
// NOTE(review): presumably the clique-level call recurses through its subtree
// (removeSubtree uses a separate NonRecursive variant) — confirm.
398  template<class CLIQUE>
400  for(const sharedClique& root: roots_) {
401  root->deleteCachedShortcuts();
402  }
403  }
404 
405  /* ************************************************************************* */
// removeClique(clique): detaches `clique` from the forest and un-indexes it.
// It is removed from roots_ if it is a root, otherwise from its parent's
// children. Its children are orphaned (their parent_ is reset), and its frontal
// keys are erased from nodes_.
// The clique's own parent_ and children are deliberately left intact:
// removePath() reads clique->parent_ and clique->children after calling this.
406  template<class CLIQUE>
408  {
409  if (clique->isRoot()) {
410  typename Roots::iterator root = std::find(roots_.begin(), roots_.end(), clique);
411  if(root != roots_.end())
412  roots_.erase(root);
413  } else { // detach clique from parent
414  sharedClique parent = clique->parent_.lock();
// NOTE(review): Roots::iterator is reused for a children list, so this assumes
// parent->children has the same container type as roots_.
415  typename Roots::iterator child = std::find(parent->children.begin(), parent->children.end(), clique);
416  assert(child != parent->children.end());
417  parent->children.erase(child);
418  }
419 
420  // orphan my children
421  for(sharedClique child: clique->children)
422  child->parent_ = typename Clique::weak_ptr();
423 
// Drop this clique's frontal keys from the nodes index.
424  for(Key j: clique->conditional()->frontals()) {
425  nodes_.unsafe_erase(j);
426  }
427  }
428 
429  /* ************************************************************************* */
// removePath(clique, bn, orphans): removes `clique` and, recursively, every
// ancestor up to its root.
// Because the recursion runs before the push_back, conditionals are appended to
// *bn root first. Each removed clique's remaining children are added to the
// front of *orphans. A removed clique that was placed in *orphans earlier (as a
// child of a clique removed before it) is first taken out of that list.
430  template <class CLIQUE>
432  Cliques* orphans) {
433  // base case: a null clique (the parent of a root) ends the recursion and nothing is done
434  if (clique) {
435  // remove the clique from orphans in case it has been added earlier
436  orphans->remove(clique);
437 
438  // remove me
439  this->removeClique(clique);
440 
441  // remove path above me
442  this->removePath(typename Clique::shared_ptr(clique->parent_.lock()), bn,
443  orphans);
444 
445  // add children to the front of the orphans list; they are copied, so
446  // clique->children is cleared explicitly below
447  orphans->insert(orphans->begin(), clique->children.begin(),
448  clique->children.end());
449  clique->children.clear();
450 
451  bn->push_back(clique->conditional_);
452  }
453  }
454 
455  /* *************************************************************************
456  */
// removeTop(keys, bn, orphans): for each key still present in the nodes index,
// removes the path from its clique up to the root via removePath(). Removed
// conditionals are collected in *bn and the detached subtrees in *orphans.
// Finally the cached shortcuts of every orphan subtree are dropped.
// A key whose clique was already removed by an earlier key's path is no longer
// in nodes_ (removeClique erases it) and is skipped.
457  template <class CLIQUE>
459  Cliques* orphans) {
460  gttic(removetop);
461  // process each key of the new factor
462  for (const Key& j : keys) {
463  // get the clique
464  // TODO(frank): Nodes will be searched again in removeClique
465  typename Nodes::const_iterator node = nodes_.find(j);
466  if (node != nodes_.end()) {
467  // remove path from clique to root
// node->second is copied into removePath's by-value parameter before the key is erased from nodes_.
468  this->removePath(node->second, bn, orphans);
469  }
470  }
471 
472  // Delete cachedShortcuts for each orphan subtree
473  // TODO(frank): Consider Improving
474  for (sharedClique& orphan : *orphans) orphan->deleteCachedShortcuts();
475  }
476 
477  /* ************************************************************************* */
// removeSubtree(subtree): detaches `subtree` from its parent (or from roots_).
// It then walks the whole subtree breadth-first. For each clique it deletes the
// cached shortcuts (non-recursively), erases the frontal keys from nodes_, and
// severs the parent/children links.
// Returns all cliques of the subtree in breadth-first order, `subtree` first.
478  template<class CLIQUE>
480  const sharedClique& subtree)
481  {
482  // Result clique list
// `subtree` is copied into `cliques` before any container is modified. This keeps
// the function valid even if `subtree` refers to an element of roots_ or of a
// children list, since `subtree` is not used after the erase below.
483  Cliques cliques;
484  cliques.push_back(subtree);
485 
486  // Remove the first clique from its parents
// NOTE(review): erase(std::find(...)) assumes `subtree` is present; if not,
// erase(end()) is undefined behaviour.
487  if(!subtree->isRoot())
488  subtree->parent()->children.erase(std::find(
489  subtree->parent()->children.begin(), subtree->parent()->children.end(), subtree));
490  else
491  roots_.erase(std::find(roots_.begin(), roots_.end(), subtree));
492 
493  // Add all subtree cliques and erase the children and parent of each
// Appending to `cliques` while iterating over it requires push_back not to
// invalidate iterators. That holds for a list (Cliques provides a member
// remove(), which suggests it is list-like) but would be UB for a vector — confirm.
494  for(typename Cliques::iterator clique = cliques.begin(); clique != cliques.end(); ++clique)
495  {
496  // Add children
497  for(const sharedClique& child: (*clique)->children) {
498  cliques.push_back(child); }
499 
500  // Delete cached shortcuts
501  (*clique)->deleteCachedShortcutsNonRecursive();
502 
503  // Remove this node from the nodes index
504  for(Key j: (*clique)->conditional()->frontals()) {
505  nodes_.unsafe_erase(j); }
506 
507  // Erase the parent and children pointers
508  (*clique)->parent_.reset();
509  (*clique)->children.clear();
510  }
511 
512  return cliques;
513  }
514 
515 }
void removePath(sharedClique clique, BayesNetType *bn, Cliques *orphans)
boost::shared_ptr< Clique > sharedClique
Shared pointer to a clique.
Definition: BayesTree.h:72
void removeTop(const KeyVector &keys, BayesNetType *bn, Cliques *orphans)
FactorGraphType::Eliminate Eliminate
Definition: BayesTree.h:83
Vector v2
This & operator=(const This &other)
Vector v1
bool equals(const This &other, double tol=1e-9) const
Vector3f p1
void addClique(const sharedClique &clique, const sharedClique &parent_clique=sharedClique())
Scalar Scalar * c
Definition: benchVecAdd.cpp:17
void addFactorsToGraph(FactorGraph< FactorType > *graph) const
Key findParentClique(const CONTAINER &parents) const
const mpreal root(const mpreal &x, unsigned long int k, mp_rnd_t r=mpreal::get_default_rnd())
Definition: mpreal.h:2194
bool check_sharedCliques(const std::pair< Key, typename BayesTree< CLIQUE >::sharedClique > &v1, const std::pair< Key, typename BayesTree< CLIQUE >::sharedClique > &v2)
void fillNodesIndex(const sharedClique &subtree)
bool stats
void insertRoot(const sharedClique &subtree)
sharedConditional marginalFactor(Key j, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
void DepthFirstForest(FOREST &forest, DATA &rootData, VISITOR_PRE &visitorPre, VISITOR_POST &visitorPost)
void deleteCachedShortcuts()
FactorGraph< FACTOR > * graph
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
#define gttic(label)
Definition: timing.h:280
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:86
Cliques removeSubtree(const sharedClique &subtree)
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
constexpr int first(int i)
Implementation details for constexpr functions.
FastVector< std::size_t > conditionalSizes
Definition: BayesTree.h:47
FastVector< std::size_t > separatorSizes
Definition: BayesTree.h:48
void saveGraph(const std::string &s, const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
int data[]
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
RealScalar s
void PrintForest(const FOREST &forest, std::string str, const KeyFormatter &keyFormatter)
BayesTreeCliqueData getCliqueData() const
CLIQUE::FactorGraphType FactorGraphType
Definition: BayesTree.h:81
Bayes Tree is a tree of cliques of a Bayes Chain.
traits
Definition: chartTesting.h:28
size_t numCachedSeparatorMarginals() const
#define gttoc(label)
Definition: timing.h:281
boost::shared_ptr< FactorGraphType > sharedFactorGraph
Definition: BayesTree.h:82
boost::shared_ptr< BayesNetType > sharedBayesNet
Definition: BayesTree.h:78
void removeClique(sharedClique clique)
size_t size() const
float * p
static Point3 p2
CLIQUE::BayesNetType BayesNetType
Definition: BayesTree.h:77
boost::shared_ptr< ConditionalType > sharedConditional
Definition: BayesTree.h:76
const G double tol
Definition: Group.h:83
const KeyVector keys
bool equal(const T &obj1, const T &obj2, double tol)
Definition: Testable.h:83
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:61
std::ptrdiff_t j
void print(const std::string &s="", const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
Timing utilities.
sharedFactorGraph joint(Key j1, Key j2, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
std::string sep
Definition: IOFormat.cpp:1


gtsam
Author(s):
autogenerated on Sat May 8 2021 02:41:41