15 #include "absl/random/distributions.h"
23 #include "gtest/gtest.h"
24 #include "absl/random/internal/distribution_test_util.h"
25 #include "absl/random/random.h"
// Shared sample-count constant for the statistical distribution tests in this
// file (presumably the length of the `values` sample buffers — confirm at the
// use sites, which are not visible in this chunk).
constexpr int kSize{400000};
36 template <
typename A,
typename B>
37 auto InferredUniformReturnT(
int)
38 -> decltype(
absl::Uniform(std::declval<absl::InsecureBitGen&>(),
39 std::declval<A>(), std::declval<B>()));
41 template <
typename,
typename>
42 Invalid InferredUniformReturnT(...);
44 template <
typename TagType,
typename A,
typename B>
45 auto InferredTaggedUniformReturnT(
int)
47 std::declval<absl::InsecureBitGen&>(),
48 std::declval<A>(), std::declval<B>()));
50 template <
typename,
typename,
typename>
51 Invalid InferredTaggedUniformReturnT(...);
77 template <
typename A,
typename B,
typename Expect>
78 void CheckArgsInferType() {
81 std::is_same<
Expect, decltype(InferredUniformReturnT<A, B>(0))>,
83 decltype(InferredUniformReturnT<B, A>(0))>>::
value,
87 std::is_same<
Expect, decltype(InferredTaggedUniformReturnT<
90 decltype(InferredTaggedUniformReturnT<
95 template <
typename A,
typename B,
typename ExplicitRet>
96 auto ExplicitUniformReturnT(
int) -> decltype(
97 absl::Uniform<ExplicitRet>(*std::declval<absl::InsecureBitGen*>(),
98 std::declval<A>(), std::declval<B>()));
100 template <
typename,
typename,
typename ExplicitRet>
101 Invalid ExplicitUniformReturnT(...);
103 template <
typename TagType,
typename A,
typename B,
typename ExplicitRet>
104 auto ExplicitTaggedUniformReturnT(
int) -> decltype(absl::Uniform<ExplicitRet>(
105 std::declval<TagType>(), *std::declval<absl::InsecureBitGen*>(),
106 std::declval<A>(), std::declval<B>()));
108 template <
typename,
typename,
typename,
typename ExplicitRet>
109 Invalid ExplicitTaggedUniformReturnT(...);
118 template <
typename A,
typename B,
typename Expect>
119 void CheckArgsReturnExpectedType() {
123 decltype(ExplicitUniformReturnT<A, B, Expect>(0))>,
124 std::is_same<
Expect, decltype(ExplicitUniformReturnT<B, A, Expect>(
130 decltype(ExplicitTaggedUniformReturnT<
132 std::is_same<
Expect, decltype(ExplicitTaggedUniformReturnT<
138 TEST_F(RandomDistributionsTest, UniformTypeInference) {
140 CheckArgsInferType<uint16_t, uint16_t, uint16_t>();
141 CheckArgsInferType<uint32_t, uint32_t, uint32_t>();
142 CheckArgsInferType<uint64_t, uint64_t, uint64_t>();
143 CheckArgsInferType<int16_t, int16_t, int16_t>();
144 CheckArgsInferType<int32_t, int32_t, int32_t>();
145 CheckArgsInferType<int64_t, int64_t, int64_t>();
146 CheckArgsInferType<float, float, float>();
147 CheckArgsInferType<double, double, double>();
150 CheckArgsReturnExpectedType<int16_t, int16_t, int32_t>();
151 CheckArgsReturnExpectedType<uint16_t, uint16_t, int32_t>();
152 CheckArgsReturnExpectedType<int16_t, int16_t, int64_t>();
153 CheckArgsReturnExpectedType<int16_t, int32_t, int64_t>();
154 CheckArgsReturnExpectedType<int16_t, int32_t, double>();
155 CheckArgsReturnExpectedType<float, float, double>();
156 CheckArgsReturnExpectedType<int, int, int16_t>();
159 CheckArgsInferType<uint16_t, uint32_t, uint32_t>();
160 CheckArgsInferType<uint16_t, uint64_t, uint64_t>();
161 CheckArgsInferType<uint16_t, int32_t, int32_t>();
162 CheckArgsInferType<uint16_t, int64_t, int64_t>();
163 CheckArgsInferType<uint16_t, float, float>();
164 CheckArgsInferType<uint16_t, double, double>();
167 CheckArgsInferType<int16_t, int32_t, int32_t>();
168 CheckArgsInferType<int16_t, int64_t, int64_t>();
169 CheckArgsInferType<int16_t, float, float>();
170 CheckArgsInferType<int16_t, double, double>();
174 CheckArgsInferType<uint16_t, int16_t, Invalid>();
175 CheckArgsInferType<int16_t, uint32_t, Invalid>();
176 CheckArgsInferType<int16_t, uint64_t, Invalid>();
179 CheckArgsInferType<uint32_t, uint64_t, uint64_t>();
180 CheckArgsInferType<uint32_t, int64_t, int64_t>();
181 CheckArgsInferType<uint32_t, double, double>();
184 CheckArgsInferType<int32_t, int64_t, int64_t>();
185 CheckArgsInferType<int32_t, double, double>();
188 CheckArgsInferType<uint32_t, int32_t, Invalid>();
189 CheckArgsInferType<int32_t, uint64_t, Invalid>();
190 CheckArgsInferType<int32_t, float, Invalid>();
191 CheckArgsInferType<uint32_t, float, Invalid>();
194 CheckArgsInferType<uint64_t, int64_t, Invalid>();
195 CheckArgsInferType<int64_t, float, Invalid>();
196 CheckArgsInferType<int64_t, double, Invalid>();
199 CheckArgsInferType<float, double, double>();
202 TEST_F(RandomDistributionsTest, UniformExamples) {
211 EXPECT_NE(1, absl::Uniform<double>(absl::IntervalOpenOpen,
gen, -1, 1));
212 EXPECT_NE(1, absl::Uniform<float>(absl::IntervalOpenOpen,
gen, 0, 1));
216 TEST_F(RandomDistributionsTest, UniformNoBounds) {
219 absl::Uniform<uint8_t>(
gen);
220 absl::Uniform<uint16_t>(
gen);
221 absl::Uniform<uint32_t>(
gen);
222 absl::Uniform<uint64_t>(
gen);
223 absl::Uniform<absl::uint128>(
gen);
226 TEST_F(RandomDistributionsTest, UniformNonsenseRanges) {
230 #if (defined(__i386__) || defined(_M_IX86)) && FLT_EVAL_METHOD != 0
235 <<
"Skipping the test because we detected x87 floating-point semantics";
243 EXPECT_EQ(0, absl::Uniform<uint64_t>(absl::IntervalOpenOpen,
gen, 0, 0));
244 EXPECT_EQ(1, absl::Uniform<uint64_t>(absl::IntervalOpenOpen,
gen, 1, 0));
258 EXPECT_EQ(0, absl::Uniform<int64_t>(absl::IntervalOpenOpen,
gen, 0, 0));
259 EXPECT_EQ(1, absl::Uniform<int64_t>(absl::IntervalOpenOpen,
gen, 1, 0));
274 const double e = std::nextafter(1.0, 2.0);
275 const double f = std::nextafter(1.0, 0.0);
276 const double g = std::numeric_limits<double>::denorm_min();
288 TEST_F(RandomDistributionsTest, UniformReal) {
304 TEST_F(RandomDistributionsTest, UniformInt) {
312 values[
i] =
static_cast<double>(
j) /
static_cast<double>(kMax);
352 TEST_F(RandomDistributionsTest, PoissonDefault) {
368 TEST_F(RandomDistributionsTest, PoissonLarge) {
369 constexpr
double kMean = 100000000.0;
374 values[
i] = absl::Poisson<int64_t>(
gen, kMean);
380 EXPECT_NEAR(kMean, moments.variance, kMean * 0.015);
381 EXPECT_NEAR(std::sqrt(kMean), moments.skewness, kMean * 0.02);
386 constexpr
double kP = 0.5151515151;
400 constexpr
double kAlpha = 2.0;
401 constexpr
double kBeta = 3.0;
427 EXPECT_NEAR(6.5944, moments.mean, 2000) << moments;
451 values[
i] = absl::LogUniform<int64_t>(
gen, 0, (1 << 10) - 1);
457 const double mean = (0 + 1 + 1 + 2 + 3 + 4 + 7 + 8 + 15 + 16 + 31 + 32 + 63 +
458 64 + 127 + 128 + 255 + 256 + 511 + 512 + 1023) /