Program Listing for File take_while_kld.hpp

Return to documentation for file (include/beluga/views/take_while_kld.hpp)

// Copyright 2023-2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BELUGA_VIEWS_TAKE_WHILE_KLD_HPP
#define BELUGA_VIEWS_TAKE_WHILE_KLD_HPP

#include <cmath>
#include <cstddef>
#include <limits>
#include <type_traits>
#include <unordered_set>
#include <utility>

#include <range/v3/view/take.hpp>
#include <range/v3/view/take_while.hpp>

#include <beluga/type_traits/particle_traits.hpp>

namespace beluga {

namespace detail {

constexpr double kDefaultKldZ = 3.;

}  // namespace detail


inline auto kld_condition(std::size_t min, double epsilon, double z = beluga::detail::kDefaultKldZ) {
  auto target_size = [two_epsilon = 2 * epsilon, z](std::size_t k) {
    if (k <= 2U) {
      return std::numeric_limits<std::size_t>::max();
    }
    const double common = 2. / static_cast<double>(9 * (k - 1));
    const double base = 1. - common + std::sqrt(common) * z;
    const double result = (static_cast<double>(k - 1) / two_epsilon) * base * base * base;
    return static_cast<std::size_t>(std::ceil(result));
  };

  return [=, count = 0ULL, buckets = std::unordered_set<std::size_t>{}](std::size_t hash) mutable {
    count++;
    buckets.insert(hash);
    return count <= min || count <= target_size(buckets.size());
  };
}

namespace views {

namespace detail {

struct take_while_kld_fn {

  template <class Range, class Hasher, std::enable_if_t<ranges::range<Range>, int> = 0>
  constexpr auto operator()(
      Range&& range,
      Hasher hasher,
      std::size_t min,
      std::size_t max,
      double epsilon,
      double z = beluga::detail::kDefaultKldZ) const {
    static_assert(ranges::input_range<Range>);

    auto proj = [&hasher]() {
      if constexpr (std::is_invocable_r_v<std::size_t, Hasher, ranges::range_value_t<Range>>) {
        // Try to invoke the hasher with the range values by default.
        return std::move(hasher);
      } else {
        // If the above is not possible, assume this is a particle range and invoke
        // the hasher with the state element of each particle.
        static_assert(is_particle_range_v<Range>);
        static_assert(std::is_invocable_r_v<std::size_t, Hasher, beluga::state_t<ranges::range_value_t<Range>>>);
        return ranges::compose(std::move(hasher), beluga::state);
      }
    }();

    return ranges::views::all(std::forward<Range>(range)) |                                      //
           ranges::views::take_while(beluga::kld_condition(min, epsilon, z), std::move(proj)) |  //
           ranges::views::take(max);
  }


  template <class Hasher>
  constexpr auto operator()(
      Hasher hasher,
      std::size_t min,
      std::size_t max,
      double epsilon,
      double z = beluga::detail::kDefaultKldZ) const {
    return ranges::make_view_closure(ranges::bind_back(take_while_kld_fn{}, std::move(hasher), min, max, epsilon, z));
  }
};

}  // namespace detail


/// Range adaptor object of type `detail::take_while_kld_fn`.
/**
 * It can be called either with a range followed by
 * `(hasher, min, max, epsilon[, z])`, or with only
 * `(hasher, min, max, epsilon[, z])` to obtain a view closure.
 */
inline constexpr detail::take_while_kld_fn take_while_kld;

}  // namespace views

}  // namespace beluga

#endif