DiscreteSearch.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
22 
23 namespace gtsam {
24 
27 
28 /*
29  * A SearchNode represents a node in the search tree for the search algorithm.
30  * Each SearchNode contains a partial assignment of discrete variables, the
31  * current error, a bound on the final error, and the index of the next
32  * slot to be assigned.
33  */
34 struct SearchNode {
35  DiscreteValues assignment; // Partial assignment of discrete variables.
36  double error; // Current error for the partial assignment.
37  double bound; // Lower bound on the final error
38  std::optional<size_t> next; // Index of the next slot to be assigned.
39 
40  // Construct the root node for the search.
41  static SearchNode Root(size_t numSlots, double bound) {
42  return {DiscreteValues(), 0.0, bound, 0};
43  }
44 
45  struct Compare {
46  bool operator()(const SearchNode& a, const SearchNode& b) const {
47  return a.bound > b.bound; // smallest bound -> highest priority
48  }
49  };
50 
51  // Checks if the node represents a complete assignment.
52  inline bool isComplete() const { return !next; }
53 
54  // Expands the node by assigning the next variable(s).
55  SearchNode expand(const DiscreteValues& fa, const Slot& slot,
56  std::optional<size_t> nextSlot) const {
57  // Combine the new frontal assignment with the current partial assignment
58  DiscreteValues newAssignment = assignment;
59  for (auto& [key, value] : fa) {
60  newAssignment[key] = value;
61  }
62  double errorSoFar = error + slot.factor->error(newAssignment);
63  return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot};
64  }
65 
66  // Prints the SearchNode to an output stream.
67  friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
68  os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
69  return os;
70  }
71 };
72 
74  bool operator()(const Solution& a, const Solution& b) const {
75  return a.error < b.error;
76  }
77 };
78 
79 /*
80  * A Solutions object maintains a priority queue of the best solutions found
81  * during the search. The priority queue is limited to a maximum size, and
82  * solutions are only added if they are better than the worst solution.
83  */
84 class Solutions {
85  size_t maxSize_; // Maximum number of solutions to keep
86  std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;
87 
88  public:
89  Solutions(size_t maxSize) : maxSize_(maxSize) {}
90 
91  // Add a solution to the priority queue, possibly evicting the worst one.
92  // Return true if we added the solution.
93  bool maybeAdd(double error, const DiscreteValues& assignment) {
94  const bool full = pq_.size() == maxSize_;
95  if (full && error >= pq_.top().error) return false;
96  if (full) pq_.pop();
97  pq_.emplace(error, assignment);
98  return true;
99  }
100 
101  // Check if we have any solutions
102  bool empty() const { return pq_.empty(); }
103 
104  // Method to print all solutions
105  friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
106  os << "Solutions (top " << sn.pq_.size() << "):\n";
107  auto pq = sn.pq_;
108  while (!pq.empty()) {
109  os << pq.top() << "\n";
110  pq.pop();
111  }
112  return os;
113  }
114 
115  // Check if (partial) solution with given bound can be pruned. If we have
116  // room, we never prune. Otherwise, prune if lower bound on error is worse
117  // than our current worst error.
118  bool prune(double bound) const {
119  if (pq_.size() < maxSize_) return false;
120  return bound >= pq_.top().error;
121  }
122 
123  // Method to extract solutions in ascending order of error
124  std::vector<Solution> extractSolutions() {
125  std::vector<Solution> result;
126  while (!pq_.empty()) {
127  result.push_back(pq_.top());
128  pq_.pop();
129  }
130  std::sort(
131  result.begin(), result.end(),
132  [](const Solution& a, const Solution& b) { return a.error < b.error; });
133  return result;
134  }
135 };
136 
137 // Get the factor associated with a node, possibly product of factors.
138 template <typename NodeType>
139 static DiscreteFactor::shared_ptr getFactor(const NodeType& node) {
140  const auto& factors = node->factors;
141  return factors.size() == 1 ? factors.back()
143 }
144 
146  using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
147  auto visitor = [this](const NodePtr& node, int data) {
149  const size_t cardinality = factor->cardinality(node->key);
150  std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
151  const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
152  slots_.emplace_back(std::move(slot));
153  return data + 1;
154  };
155 
156  int data = 0; // unused
157  treeTraversal::DepthFirstForest(etree, data, visitor);
159 }
160 
162  using NodePtr = std::shared_ptr<DiscreteJunctionTree::Cluster>;
163  auto visitor = [this](const NodePtr& cluster, int data) {
164  const auto factor = getFactor(cluster);
165  std::vector<std::pair<Key, size_t>> pairs;
166  for (Key key : cluster->orderedFrontalKeys) {
167  pairs.emplace_back(key, factor->cardinality(key));
168  }
169  const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
170  slots_.emplace_back(std::move(slot));
171  return data + 1;
172  };
173 
174  int data = 0; // unused
175  treeTraversal::DepthFirstForest(junctionTree, data, visitor);
177 }
178 
181  bool buildJunctionTree) {
183  if (buildJunctionTree) {
184  const DiscreteJunctionTree junctionTree(etree);
185  return DiscreteSearch(junctionTree);
186  } else {
187  return DiscreteSearch(etree);
188  }
189 }
190 
192  slots_.reserve(bayesNet.size());
193  for (auto& conditional : bayesNet) {
194  const Slot slot{conditional, conditional->frontalAssignments(), 0.0};
195  slots_.emplace_back(std::move(slot));
196  }
197  std::reverse(slots_.begin(), slots_.end());
199 }
200 
202  using NodePtr = DiscreteBayesTree::sharedClique;
203  auto visitor = [this](const NodePtr& clique, int data) {
204  auto conditional = clique->conditional();
205  const Slot slot{conditional, conditional->frontalAssignments(), 0.0};
206  slots_.emplace_back(std::move(slot));
207  return data + 1;
208  };
209 
210  int data = 0; // unused
213 }
214 
215 void DiscreteSearch::print(const std::string& name,
216  const KeyFormatter& formatter) const {
217  std::cout << name << " with " << slots_.size() << " slots:\n";
218  for (size_t i = 0; i < slots_.size(); ++i) {
219  std::cout << i << ": " << slots_[i] << std::endl;
220  }
221 }
222 
223 using SearchNodeQueue = std::priority_queue<SearchNode, std::vector<SearchNode>,
225 
226 std::vector<Solution> DiscreteSearch::run(size_t K) const {
227  if (slots_.empty()) {
228  return {Solution(0.0, DiscreteValues())};
229  }
230 
231  Solutions solutions(K);
232  SearchNodeQueue expansions;
233  expansions.push(SearchNode::Root(slots_.size(), lowerBound_));
234 
235  // Perform the search
236  while (!expansions.empty()) {
237  // Pop the partial assignment with the smallest bound
238  SearchNode current = expansions.top();
239  expansions.pop();
240 
241  // If we already have K solutions, prune if we cannot beat the worst one.
242  if (solutions.prune(current.bound)) {
243  continue;
244  }
245 
246  // Check if we have a complete assignment
247  if (current.isComplete()) {
248  solutions.maybeAdd(current.error, current.assignment);
249  continue;
250  }
251 
252  // Get the next slot to expand
253  const auto& slot = slots_[*current.next];
254  std::optional<size_t> nextSlot = *current.next + 1;
255  if (nextSlot == slots_.size()) nextSlot.reset();
256  for (auto& fa : slot.assignments) {
257  auto childNode = current.expand(fa, slot, nextSlot);
258 
259  // Again, prune if we cannot beat the worst solution
260  if (!solutions.prune(childNode.bound)) {
261  expansions.emplace(childNode);
262  }
263  }
264  }
265 
266  // Extract solutions from bestSolutions in ascending order of error
267  return solutions.extractSolutions();
268 }
269 /*
270  * We have a number of factors, each with a max value, and we want to compute
271  * a lower-bound on the cost-to-go for each slot, *not* including this factor.
272  * For the last slot[n-1], this is 0.0, as the cost after that is zero.
273  * For the second-to-last slot, it is h = -log(max(factor[n-1])), because after
274  * we assign slot[n-2] we still need to assign slot[n-1], which will cost *at
275  * least* h. We return the estimated lower bound of the cost for *all* slots.
276  */
278  double error = 0.0;
279  for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) {
280  it->heuristic = error;
281  Ordering ordering(it->factor->begin(), it->factor->end());
282  auto maxx = it->factor->max(ordering);
283  error -= std::log(maxx->evaluate({}));
284  }
285  return error;
286 }
287 
288 } // namespace gtsam
gtsam::SearchNode::Compare::operator()
bool operator()(const SearchNode &a, const SearchNode &b) const
Definition: DiscreteSearch.cpp:46
gtsam::SearchNode::expand
SearchNode expand(const DiscreteValues &fa, const Slot &slot, std::optional< size_t > nextSlot) const
Definition: DiscreteSearch.cpp:55
gtsam::DiscreteSearch::Slot::heuristic
double heuristic
Definition: DiscreteSearch.h:65
gtsam::SearchNode::operator<<
friend std::ostream & operator<<(std::ostream &os, const SearchNode &node)
Definition: DiscreteSearch.cpp:67
name
Annotation for function names.
Definition: attr.h:51
asia::bayesNet
static const DiscreteBayesNet bayesNet
Definition: testDiscreteSearch.cpp:30
gtsam::DiscreteFactorGraph
Definition: DiscreteFactorGraph.h:99
gtsam::CompareSolution
Definition: DiscreteSearch.cpp:73
gtsam::Solutions::Solutions
Solutions(size_t maxSize)
Definition: DiscreteSearch.cpp:89
gtsam::SearchNode::Compare
Definition: DiscreteSearch.cpp:45
gtsam::Solutions::operator<<
friend std::ostream & operator<<(std::ostream &os, const Solutions &sn)
Definition: DiscreteSearch.cpp:105
gtsam::SearchNode::bound
double bound
Definition: DiscreteSearch.cpp:37
gtsam::bound
double bound(double a, double min, double max)
Definition: PoseRTV.cpp:19
gtsam::Slot
DiscreteSearch::Slot Slot
Definition: DiscreteSearch.cpp:25
simple_graph::factors
const GaussianFactorGraph factors
Definition: testJacobianFactor.cpp:213
gtsam::Solutions::maxSize_
size_t maxSize_
Definition: DiscreteSearch.cpp:85
gtsam::DiscreteSearch::run
std::vector< Solution > run(size_t K=1) const
Search for the K best solutions.
Definition: DiscreteSearch.cpp:226
gtsam::Solutions::empty
bool empty() const
Definition: DiscreteSearch.cpp:102
gtsam::DiscreteSearch
DiscreteSearch: Search for the K best solutions.
Definition: DiscreteSearch.h:44
asia::bayesTree
static const DiscreteBayesTree bayesTree
Definition: testDiscreteSearch.cpp:40
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
log
const EIGEN_DEVICE_FUNC LogReturnType log() const
Definition: ArrayCwiseUnaryOps.h:128
gtsam::SearchNode::error
double error
Definition: DiscreteSearch.cpp:36
gtsam::SearchNodeQueue
std::priority_queue< SearchNode, std::vector< SearchNode >, SearchNode::Compare > SearchNodeQueue
Definition: DiscreteSearch.cpp:224
gtsam::DiscreteSearch::DiscreteSearch
DiscreteSearch(const DiscreteEliminationTree &etree)
Construct from a DiscreteEliminationTree.
Definition: DiscreteSearch.cpp:145
os
ofstream os("timeSchurFactors.csv")
gtsam::DiscreteSearch::Solution
Definition: DiscreteSearch.h:79
result
Values result
Definition: OdometryOptimize.cpp:8
gtsam::SearchNode
Definition: DiscreteSearch.cpp:34
gtsam::DiscreteSearch::FromFactorGraph
static DiscreteSearch FromFactorGraph(const DiscreteFactorGraph &factorGraph, const Ordering &ordering, bool buildJunctionTree=false)
Definition: DiscreteSearch.cpp:179
gtsam::SearchNode::assignment
DiscreteValues assignment
Definition: DiscreteSearch.cpp:35
gtsam::Solutions::extractSolutions
std::vector< Solution > extractSolutions()
Definition: DiscreteSearch.cpp:124
asia::factorGraph
static const DiscreteFactorGraph factorGraph(bayesNet)
pruning_fixture::factor
DecisionTreeFactor factor(D &C &B &A, "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0")
gtsam::DiscreteSearch::lowerBound_
double lowerBound_
Lower bound on the cost-to-go for the entire search.
Definition: DiscreteSearch.h:161
gtsam::Solution
DiscreteSearch::Solution Solution
Definition: DiscreteSearch.cpp:26
gtsam::DiscreteJunctionTree
Definition: DiscreteJunctionTree.h:53
data
int data[]
Definition: Map_placement_new.cpp:1
gtsam::DiscreteBayesNet
Definition: DiscreteBayesNet.h:38
gtsam::DiscreteValues::CartesianProduct
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
Definition: DiscreteValues.h:148
gtsam::DiscreteEliminationTree
Elimination tree for discrete factors.
Definition: DiscreteEliminationTree.h:31
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
DiscreteJunctionTree.h
gtsam::treeTraversal::DepthFirstForest
void DepthFirstForest(FOREST &forest, DATA &rootData, VISITOR_PRE &visitorPre, VISITOR_POST &visitorPost)
Definition: treeTraversal-inst.h:77
gtsam::DiscreteSearch::computeHeuristic
double computeHeuristic()
Definition: DiscreteSearch.cpp:277
gtsam::CompareSolution::operator()
bool operator()(const Solution &a, const Solution &b) const
Definition: DiscreteSearch.cpp:74
gtsam::DiscreteFactorGraph::product
DiscreteFactor::shared_ptr product() const
Definition: DiscreteFactorGraph.cpp:67
DiscreteEliminationTree.h
gtsam::SearchNode::isComplete
bool isComplete() const
Definition: DiscreteSearch.cpp:52
ordering
static enum @1096 ordering
key
const gtsam::Symbol key('X', 0)
gtsam::FactorGraph::back
sharedFactor back() const
Definition: FactorGraph.h:348
gtsam::FactorGraph::size
size_t size() const
Definition: FactorGraph.h:297
gtsam::b
const G & b
Definition: Group.h:79
sn
static double sn[6]
Definition: fresnl.c:63
a
ArrayXXi a
Definition: Array_initializer_list_23_cxx11.cpp:1
gtsam::DiscreteSearch::Slot
Definition: DiscreteSearch.h:62
gtsam
traits
Definition: SFMdata.h:40
gtsam::DiscreteFactor::shared_ptr
std::shared_ptr< DiscreteFactor > shared_ptr
shared_ptr to this class
Definition: DiscreteFactor.h:45
error
static double error
Definition: testRot3.cpp:37
K
#define K
Definition: igam.h:8
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::BayesTree< DiscreteBayesTreeClique >::sharedClique
std::shared_ptr< Clique > sharedClique
Shared pointer to a clique.
Definition: BayesTree.h:74
gtsam::getFactor
static DiscreteFactor::shared_ptr getFactor(const NodeType &node)
Definition: DiscreteSearch.cpp:139
DiscreteSearch.h
Defines the DiscreteSearch class for discrete search algorithms.
gtsam::DiscreteBayesTree
A Bayes tree representing a Discrete distribution.
Definition: DiscreteBayesTree.h:73
gtsam::SearchNode::next
std::optional< size_t > next
Definition: DiscreteSearch.cpp:38
gtsam::DiscreteSearch::Slot::factor
DiscreteFactor::shared_ptr factor
Definition: DiscreteSearch.h:63
gtsam::Solutions
Definition: DiscreteSearch.cpp:84
reverse
void reverse(const MatrixType &m)
Definition: array_reverse.cpp:16
gtsam::Solutions::maybeAdd
bool maybeAdd(double error, const DiscreteValues &assignment)
Definition: DiscreteSearch.cpp:93
gtsam::SearchNode::Root
static SearchNode Root(size_t numSlots, double bound)
Definition: DiscreteSearch.cpp:41
gtsam::DiscreteSearch::slots_
std::vector< Slot > slots_
The slots to fill in the search.
Definition: DiscreteSearch.h:162
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
gtsam::Ordering
Definition: inference/Ordering.h:33
gtsam::DiscreteSearch::print
void print(const std::string &name="DiscreteSearch: ", const KeyFormatter &formatter=DefaultKeyFormatter) const
Definition: DiscreteSearch.cpp:215
test_callbacks.value
value
Definition: test_callbacks.py:162
gtsam::Solutions::pq_
std::priority_queue< Solution, std::vector< Solution >, CompareSolution > pq_
Definition: DiscreteSearch.cpp:86
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
gtsam::Solutions::prune
bool prune(double bound) const
Definition: DiscreteSearch.cpp:118


gtsam
Author(s):
autogenerated on Wed Mar 19 2025 03:01:36