12 #include <boost/range/adaptor/map.hpp> 13 #include <boost/assign/list_of.hpp> 18 using namespace boost;
20 using namespace gtsam;
39 const CorrectedBeliefIndices& _beliefIndices,
41 star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_(
45 void print(
const std::string&
s =
"")
const {
46 cout <<
s <<
":" << endl;
47 star->print(
"Star graph: ");
48 for(
Key key: correctedBeliefIndices | boost::adaptors::map_keys) {
49 cout <<
"Belief factor index for " <<
key <<
": " 50 << correctedBeliefIndices.at(
key) << endl;
53 unary->print(
"Unary: ");
66 const std::map<Key, DiscreteKey>& allDiscreteKeys) :
67 starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
71 void print(
const std::string&
s =
"")
const {
72 cout <<
s <<
":" << endl;
73 for(
Key key: starGraphs_ | boost::adaptors::map_keys) {
74 starGraphs_.at(
key).print((boost::format(
"Node %d:") %
key).
str());
80 const std::map<Key, DiscreteKey>& allDiscreteKeys) {
81 static const bool debug =
false;
84 std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
86 for(
Key key: starGraphs_ | boost::adaptors::map_keys) {
92 std::map<Key, DiscreteFactor::shared_ptr> messages;
95 for(
Key neighbor: starGraphs_.at(
key).correctedBeliefIndices | boost::adaptors::map_keys) {
97 for(
size_t factor: starGraphs_.at(
key).varIndex_[neighbor]) {
98 subGraph.
push_back(starGraphs_.at(
key).star->at(factor));
100 if (debug) subGraph.
print(
"------- Subgraph:");
105 messages.insert(make_pair(neighbor, message));
106 if (debug) message->print(
"------- Message: ");
115 boost::make_shared<DecisionTreeFactor>(
120 if (starGraphs_.at(
key).unary)
121 beliefAtKey = boost::make_shared<DecisionTreeFactor>(
122 (*beliefAtKey) * (*starGraphs_.at(
key).unary));
123 if (debug) beliefAtKey->
print(
"New belief at key: ");
126 for (
size_t v = 0;
v < allDiscreteKeys.at(
key).second; ++
v) {
129 sum += (*beliefAtKey)(val);
131 string sumFactorTable;
132 for (
size_t v = 0;
v < allDiscreteKeys.at(
key).second; ++
v)
133 sumFactorTable = (boost::format(
"%s %f") % sumFactorTable %
sum).
str();
135 if (debug) sumFactor.
print(
"denomFactor: ");
136 beliefAtKey = boost::make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
137 if (debug) beliefAtKey->print(
"New belief at key normalized: ");
138 beliefs->push_back(beliefAtKey);
139 allMessages[
key] = messages;
144 for(
Key key: starGraphs_ | boost::adaptors::map_keys) {
145 std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[
key];
146 for(
Key neighbor: starGraphs_.at(
key).correctedBeliefIndices | boost::adaptors::map_keys) {
150 messages.at(neighbor)));
151 if (debug) correctedBelief.
print(
"correctedBelief: ");
152 size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at(
154 starGraphs_.at(neighbor).star->replace(beliefIndex,
155 boost::make_shared<DecisionTreeFactor>(correctedBelief));
159 if (debug)
print(
"After update: ");
169 const std::map<Key, DiscreteKey>& allDiscreteKeys)
const {
170 StarGraphs starGraphs;
172 for(
Key key: varIndex | boost::adaptors::map_keys) {
178 for(
size_t factorIndex: varIndex[
key]) {
179 star->push_back(graph.
at(factorIndex));
182 if (graph.
at(factorIndex)->size() == 1) {
185 graph.
at(factorIndex));
187 prodOfUnaries = boost::make_shared<DecisionTreeFactor>(
190 graph.
at(factorIndex))));
196 KeySet neighbors = star->keys();
197 neighbors.erase(key);
198 CorrectedBeliefIndices correctedBeliefIndices;
199 for(
Key neighbor: neighbors) {
201 string initialBelief;
202 for (
size_t v = 0;
v < allDiscreteKeys.at(neighbor).second - 1; ++
v) {
203 initialBelief = initialBelief +
"0.0 ";
205 initialBelief = initialBelief +
"1.0";
208 correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
212 StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
226 std::map<Key, DiscreteKey> allKeys = map_list_of(0,
C)(1,
S)(2,
R)(3,
W);
240 graph.
print(
"graph: ");
243 solver.
print(
"Loopy belief: ");
247 cout <<
"==================================" << endl;
248 cout <<
"iteration: " <<
iter << endl;
void print(const Matrix &A, const string &s, ostream &stream)
std::map< Key, StarGraph > StarGraphs
static int runAllTests(TestResult &result)
DiscreteFactorGraph::shared_ptr iterate(const std::map< Key, DiscreteKey > &allDiscreteKeys)
One step of belief propagation.
std::map< Key, size_t > CorrectedBeliefIndices
Rot2 R(Rot2::fromAngle(0.1))
iterator iter(handle obj)
BiCGSTAB< SparseMatrix< double > > solver
IsDerived< DERIVEDFACTOR > push_back(boost::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
NonlinearFactorGraph graph
boost::shared_ptr< This > shared_ptr
shared_ptr to this class
DiscreteFactorGraph::shared_ptr star
std::pair< Key, size_t > DiscreteKey
void print(const std::string &s="") const
StarGraph(const DiscreteFactorGraph::shared_ptr &_star, const CorrectedBeliefIndices &_beliefIndices, const DecisionTreeFactor::shared_ptr &_unary)
const mpreal sum(const mpreal tab[], const unsigned long int n, int &status, mp_rnd_t mode=mpreal::get_default_rnd())
std::pair< DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr > EliminateDiscrete(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
LoopyBelief(const DiscreteFactorGraph &graph, const std::map< Key, DiscreteKey > &allDiscreteKeys)
boost::shared_ptr< This > shared_ptr
shared_ptr to this class
boost::shared_ptr< DiscreteFactor > shared_ptr
shared_ptr to this class
Matrix< Scalar, Dynamic, Dynamic > C
const sharedFactor at(size_t i) const
DecisionTreeFactor::shared_ptr unary
boost::shared_ptr< DecisionTreeFactor > shared_ptr
void print(const std::string &s="DiscreteFactorGraph", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
CorrectedBeliefIndices correctedBeliefIndices
TEST_UNSAFE(LoopyBelief, construction)
StarGraphs buildStarGraphs(const DiscreteFactorGraph &graph, const std::map< Key, DiscreteKey > &allDiscreteKeys) const
void print(const std::string &s="DecisionTreeFactor:\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
StarGraphs starGraphs_
star graph at each variable
void print(const std::string &s="") const
print
std::uint64_t Key
Integer nonlinear key type.