test_take_while_kld.cpp
Go to the documentation of this file.
1 // Copyright 2023 Ekumen, Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <gtest/gtest.h>
16 
17 #include <array>
18 #include <cstddef>
19 #include <tuple>
20 #include <vector>
21 
22 #include <range/v3/algorithm/equal.hpp>
23 #include <range/v3/iterator/operations.hpp>
24 #include <range/v3/range/access.hpp>
25 #include <range/v3/range/concepts.hpp>
26 #include <range/v3/range/conversion.hpp>
27 #include <range/v3/view/empty.hpp>
28 #include <range/v3/view/generate.hpp>
29 #include <range/v3/view/intersperse.hpp>
30 #include <range/v3/view/sample.hpp>
31 #include <range/v3/view/take_while.hpp>
32 
33 #include "beluga/primitives.hpp"
35 
36 namespace {
37 
/// Generic pass-through callable: hands back whatever argument it receives.
/// The return type is deduced by value, so the caller always gets a new object
/// (copy-constructed from an lvalue argument, move-constructed from an rvalue one).
inline constexpr auto identity = [](auto&& value) noexcept {
  // Same effect as std::forward<decltype(value)>(value), spelled as a cast.
  return static_cast<decltype(value)&&>(value);
};
39 
40 TEST(TakeWhileKld, ConceptChecksFromContiguousRange) {
41  const std::size_t min = 0;
42  const std::size_t max = 1200;
43  const double epsilon = 0.05;
44  auto input = std::array{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
45  auto output = beluga::views::take_while_kld(input, identity, min, max, epsilon);
46 
47  static_assert(ranges::common_range<decltype(input)>);
48  static_assert(!ranges::common_range<decltype(output)>);
49 
50  static_assert(!ranges::viewable_range<decltype(input)>);
51  static_assert(ranges::viewable_range<decltype(output)>);
52 
53  static_assert(ranges::forward_range<decltype(input)>);
54  static_assert(ranges::forward_range<decltype(output)>);
55 
56  static_assert(ranges::sized_range<decltype(input)>);
57  static_assert(!ranges::sized_range<decltype(output)>);
58 
59  static_assert(ranges::bidirectional_range<decltype(input)>);
60  static_assert(ranges::bidirectional_range<decltype(output)>);
61 
62  static_assert(ranges::random_access_range<decltype(input)>);
63  static_assert(ranges::random_access_range<decltype(output)>);
64 
65  static_assert(ranges::contiguous_range<decltype(input)>);
66  static_assert(!ranges::contiguous_range<decltype(output)>);
67 }
68 
69 TEST(TakeWhileKld, ConceptChecksFromInfiniteRange) {
70  const std::size_t min = 0;
71  const std::size_t max = 1200;
72  const double epsilon = 0.05;
73  auto input = ranges::views::generate([]() { return 1; });
74  auto output = beluga::views::take_while_kld(input, identity, min, max, epsilon);
75 
76  static_assert(!ranges::common_range<decltype(input)>);
77  static_assert(!ranges::common_range<decltype(output)>);
78 
79  static_assert(ranges::viewable_range<decltype(input)>);
80  static_assert(ranges::viewable_range<decltype(output)>);
81 
82  static_assert(!ranges::forward_range<decltype(input)>);
83  static_assert(!ranges::forward_range<decltype(output)>);
84 
85  static_assert(!ranges::sized_range<decltype(input)>);
86  static_assert(!ranges::sized_range<decltype(output)>);
87 
88  static_assert(!ranges::bidirectional_range<decltype(input)>);
89  static_assert(!ranges::bidirectional_range<decltype(output)>);
90 
91  static_assert(!ranges::random_access_range<decltype(input)>);
92  static_assert(!ranges::random_access_range<decltype(output)>);
93 
94  static_assert(!ranges::contiguous_range<decltype(input)>);
95  static_assert(!ranges::contiguous_range<decltype(output)>);
96 }
97 
98 class KldConditionWithParam : public ::testing::TestWithParam<std::tuple<double, std::size_t, std::size_t>> {};
99 
100 auto GenerateDistinctHashes(std::size_t count) {
101  return ranges::views::generate([count, hash = 0UL]() mutable {
102  if (hash < count) {
103  ++hash;
104  }
105  return hash;
106  });
107 }
108 
109 TEST_P(KldConditionWithParam, Minimum) {
110  const std::size_t cluster_count = std::get<1>(GetParam());
111  const std::size_t min = 1'000;
112  const double epsilon = 0.01;
113  const double kld_k = 0.95;
114  auto output = GenerateDistinctHashes(cluster_count) | //
115  ranges::views::take_while(beluga::kld_condition(min, epsilon, kld_k));
116  ASSERT_GE(ranges::distance(output), min);
117 }
118 
119 TEST_P(KldConditionWithParam, Limit) {
120  const double kld_k = std::get<0>(GetParam());
121  const std::size_t cluster_count = std::get<1>(GetParam());
122  const std::size_t expected_count = std::get<2>(GetParam());
123  const std::size_t min = 0;
124  const double epsilon = 0.01;
125  auto output = GenerateDistinctHashes(cluster_count) | //
126  ranges::views::take_while(beluga::kld_condition(min, epsilon, kld_k));
127  ASSERT_EQ(ranges::distance(output), expected_count);
128 }
129 
// Values supplied as the `kld_k` parameter in the parameterized cases below.
// NOTE(review): the names suggest these are the 90th / 99th percentile quantiles
// of the standard normal distribution - confirm against the kld_condition docs.
constexpr double kPercentile90th{1.28155156327703};
constexpr double kPercentile99th{2.32634787735669};
132 
133 INSTANTIATE_TEST_SUITE_P(
134  KldPairs,
135  KldConditionWithParam,
136  testing::Values(
137  std::make_tuple(kPercentile90th, 3, 228),
138  std::make_tuple(kPercentile90th, 4, 311),
139  std::make_tuple(kPercentile90th, 5, 388),
140  std::make_tuple(kPercentile90th, 6, 461),
141  std::make_tuple(kPercentile90th, 7, 531),
142  std::make_tuple(kPercentile90th, 100, 5871),
143  std::make_tuple(kPercentile99th, 3, 462),
144  std::make_tuple(kPercentile99th, 4, 569),
145  std::make_tuple(kPercentile99th, 5, 666),
146  std::make_tuple(kPercentile99th, 6, 756),
147  std::make_tuple(kPercentile99th, 7, 843),
148  std::make_tuple(kPercentile99th, 100, 6733)));
149 
150 TEST(TakeWhileKld, TakeZero) {
151  const std::size_t min = 2;
152  const std::size_t max = 3;
153  const double epsilon = 0.1;
154  auto output = ranges::views::empty<std::size_t> | //
155  beluga::views::take_while_kld(identity, min, max, epsilon);
156  ASSERT_EQ(ranges::distance(output), 0);
157 }
158 
159 TEST(TakeWhileKld, TakeMaximum) {
160  const std::size_t min = 200;
161  const std::size_t max = 1200;
162  const double epsilon = 0.05;
163  auto output = ranges::views::generate([]() { return 1UL; }) | //
164  beluga::views::take_while_kld(identity, min, max, epsilon);
165  ASSERT_EQ(ranges::distance(output), max);
166 }
167 
168 TEST(TakeWhileKld, TakeLimit) {
169  const std::size_t min = 0;
170  const std::size_t max = 1200;
171  const double epsilon = 0.05;
172  auto output = ranges::views::generate([]() { return 1UL; }) | //
173  ranges::views::intersperse(2UL) | //
174  ranges::views::intersperse(3UL) | //
175  beluga::views::take_while_kld(identity, min, max, epsilon);
176  ASSERT_EQ(ranges::distance(output), 135);
177 }
178 
179 TEST(TakeWhileKld, TakeMinimum) {
180  const std::size_t min = 200;
181  const std::size_t max = 1200;
182  const double epsilon = 0.05;
183  auto output = ranges::views::generate([]() { return 1UL; }) | //
184  ranges::views::intersperse(2UL) | //
185  ranges::views::intersperse(3UL) | //
186  beluga::views::take_while_kld(identity, min, max, epsilon);
187  ASSERT_EQ(ranges::distance(output), min);
188 }
189 
190 TEST(TakeWhileKld, FromParticleRange) {
191  struct State {};
192  auto input = std::array{std::make_tuple(State{}, beluga::Weight(2.0))};
193  const std::size_t min = 2;
194  const std::size_t max = 3;
195  const double epsilon = 0.1;
196  const auto hasher = [](auto) { return 42; };
197  auto output = input | beluga::views::take_while_kld(hasher, min, max, epsilon);
198  ASSERT_EQ(ranges::distance(output), 1);
199  auto [state, weight] = *ranges::begin(output);
200  ASSERT_EQ(weight, 2.0);
201 }
202 
203 TEST(TakeWhileKld, ValueHashCorrespondence) {
204  // Make sure that the hasher is called on each particle in order.
205  // Even when the sample is random, the correspondence should be kept
206  auto input = std::array{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
207  const std::size_t min = 2;
208  const std::size_t max = 5;
209  const double epsilon = 0.1;
210  auto hashes = std::vector<std::size_t>{};
211  const auto hasher = [&hashes](auto value) {
212  hashes.push_back(static_cast<std::size_t>(value));
213  return value;
214  };
215  auto output = input | //
216  ranges::views::sample(5) | //
217  beluga::views::take_while_kld(hasher, min, max, epsilon) | //
218  ranges::to<std::vector>;
219  ASSERT_TRUE(ranges::equal(hashes, output));
220 }
221 
222 TEST(TakeWhileKld, HashStoredInParticle) {
223  struct Particle {
224  int state;
225  double weight;
226  std::size_t hash;
227  };
228  auto input = std::array{Particle{1, 1.0, 42}, Particle{2, 1.0, 43}, Particle{3, 1.0, 44}};
229  const std::size_t min = 2;
230  const std::size_t max = 5;
231  const double epsilon = 0.1;
232  const auto hasher = [](const auto& particle) { return particle.hash; };
233  auto output = input | beluga::views::take_while_kld(hasher, min, max, epsilon);
234  ASSERT_EQ(ranges::distance(output), 3);
235 }
236 
237 } // namespace
primitives.hpp
Implementation of library primitives to abstract member access.
take_while_kld.hpp
Implementation of a take_while_kld range adaptor object.
beluga::views::take_while_kld
constexpr detail::take_while_kld_fn take_while_kld
Definition: take_while_kld.hpp:170
beluga::TEST
TEST(Bresenham, MultiPassGuarantee)
Definition: test_bresenham.cpp:27


beluga
Author(s):
autogenerated on Tue Jul 16 2024 02:59:53