15 #include <gtest/gtest.h>
// Checks that a negative alpha_slow (-0.2) is treated as an invalid argument.
// NOTE(review): this excerpt is missing several lines and carries listing line numbers, so it
// does not compile as shown. The missing lines are the estimator construction, the expectation
// (presumably ASSERT_THROW or similar) and the closing brace.
25 TEST(ThrunRecoveryProbabilityEstimator, InvalidAlphaSlow) {
26 const double alpha_slow = -0.2;
27 const double alpha_fast = 0.4;
// Checks that an alpha_fast (0.1) smaller than alpha_slow (0.2) is treated as invalid.
// NOTE(review): presumably the estimator requires alpha_fast > alpha_slow. Confirm against the
// constructor's argument checks.
// The construction/assertion lines and the closing brace are not present in this excerpt.
31 TEST(ThrunRecoveryProbabilityEstimator, InvalidAlphaFast) {
32 const double alpha_slow = 0.2;
33 const double alpha_fast = 0.1;
// An empty range of (state, weight) particles must yield a recovery probability of exactly 0.0.
// NOTE(review): the line constructing `estimator` from alpha_slow/alpha_fast and the closing
// brace are not present in this excerpt.
37 TEST(ThrunRecoveryProbabilityEstimator, ProbabilityWithNoParticles) {
39 const double alpha_slow = 0.2;
40 const double alpha_fast = 0.4;
42 ASSERT_EQ(estimator(std::vector<std::tuple<int, beluga::Weight>>{}), 0.0);
// Particles whose weights are all zero must yield a recovery probability of exactly 0.0.
// NOTE(review): presumably this exercises a guard against a zero (or near-zero) average
// weight in the estimator. Confirm against the implementation.
// The estimator construction line and the closing brace are not present in this excerpt.
45 TEST(ThrunRecoveryProbabilityEstimator, ProbabilityWithZeroWeight) {
47 const double alpha_slow = 0.2;
48 const double alpha_fast = 0.4;
50 ASSERT_EQ(estimator(std::vector<std::tuple<int, beluga::Weight>>{{1, 0.0}, {2, 0.0}}), 0.0);
// Feeds the estimator successive particle sets whose average weight drops, then checks that a
// reset brings the recovery probability back to 0.0.
// NOTE(review): the expected values are consistent with the following model, which should be
// confirmed against the estimator implementation (not part of this file):
//   - short/long-term averages are exponential filters seeded with the first average weight;
//   - probability = max(0, 1 - w_fast / w_slow).
// Under that model:
//   step 1: avg 2.0 -> w_slow = 2.0,  w_fast = 2.0 -> 0.0
//   step 2: avg 1.0 -> w_slow = 1.5,  w_fast = 1.0 -> 1 - 1/1.5  ~= 0.33
//   step 3: avg 1.0 -> w_slow = 1.25, w_fast = 1.0 -> 1 - 1/1.25  = 0.20
// The estimator construction line and the closing brace are not present in this excerpt.
53 TEST(ThrunRecoveryProbabilityEstimator, ProbabilityAfterUpdateAndReset) {
54 const double alpha_slow = 0.5;
55 const double alpha_fast = 1.0;
// First update: average weight is (1 + 2 + 3) / 3 = 2.0; no recovery is expected yet.
59 auto input = std::vector<std::tuple<int, beluga::Weight>>{{5, 1.0}, {6, 2.0}, {7, 3.0}};
60 ASSERT_EQ(estimator(input), 0.0);
// Average weight halves to (0.5 + 1.0 + 1.5) / 3 = 1.0, so a positive probability is expected.
62 input = std::vector<std::tuple<int, beluga::Weight>>{{5, 0.5}, {6, 1.0}, {7, 1.5}};
63 ASSERT_NEAR(estimator(input), 0.33, 0.01);
// Same weights as the previous step (this reassignment repeats the existing value).
// With the average unchanged, the expected probability shrinks to ~0.20.
65 input = std::vector<std::tuple<int, beluga::Weight>>{{5, 0.5}, {6, 1.0}, {7, 1.5}};
66 ASSERT_NEAR(estimator(input), 0.20, 0.01);
// NOTE(review): original line 68 is not present in this excerpt. Presumably it is the call that
// resets the estimator (per the test name); after it, the next update must return 0.0 again.
69 ASSERT_EQ(estimator(input), 0.0);
// Value-parameterized fixture. Each parameter is a tuple of
// (initial_weight, final_weight, expected_probability), as unpacked in the test body.
72 class ThrunRecoveryProbabilityWithParam :
public ::testing::TestWithParam<std::tuple<double, double, double>> {};
// Updates the estimator with a single particle of weight initial_weight and expects ~0.0.
// It then checks the recovery probability after a second update against expected_probability
// (tolerance 0.01).
// NOTE(review): several lines are not present in this excerpt:
//   - the lines constructing the estimator;
//   - the lines presumably switching `particles` to final_weight (original lines 83-85);
//   - the closing brace.
74 TEST_P(ThrunRecoveryProbabilityWithParam, Probabilities) {
75 const auto [initial_weight, final_weight, expected_probability] = GetParam();
// alpha_slow is two orders of magnitude smaller than alpha_fast.
77 const double alpha_slow = 0.001;
78 const double alpha_fast = 0.1;
80 auto particles = std::vector<std::tuple<int, beluga::Weight>>{{1, initial_weight}};
82 ASSERT_NEAR(estimator(particles), 0.0, 0.01);
86 ASSERT_NEAR(estimator(particles), expected_probability, 0.01);
// Cases are (initial_weight, final_weight, expected_probability).
//   - A rising weight (1.0 -> 1.5, 1.0 -> 2.0) expects no recovery.
//   - A falling weight expects a small positive probability that grows with the relative size of
//     the drop: 50% -> 0.05, 80% -> 0.08, 100% -> 0.10.
// NOTE(review): the generator line (original line 92, presumably `::testing::Values(`) is not
// present in this excerpt.
89 INSTANTIATE_TEST_SUITE_P(
90 ThrunRecoveryProbability,
91 ThrunRecoveryProbabilityWithParam,
93 std::make_tuple(1.0, 1.5, 0.00),
94 std::make_tuple(1.0, 2.0, 0.00),
95 std::make_tuple(1.0, 0.5, 0.05),
96 std::make_tuple(0.5, 0.1, 0.08),
97 std::make_tuple(0.5, 0.0, 0.10)));