Program Listing for File multivariate_normal_distribution.hpp
↰ Return to documentation for file `include/beluga/random/multivariate_normal_distribution.hpp`
// Copyright 2023-2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef BELUGA_RANDOM_MULTIVARIATE_NORMAL_DISTRIBUTION_HPP
#define BELUGA_RANDOM_MULTIVARIATE_NORMAL_DISTRIBUTION_HPP
#include <algorithm>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <Eigen/Core>
#include <Eigen/Eigenvalues>

#include <beluga/random/multivariate_distribution_traits.hpp>
namespace beluga {
template <class Vector, class Matrix>
class MultivariateNormalDistributionParam {
public:
static_assert(std::is_base_of_v<Eigen::EigenBase<Vector>, Vector>, "Vector should be an Eigen type");
static_assert(
Vector::ColsAtCompileTime == 1 || Vector::RowsAtCompileTime == 1,
"Vector should be a column or row vector");
using scalar_type = typename Vector::Scalar;
using vector_type = Vector;
using matrix_type = Matrix;
MultivariateNormalDistributionParam() = default;
explicit MultivariateNormalDistributionParam(matrix_type covariance)
: transform_{make_transform(std::move(covariance))} {}
MultivariateNormalDistributionParam(vector_type mean, matrix_type covariance)
: mean_{std::move(mean)}, transform_{make_transform(std::move(covariance))} {}
[[nodiscard]] bool operator==(const MultivariateNormalDistributionParam& other) const {
return mean_ == other.mean_ && transform_ == other.transform_;
}
[[nodiscard]] bool operator!=(const MultivariateNormalDistributionParam& other) const { return !(*this == other); }
template <class Generator>
[[nodiscard]] auto operator()(std::normal_distribution<scalar_type>& distribution, Generator& generator) const {
const auto delta = vector_type{}.unaryExpr([&distribution, &generator](auto) { return distribution(generator); });
if constexpr (vector_type::ColsAtCompileTime == 1) {
return mean_ + transform_ * delta;
} else {
return mean_ + delta * transform_;
}
}
private:
vector_type mean_{vector_type::Zero()};
matrix_type transform_{make_transform(vector_type::Ones().asDiagonal())};
[[nodiscard]] static matrix_type make_transform(matrix_type covariance) {
// For more information about the method used to generate correlated normal vectors
// from independent normal distributions, see:
// https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Drawing_values_from_the_distribution
// \cite gentle2009computationalstatistics.
if (!covariance.isApprox(covariance.transpose())) {
throw std::runtime_error("Invalid covariance matrix, it is not symmetric.");
}
const auto solver = Eigen::SelfAdjointEigenSolver<matrix_type>{covariance};
if (solver.info() != Eigen::Success) {
throw std::runtime_error("Invalid covariance matrix, eigen solver failed.");
}
const auto& eigenvalues = solver.eigenvalues();
if ((eigenvalues.array() < 0.0).any()) {
throw std::runtime_error("Invalid covariance matrix, it has negative eigenvalues.");
}
return solver.eigenvectors() * eigenvalues.cwiseSqrt().asDiagonal();
}
};
template <class T>
class MultivariateNormalDistribution {
public:
template <class U>
friend class MultivariateNormalDistribution;
using scalar_type = typename multivariate_distribution_traits<T>::scalar_type;
using vector_type = typename multivariate_distribution_traits<T>::vector_type;
using covariance_type = typename multivariate_distribution_traits<T>::covariance_type;
using result_type = typename multivariate_distribution_traits<T>::result_type;
using param_type = MultivariateNormalDistributionParam<vector_type, covariance_type>;
MultivariateNormalDistribution() = default;
explicit MultivariateNormalDistribution(const param_type& params) : params_{params} {}
explicit MultivariateNormalDistribution(covariance_type covariance) : params_{std::move(covariance)} {}
MultivariateNormalDistribution(result_type mean, covariance_type covariance)
: params_{multivariate_distribution_traits<T>::to_vector(std::move(mean)), std::move(covariance)} {}
template <class U>
/* implicit */ MultivariateNormalDistribution(const MultivariateNormalDistribution<U>& other) // NOLINT
: params_{other.params_}, distribution_{other.distribution_} {}
template <class U>
/* implicit */ MultivariateNormalDistribution(MultivariateNormalDistribution<U>&& other) noexcept // NOLINT
: params_(std::move(other.params_)), distribution_{std::move(other.distribution_)} {}
template <class U>
MultivariateNormalDistribution& operator=(const MultivariateNormalDistribution<U>& other) {
params_ = other.params_;
distribution_ = other.distribution_;
return *this;
}
template <class U>
MultivariateNormalDistribution& operator=(MultivariateNormalDistribution<U>&& other) {
params_ = std::move(other.params_);
distribution_ = std::move(other.distribution_);
return *this;
}
void reset() { distribution_.reset(); }
[[nodiscard]] const param_type& param() const { return params_; }
void param(const param_type& params) { params_ = params; }
template <class Generator>
[[nodiscard]] result_type operator()(Generator& generator) {
return (*this)(generator, params_);
}
template <class Generator>
[[nodiscard]] result_type operator()(Generator& generator, const param_type& params) {
return multivariate_distribution_traits<T>::from_vector(params(distribution_, generator));
}
[[nodiscard]] bool operator==(const MultivariateNormalDistribution<T>& other) const {
return params_ == other.params_ && distribution_ == other.distribution_;
}
[[nodiscard]] bool operator!=(const MultivariateNormalDistribution<T>& other) const { return !(*this == other); }
private:
param_type params_;
std::normal_distribution<scalar_type> distribution_;
};
/// Deduction guide: a `(mean, covariance)` pair deduces the distribution over the traits'
/// `result_type` for the mean's type.
template <typename T>
MultivariateNormalDistribution(
    const T& mean,
    const typename multivariate_distribution_traits<T>::covariance_type& covariance)
    -> MultivariateNormalDistribution<typename multivariate_distribution_traits<T>::result_type>;
} // namespace beluga
#endif