DiscreteConditional.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
19 #include <gtsam/base/Testable.h>
20 #include <gtsam/base/debug.h>
24 
25 #include <algorithm>
26 #include <random>
27 #include <set>
28 #include <stdexcept>
29 #include <string>
30 #include <utility>
31 #include <vector>
32 
33 using namespace std;
34 using std::pair;
35 using std::stringstream;
36 using std::vector;
37 namespace gtsam {
38 
39 // Instantiate base class
40 template class GTSAM_EXPORT
42 
43 /* ************************************************************************** */
44 DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
45  const DecisionTreeFactor& f)
46  : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
47 
48 /* ************************************************************************** */
50  const DiscreteKeys& keys,
51  const ADT& potentials)
52  : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
53 
54 /* ************************************************************************** */
57  : BaseFactor(joint / marginal),
58  BaseConditional(joint.size() - marginal.size()) {}
59 
60 /* ************************************************************************** */
63  const Ordering& orderedKeys)
64  : DiscreteConditional(joint, marginal) {
65  keys_.clear();
66  keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
67 }
68 
69 /* ************************************************************************** */
71  : BaseFactor(signature.discreteKeys(), signature.cpt()),
72  BaseConditional(1) {}
73 
74 /* ************************************************************************** */
76  const DiscreteConditional& other) const {
77  // Take union of frontal keys
78  std::set<Key> newFrontals;
79  for (auto&& key : this->frontals()) newFrontals.insert(key);
80  for (auto&& key : other.frontals()) newFrontals.insert(key);
81 
82  // Check if frontals overlapped
83  if (nrFrontals() + other.nrFrontals() > newFrontals.size())
84  throw std::invalid_argument(
85  "DiscreteConditional::operator* called with overlapping frontal keys.");
86 
87  // Now, add cardinalities.
89  for (auto&& key : frontals())
90  discreteKeys.emplace_back(key, cardinality(key));
91  for (auto&& key : other.frontals())
92  discreteKeys.emplace_back(key, other.cardinality(key));
93 
94  // Sort
95  std::sort(discreteKeys.begin(), discreteKeys.end());
96 
97  // Add parents to set, to make them unique
98  std::set<DiscreteKey> parents;
99  for (auto&& key : this->parents())
100  if (!newFrontals.count(key)) parents.emplace(key, cardinality(key));
101  for (auto&& key : other.parents())
102  if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key));
103 
104  // Finally, add parents to keys, in order
105  for (auto&& dk : parents) discreteKeys.push_back(dk);
106 
108  return DiscreteConditional(newFrontals.size(), discreteKeys, product);
109 }
110 
111 /* ************************************************************************** */
113  if (nrParents() > 0)
114  throw std::invalid_argument(
115  "DiscreteConditional::marginal: single argument version only valid for "
116  "fully specified joint distributions (i.e., no parents).");
117 
118  // Calculate the keys as the frontal keys without the given key.
120 
121  // Calculate sum
122  ADT adt(*this);
123  for (auto&& k : frontals())
124  if (k != key) adt = adt.sum(k, cardinality(k));
125 
126  // Return new factor
127  return DiscreteConditional(1, discreteKeys, adt);
128 }
129 
130 /* ************************************************************************** */
131 void DiscreteConditional::print(const string& s,
132  const KeyFormatter& formatter) const {
133  cout << s << " P( ";
134  for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
135  cout << formatter(*it) << " ";
136  }
137  if (nrParents()) {
138  cout << "| ";
139  for (const_iterator it = beginParents(); it != endParents(); ++it) {
140  cout << formatter(*it) << " ";
141  }
142  }
143  cout << "):\n";
144  ADT::print("", formatter);
145  cout << endl;
146 }
147 
148 /* ************************************************************************** */
150  double tol) const {
151  if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
152  return false;
153  } else {
154  const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
155  return DecisionTreeFactor::equals(f, tol);
156  }
157 }
158 
159 /* ************************************************************************** */
161  const DiscreteValues& given, bool forceComplete) const {
162  // Get the big decision tree with all the levels, and then go down the
163  // branches based on the value of the parent variables.
164  DiscreteConditional::ADT adt(*this);
165  size_t value;
166  for (Key j : parents()) {
167  try {
168  value = given.at(j);
169  adt = adt.choose(j, value); // ADT keeps getting smaller.
170  } catch (std::out_of_range&) {
171  if (forceComplete) {
172  given.print("parentsValues: ");
173  throw runtime_error(
174  "DiscreteConditional::choose: parent value missing");
175  }
176  }
177  }
178  return adt;
179 }
180 
181 /* ************************************************************************** */
183  const DiscreteValues& given) const {
184  ADT adt = choose(given, false); // P(F|S=given)
185 
186  // Collect all keys not in given.
187  DiscreteKeys dKeys;
188  for (Key j : frontals()) {
189  dKeys.emplace_back(j, this->cardinality(j));
190  }
191  for (size_t i = nrFrontals(); i < size(); i++) {
192  Key j = keys_[i];
193  if (given.count(j) == 0) {
194  dKeys.emplace_back(j, this->cardinality(j));
195  }
196  }
197  return std::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
198 }
199 
200 /* ************************************************************************** */
202  const DiscreteValues& frontalValues) const {
203  // Get the big decision tree with all the levels, and then go down the
204  // branches based on the value of the frontal variables.
205  ADT adt(*this);
206  size_t value;
207  for (Key j : frontals()) {
208  try {
209  value = frontalValues.at(j);
210  adt = adt.choose(j, value); // ADT keeps getting smaller.
211  } catch (exception&) {
212  frontalValues.print("frontalValues: ");
213  throw runtime_error("DiscreteConditional::choose: frontal value missing");
214  }
215  }
216 
217  // Convert ADT to factor.
219  for (Key j : parents()) {
220  discreteKeys.emplace_back(j, this->cardinality(j));
221  }
222  return std::make_shared<DecisionTreeFactor>(discreteKeys, adt);
223 }
224 
225 /* ****************************************************************************/
227  size_t frontal) const {
228  if (nrFrontals() != 1)
229  throw std::invalid_argument(
230  "Single value likelihood can only be invoked on single-variable "
231  "conditional");
233  values.emplace(keys_[0], frontal);
234  return likelihood(values);
235 }
236 
237 /* ************************************************************************** */
239  size_t maxValue = 0;
240  double maxP = 0;
241  assert(nrFrontals() == 1);
242  assert(nrParents() == 0);
244  Key j = firstFrontalKey();
245  for (size_t value = 0; value < cardinality(j); value++) {
246  frontals[j] = value;
247  double pValueS = (*this)(frontals);
248  // Update MPE solution if better
249  if (pValueS > maxP) {
250  maxP = pValueS;
251  maxValue = value;
252  }
253  }
254  return maxValue;
255 }
256 
257 /* ************************************************************************** */
259  assert(nrFrontals() == 1);
260  Key j = (firstFrontalKey());
261  size_t sampled = sample(*values); // Sample variable given parents
262  (*values)[j] = sampled; // store result in partial solution
263 }
264 
265 /* ************************************************************************** */
266 size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
267  static mt19937 rng(2); // random number generator
268 
269  // Get the correct conditional distribution
270  ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
271 
272  // TODO(Duy): only works for one key now, seems horribly slow this way
273  if (nrFrontals() != 1) {
274  throw std::invalid_argument(
275  "DiscreteConditional::sample can only be called on single variable "
276  "conditionals");
277  }
278  Key key = firstFrontalKey();
279  size_t nj = cardinality(key);
280  vector<double> p(nj);
282  for (size_t value = 0; value < nj; value++) {
283  frontals[key] = value;
284  p[value] = pFS(frontals); // P(F=value|S=parentsValues)
285  if (p[value] == 1.0) {
286  return value; // shortcut exit
287  }
288  }
289  std::discrete_distribution<size_t> distribution(p.begin(), p.end());
290  return distribution(rng);
291 }
292 
293 /* ************************************************************************** */
294 size_t DiscreteConditional::sample(size_t parent_value) const {
295  if (nrParents() != 1)
296  throw std::invalid_argument(
297  "Single value sample() can only be invoked on single-parent "
298  "conditional");
300  values.emplace(keys_.back(), parent_value);
301  return sample(values);
302 }
303 
304 /* ************************************************************************** */
306  if (nrParents() != 0)
307  throw std::invalid_argument(
308  "sample() can only be invoked on no-parent prior");
310  return sample(values);
311 }
312 
313 /* ************************************************************************* */
314 vector<DiscreteValues> DiscreteConditional::frontalAssignments() const {
315  vector<pair<Key, size_t>> pairs;
316  for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
317  vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
318  return DiscreteValues::CartesianProduct(rpairs);
319 }
320 
321 /* ************************************************************************* */
322 vector<DiscreteValues> DiscreteConditional::allAssignments() const {
323  vector<pair<Key, size_t>> pairs;
324  for (Key key : parents()) pairs.emplace_back(key, cardinalities_.at(key));
325  for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
326  vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
327  return DiscreteValues::CartesianProduct(rpairs);
328 }
329 
330 /* ************************************************************************* */
331 // Print out signature.
332 static void streamSignature(const DiscreteConditional& conditional,
333  const KeyFormatter& keyFormatter,
334  stringstream* ss) {
335  *ss << "P(";
336  bool first = true;
337  for (Key key : conditional.frontals()) {
338  if (!first) *ss << ",";
339  *ss << keyFormatter(key);
340  first = false;
341  }
342  if (conditional.nrParents() > 0) {
343  *ss << "|";
344  bool first = true;
345  for (Key parent : conditional.parents()) {
346  if (!first) *ss << ",";
347  *ss << keyFormatter(parent);
348  first = false;
349  }
350  }
351  *ss << "):";
352 }
353 
354 /* ************************************************************************* */
355 std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
356  const Names& names) const {
357  stringstream ss;
358  ss << " *";
359  streamSignature(*this, keyFormatter, &ss);
360  ss << "*\n" << std::endl;
361  if (nrParents() == 0) {
362  // We have no parents, call factor method.
363  ss << DecisionTreeFactor::markdown(keyFormatter, names);
364  return ss.str();
365  }
366 
367  // Print out header.
368  ss << "|";
369  for (Key parent : parents()) {
370  ss << "*" << keyFormatter(parent) << "*|";
371  }
372 
373  auto frontalAssignments = this->frontalAssignments();
374  for (const auto& a : frontalAssignments) {
375  for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
376  size_t index = a.at(*it);
377  ss << DiscreteValues::Translate(names, *it, index);
378  }
379  ss << "|";
380  }
381  ss << "\n";
382 
383  // Print out separator with alignment hints.
384  ss << "|";
385  size_t n = frontalAssignments.size();
386  for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|";
387  ss << "\n";
388 
389  // Print out all rows.
390  size_t count = 0;
391  for (const auto& a : allAssignments()) {
392  if (count == 0) {
393  ss << "|";
394  for (auto&& it = beginParents(); it != endParents(); ++it) {
395  size_t index = a.at(*it);
396  ss << DiscreteValues::Translate(names, *it, index) << "|";
397  }
398  }
399  ss << operator()(a) << "|";
400  count = (count + 1) % n;
401  if (count == 0) ss << "\n";
402  }
403  return ss.str();
404 }
405 
406 /* ************************************************************************ */
407 string DiscreteConditional::html(const KeyFormatter& keyFormatter,
408  const Names& names) const {
409  stringstream ss;
410  ss << "<div>\n<p> <i>";
411  streamSignature(*this, keyFormatter, &ss);
412  ss << "</i></p>\n";
413  if (nrParents() == 0) {
414  // We have no parents, call factor method.
415  ss << DecisionTreeFactor::html(keyFormatter, names);
416  return ss.str();
417  }
418 
419  // Print out preamble.
420  ss << "<table class='DiscreteConditional'>\n <thead>\n";
421 
422  // Print out header row.
423  ss << " <tr>";
424  for (Key parent : parents()) {
425  ss << "<th><i>" << keyFormatter(parent) << "</i></th>";
426  }
427  auto frontalAssignments = this->frontalAssignments();
428  for (const auto& a : frontalAssignments) {
429  ss << "<th>";
430  for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
431  size_t index = a.at(*it);
432  ss << DiscreteValues::Translate(names, *it, index);
433  }
434  ss << "</th>";
435  }
436  ss << "</tr>\n";
437 
438  // Finish header and start body.
439  ss << " </thead>\n <tbody>\n";
440 
441  // Output all rows, one per assignment:
442  size_t count = 0, n = frontalAssignments.size();
443  for (const auto& a : allAssignments()) {
444  if (count == 0) {
445  ss << " <tr>";
446  for (auto&& it = beginParents(); it != endParents(); ++it) {
447  size_t index = a.at(*it);
448  ss << "<th>" << DiscreteValues::Translate(names, *it, index) << "</th>";
449  }
450  }
451  ss << "<td>" << operator()(a) << "</td>"; // value
452  count = (count + 1) % n;
453  if (count == 0) ss << "</tr>\n";
454  }
455 
456  // Finish up
457  ss << " </tbody>\n</table>\n</div>";
458  return ss.str();
459 }
460 
461 /* ************************************************************************* */
463  return this->evaluate(x.discrete());
464 }
465 /* ************************************************************************* */
466 
467 } // namespace gtsam
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as markdown table.
const gtsam::Symbol key('X', 0)
AlgebraicDecisionTree sum(const L &label, size_t cardinality) const
void sampleInPlace(DiscreteValues *parentsValues) const
sample in place, stores result in partial solution
DiscreteKeys discreteKeys() const
Return all the discrete keys associated with this factor.
Concept check for values that can be used in unit tests.
signatures for conditional densities
Global debugging flags.
DiscreteConditional()
Default constructor needed for serialization.
double evaluate(const DiscreteValues &values) const
Evaluate, just look up in AlgebraicDecisionTree.
void print(const std::string &s="", const typename Base::LabelFormatter &labelFormatter=&DefaultFormatter) const
print method customized to value type double.
double mul(const double &a, const double &b)
shared_ptr choose(const DiscreteValues &given) const
< DiscreteValues version
DecisionTreeFactor ::const_iterator endFrontals() const
Definition: Conditional.h:182
std::string html(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as html table.
size_t size() const
Definition: Factor.h:159
static std::mt19937 rng
int n
leaf::MyValues values
DecisionTreeFactor ::const_iterator beginParents() const
Definition: Conditional.h:185
KeyVector keys_
The keys involved in this factor.
Definition: Factor.h:87
Definition: BFloat16.h:88
DecisionTreeFactor::shared_ptr likelihood(const DiscreteValues &frontalValues) const
static std::vector< DiscreteValues > CartesianProduct(const DiscreteKeys &keys)
Return a vector of DiscreteValues, one for each possible combination of values.
const KeyFormatter & formatter
DecisionTreeFactor ::const_iterator beginFrontals() const
Definition: Conditional.h:179
DiscreteConditional operator*(const DiscreteConditional &other) const
Combine two conditionals, yielding a new conditional with the union of the frontal keys...
size_t cardinality(Key j) const
double operator()(const DiscreteValues &values) const override
Evaluate probability distribution, sugar.
DecisionTreeFactor ::const_iterator endParents() const
Definition: Conditional.h:188
static void streamSignature(const DiscreteConditional &conditional, const KeyFormatter &keyFormatter, stringstream *ss)
std::map< Key, size_t > cardinalities_
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
RealScalar s
DecisionTree apply(const Unary &op) const
size_t sample() const
Zero parent version.
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
std::string html(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as html table.
static std::stringstream ss
Definition: testBTree.cpp:31
DiscreteValues::Names Names
Translation table from values to strings.
traits
Definition: chartTesting.h:28
std::shared_ptr< DecisionTreeFactor > shared_ptr
size_t argmax() const
Return assignment that maximizes distribution.
static std::string Translate(const Names &names, Key key, size_t index)
Translate an integer index value for given key to a string.
bool equals(const DiscreteFactor &other, double tol=1e-9) const override
GTSAM-style equals.
float * p
const DiscreteValues & discrete() const
Return the discrete values.
Definition: HybridValues.h:92
const KeyVector & keys() const
Access the factor's involved variable keys.
Definition: Factor.h:142
std::vector< DiscreteValues > allAssignments() const
Return all assignments for frontal and parent variables.
bool equals(const DiscreteFactor &other, double tol=1e-9) const override
equality
const G double tol
Definition: Group.h:86
std::shared_ptr< This > shared_ptr
shared_ptr to this class
KeyVector::const_iterator const_iterator
Const iterator over keys.
Definition: Factor.h:82
DiscreteConditional marginal(Key key) const
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
std::string markdown(const KeyFormatter &keyFormatter=DefaultKeyFormatter, const Names &names={}) const override
Render as markdown table.
DecisionTree choose(const L &label, size_t index) const
Definition: DecisionTree.h:341
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:102
void print(const std::string &s="Discrete Conditional: ", const KeyFormatter &formatter=DefaultKeyFormatter) const override
GTSAM-style print.
std::ptrdiff_t j
std::vector< DiscreteValues > frontalAssignments() const
Return all assignments for frontal variables.
void product(const MatrixType &m)
Definition: product.h:20
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:34:10