TableDistribution.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
18 #include <gtsam/base/Testable.h>
19 #include <gtsam/base/debug.h>
20 #include <gtsam/discrete/Ring.h>
24 
25 #include <algorithm>
26 #include <cassert>
27 #include <random>
28 #include <set>
29 #include <stdexcept>
30 #include <string>
31 #include <utility>
32 #include <vector>
33 
34 using namespace std;
35 using std::pair;
36 using std::stringstream;
37 using std::vector;
38 namespace gtsam {
39 
42  const Eigen::SparseVector<double>& sparse_table) {
43  return sparse_table / sparse_table.sum();
44 }
45 
46 /* ************************************************************************** */
47 TableDistribution::TableDistribution(const TableFactor& f)
48  : BaseConditional(f.keys().size(), f.discreteKeys(), ADT(nullptr)),
49  table_(f / (*std::dynamic_pointer_cast<TableFactor>(
50  f.sum(f.keys().size())))) {}
51 
52 /* ************************************************************************** */
54  const std::vector<double>& potentials)
55  : BaseConditional(keys.size(), keys, ADT(nullptr)),
56  table_(TableFactor(
57  keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
58 }
59 
60 /* ************************************************************************** */
62  const std::string& potentials)
63  : BaseConditional(keys.size(), keys, ADT(nullptr)),
64  table_(TableFactor(
65  keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
66 }
67 
68 /* ************************************************************************** */
69 void TableDistribution::print(const string& s,
70  const KeyFormatter& formatter) const {
71  cout << s << " P( ";
72  for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
73  cout << formatter(*it) << " ";
74  }
75  cout << "):\n";
76  table_.print("", formatter);
77  cout << endl;
78 }
79 
80 /* ************************************************************************** */
81 bool TableDistribution::equals(const DiscreteFactor& other, double tol) const {
82  auto dtc = dynamic_cast<const TableDistribution*>(&other);
83  if (!dtc) {
84  return false;
85  } else {
86  const DiscreteConditional& f(
87  static_cast<const DiscreteConditional&>(other));
88  return table_.equals(dtc->table_, tol) &&
90  }
91 }
92 
93 /* ****************************************************************************/
95  return table_.sum(nrFrontals);
96 }
97 
98 /* ****************************************************************************/
100  return table_.sum(keys);
101 }
102 
103 /* ****************************************************************************/
105  return table_.max(nrFrontals);
106 }
107 
108 /* ****************************************************************************/
110  return table_.max(keys);
111 }
112 
113 /* ****************************************************************************/
115  return table_ * s;
116 }
117 
118 /* ****************************************************************************/
120  const DiscreteFactor::shared_ptr& f) const {
121  return table_ / f;
122 }
123 
124 /* ************************************************************************ */
126  uint64_t maxIdx = 0;
127  double maxValue = 0.0;
128 
130 
131  for (SparseIt it(sparseTable); it; ++it) {
132  if (it.value() > maxValue) {
133  maxIdx = it.index();
134  maxValue = it.value();
135  }
136  }
137 
138  return table_.findAssignments(maxIdx);
139 }
140 
141 /* ****************************************************************************/
142 void TableDistribution::prune(size_t maxNrAssignments) {
143  table_ = table_.prune(maxNrAssignments);
144 }
145 
146 /* ****************************************************************************/
147 size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
148  static mt19937 rng(2); // random number generator
149 
150  DiscreteKeys parentsKeys;
151  for (auto&& [key, _] : parentsValues) {
152  parentsKeys.push_back({key, table_.cardinality(key)});
153  }
154 
155  // Get the correct conditional distribution: P(F|S=parentsValues)
156  TableFactor pFS = table_.choose(parentsValues, parentsKeys);
157 
158  // TODO(Duy): only works for one key now, seems horribly slow this way
159  if (nrFrontals() != 1) {
160  throw std::invalid_argument(
161  "TableDistribution::sample can only be called on single variable "
162  "conditionals");
163  }
164  Key key = firstFrontalKey();
165  size_t nj = cardinality(key);
166  vector<double> p(nj);
168  for (size_t value = 0; value < nj; value++) {
169  frontals[key] = value;
170  p[value] = pFS(frontals); // P(F=value|S=parentsValues)
171  if (p[value] == 1.0) {
172  return value; // shortcut exit
173  }
174  }
175  std::discrete_distribution<size_t> distribution(p.begin(), p.end());
176  return distribution(rng);
177 }
178 
179 } // namespace gtsam
gtsam::DiscreteConditional::sample
size_t sample() const
Zero parent version.
Definition: DiscreteConditional.cpp:328
TableDistribution.h
gtsam::TableFactor
Definition: TableFactor.h:54
rng
static std::mt19937 rng
Definition: timeFactorOverhead.cpp:31
gtsam::TableFactor::sum
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override
Create new factor by summing all values with the same separator values.
Definition: TableFactor.cpp:393
Eigen::SparseVector::sum
Scalar sum() const
Definition: SparseRedux.h:41
s
RealScalar s
Definition: level1_cplx_impl.h:126
gtsam::TableDistribution::argmax
DiscreteValues argmax() const
Return assignment that maximizes value.
Definition: TableDistribution.cpp:125
gtsam::Conditional::equals
bool equals(const This &c, double tol=1e-9) const
Definition: Conditional-inst.h:41
Testable.h
Concept check for values that can be used in unit tests.
keys
const KeyVector keys
Definition: testRegularImplicitSchurFactor.cpp:40
gtsam::TableDistribution::table_
TableFactor table_
Definition: TableDistribution.h:41
formatter
const KeyFormatter & formatter
Definition: treeTraversal-inst.h:204
gtsam::DiscreteKeys
DiscreteKeys is a set of keys that can be assembled using the & operator.
Definition: DiscreteKey.h:41
Eigen::SparseCompressedBase::InnerIterator
Definition: SparseCompressedBase.h:158
gtsam::AlgebraicDecisionTree< Key >
size
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
gtsam::TableDistribution
Definition: TableDistribution.h:39
gtsam::Conditional< DecisionTreeFactor, DiscreteConditional >::endFrontals
DecisionTreeFactor ::const_iterator endFrontals() const
Definition: Conditional.h:183
gtsam::TableDistribution::TableDistribution
TableDistribution()
Default constructor needed for serialization.
Definition: TableDistribution.h:58
gtsam::KeyFormatter
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
gtsam::TableFactor::max
double max() const override
Find the maximum value in the factor.
Definition: TableFactor.cpp:403
gtsam::TableFactor::findAssignments
DiscreteValues findAssignments(const uint64_t idx) const
Find DiscreteValues for corresponding index.
Definition: TableFactor.cpp:586
Signature.h
signatures for conditional densities
gtsam::TableFactor::prune
TableFactor prune(size_t maxNrAssignments) const
Prune the decision tree of discrete variables.
Definition: TableFactor.cpp:746
gtsam::TableFactor::sparseTable
Eigen::SparseVector< double > sparseTable() const
Getter for the underlying sparse vector.
Definition: TableFactor.h:166
gtsam::Conditional< DecisionTreeFactor, DiscreteConditional >::beginFrontals
DecisionTreeFactor ::const_iterator beginFrontals() const
Definition: Conditional.h:180
gtsam::Conditional
Definition: Conditional.h:63
gtsam::Conditional< DecisionTreeFactor, DiscreteConditional >::nrFrontals
size_t nrFrontals() const
Definition: Conditional.h:131
gtsam::AlgebraicDecisionTree< Key >::sum
double sum() const
Compute sum of all values.
Definition: AlgebraicDecisionTree.h:211
key
const gtsam::Symbol key('X', 0)
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
gtsam::DiscreteFactor::cardinality
size_t cardinality(Key j) const
Definition: DiscreteFactor.h:99
gtsam::Conditional< DecisionTreeFactor, DiscreteConditional >::firstFrontalKey
Key firstFrontalKey() const
Definition: Conditional.h:137
gtsam::DiscreteConditional
Definition: DiscreteConditional.h:37
gtsam::Factor::const_iterator
KeyVector::const_iterator const_iterator
Const iterator over keys.
Definition: Factor.h:83
gtsam
traits
Definition: SFMdata.h:40
gtsam::DiscreteFactor::shared_ptr
std::shared_ptr< DiscreteFactor > shared_ptr
shared_ptr to this class
Definition: DiscreteFactor.h:45
gtsam::TableDistribution::operator*
DiscreteFactor::shared_ptr operator*(double s) const override
Multiply by scalar s.
Definition: TableDistribution.cpp:114
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
gtsam::Factor::keys
const KeyVector & keys() const
Access the factor's involved variable keys.
Definition: Factor.h:143
gtsam::TableDistribution::prune
virtual void prune(size_t maxNrAssignments) override
Prune the conditional.
Definition: TableDistribution.cpp:142
Eigen::SparseVector< double >
std
Definition: BFloat16.h:88
gtsam::TableFactor::choose
TableFactor choose(const DiscreteValues parentAssignments, DiscreteKeys parent_keys) const
Create a TableFactor that is a subset of this TableFactor.
Definition: TableFactor.cpp:322
p
float * p
Definition: Tutorial_Map_using.cpp:9
Ring.h
Real Ring definition.
gtsam::TableDistribution::operator/
DiscreteFactor::shared_ptr operator/(const DiscreteFactor::shared_ptr &f) const override
divide by DiscreteFactor::shared_ptr f (safely)
Definition: TableDistribution.cpp:119
gtsam::TableDistribution::max
double max() const override
Find the maximum value in the factor.
Definition: TableDistribution.h:120
gtsam::tol
const G double tol
Definition: Group.h:79
gtsam::TableFactor::equals
bool equals(const DiscreteFactor &other, double tol=1e-9) const override
equality
Definition: TableFactor.cpp:206
gtsam::TableFactor::print
void print(const std::string &s="TableFactor:\n", const KeyFormatter &formatter=DefaultKeyFormatter) const override
print
Definition: TableFactor.cpp:375
uint64_t
unsigned __int64 uint64_t
Definition: ms_stdint.h:95
gtsam::DiscreteFactor
Definition: DiscreteFactor.h:40
gtsam::Key
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:97
_
constexpr descr< N - 1 > _(char const (&text)[N])
Definition: descr.h:109
HybridValues.h
gtsam::Ordering
Definition: inference/Ordering.h:33
gtsam::Conditional< DecisionTreeFactor, DiscreteConditional >::frontals
Frontals frontals() const
Definition: Conditional.h:145
gtsam::TableDistribution::print
void print(const std::string &s="Table Distribution: ", const KeyFormatter &formatter=DefaultKeyFormatter) const override
GTSAM-style print.
Definition: TableDistribution.cpp:69
test_callbacks.value
value
Definition: test_callbacks.py:162
gtsam::normalizeSparseTable
static Eigen::SparseVector< double > normalizeSparseTable(const Eigen::SparseVector< double > &sparse_table)
Normalize sparse_table.
Definition: TableDistribution.cpp:41
pybind_wrapper_test_script.other
other
Definition: pybind_wrapper_test_script.py:42
gtsam::TableDistribution::equals
bool equals(const DiscreteFactor &other, double tol=1e-9) const override
GTSAM-style equals.
Definition: TableDistribution.cpp:81
debug.h
Global debugging flags.


gtsam
Author(s):
autogenerated on Wed Mar 19 2025 03:04:38