TableDistribution.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
18 #include <gtsam/base/Testable.h>
19 #include <gtsam/base/debug.h>
20 #include <gtsam/base/utilities.h>
21 #include <gtsam/discrete/Ring.h>
25 
26 #include <algorithm>
27 #include <cassert>
28 #include <random>
29 #include <set>
30 #include <stdexcept>
31 #include <string>
32 #include <utility>
33 #include <vector>
34 
35 using namespace std;
36 using std::pair;
37 using std::stringstream;
38 using std::vector;
39 namespace gtsam {
40 
43  const Eigen::SparseVector<double>& sparse_table) {
44  return sparse_table / sparse_table.sum();
45 }
46 
47 /* ************************************************************************** */
48 TableDistribution::TableDistribution(const TableFactor& f)
49  : BaseConditional(f.keys().size(), f.discreteKeys(), ADT(nullptr)),
50  table_(f / (*std::dynamic_pointer_cast<TableFactor>(
51  f.sum(f.keys().size())))) {}
52 
53 /* ************************************************************************** */
// Construct a distribution over `keys` from a flat list of potentials. The
// potentials are converted with TableFactor::Convert and then normalized by
// normalizeSparseTable, so the stored entries are divided by their sum.
// keys.size() is passed as the first BaseConditional argument (presumably the
// frontal count), and the base receives ADT(nullptr); the probabilities live
// in table_.
// NOTE(review): the opening line of this definition (source line 54) is
// missing from this listing, so the constructor name and the declared type of
// `keys` are not visible here; presumably
// `TableDistribution::TableDistribution(const DiscreteKeys& keys,` -- confirm.
55  const std::vector<double>& potentials)
56  : BaseConditional(keys.size(), keys, ADT(nullptr)),
57  table_(TableFactor(
58  keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
59 }
60 
61 /* ************************************************************************** */
// Construct a distribution over `keys` from potentials given as a string.
// The string is parsed by TableFactor::Convert and the result is normalized by
// normalizeSparseTable (entries divided by their sum).
// keys.size() is passed as the first BaseConditional argument (presumably the
// frontal count), and the base receives ADT(nullptr); the probabilities live
// in table_.
// NOTE(review): the opening line of this definition (source line 62) is
// missing from this listing; presumably
// `TableDistribution::TableDistribution(const DiscreteKeys& keys,` -- confirm.
63  const std::string& potentials)
64  : BaseConditional(keys.size(), keys, ADT(nullptr)),
65  table_(TableFactor(
66  keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
67 }
68 
69 /* ************************************************************************** */
70 void TableDistribution::print(const string& s,
71  const KeyFormatter& formatter) const {
72  cout << s << " P( ";
73  for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
74  cout << formatter(*it) << " ";
75  }
76  cout << "):\n";
77  table_.print("", formatter);
78  cout << endl;
79 }
80 
81 /* ************************************************************************** */
82 bool TableDistribution::equals(const DiscreteFactor& other, double tol) const {
83  auto dtc = dynamic_cast<const TableDistribution*>(&other);
84  if (!dtc) {
85  return false;
86  } else {
87  const DiscreteConditional& f(
88  static_cast<const DiscreteConditional&>(other));
89  return table_.equals(dtc->table_, tol) &&
91  }
92 }
93 
94 /* ****************************************************************************/
// Sum out frontal variables by delegating to the underlying TableFactor
// (TableFactor::sum(size_t nrFrontals) creates a new factor by summing all
// values with the same separator values). The result of table_.sum is
// returned as-is; it is not wrapped in a TableDistribution.
// NOTE(review): the signature line (source line 95) is missing from this
// listing; presumably
// `DiscreteFactor::shared_ptr TableDistribution::sum(size_t nrFrontals) const {`
// -- confirm against TableDistribution.h.
96  return table_.sum(nrFrontals);
97 }
98 
99 /* ****************************************************************************/
// Sum out the variables in `keys` by delegating to table_.sum(keys); the
// result is returned as-is.
// NOTE(review): the signature line (source line 100) is missing from this
// listing; presumably
// `DiscreteFactor::shared_ptr TableDistribution::sum(const Ordering& keys) const {`
// -- confirm against TableDistribution.h.
101  return table_.sum(keys);
102 }
103 
104 /* ****************************************************************************/
// Max out frontal variables by delegating to table_.max(nrFrontals); the
// result is returned as-is.
// NOTE(review): the signature line (source line 105) is missing from this
// listing; presumably
// `DiscreteFactor::shared_ptr TableDistribution::max(size_t nrFrontals) const {`
// -- confirm against TableDistribution.h.
106  return table_.max(nrFrontals);
107 }
108 
109 /* ****************************************************************************/
// Max out the variables in `keys` by delegating to table_.max(keys); the
// result is returned as-is.
// NOTE(review): the signature line (source line 110) is missing from this
// listing; presumably
// `DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const {`
// -- confirm against TableDistribution.h.
111  return table_.max(keys);
112 }
113 
114 /* ****************************************************************************/
116  return table_ * s;
117 }
118 
119 /* ****************************************************************************/
121  const DiscreteFactor::shared_ptr& f) const {
122  return table_ / f;
123 }
124 
125 /* ************************************************************************ */
127  uint64_t maxIdx = 0;
128  double maxValue = 0.0;
129 
131 
132  for (SparseIt it(sparseTable); it; ++it) {
133  if (it.value() > maxValue) {
134  maxIdx = it.index();
135  maxValue = it.value();
136  }
137  }
138 
139  return table_.findAssignments(maxIdx);
140 }
141 
142 /* ****************************************************************************/
143 void TableDistribution::prune(size_t maxNrAssignments) {
144  table_ = table_.prune(maxNrAssignments);
145 }
146 
147 /* ****************************************************************************/
148 size_t TableDistribution::sample(const DiscreteValues& parentsValues,
149  std::mt19937_64* rng) const {
150  DiscreteKeys parentsKeys;
151  for (auto&& [key, _] : parentsValues) {
152  parentsKeys.push_back({key, table_.cardinality(key)});
153  }
154 
155  // Get the correct conditional distribution: P(F|S=parentsValues)
156  TableFactor pFS = table_.choose(parentsValues, parentsKeys);
157 
158  // TODO(Duy): only works for one key now, seems horribly slow this way
159  if (nrFrontals() != 1) {
160  throw std::invalid_argument(
161  "TableDistribution::sample can only be called on single variable "
162  "conditionals");
163  }
164  Key key = firstFrontalKey();
165  size_t nj = cardinality(key);
166  vector<double> p(nj);
168  for (size_t value = 0; value < nj; value++) {
169  frontals[key] = value;
170  p[value] = pFS(frontals); // P(F=value|S=parentsValues)
171  if (p[value] == 1.0) {
172  return value; // shortcut exit
173  }
174  }
175 
176  // Check if rng is nullptr, then assign default
177  rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;
178 
179  std::discrete_distribution<size_t> distribution(p.begin(), p.end());
180  return distribution(*rng);
181 }
182 
183 } // namespace gtsam
TableDistribution.h
gtsam::TableFactor
Definition: TableFactor.h:54
rng
static std::mt19937 rng
Definition: timeFactorOverhead.cpp:31
gtsam::TableFactor::sum
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override
Create new factor by summing all values with the same separator values.
Definition: TableFactor.cpp:395
Eigen::SparseVector::sum
Scalar sum() const
Definition: SparseRedux.h:41
s
RealScalar s
Definition: level1_cplx_impl.h:126
gtsam::TableDistribution::argmax
DiscreteValues argmax() const
Return assignment that maximizes value.
Definition: TableDistribution.cpp:126
gtsam::Conditional::equals
bool equals(const This &c, double tol=1e-9) const
Definition: Conditional-inst.h:41
Testable.h
Concept check for values that can be used in unit tests.
keys
const KeyVector keys
Definition: testRegularImplicitSchurFactor.cpp:40
gtsam::TableDistribution::table_
TableFactor table_
Definition: TableDistribution.h:41
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
utilities.h
Eigen::SparseCompressedBase::InnerIterator
Definition: SparseCompressedBase.h:158
gtsam::AlgebraicDecisionTree< Key >
size
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
gtsam::TableDistribution
Definition: TableDistribution.h:39
gtsam::Conditional< DecisionTreeFactor, DiscreteConditional >::endFrontals
DecisionTreeFactor ::const_iterator endFrontals() const
Definition: Conditional.h:183
gtsam::TableDistribution::TableDistribution
TableDistribution()
Default constructor needed for serialization.
Definition: TableDistribution.h:58
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
gtsam::TableFactor::max
double max() const override
Find the maximum value in the factor.
Definition: TableFactor.cpp:405
gtsam::TableFactor::findAssignments
DiscreteValues findAssignments(const uint64_t idx) const
Find DiscreteValues for corresponding index.
Definition: TableFactor.cpp:588
Signature.h
signatures for conditional densities
gtsam::TableFactor::prune
TableFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
Definition: TableFactor.cpp:748
gtsam::TableFactor::sparseTable
Eigen::SparseVector< double > sparseTable() const
Getter for the underlying sparse vector.
Definition: TableFactor.h:166
gtsam::TableDistribution::sample
virtual size_t sample(const DiscreteValues &parentsValues, std::mt19937_64 *rng=nullptr) const override
Definition: TableDistribution.cpp:148
gtsam::Conditional< DecisionTreeFactor, DiscreteConditional >::beginFrontals
DecisionTreeFactor ::const_iterator beginFrontals() const
Definition: Conditional.h:180
gtsam::Conditional
Definition: Conditional.h:63
gtsam::Conditional< DecisionTreeFactor, DiscreteConditional >::nrFrontals
size_t nrFrontals() const
Definition: Conditional.h:131
gtsam::AlgebraicDecisionTree< Key >::sum
double sum() const
Compute sum of all values.
Definition: AlgebraicDecisionTree.h:211
key
const gtsam::Symbol key('X', 0)
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
gtsam::DiscreteFactor::cardinality
size_t cardinality(Key j) const
Definition: DiscreteFactor.h:99
gtsam::Conditional< DecisionTreeFactor, DiscreteConditional >::firstFrontalKey
Key firstFrontalKey() const
Definition: Conditional.h:137
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:38
gtsam::Factor::const_iterator
KeyVector::const_iterator const_iterator
Const iterator over keys.
Definition: Factor.h:83
gtsam
traits
Definition: ABC.h:17
gtsam::DiscreteFactor::shared_ptr
std::shared_ptr< DiscreteFactor > shared_ptr
shared_ptr to this class
Definition: DiscreteFactor.h:45
kRandomNumberGenerator
static std::mt19937_64 kRandomNumberGenerator(42)
Global default pseudo-random number generator object. In wrappers we can access std::mt19937_64 via g...
gtsam::TableDistribution::operator*
DiscreteFactor::shared_ptr operator*(double s) const override
Multiply by scalar s.
Definition: TableDistribution.cpp:115
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::Factor::keys
const KeyVector & keys() const
Access the factor's involved variable keys.
Definition: Factor.h:143
gtsam::TableDistribution::prune
virtual void prune(size_t maxNrAssignments) override
Prune the conditional.
Definition: TableDistribution.cpp:143
Eigen::SparseVector< double >
std
Definition: BFloat16.h:88
gtsam::TableFactor::choose
TableFactor choose(const DiscreteValues parentAssignments, DiscreteKeys parent_keys) const
Create a TableFactor that is a subset of this TableFactor.
Definition: TableFactor.cpp:324
p
float * p
Definition: Tutorial_Map_using.cpp:9
Ring.h
Real Ring definition.
gtsam::TableDistribution::operator/
DiscreteFactor::shared_ptr operator/(const DiscreteFactor::shared_ptr &f) const override
divide by DiscreteFactor::shared_ptr f (safely)
Definition: TableDistribution.cpp:120
gtsam::TableDistribution::max
double max() const override
Find the maximum value in the factor.
Definition: TableDistribution.h:120
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::TableFactor::equals
bool equals(const DiscreteFactor &other, double tol=1e-9) const override
equality
Definition: TableFactor.cpp:208
gtsam::TableFactor::print
void print(const std::string &s="TableFactor:\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
Definition: TableFactor.cpp:377
uint64_t
unsigned __int64 uint64_t
Definition: ms_stdint.h:95
gtsam::DiscreteFactor
Definition: DiscreteFactor.h:40
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
_
constexpr descr< N - 1 > _(char const (&text)[N])
Definition: descr.h:109
HybridValues.h
gtsam::Ordering
Definition: inference/Ordering.h:33
gtsam::Conditional< DecisionTreeFactor, DiscreteConditional >::frontals
Frontals frontals() const
Definition: Conditional.h:145
gtsam::TableDistribution::print
void print(const std::string &s="Table Distribution: ", const KeyFormatter &formatter=DefaultKeyFormatter) const override
GTSAM-style print.
Definition: TableDistribution.cpp:70
test_callbacks.value
value
Definition: test_callbacks.py:162
gtsam::normalizeSparseTable
static Eigen::SparseVector< double > normalizeSparseTable(const Eigen::SparseVector< double > &sparse_table)
Normalize sparse_table.
Definition: TableDistribution.cpp:42
pybind_wrapper_test_script.other
other
Definition: pybind_wrapper_test_script.py:42
gtsam::TableDistribution::equals
bool equals(const DiscreteFactor &other, double tol=1e-9) const override
GTSAM-style equals.
Definition: TableDistribution.cpp:82
debug.h
Global debugging flags.


gtsam
Author(s):
autogenerated on Wed May 28 2025 03:04:19