Program Listing for File sample.hpp

Return to documentation for file (include/beluga/views/sample.hpp)

// Copyright 2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BELUGA_VIEWS_SAMPLE_HPP
#define BELUGA_VIEWS_SAMPLE_HPP

#include <cassert>
#include <functional>
#include <memory>
#include <random>
#include <type_traits>
#include <utility>

#include <range/v3/utility/random.hpp>
#include <range/v3/view/common.hpp>
#include <range/v3/view/facade.hpp>
#include <range/v3/view/generate.hpp>
#include <range/v3/view/transform.hpp>

#include <beluga/type_traits/particle_traits.hpp>
#include <beluga/views/particles.hpp>

namespace beluga::views {

namespace detail {


/// Implementation detail for an infinite view that randomly samples elements of an underlying range.
/**
 * Every time the cursor advances, an index is drawn from `Distribution` using the `URNG` engine,
 * and the view yields the element of the underlying range at that index. Elements may repeat.
 *
 * \tparam Range Sized, random access range type to sample from. It is stored by value.
 * \tparam Distribution Random number distribution whose `result_type` is the range difference type.
 * \tparam URNG Random number engine type.
 */
template <class Range, class Distribution, class URNG = typename ranges::detail::default_random_engine>
struct sample_view : public ranges::view_facade<sample_view<Range, Distribution, URNG>, ranges::infinite> {
 public:
  /// Default constructor. Leaves the view without an engine; it must be assigned before iterating.
  sample_view() = default;

  /// Constructs a sample view.
  /**
   * \param range Range to sample from. Must not be empty.
   * \param distribution Index distribution. Its `min()` must be 0 and its `max()` must be the last
   *  valid index of `range`.
   * \param engine Random number engine. Only its address is stored, so it must outlive the view.
   */
  constexpr sample_view(Range range, Distribution distribution, URNG& engine = ranges::detail::get_random_engine())
      : range_{std::move(range)}, distribution_{std::move(distribution)}, engine_{std::addressof(engine)} {
    // Inspect `range_`, not `range`: the parameter has already been moved from at this point.
    assert(ranges::size(range_) > 0);
    assert(distribution_.min() == 0);
    assert(distribution_.max() == static_cast<typename Distribution::result_type>(ranges::size(range_)) - 1);
  }

 private:
  // `ranges::range_access` needs access to the cursor members.
  friend ranges::range_access;

  static_assert(ranges::sized_range<Range>);
  static_assert(ranges::random_access_range<Range>);
  static_assert(std::is_same_v<typename Distribution::result_type, ranges::range_difference_t<Range>>);

  /// Cursor that jumps to a freshly sampled position of the underlying range on every `next()`.
  struct cursor {
   public:
    cursor() = default;

    /// Creates a cursor over `view`, positioned at an initially sampled element.
    constexpr explicit cursor(sample_view* view)
        : view_(view), first_{ranges::begin(view_->range_)}, it_{first_ + view_->compute_offset()} {}

    /// Returns the element at the current sampled position.
    [[nodiscard]] constexpr decltype(auto) read() const noexcept(noexcept(*this->it_)) { return *it_; }

    /// Draws a new position, regardless of the current one.
    constexpr void next() { it_ = first_ + view_->compute_offset(); }

   private:
    sample_view* view_{nullptr};
    ranges::iterator_t<Range> first_;  // Beginning of the underlying range.
    ranges::iterator_t<Range> it_;     // Current sampled position.
  };

  /// Returns a cursor positioned at a freshly sampled element.
  [[nodiscard]] constexpr auto begin_cursor() { return cursor{this}; }

  /// The view is infinite: its end is never reached.
  [[nodiscard]] constexpr auto end_cursor() const noexcept { return ranges::unreachable_sentinel_t{}; }

  /// Draws the next index using the stored distribution and engine.
  [[nodiscard]] constexpr auto compute_offset() { return distribution_(*engine_); }

  Range range_;
  Distribution distribution_;
  URNG* engine_{nullptr};
};


/// Trait that detects random number distributions.
/**
 * `T` is considered a distribution when an lvalue of type `T` can be invoked with an lvalue
 * `std::mt19937` engine; otherwise the trait derives from `std::false_type`.
 */
template <class T, class Enable = void>
struct is_random_distribution : std::false_type {};

/// Result of invoking an lvalue `T` with an lvalue `std::mt19937`; ill-formed when that call is invalid.
template <class T>
using distribution_invoke_result_t = decltype(std::declval<T&>()(std::declval<std::mt19937&>()));

/// Selected only when `distribution_invoke_result_t<T>` is well-formed.
template <class T>
struct is_random_distribution<T, std::void_t<distribution_invoke_result_t<T>>> : std::true_type {};

/// Shorthand for `is_random_distribution<T>::value`.
template <class T>
inline constexpr bool is_random_distribution_v = is_random_distribution<T>::value;


/// Implementation detail providing the sampling strategies used by the `sample` adaptor.
struct sample_base_fn {
 protected:
  /// Returns an infinite view that samples `range` with probabilities proportional to `weights`.
  /**
   * \param range Sized, random access range to sample from.
   * \param weights Input range with one weight per element of `range`.
   * \param engine Random number engine. The returned view keeps a pointer to it, so it must outlive the view.
   */
  template <class Range, class Weights, class URNG>
  constexpr auto sample_from_range(Range&& range, Weights&& weights, URNG& engine) const {
    static_assert(ranges::sized_range<Range>);
    static_assert(ranges::random_access_range<Range>);
    static_assert(ranges::input_range<Weights>);
    using result_type = ranges::range_difference_t<Range>;
    // `std::discrete_distribution` takes an iterator pair of a single type, hence the common view.
    auto w = ranges::views::common(weights);
    // Parentheses make it explicit that the iterator-range constructor is used, never the
    // `std::initializer_list<double>` one.
    auto distribution = std::discrete_distribution<result_type>(ranges::begin(w), ranges::end(w));
    return sample_view{ranges::views::all(std::forward<Range>(range)), std::move(distribution), engine};
  }

  /// Returns an infinite view that samples `range`.
  /**
   * Particle ranges are sampled according to their weights, and each sampled state is mapped through
   * `beluga::make_from_state<ranges::range_value_t<Range>>`. Any other range is sampled uniformly.
   *
   * \param range Sized, random access range to sample from. Must not be empty.
   * \param engine Random number engine. It must outlive the returned view.
   */
  template <class Range, class URNG>
  constexpr auto sample_from_range(Range&& range, URNG& engine) const {
    static_assert(ranges::sized_range<Range>);
    static_assert(ranges::random_access_range<Range>);
    if constexpr (beluga::is_particle_range_v<Range>) {
      // NOTE(review): `range` is passed as an lvalue to `states` and `weights`; if a caller ever passes an
      // rvalue owning container here, the resulting views could outlive it. Confirm callers pass lvalues or views.
      return sample_from_range(beluga::views::states(range), beluga::views::weights(range), engine) |
             ranges::views::transform(beluga::make_from_state<ranges::range_value_t<Range>>);
    } else {
      // `std::uniform_int_distribution` requires `min <= max`, so reject empty ranges before building it.
      assert(ranges::size(range) > 0);
      using result_type = ranges::range_difference_t<Range>;
      auto distribution =
          std::uniform_int_distribution<result_type>{0, static_cast<result_type>(ranges::size(range) - 1)};
      return sample_view{ranges::views::all(std::forward<Range>(range)), std::move(distribution), engine};
    }
  }

  /// Returns an infinite view of values drawn from `distribution` using `engine`.
  /**
   * The distribution is moved into the view; the engine is captured by reference, so it must outlive the view.
   */
  template <class Distribution, class URNG>
  constexpr auto sample_from_distribution(Distribution distribution, URNG& engine) const {
    return ranges::views::generate(
        [distribution = std::move(distribution), &engine]() mutable { return distribution(engine); });
  }
};

/// Implementation detail for the `sample` range adaptor object.
/**
 * The overloads dispatch on the kind of their arguments (checked with `ranges::range` and
 * `is_random_distribution_v`) and forward to the helpers in `sample_base_fn`.
 */
struct sample_fn : public sample_base_fn {
  /// Samples elements of `t` with weights `u`, drawing indices with the engine `v`.
  template <class T, class U, class V>
  constexpr auto operator()(T&& t, U&& u, V& v) const {
    static_assert(ranges::range<T>);
    static_assert(ranges::range<U>);
    return sample_from_range(std::forward<T>(t), std::forward<U>(u), v);  // Assume V is a URNG
  }

  /// Two-argument overload.
  /**
   * - If both arguments are ranges, samples `t` with weights `u`, using the engine returned by
   *   `ranges::detail::get_random_engine()`.
   * - Otherwise, if `t` is a random distribution, returns an infinite view of values drawn from `t`
   *   with the engine `u` (which must be an lvalue).
   * - Otherwise, `t` must be a range and `u` an lvalue engine; `t` is sampled with `u`.
   */
  template <class T, class U>
  constexpr auto operator()(T&& t, U&& u) const {
    if constexpr (ranges::range<T> && ranges::range<U>) {
      auto& engine = ranges::detail::get_random_engine();
      return sample_from_range(std::forward<T>(t), std::forward<U>(u), engine);
    } else if constexpr (is_random_distribution_v<T>) {
      static_assert(std::is_lvalue_reference_v<U&&>);  // Assume U is a URNG
      return sample_from_distribution(std::forward<T>(t), u);
    } else {
      static_assert(ranges::range<T>);
      static_assert(std::is_lvalue_reference_v<U&&>);  // Assume U is a URNG
      return sample_from_range(std::forward<T>(t), u);
    }
  }

  /// One-argument overload.
  /**
   * - A range is sampled using the engine returned by `ranges::detail::get_random_engine()`.
   * - A random distribution yields an infinite view of its values, also using
   *   `ranges::detail::get_random_engine()`.
   * - Anything else is treated as an engine (it must be an lvalue). The result is a view closure that
   *   binds it with `std::ref`, so it can be used as `range | sample(engine)`.
   */
  template <class T>
  constexpr auto operator()(T&& t) const {
    if constexpr (ranges::range<T>) {
      auto& engine = ranges::detail::get_random_engine();
      return sample_from_range(std::forward<T>(t), engine);
    } else if constexpr (is_random_distribution_v<T>) {
      auto& engine = ranges::detail::get_random_engine();
      return sample_from_distribution(std::forward<T>(t), engine);
    } else {
      static_assert(std::is_lvalue_reference_v<T&&>);  // Assume T is a URNG
      return ranges::make_view_closure(ranges::bind_back(sample_fn{}, std::ref(t)));
    }
  }

  /// Samples `range` with an engine wrapped in `std::reference_wrapper`.
  /**
   * Calls made through the closure returned by the one-argument overload (engine case) pass the engine
   * in this wrapped form.
   */
  template <class Range, class URNG>
  constexpr auto operator()(Range&& range, std::reference_wrapper<URNG> engine) const {
    static_assert(ranges::range<Range>);
    return sample_from_range(std::forward<Range>(range), engine.get());
  }
};

}  // namespace detail


inline constexpr ranges::views::view_closure<detail::sample_fn> sample;

}  // namespace beluga::views

#endif