Program Listing for File reweight.hpp

Return to documentation for file (include/beluga/actions/reweight.hpp)

// Copyright 2024 Ekumen, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BELUGA_ACTIONS_REWEIGHT_HPP
#define BELUGA_ACTIONS_REWEIGHT_HPP

#include <algorithm>
#include <execution>
#include <type_traits>

#include <beluga/type_traits/particle_traits.hpp>
#include <beluga/views/likelihoods.hpp>
#include <beluga/views/particles.hpp>

#include <range/v3/action/action.hpp>
#include <range/v3/view/zip.hpp>

namespace beluga::actions {

namespace detail {

/// Type trait telling whether `T` has a member named `likelihood`.
/**
 * The trait is true when `t.likelihood` is a well-formed expression for an lvalue `t` of type `T`.
 * The primary template covers the negative case, and the partial specialization below is selected by SFINAE.
 */
template <typename T, typename Enable = void>
struct has_likelihood_member : std::bool_constant<false> {};

/// Positive case: selected when `std::declval<T&>().likelihood` is well-formed.
template <typename T>
struct has_likelihood_member<T, std::void_t<decltype(std::declval<T&>().likelihood)>> : std::bool_constant<true> {};

/// Convenience variable template for `has_likelihood_member<T>::value`.
template <typename T>
inline constexpr bool has_likelihood_member_v = has_likelihood_member<T>{};

/// Implementation detail for the `reweight` action object.
struct reweight_fn {
 private:
  /// Multiplies every particle weight by its likelihood, in place.
  /**
   * If the particle type has a `likelihood` member (see `has_likelihood_member_v`), the raw likelihood
   * is also written into it. Each element of `likelihoods` is read exactly once.
   *
   * \param policy Execution policy forwarded to the standard algorithm that performs the update.
   * \param range Particle range whose weights are updated in place.
   * \param likelihoods Range of likelihoods, one per particle, in the same order as `range`.
   */
  template <class ExecutionPolicy, class Range, class LikelihoodRange>
  constexpr void apply_impl(ExecutionPolicy&& policy, Range& range, LikelihoodRange&& likelihoods) const {
    auto weights = range | beluga::views::weights;

    using ParticleType = ranges::range_value_t<Range>;
    if constexpr (has_likelihood_member_v<ParticleType>) {
      // Update the weight and store the raw likelihood in a single fused pass. Traversing `likelihoods`
      // twice would recompute every likelihood when it is a lazy view (such as the one built from
      // `beluga::views::likelihoods` in `operator()`), and would break single-pass ranges.
      // zip is used because the body performs side effects on several members, which std::transform
      // is not designed for.
      auto zipped = ranges::views::zip(range, weights, likelihoods);
      std::for_each(policy, std::begin(zipped), std::end(zipped), [](auto&& tuple) {
        const auto likelihood = std::get<2>(tuple);
        std::get<0>(tuple).likelihood = likelihood;
        std::get<1>(tuple) = std::get<1>(tuple) * likelihood;
      });
    } else {
      // No likelihood storage: an element-wise product over the weights is all that is needed.
      std::transform(
          policy, std::begin(weights), std::end(weights), std::begin(likelihoods), std::begin(weights),
          [](auto weight, auto likelihood) { return weight * likelihood; });
    }
  }

 public:
  /// Reweights a particle range using an execution policy.
  /**
   * \tparam ExecutionPolicy An execution policy type (`std::is_execution_policy_v` must hold).
   * \tparam Range A particle range type (`beluga::is_particle_range_v` must hold).
   * \tparam Input Either a range of likelihoods (used as-is), or an object passed to
   *  `beluga::views::likelihoods` to compute the likelihoods from `range`.
   * \param policy Execution policy used to update the particles.
   * \param range Particle range to reweight in place.
   * \param input Likelihood range or likelihood source, see `Input`.
   * \return A reference to `range`, with updated weights.
   */
  template <
      class ExecutionPolicy,
      class Range,
      class Input,
      std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0,
      std::enable_if_t<ranges::range<Range>, int> = 0>
  constexpr auto operator()(ExecutionPolicy&& policy, Range& range, Input&& input) const -> Range& {
    static_assert(beluga::is_particle_range_v<Range>);
    if constexpr (ranges::range<Input>) {
      apply_impl(std::forward<ExecutionPolicy>(policy), range, input);
    } else {
      apply_impl(
          std::forward<ExecutionPolicy>(policy), range, range | beluga::views::likelihoods(std::forward<Input>(input)));
    }
    return range;
  }

  /// Sequential overload: same as calling with `std::execution::seq` as the first argument.
  /**
   * \param range Particle range to reweight in place.
   * \param input Likelihood range or likelihood source.
   * \return A reference to `range`, with updated weights.
   */
  template <
      class Range,
      class Input,
      std::enable_if_t<ranges::range<Range>, int> = 0,
      std::enable_if_t<!std::is_execution_policy_v<std::decay_t<Input>>, int> = 0>
  constexpr auto operator()(Range& range, Input&& input) const -> Range& {
    return (*this)(std::execution::seq, range, std::forward<Input>(input));
  }

  /// Overload taking the execution policy last.
  /**
   * This is the argument order used by the action closure built from `(policy, input)`.
   *
   * \param range Particle range to reweight in place.
   * \param input Likelihood range or likelihood source.
   * \param policy Execution policy used to update the particles.
   * \return A reference to `range`, with updated weights.
   */
  template <
      class Range,
      class Input,
      class ExecutionPolicy,
      std::enable_if_t<ranges::range<Range>, int> = 0,
      std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0>
  constexpr auto operator()(Range&& range, Input&& input, ExecutionPolicy&& policy) const -> auto& {
    return (*this)(std::forward<ExecutionPolicy>(policy), range, std::forward<Input>(input));
  }

  /// Builds a range-v3 action closure that reweights a range with `input` using `policy`.
  /**
   * \param policy Execution policy stored in the closure.
   * \param input Likelihood range or likelihood source stored in the closure.
   * \return An action closure that can be applied to a particle range.
   */
  template <
      class ExecutionPolicy,
      class Input,
      std::enable_if_t<std::is_execution_policy_v<std::decay_t<ExecutionPolicy>>, int> = 0,
      std::enable_if_t<!ranges::range<std::decay_t<ExecutionPolicy>>, int> = 0>
  constexpr auto operator()(ExecutionPolicy policy, Input input) const {
    return ranges::make_action_closure(ranges::bind_back(reweight_fn{}, std::move(input), std::move(policy)));
  }

  /// Builds a range-v3 action closure that reweights a range sequentially with `input`.
  /**
   * \param input Likelihood range or likelihood source stored in the closure.
   * \return An action closure that can be applied to a particle range.
   */
  template <class Input, std::enable_if_t<!std::is_execution_policy_v<std::decay_t<Input>>, int> = 0>
  constexpr auto operator()(Input input) const {
    return ranges::make_action_closure(ranges::bind_back(reweight_fn{}, std::move(input)));
  }
};

}  // namespace detail


inline constexpr detail::reweight_fn reweight;

}  // namespace beluga::actions

#endif