Program Listing for File sample.hpp
↰ Return to documentation for file (include/beluga/views/sample.hpp)
// Copyright 2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef BELUGA_VIEWS_SAMPLE_HPP
#define BELUGA_VIEWS_SAMPLE_HPP
#include <cassert>
#include <functional>
#include <memory>
#include <random>
#include <type_traits>
#include <utility>

#include <range/v3/utility/random.hpp>
#include <range/v3/view/all.hpp>
#include <range/v3/view/common.hpp>
#include <range/v3/view/facade.hpp>
#include <range/v3/view/generate.hpp>
#include <range/v3/view/transform.hpp>

#include <beluga/type_traits/particle_traits.hpp>
#include <beluga/views/particles.hpp>
namespace beluga::views {
namespace detail {
/// Infinite view whose elements are drawn at random from a sized, random-access range.
///
/// On construction of the cursor, and every time it is advanced, an offset is drawn from
/// `Distribution` using the stored random engine. The element at that offset from the
/// beginning of `Range` is then exposed. The view never ends: its end is an unreachable sentinel.
///
/// \tparam Range Sized, random-access range to sample from.
/// \tparam Distribution Distribution of offsets. Its `result_type` must equal
///   `ranges::range_difference_t<Range>`, with `min() == 0` and `max() == size(range) - 1`.
/// \tparam URNG Random engine type.
///
/// \note Only the address of the engine is stored; the engine must outlive the view and any
///   iterator obtained from it.
template <class Range, class Distribution, class URNG = typename ranges::detail::default_random_engine>
struct sample_view : public ranges::view_facade<sample_view<Range, Distribution, URNG>, ranges::infinite> {
 public:
  /// Default constructor. Leaves the engine pointer null; a default-constructed view must not be iterated.
  sample_view() = default;

  /// Constructs the view.
  /// \param range Range to sample from. Must be non-empty.
  /// \param distribution Offset distribution, with `min() == 0` and `max() == size(range) - 1`.
  /// \param engine Random engine used to draw offsets. Only its address is kept.
  constexpr sample_view(Range range, Distribution distribution, URNG& engine = ranges::detail::get_random_engine())
      : range_{std::move(range)}, distribution_{std::move(distribution)}, engine_{std::addressof(engine)} {
    // Check the members, not the parameters: `range` and `distribution` have been moved from.
    assert(ranges::size(range_) > 0);
    assert(distribution_.min() == 0);
    assert(distribution_.max() == static_cast<typename Distribution::result_type>(ranges::size(range_)) - 1);
  }

 private:
  // `ranges::range_access` needs access to the cursor members.
  friend ranges::range_access;

  static_assert(ranges::sized_range<Range>);
  static_assert(ranges::random_access_range<Range>);
  static_assert(std::is_same_v<typename Distribution::result_type, ranges::range_difference_t<Range>>);

  /// Cursor that jumps to a freshly drawn offset on every `next()`.
  struct cursor {
   public:
    cursor() = default;

    constexpr explicit cursor(sample_view* view)
        : view_(view), first_{ranges::begin(view_->range_)}, it_{first_ + view_->compute_offset()} {}

    [[nodiscard]] constexpr decltype(auto) read() const noexcept(noexcept(*this->it_)) { return *it_; }

    // Random access lets us jump directly from the start of the range to the new offset.
    constexpr void next() { it_ = first_ + view_->compute_offset(); }

   private:
    sample_view* view_{nullptr};
    ranges::iterator_t<Range> first_;
    ranges::iterator_t<Range> it_;
  };

  [[nodiscard]] constexpr auto begin_cursor() { return cursor{this}; }

  [[nodiscard]] constexpr auto end_cursor() const noexcept { return ranges::unreachable_sentinel_t{}; }

  /// Draws the next offset into `range_`.
  [[nodiscard]] constexpr auto compute_offset() { return distribution_(*engine_); }

  Range range_;
  Distribution distribution_;
  URNG* engine_{nullptr};
};
/// Trait that detects random number distributions.
///
/// `T` is considered a distribution if a non-const lvalue of `T` (with references and
/// cv-qualifiers stripped) can be invoked with a `std::mt19937&`, as in
/// `std::normal_distribution<double>`. Qualifiers are stripped so that a `const` distribution
/// passed by reference is still recognized. Standard distributions have a non-const
/// `operator()`, so probing the const type directly would misclassify them.
template <class T, class Enable = void>
struct is_random_distribution : public std::false_type {};

/// Specialization selected when the unqualified type is invocable with a `std::mt19937&`.
template <class T>
struct is_random_distribution<
    T,
    std::void_t<decltype(std::declval<std::remove_cv_t<std::remove_reference_t<T>>&>()(
        std::declval<std::mt19937&>()))>> : std::true_type {};

/// Convenience variable template for `is_random_distribution<T>::value`.
template <class T>
inline constexpr bool is_random_distribution_v = is_random_distribution<T>::value;
struct sample_base_fn {
protected:
template <class Range, class Weights, class URNG>
constexpr auto sample_from_range(Range&& range, Weights&& weights, URNG& engine) const {
static_assert(ranges::sized_range<Range>);
static_assert(ranges::random_access_range<Range>);
static_assert(ranges::input_range<Weights>);
using result_type = ranges::range_difference_t<Range>;
auto w = ranges::views::common(weights);
auto distribution = std::discrete_distribution<result_type>{ranges::begin(w), ranges::end(w)};
return sample_view{ranges::views::all(std::forward<Range>(range)), std::move(distribution), engine};
}
template <class Range, class URNG>
constexpr auto sample_from_range(Range&& range, URNG& engine) const {
static_assert(ranges::sized_range<Range>);
static_assert(ranges::random_access_range<Range>);
if constexpr (beluga::is_particle_range_v<Range>) {
return sample_from_range(beluga::views::states(range), beluga::views::weights(range), engine) |
ranges::views::transform(beluga::make_from_state<ranges::range_value_t<Range>>);
} else {
using result_type = ranges::range_difference_t<Range>;
auto distribution =
std::uniform_int_distribution<result_type>{0, static_cast<result_type>(ranges::size(range) - 1)};
return sample_view{ranges::views::all(std::forward<Range>(range)), std::move(distribution), engine};
}
}
template <class Distribution, class URNG>
constexpr auto sample_from_distribution(Distribution distribution, URNG& engine) const {
return ranges::views::generate(
[distribution = std::move(distribution), &engine]() mutable { return distribution(engine); });
}
};
/// Function object behind `beluga::views::sample`.
///
/// Dispatches on the kinds of its arguments:
/// - `(range, weights, engine)`: weighted sampling of `range` with `engine`.
/// - `(range, weights)`: weighted sampling with range-v3's default random engine.
/// - `(distribution, engine)`: infinite view of values drawn from `distribution` with `engine`.
/// - `(range, engine)`: uniform sampling of `range`, or sampling by particle weights for particle ranges.
/// - `(range)` / `(distribution)`: as above, with range-v3's default random engine.
/// - `(engine)`: returns a view closure that samples a range with `engine` when applied to it.
///
/// Engines are held by reference or pointer, never copied, so they must outlive the returned views.
struct sample_fn : public sample_base_fn {
// (range, weights, engine): weighted sampling with an explicit engine.
template <class T, class U, class V>
constexpr auto operator()(T&& t, U&& u, V& v) const {
static_assert(ranges::range<T>);
static_assert(ranges::range<U>);
return sample_from_range(std::forward<T>(t), std::forward<U>(u), v); // Assume V is a URNG
}
// Two-argument form. The branch order matters: (range, weights) is tried first, then
// (distribution, engine), and finally (range, engine).
template <class T, class U>
constexpr auto operator()(T&& t, U&& u) const {
if constexpr (ranges::range<T> && ranges::range<U>) {
auto& engine = ranges::detail::get_random_engine();
return sample_from_range(std::forward<T>(t), std::forward<U>(u), engine);
} else if constexpr (is_random_distribution_v<T>) {
// The engine is captured by reference, so a temporary engine would dangle.
static_assert(std::is_lvalue_reference_v<U&&>); // Assume U is a URNG
return sample_from_distribution(std::forward<T>(t), u);
} else {
static_assert(ranges::range<T>);
// The engine's address is stored in the view, so a temporary engine would dangle.
static_assert(std::is_lvalue_reference_v<U&&>); // Assume U is a URNG
return sample_from_range(std::forward<T>(t), u);
}
}
// One-argument form: a range or a distribution uses the default engine. Anything else is
// taken to be an engine and bound into a pipeable view closure.
template <class T>
constexpr auto operator()(T&& t) const {
if constexpr (ranges::range<T>) {
auto& engine = ranges::detail::get_random_engine();
return sample_from_range(std::forward<T>(t), engine);
} else if constexpr (is_random_distribution_v<T>) {
auto& engine = ranges::detail::get_random_engine();
return sample_from_distribution(std::forward<T>(t), engine);
} else {
static_assert(std::is_lvalue_reference_v<T&&>); // Assume T is a URNG
// `std::ref` keeps the closure from copying the engine.
return ranges::make_view_closure(ranges::bind_back(sample_fn{}, std::ref(t)));
}
}
// (range, std::reference_wrapper<engine>): presumably the overload reached when the closure
// built above is applied to a range. It unwraps the engine reference before sampling.
template <class Range, class URNG>
constexpr auto operator()(Range&& range, std::reference_wrapper<URNG> engine) const {
static_assert(ranges::range<Range>);
return sample_from_range(std::forward<Range>(range), engine.get());
}
};
} // namespace detail
inline constexpr ranges::views::view_closure<detail::sample_fn> sample;
} // namespace beluga::views
#endif