Program Listing for File multivariate_normal_distribution.hpp

Return to documentation for file (include/beluga/random/multivariate_normal_distribution.hpp)

// Copyright 2023-2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BELUGA_RANDOM_MULTIVARIATE_NORMAL_DISTRIBUTION_HPP
#define BELUGA_RANDOM_MULTIVARIATE_NORMAL_DISTRIBUTION_HPP

#include <random>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <beluga/random/multivariate_distribution_traits.hpp>

namespace beluga {

/// Parameter set for a multivariate normal distribution.
/**
 * Holds the mean vector and a transform matrix derived from the covariance matrix.
 * The transform is `eigenvectors * sqrt(eigenvalues)` of the covariance, so that
 * `transform * transform^T` reproduces the covariance.
 *
 * \tparam Vector Eigen column or row vector type (checked with static assertions).
 * \tparam Matrix Matrix type used for the covariance and for the stored transform.
 */
template <class Vector, class Matrix>
class MultivariateNormalDistributionParam {
 public:
  static_assert(std::is_base_of_v<Eigen::EigenBase<Vector>, Vector>, "Vector should be an Eigen type");
  static_assert(
      Vector::ColsAtCompileTime == 1 || Vector::RowsAtCompileTime == 1,
      "Vector should be a column or row vector");

  /// Scalar type of the vector coefficients.
  using scalar_type = typename Vector::Scalar;

  /// Vector type used for the mean and for generated samples.
  using vector_type = Vector;

  /// Matrix type used for the covariance and the transform.
  using matrix_type = Matrix;

  /// Constructs a parameter set with zero mean and identity covariance.
  MultivariateNormalDistributionParam() = default;

  /// Constructs a parameter set with zero mean and the given covariance.
  /**
   * \param covariance Symmetric matrix with non-negative eigenvalues.
   * \throws std::runtime_error If the covariance matrix is rejected by make_transform().
   */
  explicit MultivariateNormalDistributionParam(matrix_type covariance)
      : transform_{make_transform(std::move(covariance))} {}

  /// Constructs a parameter set with the given mean and covariance.
  /**
   * \param mean Mean vector of the distribution.
   * \param covariance Symmetric matrix with non-negative eigenvalues.
   * \throws std::runtime_error If the covariance matrix is rejected by make_transform().
   */
  MultivariateNormalDistributionParam(vector_type mean, matrix_type covariance)
      : mean_{std::move(mean)}, transform_{make_transform(std::move(covariance))} {}

  /// Returns true when both the mean and the transform compare equal.
  [[nodiscard]] bool operator==(const MultivariateNormalDistributionParam& other) const {
    return mean_ == other.mean_ && transform_ == other.transform_;
  }

  /// Returns the negation of operator==.
  [[nodiscard]] bool operator!=(const MultivariateNormalDistributionParam& other) const { return !(*this == other); }

  /// Draws one sample using the given scalar normal distribution and generator.
  /**
   * \tparam Generator Type of the random bit generator.
   * \param distribution Scalar normal distribution used to draw every coefficient.
   * \param generator Random bit generator passed to `distribution`.
   * \return An evaluated vector: `mean + transform * delta` for column vectors and
   *  `mean + delta * transform^T` for row vectors, where `delta` holds the independent draws.
   */
  template <class Generator>
  [[nodiscard]] vector_type operator()(std::normal_distribution<scalar_type>& distribution, Generator& generator)
      const {
    // Evaluate the draws into a concrete vector right away. Keeping the lazy expression around (and
    // returning another lazy expression built on top of it) would leave references to objects that
    // no longer exist by the time the caller evaluates the result.
    const vector_type delta =
        vector_type::Zero().unaryExpr([&distribution, &generator](auto) { return distribution(generator); });
    if constexpr (vector_type::ColsAtCompileTime == 1) {
      return mean_ + transform_ * delta;
    } else {
      // Row vectors multiply from the left, so the transform must be transposed for the
      // resulting covariance to be transform * transform^T.
      return mean_ + delta * transform_.transpose();
    }
  }

 private:
  vector_type mean_{vector_type::Zero()};
  matrix_type transform_{make_transform(vector_type::Ones().asDiagonal())};

  /// Computes the transform matrix from a covariance matrix.
  /**
   * \param covariance Candidate covariance matrix.
   * \return `eigenvectors * sqrt(eigenvalues)` as computed by the self-adjoint eigen solver.
   * \throws std::runtime_error If the matrix is not approximately symmetric, if the eigen
   *  solver does not report success, or if any eigenvalue is negative.
   */
  [[nodiscard]] static matrix_type make_transform(matrix_type covariance) {
    // For more information about the method used to generate correlated normal vectors
    // from independent normal distributions, see:
    // https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Drawing_values_from_the_distribution
    // \cite gentle2009computationalstatistics.
    if (!covariance.isApprox(covariance.transpose())) {
      throw std::runtime_error("Invalid covariance matrix, it is not symmetric.");
    }
    const auto solver = Eigen::SelfAdjointEigenSolver<matrix_type>{covariance};
    if (solver.info() != Eigen::Success) {
      throw std::runtime_error("Invalid covariance matrix, eigen solver failed.");
    }
    const auto& eigenvalues = solver.eigenvalues();
    if ((eigenvalues.array() < 0.0).any()) {
      throw std::runtime_error("Invalid covariance matrix, it has negative eigenvalues.");
    }
    return solver.eigenvectors() * eigenvalues.cwiseSqrt().asDiagonal();
  }
};


template <class T>
class MultivariateNormalDistribution {
 public:
  template <class U>
  friend class MultivariateNormalDistribution;

  using scalar_type = typename multivariate_distribution_traits<T>::scalar_type;

  using vector_type = typename multivariate_distribution_traits<T>::vector_type;

  using covariance_type = typename multivariate_distribution_traits<T>::covariance_type;

  using result_type = typename multivariate_distribution_traits<T>::result_type;

  using param_type = MultivariateNormalDistributionParam<vector_type, covariance_type>;

  MultivariateNormalDistribution() = default;


  explicit MultivariateNormalDistribution(const param_type& params) : params_{params} {}


  explicit MultivariateNormalDistribution(covariance_type covariance) : params_{std::move(covariance)} {}


  MultivariateNormalDistribution(result_type mean, covariance_type covariance)
      : params_{multivariate_distribution_traits<T>::to_vector(std::move(mean)), std::move(covariance)} {}


  template <class U>
  /* implicit */ MultivariateNormalDistribution(const MultivariateNormalDistribution<U>& other)  // NOLINT
      : params_{other.params_}, distribution_{other.distribution_} {}


  template <class U>
  /* implicit */ MultivariateNormalDistribution(MultivariateNormalDistribution<U>&& other) noexcept  // NOLINT
      : params_(std::move(other.params_)), distribution_{std::move(other.distribution_)} {}


  template <class U>
  MultivariateNormalDistribution& operator=(const MultivariateNormalDistribution<U>& other) {
    params_ = other.params_;
    distribution_ = other.distribution_;
    return *this;
  }


  template <class U>
  MultivariateNormalDistribution& operator=(MultivariateNormalDistribution<U>&& other) {
    params_ = std::move(other.params_);
    distribution_ = std::move(other.distribution_);
    return *this;
  }

  void reset() { distribution_.reset(); }

  [[nodiscard]] const param_type& param() const { return params_; }


  void param(const param_type& params) { params_ = params; }


  template <class Generator>
  [[nodiscard]] result_type operator()(Generator& generator) {
    return (*this)(generator, params_);
  }


  template <class Generator>
  [[nodiscard]] result_type operator()(Generator& generator, const param_type& params) {
    return multivariate_distribution_traits<T>::from_vector(params(distribution_, generator));
  }


  [[nodiscard]] bool operator==(const MultivariateNormalDistribution<T>& other) const {
    return params_ == other.params_ && distribution_ == other.distribution_;
  }


  [[nodiscard]] bool operator!=(const MultivariateNormalDistribution<T>& other) const { return !(*this == other); }

 private:
  param_type params_;
  std::normal_distribution<scalar_type> distribution_;
};

/// Deduction guide for the two-argument form.
/**
 * `T` is deduced from the first argument only; the second parameter is the covariance type
 * that `multivariate_distribution_traits<T>` reports. The deduced class is the instantiation
 * over `multivariate_distribution_traits<T>::result_type`.
 */
template <class T>
MultivariateNormalDistribution(const T&, const typename multivariate_distribution_traits<T>::covariance_type&)
    -> MultivariateNormalDistribution<typename multivariate_distribution_traits<T>::result_type>;

}  // namespace beluga

#endif