testGaussianMixture.cpp
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
27 #include <gtsam/inference/Key.h>
28 #include <gtsam/inference/Symbol.h>
31 
32 // Include for test suite
34 
35 using namespace gtsam;
38 
39 // Define mode key and an assignment m==1
40 const DiscreteKey m(M(0), 2);
41 const DiscreteValues m1Assignment{{M(0), 1}};
42 
43 // Define a 60/40 prior on the mode
45  std::make_shared<DiscreteConditional>(m, "60/40");
46 
/**
 * Gaussian density function.
 *
 * Evaluates the univariate normal density
 *   exp(-0.5 * ((z - mu) / sigma)^2) / sqrt(2 * pi * sigma^2).
 *
 * @param mu    mean of the distribution
 * @param sigma standard deviation; it is not validated here, so sigma == 0
 *              leads to a division by zero
 * @param z     point at which the density is evaluated
 * @return the density value at z
 */
double Gaussian(double mu, double sigma, double z) {
  // Standardized residual; squared by multiplication instead of pow(x, 2).
  const double r = (z - mu) / sigma;
  return exp(-0.5 * r * r) / sqrt(2 * M_PI * sigma * sigma);
}
51 
57 double prob_m_z(double mu0, double mu1, double sigma0, double sigma1,
58  double z) {
59  const double p0 = 0.6 * Gaussian(mu0, sigma0, z);
60  const double p1 = 0.4 * Gaussian(mu1, sigma1, z);
61  return p1 / (p0 + p1);
62 };
63 
64 /*
65  * Test a Gaussian Mixture Model P(m)p(z|m) with same sigma.
66  * The posterior, as a function of z, should be a sigmoid function.
67  */
68 TEST(GaussianMixture, GaussianMixtureModel) {
    // Compares the discrete posterior obtained by elimination against the
    // closed-form value from prob_m_z(), for two components sharing one sigma.
69  double mu0 = 1.0, mu1 = 3.0;
70  double sigma = 2.0;
71 
72  // Create a Gaussian mixture model p(z|m) with same sigma.
73  HybridBayesNet gmm;
74  std::vector<std::pair<Vector, double>> parameters{{Vector1(mu0), sigma},
75  {Vector1(mu1), sigma}};
    // NOTE(review): nothing visible here consumes `parameters`; the statement
    // that adds the p(z|m) conditional to `gmm` appears to be missing from this
    // listing (inner numbering jumps from 75 to 77) -- confirm against the repo.
77  gmm.push_back(mixing);
78 
79  // At the halfway point between the means the two likelihoods are equal
    // (same sigma), so the posterior P(m|z) should equal the 60/40 prior.
    // NOTE(review): `mu1 - mu0` coincides with the true midpoint
    // (mu0 + mu1) / 2 only because mu0 = 1 and mu1 = 3; the expression should
    // probably be the average of the means.
80  double midway = mu1 - mu0;
81  auto eliminationResult =
82  gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
    // The first conditional of the elimination result is read as the discrete
    // distribution over the mode.
83  auto pMid = eliminationResult->at(0)->asDiscrete<TableDistribution>();
84  EXPECT(assert_equal(TableDistribution(m, "60 40"), *pMid));
85 
86  // Everywhere else, the result should be a sigmoid.
87  for (const double shift : {-4, -2, 0, 2, 4}) {
88  const double z = midway + shift;
    // Closed-form reference value for P(m==1 | z).
89  const double expected = prob_m_z(mu0, mu1, sigma, sigma, z);
90 
91  // Workflow 1: convert HBN to HFG and solve
92  auto eliminationResult1 =
93  gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
94  auto posterior1 =
95  *eliminationResult1->at(0)->asDiscrete<TableDistribution>();
    // m1Assignment selects m==1, so this compares P(m==1 | z).
96  EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
97 
98  // Workflow 2: directly specify HFG and solve
    // NOTE(review): the declaration of `hfg1` and the opening of the call that
    // receives the arguments below are not visible in this listing (inner
    // numbering jumps from 98 to 101) -- confirm against the repository.
    // The vector holds the two component likelihoods evaluated at z.
101  m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)});
102  hfg1.push_back(mixing);
103  auto eliminationResult2 = hfg1.eliminateSequential();
104  auto posterior2 =
105  *eliminationResult2->at(0)->asDiscrete<TableDistribution>();
106  EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
107  }
108 }
109 
110 /*
111  * Test a Gaussian Mixture Model P(m)p(z|m) with different sigmas.
112  * The posterior, as a function of z, should be a unimodal function.
113  */
114 TEST(GaussianMixture, GaussianMixtureModel2) {
    // Same checks as the equal-sigma test, but the two components have
    // different standard deviations, and the discrete posteriors of the Bayes
    // net, the factor graph and the elimination result are also compared.
115  double mu0 = 1.0, mu1 = 3.0;
116  double sigma0 = 8.0, sigma1 = 4.0;
117 
118  // Create a Gaussian mixture model p(z|m) with different sigmas.
119  HybridBayesNet gmm;
120  std::vector<std::pair<Vector, double>> parameters{{Vector1(mu0), sigma0},
121  {Vector1(mu1), sigma1}};
    // NOTE(review): nothing visible here consumes `parameters`; the statement
    // that adds the p(z|m) conditional to `gmm` appears to be missing from this
    // listing (inner numbering jumps from 121 to 123) -- confirm against the repo.
123  gmm.push_back(mixing);
124 
125  // zMax=3.133 is described as the location of the maximum of the posterior
126  // function; the assertion below expects roughly a 42/58 split in favor of
    // m==1 at that point.
    // NOTE(review): the earlier wording ("about twice as probable") does not
    // match the asserted 42/58 posterior -- confirm what was intended.
127  double zMax = 3.133;
128  const VectorValues vv{{Z(0), Vector1(zMax)}};
129  auto gfg = gmm.toFactorGraph(vv);
130 
131  // Equality of posteriors asserts that the elimination is correct (same ratios
132  // for all modes)
133  const auto& expectedDiscretePosterior = gmm.discretePosterior(vv);
134  EXPECT(assert_equal(expectedDiscretePosterior, gfg.discretePosterior(vv)));
135 
136  // Eliminate the graph!
137  auto eliminationResultMax = gfg.eliminateSequential();
138 
139  // Equality of posteriors asserts that the elimination is correct
140  // (same ratios for all modes)
141  EXPECT(assert_equal(expectedDiscretePosterior,
142  eliminationResultMax->discretePosterior(vv)));
143 
    // The first conditional of the elimination result is read as the discrete
    // distribution over the mode and compared with tolerance 1e-4.
144  auto pMax = *eliminationResultMax->at(0)->asDiscrete<TableDistribution>();
145  EXPECT(assert_equal(TableDistribution(m, "42 58"), pMax, 1e-4));
146 
147  // Everywhere else, the result should be a bell curve like function.
148  for (const double shift : {-4, -2, 0, 2, 4}) {
149  const double z = zMax + shift;
    // Closed-form reference value for P(m==1 | z).
150  const double expected = prob_m_z(mu0, mu1, sigma0, sigma1, z);
151 
152  // Workflow 1: convert HBN to HFG and solve
153  auto eliminationResult1 =
154  gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
155  auto posterior1 =
156  *eliminationResult1->at(0)->asDiscrete<TableDistribution>();
    // m1Assignment selects m==1, so this compares P(m==1 | z).
157  EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
158 
159  // Workflow 2: directly specify HFG and solve
    // NOTE(review): the declaration of `hfg` and the opening of the call that
    // receives the arguments below are not visible in this listing (inner
    // numbering jumps from 159 to 162) -- confirm against the repository.
    // The vector holds the two component likelihoods evaluated at z.
162  m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)});
163  hfg.push_back(mixing);
164  auto eliminationResult2 = hfg.eliminateSequential();
165  auto posterior2 =
166  *eliminationResult2->at(0)->asDiscrete<TableDistribution>();
167  EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
168  }
169 }
170 
171 /* ************************************************************************* */
172 int main() {
173  TestResult tr;
174  return TestRegistry::runAllTests(tr);
175 }
176 /* ************************************************************************* */
TestRegistry::runAllTests
static int runAllTests(TestResult &result)
Definition: TestRegistry.cpp:27
Gaussian
double Gaussian(double mu, double sigma, double z)
Gaussian density function.
Definition: testGaussianMixture.cpp:48
gtsam::EliminateableFactorGraph::eliminateSequential
std::shared_ptr< BayesNetType > eliminateSequential(OptionalOrderingType orderingType={}, const Eliminate &function=EliminationTraitsType::DefaultEliminate, OptionalVariableIndex variableIndex={}) const
Definition: EliminateableFactorGraph-inst.h:29
TableDistribution.h
gtsam::Vector1
Eigen::Matrix< double, 1, 1 > Vector1
Definition: Vector.h:42
simple_graph::sigma1
double sigma1
Definition: testJacobianFactor.cpp:193
gtsam::DecisionTreeFactor
Definition: DecisionTreeFactor.h:45
GaussianConditional.h
Conditional Gaussian Base class.
HybridGaussianConditional.h
A hybrid conditional in the Conditional Linear Gaussian scheme.
e
Array< double, 1, 3 > e(1./3., 0.5, 2.)
EXPECT
#define EXPECT(condition)
Definition: Test.h:150
TestHarness.h
gtsam::HybridBayesNet
Definition: HybridBayesNet.h:37
gtsam::symbol_shorthand::M
Key M(std::uint64_t j)
Definition: inference/Symbol.h:160
mu
double mu
Definition: testBoundingConstraint.cpp:37
HybridBayesNet.h
A Bayes net of Gaussian Conditionals indexed by discrete keys.
vv
static const VectorValues vv
Definition: testHybridGaussianConditional.cpp:45
DiscreteConditional.h
exp
const EIGEN_DEVICE_FUNC ExpReturnType exp() const
Definition: ArrayCwiseUnaryOps.h:97
sampling::sigma
static const double sigma
Definition: testGaussianBayesNet.cpp:170
gtsam::FactorGraph::at
const sharedFactor at(size_t i) const
Definition: FactorGraph.h:306
Key.h
gtsam::VectorValues
Definition: VectorValues.h:74
gtsam::TableDistribution
Definition: TableDistribution.h:39
gtsam::symbol_shorthand::Z
Key Z(std::uint64_t j)
Definition: inference/Symbol.h:173
parameters
static ConjugateGradientParameters parameters
Definition: testIterative.cpp:33
cholesky::expected
Matrix expected
Definition: testMatrix.cpp:916
gtsam::HybridGaussianConditional
A conditional of gaussian conditionals indexed by discrete variables, as part of a Bayes Network....
Definition: HybridGaussianConditional.h:55
gtsam::HybridGaussianFactorGraph
Definition: HybridGaussianFactorGraph.h:106
main
int main()
Definition: testGaussianMixture.cpp:172
Symbol.h
pybind_wrapper_test_script.z
z
Definition: pybind_wrapper_test_script.py:61
ceres::pow
Jet< T, N > pow(const Jet< T, N > &f, double g)
Definition: jet.h:570
EXPECT_DOUBLES_EQUAL
#define EXPECT_DOUBLES_EQUAL(expected, actual, threshold)
Definition: Test.h:161
p1
Vector3f p1
Definition: MatrixBase_all.cpp:2
TestResult
Definition: TestResult.h:26
DiscreteKey.h
specialized key for discrete variables
gtsam::DiscreteConditional::shared_ptr
std::shared_ptr< This > shared_ptr
shared_ptr to this class
Definition: DiscreteConditional.h:43
gtsam::HybridBayesNet::discretePosterior
AlgebraicDecisionTree< Key > discretePosterior(const VectorValues &continuousValues) const
Compute normalized posterior P(M|X=x) and return as a tree.
Definition: HybridBayesNet.cpp:256
prob_m_z
double prob_m_z(double mu0, double mu1, double sigma0, double sigma1, double z)
Definition: testGaussianMixture.cpp:57
gtsam::HybridBayesNet::push_back
void push_back(std::shared_ptr< HybridConditional > conditional)
Add a hybrid conditional using a shared_ptr.
Definition: HybridBayesNet.h:76
mixing
DiscreteConditional::shared_ptr mixing
Definition: testGaussianMixture.cpp:44
HybridGaussianFactorGraph.h
Linearized Hybrid factor graph that uses type erasure.
gtsam
traits
Definition: SFMdata.h:40
gtsam::TEST
TEST(SmartFactorBase, Pinhole)
Definition: testSmartFactorBase.cpp:38
NoiseModel.h
gtsam::DiscreteValues
Definition: DiscreteValues.h:34
m1Assignment
const DiscreteValues m1Assignment
Definition: testGaussianMixture.cpp:41
gtsam::FactorGraph::push_back
IsDerived< DERIVEDFACTOR > push_back(std::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:147
p0
Vector3f p0
Definition: MatrixBase_all.cpp:2
gtsam::DiscreteKey
std::pair< Key, size_t > DiscreteKey
Definition: DiscreteKey.h:38
gtsam::assert_equal
bool assert_equal(const Matrix &expected, const Matrix &actual, double tol)
Definition: Matrix.cpp:41
m
const DiscreteKey m(M(0), 2)
M_PI
#define M_PI
Definition: mconf.h:117
gtsam::HybridBayesNet::toFactorGraph
HybridGaussianFactorGraph toFactorGraph(const VectorValues &measurements) const
Definition: HybridBayesNet.cpp:270
Z
#define Z
Definition: icosphere.cpp:21
DecisionTreeFactor.h
ceres::sqrt
Jet< T, N > sqrt(const Jet< T, N > &f)
Definition: jet.h:418
gtsam::FactorGraph::emplace_shared
IsDerived< DERIVEDFACTOR > emplace_shared(Args &&... args)
Emplace a shared pointer to factor of given type.
Definition: FactorGraph.h:153
gtsam::HybridBayesNet::emplace_shared
void emplace_shared(Args &&...args)
Definition: HybridBayesNet.h:116
M
Matrix< RealScalar, Dynamic, Dynamic > M
Definition: bench_gemm.cpp:51


gtsam
Author(s):
autogenerated on Wed Mar 19 2025 03:06:42