18 using namespace boost;
19 using namespace gtsam;
39 const CorrectedBeliefIndices& _beliefIndices,
42 correctedBeliefIndices(_beliefIndices),
46 void print(
const std::string&
s =
"")
const {
47 cout <<
s <<
":" << endl;
48 star->print(
"Star graph: ");
49 for (
const auto& [
key,
_] : correctedBeliefIndices) {
50 cout <<
"Belief factor index for " <<
key <<
": " 51 << correctedBeliefIndices.at(
key) << endl;
53 if (unary) unary->
print(
"Unary: ");
67 const std::map<Key, DiscreteKey>& allDiscreteKeys)
68 : starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {}
71 void print(
const std::string&
s =
"")
const {
72 cout <<
s <<
":" << endl;
73 for (
const auto& [
key,
_] : starGraphs_) {
74 starGraphs_.at(
key).print(
"Node " + std::to_string(
key) +
":");
80 const std::map<Key, DiscreteKey>& allDiscreteKeys) {
81 static const bool debug =
false;
83 std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
85 for (
const auto& [
key,
_] : starGraphs_) {
91 std::map<Key, DiscreteFactor::shared_ptr> messages;
94 for (
const auto& [neighbor,
_] : starGraphs_.at(
key).correctedBeliefIndices) {
96 for (
size_t factor : starGraphs_.at(
key).varIndex_[neighbor]) {
97 subGraph.
push_back(starGraphs_.at(
key).star->at(factor));
99 if (debug) subGraph.
print(
"------- Subgraph:");
100 const auto [dummyCond, message] =
103 messages.insert(make_pair(neighbor, message));
104 if (debug) message->print(
"------- Message: ");
112 beliefAtKey = std::make_shared<DecisionTreeFactor>(
116 if (starGraphs_.at(
key).unary)
117 beliefAtKey = std::make_shared<DecisionTreeFactor>(
118 (*beliefAtKey) * (*starGraphs_.at(
key).unary));
119 if (debug) beliefAtKey->
print(
"New belief at key: ");
122 for (
size_t v = 0;
v < allDiscreteKeys.at(
key).second; ++
v) {
125 sum += (*beliefAtKey)(val);
128 string sumFactorTable;
129 for (
size_t v = 0;
v < allDiscreteKeys.at(
key).second; ++
v) {
130 sumFactorTable = sumFactorTable +
" " + std::to_string(sum);
133 if (debug) sumFactor.
print(
"denomFactor: ");
135 std::make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
136 if (debug) beliefAtKey->print(
"New belief at key normalized: ");
137 beliefs->push_back(beliefAtKey);
138 allMessages[
key] = messages;
143 for (
const auto& [
key,
_] : starGraphs_) {
144 std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[
key];
145 for (
const auto& [neighbor,
_] : starGraphs_.at(
key).correctedBeliefIndices) {
148 beliefs->at(beliefFactors[
key].front()))) /
150 messages.at(neighbor)));
151 if (debug) correctedBelief.
print(
"correctedBelief: ");
153 starGraphs_.at(neighbor).correctedBeliefIndices.at(
key);
154 starGraphs_.at(neighbor).star->replace(
156 std::make_shared<DecisionTreeFactor>(correctedBelief));
160 if (debug)
print(
"After update: ");
171 const std::map<Key, DiscreteKey>& allDiscreteKeys)
const {
172 StarGraphs starGraphs;
174 for (
const auto& [
key,
_] : varIndex) {
180 for (
size_t factorIndex : varIndex[
key]) {
181 star->push_back(graph.
at(factorIndex));
184 if (graph.
at(factorIndex)->size() == 1) {
187 graph.
at(factorIndex));
189 prodOfUnaries = std::make_shared<DecisionTreeFactor>(
192 graph.
at(factorIndex))));
198 KeySet neighbors = star->keys();
199 neighbors.erase(key);
200 CorrectedBeliefIndices correctedBeliefIndices;
201 for (
Key neighbor : neighbors) {
203 string initialBelief;
204 for (
size_t v = 0;
v < allDiscreteKeys.at(neighbor).second - 1; ++
v) {
205 initialBelief = initialBelief +
"0.0 ";
207 initialBelief = initialBelief +
"1.0";
210 correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
212 starGraphs.insert(make_pair(
213 key,
StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
227 std::map<Key, DiscreteKey> allKeys{{0,
C}, {1,
S}, {2,
R}, {3, W}};
241 graph.
print(
"graph: ");
244 solver.
print(
"Loopy belief: ");
248 cout <<
"==================================" << endl;
249 cout <<
"iteration: " <<
iter << endl;
std::map< Key, StarGraph > StarGraphs
const gtsam::Symbol key('X', 0)
std::pair< DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr > EliminateDiscrete(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
Main elimination function for DiscreteFactorGraph.
static int runAllTests(TestResult &result)
DiscreteFactorGraph::shared_ptr iterate(const std::map< Key, DiscreteKey > &allDiscreteKeys)
One step of belief propagation.
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
std::map< Key, size_t > CorrectedBeliefIndices
Rot2 R(Rot2::fromAngle(0.1))
iterator iter(handle obj)
BiCGSTAB< SparseMatrix< double > > solver
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
NonlinearFactorGraph graph
static constexpr bool debug
DiscreteFactorGraph::shared_ptr star
void print(const std::string &s="") const
Print.
std::shared_ptr< This > shared_ptr
shared_ptr to This
void print(const std::string &s="") const
Array< int, Dynamic, 1 > v
void print(const std::string &s="") const
print
StarGraph(const DiscreteFactorGraph::shared_ptr &_star, const CorrectedBeliefIndices &_beliefIndices, const DecisionTreeFactor::shared_ptr &_unary)
LoopyBelief(const DiscreteFactorGraph &graph, const std::map< Key, DiscreteKey > &allDiscreteKeys)
void print(const std::string &s="DecisionTreeFactor:\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
Matrix< Scalar, Dynamic, Dynamic > C
StarGraphs buildStarGraphs(const DiscreteFactorGraph &graph, const std::map< Key, DiscreteKey > &allDiscreteKeys) const
std::shared_ptr< DecisionTreeFactor > shared_ptr
constexpr descr< N - 1 > _(char const (&text)[N])
DecisionTreeFactor::shared_ptr unary
const sharedFactor at(size_t i) const
std::pair< Key, size_t > DiscreteKey
void print(const std::string &s="DiscreteFactorGraph", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
CorrectedBeliefIndices correctedBeliefIndices
TEST_UNSAFE(LoopyBelief, construction)
StarGraphs starGraphs_
star graph at each variable
std::uint64_t Key
Integer nonlinear key type.