testLoopyBelief.cpp
Go to the documentation of this file.
1 
13 
14 #include <fstream>
15 #include <iostream>
16 
17 using namespace std;
18 using namespace boost;
19 using namespace gtsam;
20 
// LoopyBelief: loopy belief propagation on a DiscreteFactorGraph, organized as
// one "star graph" per variable (built by buildStarGraphs below).
//
// NOTE(review): this listing omits source lines 25-31, 34, 36-38, 61-66, 70,
// 78-79, 82, 88, 166-168 and 179. Presumably these hold the StarGraph members
// `star`, `unary` and `varIndex_`, the first lines of the constructor and
// iterate() signatures, and the declarations of the locals `beliefs`,
// `beliefAtKey` and `star` used below. None of them can be confirmed from
// here, so the class does not compile as shown.
24 class LoopyBelief {
// Neighbor key -> index, within a star graph, of the belief factor that
// stands in for that neighbor (recorded in buildStarGraphs).
32  typedef std::map<Key, size_t> CorrectedBeliefIndices;
// Per-variable subgraph: the original factors that involve the variable plus
// one belief factor per neighbor, the product of the variable's single-key
// factors (may be null), and a VariableIndex built over the star graph.
33  struct StarGraph {
35  CorrectedBeliefIndices correctedBeliefIndices;
// Constructor tail (its first line is not in this listing): stores the inputs
// and indexes the star graph by variable.
39  const CorrectedBeliefIndices& _beliefIndices,
40  const DecisionTreeFactor::shared_ptr& _unary)
41  : star(_star),
42  correctedBeliefIndices(_beliefIndices),
43  unary(_unary),
44  varIndex_(*_star) {}
45 
// Prints the star graph, each neighbor's belief-factor index, and the unary
// product when one is present.
46  void print(const std::string& s = "") const {
47  cout << s << ":" << endl;
48  star->print("Star graph: ");
// NOTE(review): the binding already holds the index (in `_`), so
// correctedBeliefIndices.at(key) repeats the map lookup. Binding it by name,
// e.g. [key, index], and printing `index` avoids that.
49  for (const auto& [key, _] : correctedBeliefIndices) {
50  cout << "Belief factor index for " << key << ": "
51  << correctedBeliefIndices.at(key) << endl;
52  }
53  if (unary) unary->print("Unary: ");
54  }
55  };
56 
// All star graphs, keyed by variable.
57  typedef std::map<Key, StarGraph> StarGraphs;
58  StarGraphs starGraphs_;
59 
60  public:
// Constructor tail (the first signature line is not in this listing): builds
// every star graph once from the input graph.
67  const std::map<Key, DiscreteKey>& allDiscreteKeys)
68  : starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {}
69 
// Prints every star graph, labeled "Node <key>:".
71  void print(const std::string& s = "") const {
72  cout << s << ":" << endl;
// NOTE(review): bind the value here too ([key, starGraph]) instead of looking
// it up again with starGraphs_.at(key).
73  for (const auto& [key, _] : starGraphs_) {
74  starGraphs_.at(key).print("Node " + std::to_string(key) + ":");
75  }
76  }
77 
// One propagation sweep (the first signature lines are not in this listing).
// Phase 1, per variable:
//   - eliminate each neighbor from the star-graph factors that involve it, and
//     keep the resulting factor as that neighbor's message;
//   - multiply the messages and the unary product into a belief;
//   - normalize the belief over the variable's values and append it to
//     `beliefs`.
// Phase 2, after every belief has been computed: in each neighbor's star
// graph, replace the belief factor for this variable with
// belief / message-from-that-neighbor.
// allDiscreteKeys supplies each variable's cardinality.
// Returns `beliefs`, which receives one factor per star-graph variable.
80  const std::map<Key, DiscreteKey>& allDiscreteKeys) {
81  static const bool debug = false;
// Messages per variable, kept so phase 2 can divide them back out.
83  std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
84  // Eliminate each star graph
85  for (const auto& [key, _] : starGraphs_) {
86  // cout << "***** Node " << key << "*****" << endl;
87  // initialize belief to the unary factor from the original graph
// NOTE(review): the statement this comment describes (source line 88) is not
// in this listing. If beliefAtKey really starts as the unary factor, the unary
// is multiplied in a second time after the message loop below. If it starts
// empty, this comment is stale. Confirm which.
89 
90  // keep intermediate messages to divide later
91  std::map<Key, DiscreteFactor::shared_ptr> messages;
92 
93  // eliminate each neighbor in this star graph one by one
94  for (const auto& [neighbor, _] : starGraphs_.at(key).correctedBeliefIndices) {
// Collect the star-graph factors that involve `neighbor`.
95  DiscreteFactorGraph subGraph;
96  for (size_t factor : starGraphs_.at(key).varIndex_[neighbor]) {
97  subGraph.push_back(starGraphs_.at(key).star->at(factor));
98  }
99  if (debug) subGraph.print("------- Subgraph:");
// Eliminating `neighbor` yields a conditional (unused here) and the remaining
// factor, which becomes the message for this neighbor.
100  const auto [dummyCond, message] =
101  EliminateDiscrete(subGraph, Ordering{neighbor});
102  // store the new factor into messages
103  messages.insert(make_pair(neighbor, message));
104  if (debug) message->print("------- Message: ");
105 
106  // Belief is the product of all messages and the unary factor
107  // Incorporate the new factor into the belief
// NOTE(review): EliminateDiscrete is declared to return a
// DecisionTreeFactor::shared_ptr as its second element, so the
// dynamic_pointer_cast calls on `message` here look redundant.
108  if (!beliefAtKey)
109  beliefAtKey =
110  std::dynamic_pointer_cast<DecisionTreeFactor>(message);
111  else
112  beliefAtKey = std::make_shared<DecisionTreeFactor>(
113  (*beliefAtKey) *
114  (*std::dynamic_pointer_cast<DecisionTreeFactor>(message)));
115  }
// NOTE(review): for a variable with no neighbors the loop above never runs.
// If beliefAtKey is still null at this point, both the product below and the
// normalization loop dereference a null pointer.
116  if (starGraphs_.at(key).unary)
117  beliefAtKey = std::make_shared<DecisionTreeFactor>(
118  (*beliefAtKey) * (*starGraphs_.at(key).unary));
119  if (debug) beliefAtKey->print("New belief at key: ");
120  // normalize belief
// Sum the belief over every value of this variable.
121  double sum = 0.0;
122  for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) {
123  DiscreteValues val;
124  val[key] = v;
125  sum += (*beliefAtKey)(val);
126  }
127  // TODO(kartikarcot): Check if this makes sense
// Build a table holding `sum` for every value, then divide the belief by it.
// NOTE(review): std::to_string(double) formats like printf("%f"), i.e. with
// six digits after the decimal point, so `sum` is rounded before being parsed
// back into sumFactor. The result is only approximately normalized, and a sum
// below about 5e-7 is written as "0.000000", so the belief would be divided by
// a zero table. Building the table numerically, or scaling by 1.0 / sum,
// avoids the string round trip.
128  string sumFactorTable;
129  for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) {
130  sumFactorTable = sumFactorTable + " " + std::to_string(sum);
131  }
132  DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable);
133  if (debug) sumFactor.print("denomFactor: ");
134  beliefAtKey =
135  std::make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
136  if (debug) beliefAtKey->print("New belief at key normalized: ");
137  beliefs->push_back(beliefAtKey);
138  allMessages[key] = messages;
139  }
140 
141  // Update corrected beliefs
// Index the belief graph by variable so each variable's belief can be found.
142  VariableIndex beliefFactors(*beliefs);
143  for (const auto& [key, _] : starGraphs_) {
// NOTE(review): this copies the whole message map for every key.
// `const auto& messages = allMessages.at(key);` would avoid the copy.
144  std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
145  for (const auto& [neighbor, _] : starGraphs_.at(key).correctedBeliefIndices) {
// This variable's belief with the message obtained by eliminating `neighbor`
// divided back out.
// NOTE(review): beliefFactors[key].front() is the first belief factor that
// mentions `key`. Presumably this is the variable's own belief only when no
// earlier belief factor also involves `key`; confirm for factors over more
// than two variables.
146  DecisionTreeFactor correctedBelief =
147  (*std::dynamic_pointer_cast<DecisionTreeFactor>(
148  beliefs->at(beliefFactors[key].front()))) /
149  (*std::dynamic_pointer_cast<DecisionTreeFactor>(
150  messages.at(neighbor)));
151  if (debug) correctedBelief.print("correctedBelief: ");
// Overwrite the belief factor that stands for `key` in the neighbor's star graph.
152  size_t beliefIndex =
153  starGraphs_.at(neighbor).correctedBeliefIndices.at(key);
154  starGraphs_.at(neighbor).star->replace(
155  beliefIndex,
156  std::make_shared<DecisionTreeFactor>(correctedBelief));
157  }
158  }
159 
160  if (debug) print("After update: ");
161 
162  return beliefs;
163  }
164 
165  private:
// Builds one StarGraph per variable that appears in `graph`. Each one holds:
//   - the factors that involve the variable;
//   - the product of the variable's single-key factors;
//   - one initial belief factor per neighbor (every other key appearing in
//     those factors), whose index is recorded in correctedBeliefIndices.
// NOTE(review): the declaration of `star` (source line 179) is not in this
// listing. Presumably it is a new graph for each variable, since each
// StarGraph stores it.
169  StarGraphs buildStarGraphs(
170  const DiscreteFactorGraph& graph,
171  const std::map<Key, DiscreteKey>& allDiscreteKeys) const {
172  StarGraphs starGraphs;
173  VariableIndex varIndex(graph);
174  for (const auto& [key, _] : varIndex) {
175  // initialize to multiply with other unary factors later
176  DecisionTreeFactor::shared_ptr prodOfUnaries;
177 
178  // collect all factors involving this key in the original graph
180  for (size_t factorIndex : varIndex[key]) {
181  star->push_back(graph.at(factorIndex));
182 
183  // accumulate unary factors
// NOTE(review): the dynamic_pointer_cast results are not checked. A single-key
// factor that is not a DecisionTreeFactor would yield a null pointer, which is
// later dereferenced.
184  if (graph.at(factorIndex)->size() == 1) {
185  if (!prodOfUnaries)
186  prodOfUnaries = std::dynamic_pointer_cast<DecisionTreeFactor>(
187  graph.at(factorIndex));
188  else
189  prodOfUnaries = std::make_shared<DecisionTreeFactor>(
190  *prodOfUnaries *
191  (*std::dynamic_pointer_cast<DecisionTreeFactor>(
192  graph.at(factorIndex))));
193  }
194  }
195 
196  // add the belief factor for each neighbor variable to this star graph
197  // also record the factor index for later modification
198  KeySet neighbors = star->keys();
199  neighbors.erase(key);
200  CorrectedBeliefIndices correctedBeliefIndices;
201  for (Key neighbor : neighbors) {
202  // TODO: default table for keys with more than 2 values?
// The table built below is "0.0 ... 0.0 1.0" (one entry per value), which puts
// all weight on the neighbor's last value.
// NOTE(review): that is not a uniform initial belief; confirm it is intended.
203  string initialBelief;
204  for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) {
205  initialBelief = initialBelief + "0.0 ";
206  }
207  initialBelief = initialBelief + "1.0";
208  star->push_back(
209  DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief));
// The belief factor was just appended, so its index is size() - 1.
210  correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
211  }
212  starGraphs.insert(make_pair(
213  key, StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
214  }
215  return starGraphs;
216  }
217 };
218 
219 /* ************************************************************************* */
220 
// Builds a small discrete graph over Cloudy, Sprinkler, Rain and Wet, and
// constructs a LoopyBelief solver over it. It then runs 20 iterations,
// printing the beliefs after each one.
// NOTE(review): there are no CHECK/EXPECT assertions, so this only exercises
// the code path without checking any result. W is placed in allKeys but
// appears in none of the factors added below.
221 TEST_UNSAFE(LoopyBelief, construction) {
222  // Variables: Cloudy, Sprinkler, Rain, Wet
223  DiscreteKey C(0, 2), S(1, 2), R(2, 2), W(3, 2);
224 
225  // Map from key to DiscreteKey for building belief factor.
226  // TODO: this is bad!
// The solver reads each key's cardinality from this map. It repeats
// information already carried by the DiscreteKeys above.
227  std::map<Key, DiscreteKey> allKeys{{0, C}, {1, S}, {2, R}, {3, W}};
228 
229  // Build graph
230  DecisionTreeFactor pC(C, "0.5 0.5");
231  DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1");
232  DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8");
233  DecisionTreeFactor pSR(S & R, "0.0 0.9 0.9 0.99");
234 
// NOTE(review): `graph` is declared on source line 235, which this listing
// omits. It is passed to the LoopyBelief constructor below, so it is
// presumably a DiscreteFactorGraph.
236  graph.push_back(pC);
237  graph.push_back(pSC);
238  graph.push_back(pRC);
239  graph.push_back(pSR);
240 
241  graph.print("graph: ");
242 
243  LoopyBelief solver(graph, allKeys);
244  solver.print("Loopy belief: ");
245 
246  // Main loop
247  for (size_t iter = 0; iter < 20; ++iter) {
248  cout << "==================================" << endl;
249  cout << "iteration: " << iter << endl;
250  DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(allKeys);
251  beliefs->print();
252  }
253 }
254 
255 /* ************************************************************************* */
256 int main() {
257  TestResult tr;
258  return TestRegistry::runAllTests(tr);
259 }
260 /* ************************************************************************* */
std::map< Key, StarGraph > StarGraphs
const gtsam::Symbol key('X', 0)
std::pair< DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr > EliminateDiscrete(const DiscreteFactorGraph &factors, const Ordering &frontalKeys)
Main elimination function for DiscreteFactorGraph.
static int runAllTests(TestResult &result)
DiscreteFactorGraph::shared_ptr iterate(const std::map< Key, DiscreteKey > &allDiscreteKeys)
One step of belief propagation.
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:190
std::map< Key, size_t > CorrectedBeliefIndices
Rot2 R(Rot2::fromAngle(0.1))
Definition: BFloat16.h:88
Key W(std::uint64_t j)
iterator iter(handle obj)
Definition: pytypes.h:2273
BiCGSTAB< SparseMatrix< double > > solver
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
int main()
DiscreteKey S(1, 2)
NonlinearFactorGraph graph
static constexpr bool debug
DiscreteFactorGraph::shared_ptr star
void print(const std::string &s="") const
Print.
Definition: Symbol.cpp:50
std::shared_ptr< This > shared_ptr
shared_ptr to This
void print(const std::string &s="") const
Array< int, Dynamic, 1 > v
void print(const std::string &s="") const
print
StarGraph(const DiscreteFactorGraph::shared_ptr &_star, const CorrectedBeliefIndices &_beliefIndices, const DecisionTreeFactor::shared_ptr &_unary)
RealScalar s
LoopyBelief(const DiscreteFactorGraph &graph, const std::map< Key, DiscreteKey > &allDiscreteKeys)
void print(const std::string &s="DecisionTreeFactor:\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
Matrix< Scalar, Dynamic, Dynamic > C
Definition: bench_gemm.cpp:50
traits
Definition: chartTesting.h:28
StarGraphs buildStarGraphs(const DiscreteFactorGraph &graph, const std::map< Key, DiscreteKey > &allDiscreteKeys) const
std::shared_ptr< DecisionTreeFactor > shared_ptr
constexpr descr< N - 1 > _(char const (&text)[N])
Definition: descr.h:109
DecisionTreeFactor::shared_ptr unary
const sharedFactor at(size_t i) const
Definition: FactorGraph.h:343
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
void print(const std::string &s="DiscreteFactorGraph", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
CorrectedBeliefIndices correctedBeliefIndices
TEST_UNSAFE(LoopyBelief, construction)
StarGraphs starGraphs_
star graph at each variable
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:102


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:38:40