31 #include <unordered_set>
36 template<
class CLIQUE>
44 template <
class CLIQUE>
47 const auto conditional = clique->conditional();
48 stats->conditionalSizes.push_back(conditional->nrFrontals());
49 stats->separatorSizes.push_back(conditional->nrParents());
56 template<
class CLIQUE>
60 count += root->numCachedSeparatorMarginals();
65 template <
class CLIQUE>
69 throw std::invalid_argument(
70 "the root of Bayes tree has not been initialized!");
78 template <
class CLIQUE>
81 dot(
ss, keyFormatter);
86 template <
class CLIQUE>
90 dot(of, keyFormatter);
95 template <
class CLIQUE>
98 int parentnum)
const {
101 std::stringstream
out;
103 std::string parent =
out.str();
104 parent +=
"[label=\"";
106 for (
Key key : clique->conditional_->frontals()) {
107 if (!first) parent +=
", ";
109 parent += keyFormatter(
key);
112 if (clique->parent()) {
114 s << parentnum <<
"->" << num <<
"\n";
118 for (
Key parentKey : clique->conditional_->parents()) {
119 if (!first) parent +=
", ";
121 parent += keyFormatter(parentKey);
129 dot(
s,
c, keyFormatter, parentnum);
134 template<
class CLIQUE>
138 size += clique->treeSize();
143 template<
class CLIQUE>
145 for(
Key j: clique->conditional()->frontals())
147 if (parent_clique !=
nullptr) {
148 clique->parent_ = parent_clique;
149 parent_clique->children.push_back(clique);
151 roots_.push_back(clique);
157 template <
class FACTOR,
class CLIQUE>
158 struct _pushCliqueFunctor {
162 graph->push_back(clique->conditional_);
169 template <
class CLIQUE>
174 _pushCliqueFunctor<FactorType, CLIQUE> functor(
graph);
179 template<
class CLIQUE>
191 template<
class CLIQUE>
198 for (
auto&& root: roots_) {
199 std::queue<sharedClique> bfs_queue;
202 bfs_queue.push(std::move(root));
207 while (!bfs_queue.empty()) {
209 auto current = std::move(bfs_queue.front());
213 for (
auto child: current->children) {
214 bfs_queue.push(std::move(child));
224 template<
typename NODE>
225 std::shared_ptr<NODE>
226 BayesTreeCloneForestVisitorPre(
const std::shared_ptr<NODE>& node,
const std::shared_ptr<NODE>& parentPointer)
229 std::shared_ptr<NODE> clone = std::make_shared<NODE>(*node);
230 clone->children.clear();
231 clone->parent_ = parentPointer;
232 parentPointer->children.push_back(clone);
238 template<
class CLIQUE>
241 std::shared_ptr<Clique> rootContainer = std::make_shared<Clique>();
243 for(
const sharedClique& root: rootContainer->children) {
251 template<
class CLIQUE>
253 std::cout <<
s <<
": cliques: " <<
size() <<
", variables: " << nodes_.size() << std::endl;
259 template<
class CLIQUE>
264 return v1.first ==
v2.first &&
265 ((!
v1.second && !
v2.second) || (
v1.second &&
v2.second &&
v1.second->equals(*
v2.second)));
269 template<
class CLIQUE>
272 std::equal(nodes_.begin(), nodes_.end(),
other.nodes_.begin(), &check_sharedCliques<CLIQUE>);
276 template<
class CLIQUE>
277 template<
class CONTAINER>
279 typename CONTAINER::const_iterator lowestOrderedParent = min_element(parents.begin(), parents.end());
280 assert(lowestOrderedParent != parents.end());
281 return *lowestOrderedParent;
285 template<
class CLIQUE>
288 for(
const Key&
j: subtree->conditional()->frontals()) {
289 bool inserted = nodes_.insert({
j, subtree}).second;
290 assert(inserted); (void)inserted;
295 fillNodesIndex(child); }
299 template<
class CLIQUE>
301 roots_.push_back(subtree);
302 fillNodesIndex(subtree);
308 template<
class CLIQUE>
312 gttic(BayesTree_marginalFactor);
322 *cliqueMarginal.marginalMultifrontalBayesNet(
Ordering{
j},
function);
325 return marginalBN.front();
331 template<
class CLIQUE>
335 gttic(BayesTree_joint);
336 return std::make_shared<FactorGraphType>(*jointBayesNet(
j1, j2,
function));
342 template <
class CLIQUE>
344 const std::shared_ptr<CLIQUE>&
C1,
const std::shared_ptr<CLIQUE>&
C2) {
346 std::unordered_set<std::shared_ptr<CLIQUE>> ancestors;
347 for (
auto p =
C1;
p;
p =
p->parent()) {
352 std::shared_ptr<CLIQUE>
B;
353 for (
auto p =
C2;
p;
p =
p->parent()) {
354 if (ancestors.count(
p)) {
365 template <
class CLIQUE>
367 const std::shared_ptr<CLIQUE>& p_F_S,
const std::shared_ptr<CLIQUE>&
B,
368 const typename CLIQUE::FactorGraphType::Eliminate& eliminate) {
369 gttic(Full_root_factoring);
372 auto p_S_B = p_F_S->shortcut(
B, eliminate);
375 KeyVector S_setminus_B = p_F_S->separator_setminus_B(
B);
379 typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal(
385 template <
class CLIQUE>
388 gttic(BayesTree_jointBayesNet);
406 p_BC1C2.push_back(p_B);
407 p_BC1C2.push_back(*p_C1_B);
408 p_BC1C2.push_back(*p_C2_B);
409 if (
C1 !=
B) p_BC1C2.push_back(
C1->conditional());
410 if (
C2 !=
B) p_BC1C2.push_back(
C2->conditional());
414 p_BC1C2.push_back(
C1->marginal2(eliminate));
415 p_BC1C2.push_back(
C2->marginal2(eliminate));
419 return p_BC1C2.marginalMultifrontalBayesNet(
Ordering{
j1, j2}, eliminate);
423 template<
class CLIQUE>
431 template<
class CLIQUE>
434 root->deleteCachedShortcuts();
439 template<
class CLIQUE>
442 if (clique->isRoot()) {
443 typename Roots::iterator root = std::find(roots_.begin(), roots_.end(), clique);
444 if(root != roots_.end())
448 typename Roots::iterator child = std::find(parent->children.begin(), parent->children.end(), clique);
449 assert(child != parent->children.end());
450 parent->children.erase(child);
457 for(
Key j: clique->conditional()->frontals()) {
458 nodes_.unsafe_erase(
j);
463 template <
class CLIQUE>
469 orphans->remove(clique);
472 this->removeClique(clique);
480 orphans->insert(orphans->begin(), clique->children.begin(),
481 clique->children.end());
482 clique->children.clear();
484 bn->push_back(clique->conditional_);
490 template <
class CLIQUE>
498 typename Nodes::const_iterator node = nodes_.find(
j);
499 if (node != nodes_.end()) {
501 this->removePath(node->second, bn, orphans);
507 for (
sharedClique& orphan : *orphans) orphan->deleteCachedShortcuts();
511 template<
class CLIQUE>
517 cliques.push_back(subtree);
520 if(!subtree->isRoot())
521 subtree->parent()->children.erase(std::find(
522 subtree->parent()->children.begin(), subtree->parent()->children.end(), subtree));
524 roots_.erase(std::find(roots_.begin(), roots_.end(), subtree));
527 for(
typename Cliques::iterator clique = cliques.begin(); clique != cliques.end(); ++clique)
531 cliques.push_back(child); }
534 (*clique)->deleteCachedShortcutsNonRecursive();
537 for(
Key j: (*clique)->conditional()->frontals()) {
538 nodes_.unsafe_erase(
j); }
541 (*clique)->parent_.reset();
542 (*clique)->children.clear();