test_DiscreteSearch.py
Go to the documentation of this file.
1 """
2 GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
4 All Rights Reserved
5 
6 See LICENSE for the license information
7 
8 Unit tests for Discrete Search.
9 Author: Frank Dellaert
10 """
11 
12 # pylint: disable=no-name-in-module, invalid-name
13 
14 import unittest
15 
16 from dfg_utils import generate_observation_cpt, generate_transition_cpt, make_key
17 from gtsam.utils.test_case import GtsamTestCase
18 
19 from gtsam import (
20  DiscreteConditional,
21  DiscreteFactorGraph,
22  DiscreteSearch,
23  Ordering,
24  DefaultKeyFormatter,
25 )
26 
OrderingType = Ordering.OrderingType


class TestDiscreteSearch(GtsamTestCase):
    """Tests for DiscreteSearch."""

    def test_MPE_chain(self):
        """
        Test for numerical underflow in EliminateMPE on long chains, and check
        that DiscreteSearch recovers the same MPE assignment.
        Adapted from the toy problem of @pcl15423
        Ref: https://github.com/borglab/gtsam/issues/1448
        """
        num_states = 3
        num_obs = 200
        desired_state = 1

        # X[i] / Z[i] are passed where GTSAM expects a discrete key; element
        # [0] is the plain key used for assignment lookups and orderings.
        X = {index: make_key("X", index, num_states) for index in range(num_obs)}
        # Measurements take values in [0, num_obs], hence cardinality num_obs + 1.
        Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
        keys = [X[i][0] for i in range(num_obs)]
        expected = [desired_state] * num_obs

        graph = DiscreteFactorGraph()

        # Chain of transitions P(X[i] | X[i-1]), added from the end of the chain.
        transition_cpt = generate_transition_cpt(num_states)
        for i in reversed(range(1, num_obs)):
            graph.push_back(DiscreteConditional(X[i], [X[i - 1]], transition_cpt))

        # Contrived example such that the desired state gives measurements [0, num_obs) with equal
        # probability but all other states always give measurement num_obs
        obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state)
        # Contrived example where each measurement is its own index
        for i in range(num_obs):
            obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
            graph.push_back(obs_conditional.likelihood(i))

        # Check the MPE returned by graph.optimize()
        mpe = graph.optimize()
        self.assertEqual([mpe[key] for key in keys], expected)

        # Search ordering: variables from last to first.
        ordering = Ordering()
        for key in reversed(keys):
            ordering.push_back(key)

        # Now do Search and check it agrees with the MPE above.
        search = DiscreteSearch.FromFactorGraph(graph, ordering)
        solutions = search.run(K=1)
        # Fail with a clear assertion (not an IndexError) if no solution is returned.
        self.assertEqual(len(solutions), 1)
        mpe2 = solutions[0].assignment
        self.assertEqual([mpe2[key] for key in keys], expected)
81 
82 
if __name__ == "__main__":
    # Run the unit tests in this module when the file is executed as a script.
    unittest.main()
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
list
Definition: pytypes.h:2168
test_DiscreteSearch.TestDiscreteSearch
Definition: test_DiscreteSearch.py:30
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
dfg_utils.generate_observation_cpt
def generate_observation_cpt(num_states, num_obs, desired_state)
Definition: dfg_utils.py:27
dfg_utils.generate_transition_cpt
def generate_transition_cpt(num_states, transitions=None)
Definition: dfg_utils.py:14
gtsam::utils.test_case
Definition: test_case.py:1
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
gtsam::utils.test_case.GtsamTestCase
Definition: test_case.py:16
test_DiscreteSearch.TestDiscreteSearch.test_MPE_chain
def test_MPE_chain(self)
Definition: test_DiscreteSearch.py:33
len
size_t len(handle h)
Get the length of a Python object.
Definition: pytypes.h:2448
dfg_utils.make_key
def make_key(character, index, cardinality)
Definition: dfg_utils.py:5
gtsam::Ordering
Definition: inference/Ordering.h:33


gtsam
Author(s):
autogenerated on Wed Mar 19 2025 03:06:17