18 using namespace boost;
19 using namespace gtsam;
42 correctedBeliefIndices(_beliefIndices),
46 void print(
const std::string&
s =
"")
const {
47 cout <<
s <<
":" << endl;
48 star->print(
"Star graph: ");
49 for (
const auto& [
key,
_] : correctedBeliefIndices) {
50 cout <<
"Belief factor index for " <<
key <<
": "
51 << correctedBeliefIndices.at(
key) << endl;
67 const std::map<Key, DiscreteKey>& allDiscreteKeys)
68 : starGraphs_(buildStarGraphs(
graph, allDiscreteKeys)) {}
71 void print(
const std::string&
s =
"")
const {
72 cout <<
s <<
":" << endl;
73 for (
const auto& [
key,
_] : starGraphs_) {
74 starGraphs_.at(
key).print(
"Node " + std::to_string(
key) +
":");
80 const std::map<Key, DiscreteKey>& allDiscreteKeys) {
81 static const bool debug =
false;
83 std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
85 for (
const auto& [
key,
_] : starGraphs_) {
91 std::map<Key, DiscreteFactor::shared_ptr> messages;
94 for (
const auto& [neighbor,
_] : starGraphs_.at(
key).correctedBeliefIndices) {
96 for (
size_t factor : starGraphs_.at(
key).varIndex_[neighbor]) {
97 subGraph.
push_back(starGraphs_.at(
key).star->at(factor));
99 if (
debug) subGraph.
print(
"------- Subgraph:");
100 const auto [dummyCond, message] =
103 messages.insert(make_pair(neighbor, message));
104 if (
debug) message->print(
"------- Message: ");
110 std::dynamic_pointer_cast<DecisionTreeFactor>(message);
112 beliefAtKey = std::make_shared<DecisionTreeFactor>(
114 (*std::dynamic_pointer_cast<DecisionTreeFactor>(message)));
116 if (starGraphs_.at(
key).unary)
117 beliefAtKey = std::make_shared<DecisionTreeFactor>(
118 (*beliefAtKey) * (*starGraphs_.at(
key).unary));
119 if (
debug) beliefAtKey->print(
"New belief at key: ");
122 for (
size_t v = 0;
v < allDiscreteKeys.at(
key).second; ++
v) {
125 sum += (*beliefAtKey)(val);
128 string sumFactorTable;
129 for (
size_t v = 0;
v < allDiscreteKeys.at(
key).second; ++
v) {
130 sumFactorTable = sumFactorTable +
" " + std::to_string(sum);
135 std::make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
136 if (
debug) beliefAtKey->print(
"New belief at key normalized: ");
137 beliefs->push_back(beliefAtKey);
138 allMessages[
key] = messages;
143 for (
const auto& [
key,
_] : starGraphs_) {
144 std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[
key];
145 for (
const auto& [neighbor,
_] : starGraphs_.at(
key).correctedBeliefIndices) {
147 (*std::dynamic_pointer_cast<DecisionTreeFactor>(
148 beliefs->at(beliefFactors[
key].front()))) /
149 (*std::dynamic_pointer_cast<DecisionTreeFactor>(
150 messages.at(neighbor)));
151 if (
debug) correctedBelief.
print(
"correctedBelief: ");
153 starGraphs_.at(neighbor).correctedBeliefIndices.at(
key);
154 starGraphs_.at(neighbor).star->replace(
156 std::make_shared<DecisionTreeFactor>(correctedBelief));
171 const std::map<Key, DiscreteKey>& allDiscreteKeys)
const {
174 for (
const auto& [
key,
_] : varIndex) {
180 for (
size_t factorIndex : varIndex[
key]) {
181 star->push_back(
graph.
at(factorIndex));
184 if (
graph.
at(factorIndex)->size() == 1) {
188 prodOfUnaries = std::make_shared<DecisionTreeFactor>(
195 KeySet neighbors = star->keys();
196 neighbors.erase(
key);
198 for (
Key neighbor : neighbors) {
200 string initialBelief;
201 for (
size_t v = 0;
v < allDiscreteKeys.at(neighbor).second - 1; ++
v) {
202 initialBelief = initialBelief +
"0.0 ";
204 initialBelief = initialBelief +
"1.0";
207 correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
209 starGraphs.insert(make_pair(
210 key,
StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
224 std::map<Key, DiscreteKey> allKeys{{0,
C}, {1,
S}, {2,
R}, {3,
W}};
241 solver.print(
"Loopy belief: ");
245 cout <<
"==================================" << endl;
246 cout <<
"iteration: " <<
iter << endl;