testLoopyBelief.cpp
Go to the documentation of this file.
1 
13 
14 #include <fstream>
15 #include <iostream>
16 
17 using namespace std;
18 using namespace boost;
19 using namespace gtsam;
20 
24 class LoopyBelief {
32  typedef std::map<Key, size_t> CorrectedBeliefIndices;
// Per-variable bundle used by the belief-propagation loop.
// NOTE(review): the member declarations and the first line of the constructor
// signature are not shown in this listing; the fields below are taken from the
// constructor's initializer list.
//   star                   - factor graph for this variable; dereferenced in
//                            the constructor to build varIndex_
//   correctedBeliefIndices - map from Key to a size_t position (see the
//                            CorrectedBeliefIndices typedef); print() labels
//                            each entry "Belief factor index"
//   unary                  - DecisionTreeFactor pointer that may be null;
//                            print() checks it before use
//   varIndex_              - built from *_star at construction time; nothing
//                            in this struct rebuilds it afterwards
33  struct StarGraph {
39  const CorrectedBeliefIndices& _beliefIndices,
40  const DecisionTreeFactor::shared_ptr& _unary)
41  : star(_star),
42  correctedBeliefIndices(_beliefIndices),
43  unary(_unary),
44  varIndex_(*_star) {}
45 
    // Write `s` followed by ":" to cout, then the star graph, then one line per
    // entry of correctedBeliefIndices (key and stored index), then the unary
    // factor when one is present.
46  void print(const std::string& s = "") const {
47  cout << s << ":" << endl;
48  star->print("Star graph: ");
    // Only the key of each entry is bound; the index is re-read with at(key).
49  for (const auto& [key, _] : correctedBeliefIndices) {
50  cout << "Belief factor index for " << key << ": "
51  << correctedBeliefIndices.at(key) << endl;
52  }
    // unary is a shared_ptr and is skipped when null.
53  if (unary) unary->print("Unary: ");
54  }
55  };
56 
57  typedef std::map<Key, StarGraph> StarGraphs;
59 
60  public:
    // Constructor. NOTE(review): the first line of the signature is not shown
    // in this listing; the initializer below passes a `graph` argument and
    // `allDiscreteKeys` straight to buildStarGraphs().
67  const std::map<Key, DiscreteKey>& allDiscreteKeys)
    // The only work done here is building starGraphs_ from the two arguments.
68  : starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {}
69 
// Write `s` followed by ":" to cout, then print every entry of starGraphs_ in
// map order, each under the heading "Node <key>:" where <key> is the numeric
// key formatted with std::to_string.
71  void print(const std::string& s = "") const {
72  cout << s << ":" << endl;
    // Only the key is bound; the StarGraph itself is looked up with at(key).
73  for (const auto& [key, _] : starGraphs_) {
74  starGraphs_.at(key).print("Node " + std::to_string(key) + ":");
75  }
76  }
77 
80  const std::map<Key, DiscreteKey>& allDiscreteKeys) {
    // Runs one update over every star graph in two passes:
    //   1. per key, eliminate each neighbor from the star graph to get a
    //      message, multiply the messages (and the unary factor, if any) into a
    //      belief, normalize it, and append it to `beliefs`;
    //   2. per (key, neighbor) pair, divide key's belief by the message obtained
    //      for that neighbor and store the quotient in the neighbor's star graph.
    // `allDiscreteKeys` supplies the cardinality of each key for normalization.
    // Returns `beliefs`.
    // NOTE(review): the first line of the signature and the declarations of
    // `beliefs` and `beliefAtKey` are not shown in this listing.

    // Switch for the verbose printing below; fixed to false here.
81  static const bool debug = false;
    // Messages produced in pass 1, kept per key for the division in pass 2.
83  std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
84  // Eliminate each star graph
85  for (const auto& [key, _] : starGraphs_) {
86  // cout << "***** Node " << key << "*****" << endl;
87  // initialize belief to the unary factor from the original graph
89 
90  // keep intermediate messages to divide later
91  std::map<Key, DiscreteFactor::shared_ptr> messages;
92 
93  // eliminate each neighbor in this star graph one by one
94  for (const auto& [neighbor, _] : starGraphs_.at(key).correctedBeliefIndices) {
    // Gather every factor of this star graph that varIndex_ lists for the
    // neighbor, then eliminate the neighbor from that subgraph.
95  DiscreteFactorGraph subGraph;
96  for (size_t factor : starGraphs_.at(key).varIndex_[neighbor]) {
97  subGraph.push_back(starGraphs_.at(key).star->at(factor));
98  }
99  if (debug) subGraph.print("------- Subgraph:");
    // Only the second element of the result (the message) is used below.
100  const auto [dummyCond, message] =
101  EliminateDiscrete(subGraph, Ordering{neighbor});
102  // store the new factor into messages
103  messages.insert(make_pair(neighbor, message));
104  if (debug) message->print("------- Message: ");
105 
106  // Belief is the product of all messages and the unary factor
107  // Incorporate the new factor into the belief
    // The first message seeds the belief; later ones are multiplied in.
108  if (!beliefAtKey)
109  beliefAtKey =
110  std::dynamic_pointer_cast<DecisionTreeFactor>(message);
111  else
112  beliefAtKey = std::make_shared<DecisionTreeFactor>(
113  (*beliefAtKey) *
114  (*std::dynamic_pointer_cast<DecisionTreeFactor>(message)));
115  }
    // NOTE(review): beliefAtKey is dereferenced from here on. If a key had no
    // entries in correctedBeliefIndices it would presumably still be unset;
    // its declaration is not visible in this listing, so confirm.
116  if (starGraphs_.at(key).unary)
117  beliefAtKey = std::make_shared<DecisionTreeFactor>(
118  (*beliefAtKey) * (*starGraphs_.at(key).unary));
119  if (debug) beliefAtKey->print("New belief at key: ");
120  // normalize belief
    // Sum the belief over every value of `key`; the number of values is the
    // second member of the key's DiscreteKey.
121  double sum = 0.0;
122  for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) {
123  DiscreteValues val;
124  val[key] = v;
125  sum += (*beliefAtKey)(val);
126  }
127  // TODO(kartikarcot): Check if this makes sense
    // Build a factor on `key` whose table repeats `sum` once per value, then
    // divide the belief by it.
    // NOTE(review): std::to_string(double) writes a fixed six digits after the
    // decimal point, so the divisor is a rounded copy of `sum`; a very small
    // sum would be written as 0.000000.
128  string sumFactorTable;
129  for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) {
130  sumFactorTable = sumFactorTable + " " + std::to_string(sum);
131  }
132  DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable);
133  if (debug) sumFactor.print("denomFactor: ");
134  beliefAtKey =
135  std::make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
136  if (debug) beliefAtKey->print("New belief at key normalized: ");
137  beliefs->push_back(beliefAtKey);
138  allMessages[key] = messages;
139  }
140 
141  // Update corrected beliefs
    // Index the freshly built beliefs by key so each key's belief can be found
    // through beliefFactors[key].front().
142  VariableIndex beliefFactors(*beliefs);
143  for (const auto& [key, _] : starGraphs_) {
    // Copy of the messages recorded for this key in the first pass.
144  std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
145  for (const auto& [neighbor, _] : starGraphs_.at(key).correctedBeliefIndices) {
    // key's belief divided by the message obtained when `neighbor` was
    // eliminated from key's star graph.
146  DecisionTreeFactor correctedBelief =
147  (*std::dynamic_pointer_cast<DecisionTreeFactor>(
148  beliefs->at(beliefFactors[key].front()))) /
149  (*std::dynamic_pointer_cast<DecisionTreeFactor>(
150  messages.at(neighbor)));
151  if (debug) correctedBelief.print("correctedBelief: ");
    // Overwrite, in the neighbor's star graph, the factor stored at the
    // position the neighbor recorded for `key`.
152  size_t beliefIndex =
153  starGraphs_.at(neighbor).correctedBeliefIndices.at(key);
154  starGraphs_.at(neighbor).star->replace(
155  beliefIndex,
156  std::make_shared<DecisionTreeFactor>(correctedBelief));
157  }
158  }
159 
160  if (debug) print("After update: ");
161 
162  return beliefs;
163  }
164 
165  private:
170  const DiscreteFactorGraph& graph,
171  const std::map<Key, DiscreteKey>& allDiscreteKeys) const {
    // Builds one StarGraph per key found in `graph` and returns them in a map.
    // For each key the star graph receives every factor of `graph` that
    // involves the key, plus one initial belief factor per neighboring key.
    // `allDiscreteKeys` supplies each neighbor's DiscreteKey for those factors.
    // NOTE(review): the first line of the signature and the declaration of
    // `star` are not shown in this listing.
172  StarGraphs starGraphs;
173  VariableIndex varIndex(graph);
174  for (const auto& [key, _] : varIndex) {
175  // initialize to multiply with other unary factors later
176  DecisionTreeFactor::shared_ptr prodOfUnaries;
177 
178  // collect all factors involving this key in the original graph
180  for (size_t factorIndex : varIndex[key]) {
181  star->push_back(graph.at(factorIndex));
182 
183  // accumulate unary factors
    // A factor with size() == 1 is treated as unary: the first one is kept
    // as-is, later ones are multiplied into the running product.
184  if (graph.at(factorIndex)->size() == 1) {
185  if (!prodOfUnaries)
186  prodOfUnaries = graph.at<DecisionTreeFactor>(factorIndex);
187  else
188  prodOfUnaries = std::make_shared<DecisionTreeFactor>(
189  *prodOfUnaries * (*graph.at<DecisionTreeFactor>(factorIndex)));
190  }
191  }
192 
193  // add the belief factor for each neighbor variable to this star graph
194  // also record the factor index for later modification
    // Neighbors are all keys appearing in the collected factors except `key`.
195  KeySet neighbors = star->keys();
196  neighbors.erase(key);
197  CorrectedBeliefIndices correctedBeliefIndices;
198  for (Key neighbor : neighbors) {
199  // TODO: default table for keys with more than 2 values?
    // Table text is "0.0 " repeated (cardinality - 1) times followed by "1.0".
    // NOTE(review): the loop bound subtracts 1 from an unsigned cardinality;
    // a cardinality of 0 would wrap around.
200  string initialBelief;
201  for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) {
202  initialBelief = initialBelief + "0.0 ";
203  }
204  initialBelief = initialBelief + "1.0";
205  star->push_back(
206  DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief));
    // The factor just pushed is the last one, so its index is size() - 1.
207  correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
208  }
    // prodOfUnaries is still null here when the key had no unary factor.
209  starGraphs.insert(make_pair(
210  key, StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
211  }
212  return starGraphs;
213  }
214 };
215 
216 /* ************************************************************************* */
217 
// Builds a four-factor discrete graph over binary variables, constructs a
// LoopyBelief solver from it, and prints the beliefs returned by 20 calls to
// iterate(). The test only prints; it contains no assertions.
// NOTE(review): the declaration of `graph` is not shown in this listing.
218 TEST_UNSAFE(LoopyBelief, construction) {
219  // Variables: Cloudy, Sprinkler, Rain, Wet
    // Each DiscreteKey is (key, cardinality); all four have cardinality 2.
220  DiscreteKey C(0, 2), S(1, 2), R(2, 2), W(3, 2);
221 
222  // Map from key to DiscreteKey for building belief factor.
223  // TODO: this is bad!
224  std::map<Key, DiscreteKey> allKeys{{0, C}, {1, S}, {2, R}, {3, W}};
225 
226  // Build graph
    // W is declared and registered in allKeys, but none of the four factors
    // below mentions it.
227  DecisionTreeFactor pC(C, "0.5 0.5");
228  DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1");
229  DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8");
230  DecisionTreeFactor pSR(S & R, "0.0 0.9 0.9 0.99");
231 
233  graph.push_back(pC);
234  graph.push_back(pSC);
235  graph.push_back(pRC);
236  graph.push_back(pSR);
237 
238  graph.print("graph: ");
239 
240  LoopyBelief solver(graph, allKeys);
241  solver.print("Loopy belief: ");
242 
243  // Main loop
    // Fixed count of 20 iterations; the beliefs are printed after each one.
244  for (size_t iter = 0; iter < 20; ++iter) {
245  cout << "==================================" << endl;
246  cout << "iteration: " << iter << endl;
247  DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(allKeys);
248  beliefs->print();
249  }
250 }
251 
252 /* ************************************************************************* */
// Test-runner entry point: runs every registered test and uses the value
// returned by TestRegistry::runAllTests as the process exit code.
253 int main() {
254  TestResult tr;
255  return TestRegistry::runAllTests(tr);
256 }
257 /* ************************************************************************* */
TestRegistry::runAllTests
static int runAllTests(TestResult &result)
Definition: TestRegistry.cpp:27
Eigen::internal::print
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
Definition: NEON/PacketMath.h:3115
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:44
LoopyBelief::CorrectedBeliefIndices
std::map< Key, size_t > CorrectedBeliefIndices
Definition: testLoopyBelief.cpp:32
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:98
main
int main()
Definition: testLoopyBelief.cpp:253
s
RealScalar s
Definition: level1_cplx_impl.h:126
TestHarness.h
DiscreteFactorGraph.h
LoopyBelief::starGraphs_
StarGraphs starGraphs_
star graph at each variable
Definition: testLoopyBelief.cpp:58
LoopyBelief::StarGraph
Definition: testLoopyBelief.cpp:33
gtsam::DiscreteFactorGraph::print
void print(const std::string &s="DiscreteFactorGraph", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
Definition: DiscreteFactorGraph.cpp:84
gtsam::FastSet< Key >
LoopyBelief::StarGraphs
std::map< Key, StarGraph > StarGraphs
Definition: testLoopyBelief.cpp:57
DiscreteConditional.h
boost
Definition: boostmultiprec.cpp:109
gtsam::DecisionTreeFactor::shared_ptr
std::shared_ptr< DecisionTreeFactor > shared_ptr
Definition: DecisionTreeFactor.h:50
solver
BiCGSTAB< SparseMatrix< double > > solver
Definition: BiCGSTAB_simple.cpp:5
LoopyBelief::StarGraph::print
void print(const std::string &s="") const
Definition: testLoopyBelief.cpp:46
gtsam::FactorGraph::at
const sharedFactor at(size_t i) const
Definition: FactorGraph.h:306
LoopyBelief::print
void print(const std::string &s="") const
print
Definition: testLoopyBelief.cpp:71
LoopyBelief::iterate
DiscreteFactorGraph::shared_ptr iterate(const std::map< Key, DiscreteKey > &allDiscreteKeys)
One step of belief propagation.
Definition: testLoopyBelief.cpp:79
gtsam::EliminateDiscrete
std::pair< DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr > EliminateDiscrete(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
Main elimination function for DiscreteFactorGraph.
Definition: DiscreteFactorGraph.cpp:205
unary
Definition: testExpression.cpp:78
debug
static constexpr bool debug
Definition: testDiscreteBayesTree.cpp:31
LoopyBelief::StarGraph::varIndex_
VariableIndex varIndex_
Definition: testLoopyBelief.cpp:37
gtsam::DecisionTreeFactor::print
void print(const std::string &s="DecisionTreeFactor:\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
Definition: DecisionTreeFactor.cpp:90
TEST_UNSAFE
TEST_UNSAFE(LoopyBelief, construction)
Definition: testLoopyBelief.cpp:218
gtsam::VariableIndex
Definition: VariableIndex.h:41
LoopyBelief::StarGraph::star
DiscreteFactorGraph::shared_ptr star
Definition: testLoopyBelief.cpp:34
TestResult
Definition: TestResult.h:26
key
const gtsam::Symbol key('X', 0)
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
LoopyBelief::StarGraph::StarGraph
StarGraph(const DiscreteFactorGraph::shared_ptr &_star, const CorrectedBeliefIndices &_beliefIndices, const DecisionTreeFactor::shared_ptr &_unary)
Definition: testLoopyBelief.cpp:38
LoopyBelief
Definition: testLoopyBelief.cpp:24
C
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
gtsam
traits
Definition: chartTesting.h:28
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::FactorGraph::push_back
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:147
gtsam::symbol_shorthand::W
Key W(std::uint64_t j)
Definition: inference/Symbol.h:170
iter
iterator iter(handle obj)
Definition: pytypes.h:2428
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
std
Definition: BFloat16.h:88
VariableIndex.h
LoopyBelief::StarGraph::correctedBeliefIndices
CorrectedBeliefIndices correctedBeliefIndices
Definition: testLoopyBelief.cpp:35
v
Array< int, Dynamic, 1 > v
Definition: Array_initializer_list_vector_cxx11.cpp:1
LoopyBelief::LoopyBelief
LoopyBelief(const DiscreteFactorGraph &graph, const std::map< Key, DiscreteKey > &allDiscreteKeys)
Definition: testLoopyBelief.cpp:66
graph
NonlinearFactorGraph graph
Definition: doc/Code/OdometryExample.cpp:2
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
_
constexpr descr< N - 1 > _(char const (&text)[N])
Definition: descr.h:109
gtsam::Ordering
Definition: inference/Ordering.h:33
LoopyBelief::buildStarGraphs
StarGraphs buildStarGraphs(const DiscreteFactorGraph &graph, const std::map< Key, DiscreteKey > &allDiscreteKeys) const
Definition: testLoopyBelief.cpp:169
DecisionTreeFactor.h
R
Rot2 R(Rot2::fromAngle(0.1))
S
DiscreteKey S(1, 2)
gtsam::DiscreteFactorGraph::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to This
Definition: DiscreteFactorGraph.h:106
LoopyBelief::StarGraph::unary
DecisionTreeFactor::shared_ptr unary
Definition: testLoopyBelief.cpp:36
gtsam::NonlinearFactorGraph::print
void print(const std::string &str="NonlinearFactorGraph: ", const KeyFormatter &keyFormatter=DefaultKeyFormatter) const override
Definition: NonlinearFactorGraph.cpp:55


gtsam
Author(s):
autogenerated on Mon Jul 1 2024 03:06:49