54 template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
58 using std::dynamic_pointer_cast;
63 const std::shared_ptr<Factor> &
f) {
65 throw std::runtime_error(
s +
" not implemented for factor type " +
74 index,
KeyVector(discrete_keys.begin(), discrete_keys.end()),
true);
81 const std::function<
bool(
const Factor * ,
83 &printCondition)
const {
84 std::cout <<
str <<
"size: " <<
size() << std::endl << std::endl;
90 std::cout <<
"Factor " <<
i <<
": ";
93 ss.str(std::string());
95 if (
auto gmf = std::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
96 if (factor ==
nullptr) {
97 std::cout <<
"nullptr"
100 factor->print(
ss.str(), keyFormatter);
101 std::cout <<
"error = ";
102 gmf->errorTree(
values.continuous()).print(
"", keyFormatter);
103 std::cout << std::endl;
105 }
else if (
auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
106 if (factor ==
nullptr) {
107 std::cout <<
"nullptr"
110 factor->print(
ss.str(), keyFormatter);
112 if (
hc->isContinuous()) {
113 std::cout <<
"error = " <<
hc->asGaussian()->error(
values) <<
"\n";
114 }
else if (
hc->isDiscrete()) {
115 std::cout <<
"error = ";
116 hc->asDiscrete()->errorTree().print(
"", keyFormatter);
120 std::cout <<
"error = ";
121 hc->asMixture()->errorTree(
values.continuous()).print();
125 }
else if (
auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
126 const double errorValue = (factor !=
nullptr ? gf->error(
values) : .0);
127 if (!printCondition(factor.get(), errorValue,
i))
130 if (factor ==
nullptr) {
131 std::cout <<
"nullptr"
134 factor->print(
ss.str(), keyFormatter);
135 std::cout <<
"error = " << errorValue <<
"\n";
137 }
else if (
auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
138 if (factor ==
nullptr) {
139 std::cout <<
"nullptr"
142 factor->print(
ss.str(), keyFormatter);
143 std::cout <<
"error = ";
144 df->errorTree().print(
"", keyFormatter);
161 if (gfgTree.
empty()) {
182 if (
auto gf = dynamic_pointer_cast<GaussianFactor>(
f)) {
184 }
else if (
auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(
f)) {
186 }
else if (
auto gm = dynamic_pointer_cast<GaussianMixture>(
f)) {
188 }
else if (
auto hc = dynamic_pointer_cast<HybridConditional>(
f)) {
189 if (
auto gm =
hc->asMixture()) {
191 }
else if (
auto g =
hc->asGaussian()) {
198 }
else if (dynamic_pointer_cast<DiscreteFactor>(
f)) {
213 static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
218 if (
auto gf = dynamic_pointer_cast<GaussianFactor>(
f)) {
220 }
else if (
auto orphan = dynamic_pointer_cast<OrphanWrapper>(
f)) {
223 }
else if (
auto hc = dynamic_pointer_cast<HybridConditional>(
f)) {
224 auto gc =
hc->asGaussian();
233 return {std::make_shared<HybridConditional>(
result.first),
result.second};
237 static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
243 if (
auto df = dynamic_pointer_cast<DiscreteFactor>(
f)) {
245 }
else if (
auto orphan = dynamic_pointer_cast<OrphanWrapper>(
f)) {
248 }
else if (
auto hc = dynamic_pointer_cast<HybridConditional>(
f)) {
249 auto dc =
hc->asDiscrete();
260 return {std::make_shared<HybridConditional>(
result.first),
result.second};
279 using Result = std::pair<std::shared_ptr<GaussianConditional>,
288 auto probability = [&](
const Result &pair) ->
double {
289 const auto &[conditional, factor] = pair;
292 if (!factor)
return 1.0;
293 return exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
298 return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
309 const auto &[conditional, factor] = pair;
311 auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
312 if (!hf)
throw std::runtime_error(
"Expected HessianFactor!");
313 hf->constantTerm() += 2.0 * conditional->logNormalizationConstant();
320 return std::make_shared<GaussianMixtureFactor>(continuousSeparator,
321 discreteSeparator, newFactors);
324 static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
328 const std::set<DiscreteKey> &discreteSeparatorSet) {
331 DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
332 discreteSeparatorSet.end());
346 return {
nullptr,
nullptr};
361 continuousSeparator.empty()
368 eliminationResults, [](
const Result &pair) {
return pair.first; });
369 auto gaussianMixture = std::make_shared<GaussianMixture>(
370 frontalKeys, continuousSeparator, discreteSeparator,
conditionals);
372 return {std::make_shared<HybridConditional>(gaussianMixture), newFactor};
389 std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
441 bool only_discrete =
true, only_continuous =
true;
442 for (
auto &&factor :
factors) {
443 if (
auto hybrid_factor = std::dynamic_pointer_cast<HybridFactor>(factor)) {
444 if (hybrid_factor->isDiscrete()) {
445 only_continuous =
false;
446 }
else if (hybrid_factor->isContinuous()) {
447 only_discrete =
false;
448 }
else if (hybrid_factor->isHybrid()) {
449 only_continuous =
false;
450 only_discrete =
false;
452 }
else if (
auto cont_factor =
453 std::dynamic_pointer_cast<GaussianFactor>(factor)) {
454 only_discrete =
false;
455 }
else if (
auto discrete_factor =
456 std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
457 only_continuous =
false;
466 }
else if (only_continuous) {
471 KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());
476 auto continuousKeySet =
factors.continuousKeySet();
478 continuousKeySet.begin(), continuousKeySet.end(),
479 frontalKeysSet.begin(), frontalKeysSet.end(),
480 std::inserter(continuousSeparator, continuousSeparator.begin()));
483 KeySet discreteSeparatorSet;
484 std::set<DiscreteKey> discreteSeparator;
485 auto discreteKeySet =
factors.discreteKeySet();
487 discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(),
488 frontalKeysSet.end(),
489 std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin()));
491 auto discreteKeyMap =
factors.discreteKeyMap();
492 for (
auto key : discreteSeparatorSet) {
493 discreteSeparator.insert(discreteKeyMap.at(
key));
511 if (
auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(
f)) {
513 error_tree = error_tree + gaussianMixture->errorTree(continuousValues);
514 }
else if (
auto gaussian = dynamic_pointer_cast<GaussianFactor>(
f)) {
517 double error = gaussian->error(continuousValues);
519 error_tree = error_tree.
apply(
520 [
error](
double leaf_value) {
return leaf_value +
error; });
521 }
else if (dynamic_pointer_cast<DiscreteFactor>(
f)) {