Program Listing for File agnocast_subscription.hpp

Return to documentation for file (include/agnocast/agnocast_subscription.hpp)

#pragma once

#include "agnocast/agnocast_callback_info.hpp"
#include "agnocast/agnocast_ioctl.hpp"
#include "agnocast/agnocast_mq.hpp"
#include "agnocast/agnocast_public_api.hpp"
#include "agnocast/agnocast_smart_pointer.hpp"
#include "agnocast/agnocast_tracepoint_wrapper.h"
#include "agnocast/agnocast_utils.hpp"
#include "rclcpp/detail/qos_parameters.hpp"
#include "rclcpp/rclcpp.hpp"

#include <fcntl.h>
#include <mqueue.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <atomic>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

namespace agnocast
{
// Forward declaration of agnocast::Node. Only pointers to it are used in this header; see the
// note on get_default_callback_group_for_tracepoint() below about avoiding circular inclusion.
class Node;

// Mutex defined in another translation unit. In this header it is held by
// BasicTakeSubscription::take() while it issues AGNOCAST_TAKE_MSG_CMD and calls
// map_read_only_area() for every publisher region reported by that ioctl.
extern std::mutex mmap_mtx;

// Called by take(), with mmap_mtx held, once per publisher shared-memory descriptor
// (pid, address, size) returned by the kernel module.
// NOTE(review): judging by the name this maps the publisher's region read-only into the calling
// process — confirm against the definition.
void map_read_only_area(const pid_t pid, const uint64_t shm_addr, const uint64_t shm_size);

// Get the default callback group from an agnocast::Node for tracepoint use.
// Defined in .cpp to avoid circular inclusion between agnocast_subscription.hpp and
// agnocast_node.hpp.
rclcpp::CallbackGroup::SharedPtr get_default_callback_group_for_tracepoint(agnocast::Node * node);
// Returns an address identifying `node`. The agnocast::Node constructors in this header pass it
// as the node-handle argument of the agnocast_subscription_init tracepoint.
const void * get_node_base_address(Node * node);

// Options accepted by the agnocast subscription constructors.
AGNOCAST_PUBLIC
struct SubscriptionOptions
{
  // Callback group the subscription callback is attached to. When left null,
  // get_valid_callback_group() falls back to the node's default callback group.
  rclcpp::CallbackGroup::SharedPtr callback_group = nullptr;

  // Forwarded unchanged to SubscriptionBase::initialize() when the subscriber is registered.
  bool ignore_local_publications = false;

  // When this lists at least one policy kind, the effective QoS is obtained through
  // rclcpp::detail::declare_qos_parameters() instead of using the requested QoS directly.
  rclcpp::QosOverridingOptions qos_overriding_options{};
};

// These are cut out of the class for information hiding.

// Opens the message queue used by subscriber `subscriber_id` on `topic_name`.
// BasicSubscription passes the returned descriptor to register_callback() and keeps
// `mq_subscription` so that it can hand it to remove_mq() on destruction.
// NOTE(review): presumably `mq_subscription` receives the descriptor together with the queue
// name — confirm against the definition.
mqd_t open_mq_for_subscription(
  const std::string & topic_name, const topic_local_id_t subscriber_id,
  std::pair<mqd_t, std::string> & mq_subscription);
// Releases a queue obtained from open_mq_for_subscription(); called from ~BasicSubscription(),
// whose comment states that mq_close() is invoked here.
void remove_mq(const std::pair<mqd_t, std::string> & mq_subscription);
// Backs SubscriptionBase::get_publisher_count() for the given topic.
uint32_t get_publisher_count_core(const std::string & topic_name);

// Picks the callback group a new subscription should use.
//
// - No group requested in `options`: the node's default callback group is returned.
// - Group requested and registered with `node`: that group is returned.
// - Group requested but not registered with `node`: an error is logged, agnocast_fd is closed
//   and the process exits with EXIT_FAILURE.
template <typename NodeT>
rclcpp::CallbackGroup::SharedPtr get_valid_callback_group(
  NodeT * node, const SubscriptionOptions & options)
{
  if (!options.callback_group) {
    return node->get_node_base_interface()->get_default_callback_group();
  }

  const bool belongs_to_node =
    node->get_node_base_interface()->callback_group_in_node(options.callback_group);
  if (!belongs_to_node) {
    RCLCPP_ERROR(logger, "Cannot create agnocast subscription, callback group not in node.");
    close(agnocast_fd);
    exit(EXIT_FAILURE);
  }

  return options.callback_group;
}

// Base class shared by the agnocast subscription types in this header. It stores the topic name
// and the subscriber id and asks the kernel module to remove the subscriber on destruction.
//
// The class is deliberately neither copyable nor movable: the destructor issues
// AGNOCAST_REMOVE_SUBSCRIBER_CMD for id_, so a second object carrying the same id would remove
// the same subscriber twice. Subscriptions are meant to be held through their SharedPtr aliases.
class SubscriptionBase
{
protected:
  // Subscriber id. Derived classes assign it from the `ret_id` field of the value returned by
  // initialize().
  // NOTE(review): there is no default initializer and the out-of-line constructors are not
  // visible from this header. If a derived constructor throws before assigning id_, the
  // destructor below still runs with whatever value id_ holds — confirm the constructors set a
  // safe value.
  topic_local_id_t id_;
  const std::string topic_name_;
  // Defined out of line. Derived classes call it once during construction and read `ret_id`
  // from the result.
  union ioctl_add_subscriber_args initialize(
    const rclcpp::QoS & qos, const bool is_take_sub, const bool ignore_local_publications,
    const bool is_bridge, const std::string & node_name);

public:
  SubscriptionBase(rclcpp::Node * node, const std::string & topic_name);
  SubscriptionBase(agnocast::Node * node, const std::string & topic_name);

  SubscriptionBase(const SubscriptionBase &) = delete;
  SubscriptionBase & operator=(const SubscriptionBase &) = delete;
  SubscriptionBase(SubscriptionBase &&) = delete;
  SubscriptionBase & operator=(SubscriptionBase &&) = delete;

  // Number of publishers on this subscription's topic, as reported by
  // get_publisher_count_core().
  uint32_t get_publisher_count() const { return get_publisher_count_core(topic_name_); }

  // Removes this subscriber from the kernel module. A failing ioctl is only logged as a warning
  // because a destructor has no way to report the error to the caller.
  virtual ~SubscriptionBase()
  {
    // NOTE: Unmapping memory when a subscriber is destroyed is not implemented. Multiple
    // subscribers may share the same mmap region, requiring reference counting in kmod. Since
    // leaving the memory mapped should not cause any functional issues, this is left as future
    // work.
    struct ioctl_remove_subscriber_args remove_subscriber_args
    {
    };
    remove_subscriber_args.topic_name = {topic_name_.c_str(), topic_name_.size()};
    remove_subscriber_args.subscriber_id = id_;
    if (ioctl(agnocast_fd, AGNOCAST_REMOVE_SUBSCRIBER_CMD, &remove_subscriber_args) < 0) {
      RCLCPP_WARN(logger, "Failed to remove subscriber (id=%d) from kernel.", id_);
    }
  }
};

// Internal implementation — users should use agnocast::Subscription<MessageT> instead.
//
// Callback-driven subscription. Construction registers the subscriber with the kernel module,
// asks BridgeRequestPolicy for a bridge, opens the subscription's message queue and registers
// the user callback. Destruction undoes the callback registration and releases the queue.
//
// The class is neither copyable nor movable: it owns the message-queue descriptor and the
// id2_callback_info entry, and the destructor releases both, so a copy would release them twice.
template <typename MessageT, typename BridgeRequestPolicy>
class BasicSubscription : public SubscriptionBase
{
  // Filled in by open_mq_for_subscription() and handed to remove_mq() in the destructor.
  std::pair<mqd_t, std::string> mq_subscription_;
  // Key of this subscription's entry in id2_callback_info, as returned by register_callback().
  uint32_t callback_info_id_;

  // Construction steps shared by both public constructors.
  //
  // Resolves the effective QoS (declaring QoS override parameters on the node when
  // `options.qos_overriding_options` lists any policy kind), validates it, registers the
  // subscriber, requests a bridge, opens the message queue and registers `callback`.
  // Sets id_, mq_subscription_ and callback_info_id_.
  //
  // `options` is taken by const reference: it is only read, and copying it would copy a
  // shared_ptr and the QoS overriding options for no benefit.
  //
  // Returns the effective QoS so that the caller can report its depth in the tracepoint.
  template <typename NodeT, typename Func>
  rclcpp::QoS constructor_impl(
    NodeT * node, const rclcpp::QoS & qos, Func && callback,
    rclcpp::CallbackGroup::SharedPtr callback_group,
    const agnocast::SubscriptionOptions & options, const bool is_bridge)
  {
    auto node_parameters = node->get_node_parameters_interface();
    const rclcpp::QoS actual_qos =
      options.qos_overriding_options.get_policy_kinds().empty()
        ? qos
        : rclcpp::detail::declare_qos_parameters(
            options.qos_overriding_options, node_parameters, topic_name_, qos,
            rclcpp::detail::SubscriptionQosParametersTraits{});

    validate_qos(actual_qos);

    // `false`: this is a callback subscription, not a take subscription.
    union ioctl_add_subscriber_args add_subscriber_args = initialize(
      actual_qos, false, options.ignore_local_publications, is_bridge,
      node->get_fully_qualified_name());

    id_ = add_subscriber_args.ret_id;
    BridgeRequestPolicy::template request_bridge<MessageT>(topic_name_, id_);

    mqd_t mq = open_mq_for_subscription(topic_name_, id_, mq_subscription_);

    const bool is_transient_local =
      actual_qos.durability() == rclcpp::DurabilityPolicy::TransientLocal;
    callback_info_id_ = agnocast::register_callback<MessageT>(
      std::forward<Func>(callback), topic_name_, id_, is_transient_local, mq, callback_group);

    return actual_qos;
  }

public:
  using SharedPtr = std::shared_ptr<BasicSubscription<MessageT, BridgeRequestPolicy>>;

  // Creates a subscription owned by an rclcpp::Node.
  //
  // node       Node that owns the subscription.
  // topic_name Topic to subscribe to.
  // qos        Requested QoS; may be overridden through options.qos_overriding_options.
  // callback   Callable registered through agnocast::register_callback<MessageT>().
  // options    Callback group, local-publication filtering and QoS overriding options.
  // is_bridge  Forwarded to SubscriptionBase::initialize().
  //
  // If options.callback_group is set but does not belong to `node`, the process exits (see
  // get_valid_callback_group()).
  template <typename Func>
  BasicSubscription(
    rclcpp::Node * node, const std::string & topic_name, const rclcpp::QoS & qos, Func && callback,
    agnocast::SubscriptionOptions options, const bool is_bridge = false)
  : SubscriptionBase(node, topic_name)
  {
    rclcpp::CallbackGroup::SharedPtr callback_group = get_valid_callback_group(node, options);

    // Captured before `callback` is forwarded, because it may be moved from by constructor_impl.
    const void * callback_addr = static_cast<const void *>(&callback);
    const char * callback_symbol = tracetools::get_symbol(callback);

    const rclcpp::QoS actual_qos =
      constructor_impl(node, qos, std::forward<Func>(callback), callback_group, options, is_bridge);

    {
      // Upper 32 bits: process id, lower 32 bits: callback info id.
      uint64_t pid_callback_info_id = (static_cast<uint64_t>(getpid()) << 32) | callback_info_id_;
      TRACEPOINT(
        agnocast_subscription_init, static_cast<const void *>(this),
        static_cast<const void *>(
          node->get_node_base_interface()->get_shared_rcl_node_handle().get()),
        callback_addr, static_cast<const void *>(callback_group.get()), callback_symbol,
        topic_name_.c_str(), actual_qos.depth(), pid_callback_info_id);
    }
  }

  // Creates a subscription owned by an agnocast::Node. Same as the rclcpp::Node overload except
  // that the subscriber is never registered as a bridge and the tracepoint identifies the node
  // through get_node_base_address().
  template <typename Func>
  BasicSubscription(
    agnocast::Node * node, const std::string & topic_name, const rclcpp::QoS & qos,
    Func && callback, agnocast::SubscriptionOptions options)
  : SubscriptionBase(node, topic_name)
  {
    rclcpp::CallbackGroup::SharedPtr callback_group = get_valid_callback_group(node, options);

    // Captured before `callback` is forwarded, because it may be moved from by constructor_impl.
    const void * callback_addr = static_cast<const void *>(&callback);
    const char * callback_symbol = tracetools::get_symbol(callback);

    const rclcpp::QoS actual_qos =
      constructor_impl(node, qos, std::forward<Func>(callback), callback_group, options, false);

    {
      // Upper 32 bits: process id, lower 32 bits: callback info id.
      uint64_t pid_callback_info_id = (static_cast<uint64_t>(getpid()) << 32) | callback_info_id_;
      TRACEPOINT(
        agnocast_subscription_init, static_cast<const void *>(this),
        static_cast<const void *>(get_node_base_address(node)), callback_addr,
        static_cast<const void *>(callback_group.get()), callback_symbol, topic_name_.c_str(),
        actual_qos.depth(), pid_callback_info_id);
    }
  }

  BasicSubscription(const BasicSubscription &) = delete;
  BasicSubscription & operator=(const BasicSubscription &) = delete;
  BasicSubscription(BasicSubscription &&) = delete;
  BasicSubscription & operator=(BasicSubscription &&) = delete;

  ~BasicSubscription() override
  {
    // Remove from callback info map to prevent stale references on re-subscription and to avoid
    // fd reuse conflicts. When mq_close() is called in remove_mq(), the OS may later reuse the
    // same fd number for a new subscription. If the old entry remains in id2_callback_info,
    // adding the new fd to epoll (EPOLL_CTL_ADD) can fail with EEXIST because epoll still
    // associates that fd number with the stale entry.
    {
      std::lock_guard<std::mutex> lock(id2_callback_info_mtx);
      id2_callback_info.erase(callback_info_id_);
    }
    remove_mq(mq_subscription_);
  }
};

// Internal implementation — users should use agnocast::TakeSubscription<MessageT> instead.
//
// Polling-style subscription: no callback is registered; messages are fetched explicitly with
// take().
template <typename MessageT, typename BridgeRequestPolicy>
class BasicTakeSubscription : public SubscriptionBase
{
private:
  // Cached pointer from the most recent take(allow_same_message=true) call.
  // When the same entry is returned again, a copy sharing the same control_block is returned
  // so that the kernel-side reference is not released until all userspace copies are destroyed.
  agnocast::ipc_shared_ptr<const MessageT> last_taken_ptr_;
  // Guards last_taken_ptr_.
  std::mutex last_taken_ptr_mtx_;

  // Construction steps shared by both public constructors.
  //
  // Resolves the effective QoS (declaring QoS override parameters on the node when
  // `options.qos_overriding_options` lists any policy kind), validates it, registers the
  // subscriber as a take subscription and requests a bridge. Sets id_.
  //
  // `options` is taken by const reference because it is only read.
  //
  // Returns the effective QoS so that the caller can report its depth in the tracepoint.
  template <typename NodeT>
  rclcpp::QoS constructor_impl(
    NodeT * node, const rclcpp::QoS & qos, const agnocast::SubscriptionOptions & options)
  {
    auto node_parameters = node->get_node_parameters_interface();
    const rclcpp::QoS actual_qos =
      options.qos_overriding_options.get_policy_kinds().empty()
        ? qos
        : rclcpp::detail::declare_qos_parameters(
            options.qos_overriding_options, node_parameters, topic_name_, qos,
            rclcpp::detail::SubscriptionQosParametersTraits{});

    validate_qos(actual_qos);

    // `true`: take subscription; `false`: never registered as a bridge.
    union ioctl_add_subscriber_args add_subscriber_args = initialize(
      actual_qos, true, options.ignore_local_publications, false, node->get_fully_qualified_name());

    id_ = add_subscriber_args.ret_id;
    BridgeRequestPolicy::template request_bridge<MessageT>(topic_name_, id_);

    return actual_qos;
  }

public:
  using SharedPtr = std::shared_ptr<BasicTakeSubscription<MessageT, BridgeRequestPolicy>>;

  // Creates a take subscription owned by an rclcpp::Node.
  //
  // node       Node that owns the subscription.
  // topic_name Topic to subscribe to.
  // qos        Requested QoS; may be overridden through options.qos_overriding_options.
  // options    Local-publication filtering and QoS overriding options.
  BasicTakeSubscription(
    rclcpp::Node * node, const std::string & topic_name, const rclcpp::QoS & qos,
    agnocast::SubscriptionOptions options = agnocast::SubscriptionOptions())
  : SubscriptionBase(node, topic_name)
  {
    const rclcpp::QoS actual_qos = constructor_impl(node, qos, options);

    {
      // There is no user callback, so the tracepoint is fed the node's default callback group
      // and a placeholder callback/symbol.
      auto default_cbg = node->get_node_base_interface()->get_default_callback_group();
      auto dummy_cb = []() {};
      std::string dummy_cb_symbols = "dummy_take" + topic_name_;
      TRACEPOINT(
        agnocast_subscription_init, static_cast<const void *>(this),
        static_cast<const void *>(
          node->get_node_base_interface()->get_shared_rcl_node_handle().get()),
        static_cast<const void *>(&dummy_cb), static_cast<const void *>(default_cbg.get()),
        dummy_cb_symbols.c_str(), topic_name_.c_str(), actual_qos.depth(), 0);
    }
  }

  // Creates a take subscription owned by an agnocast::Node. Same as the rclcpp::Node overload
  // except for how the node and its default callback group are obtained for the tracepoint.
  BasicTakeSubscription(
    agnocast::Node * node, const std::string & topic_name, const rclcpp::QoS & qos,
    agnocast::SubscriptionOptions options = agnocast::SubscriptionOptions())
  : SubscriptionBase(node, topic_name)
  {
    const rclcpp::QoS actual_qos = constructor_impl(node, qos, options);

    {
      // There is no user callback, so the tracepoint is fed the node's default callback group
      // and a placeholder callback/symbol.
      auto default_cbg = get_default_callback_group_for_tracepoint(node);
      auto dummy_cb = []() {};
      std::string dummy_cb_symbols = "dummy_take" + topic_name_;
      TRACEPOINT(
        agnocast_subscription_init, static_cast<const void *>(this),
        static_cast<const void *>(get_node_base_address(node)),
        static_cast<const void *>(&dummy_cb), static_cast<const void *>(default_cbg.get()),
        dummy_cb_symbols.c_str(), topic_name_.c_str(), actual_qos.depth(), 0);
    }
  }

  // Fetches a message from the kernel module.
  //
  // allow_same_message  Forwarded to the kernel module. When true, the returned pointer is also
  //                     cached so that a repeated take of the same entry yields a copy of the
  //                     cached pointer instead of a second, independent one.
  //
  // Returns an empty pointer when the kernel module reports no message (ret_addr == 0);
  // otherwise a pointer to the message in the publisher's shared memory.
  //
  // Exits the process if AGNOCAST_TAKE_MSG_CMD fails.
  AGNOCAST_PUBLIC
  agnocast::ipc_shared_ptr<const MessageT> take(bool allow_same_message = false)
  {
    publisher_shm_info pub_shm_infos[MAX_PUBLISHER_NUM]{};

    // Value-initialized, like the ioctl argument in ~SubscriptionBase(), so that no
    // indeterminate bytes are handed to the kernel module.
    union ioctl_take_msg_args take_args = {};
    take_args.topic_name = {topic_name_.c_str(), topic_name_.size()};
    take_args.subscriber_id = id_;
    take_args.allow_same_message = allow_same_message;
    take_args.pub_shm_info_addr = reinterpret_cast<uint64_t>(pub_shm_infos);
    take_args.pub_shm_info_size = MAX_PUBLISHER_NUM;

    {
      // The ioctl and the mapping of the publisher regions it reports happen under one lock.
      std::lock_guard<std::mutex> lock(mmap_mtx);

      if (ioctl(agnocast_fd, AGNOCAST_TAKE_MSG_CMD, &take_args) < 0) {
        RCLCPP_ERROR(logger, "AGNOCAST_TAKE_MSG_CMD failed: %s", strerror(errno));
        close(agnocast_fd);
        exit(EXIT_FAILURE);
      }

      for (uint32_t i = 0; i < take_args.ret_pub_shm_num; i++) {
        const pid_t pid = pub_shm_infos[i].pid;
        const uint64_t addr = pub_shm_infos[i].shm_addr;
        const uint64_t size = pub_shm_infos[i].shm_size;
        map_read_only_area(pid, addr, size);
      }
    }

    // A zero address means there is nothing to return.
    if (take_args.ret_addr == 0) {
      TRACEPOINT(agnocast_take, static_cast<void *>(this), 0, 0);
      return agnocast::ipc_shared_ptr<const MessageT>();
    }

    TRACEPOINT(
      agnocast_take, static_cast<void *>(this), reinterpret_cast<void *>(take_args.ret_addr),
      take_args.ret_entry_id);

    if (allow_same_message) {
      // Declared outside the lock scope so that its destructor (which may call ioctl to release
      // the kernel reference) runs after the mutex is released, avoiding unnecessary contention.
      agnocast::ipc_shared_ptr<const MessageT> old_ptr;
      {
        std::lock_guard<std::mutex> lock(last_taken_ptr_mtx_);

        // When the kernel returned the same entry as last time, return a copy of the cached
        // pointer (sharing the same control_block) instead of creating a new one.
        // This keeps the kernel-side reference alive until all copies are destroyed.
        if (last_taken_ptr_ && last_taken_ptr_.get_entry_id() == take_args.ret_entry_id) {
          return last_taken_ptr_;
        }

        MessageT * ptr = reinterpret_cast<MessageT *>(take_args.ret_addr);
        auto result =
          agnocast::ipc_shared_ptr<const MessageT>(ptr, topic_name_, id_, take_args.ret_entry_id);
        // Park the previous pointer in old_ptr so that it is released only after the lock.
        old_ptr = std::move(last_taken_ptr_);
        last_taken_ptr_ = result;
        return result;
      }
    }

    MessageT * ptr = reinterpret_cast<MessageT *>(take_args.ret_addr);
    return agnocast::ipc_shared_ptr<const MessageT>(ptr, topic_name_, id_, take_args.ret_entry_id);
  }
};

// Internal implementation — users should use agnocast::PollingSubscriber<MessageT> instead.
template <typename MessageT, typename BridgeRequestPolicy>
class BasicPollingSubscriber
{
  typename BasicTakeSubscription<MessageT, BridgeRequestPolicy>::SharedPtr subscriber_;

public:
  using SharedPtr = std::shared_ptr<BasicPollingSubscriber<MessageT, BridgeRequestPolicy>>;

  explicit BasicPollingSubscriber(
    rclcpp::Node * node, const std::string & topic_name, const rclcpp::QoS & qos = rclcpp::QoS{1},
    agnocast::SubscriptionOptions options = agnocast::SubscriptionOptions())
  {
    subscriber_ = std::make_shared<BasicTakeSubscription<MessageT, BridgeRequestPolicy>>(
      node, topic_name, qos, options);
  };

  explicit BasicPollingSubscriber(
    agnocast::Node * node, const std::string & topic_name, const rclcpp::QoS & qos = rclcpp::QoS{1},
    agnocast::SubscriptionOptions options = agnocast::SubscriptionOptions())
  {
    subscriber_ = std::make_shared<BasicTakeSubscription<MessageT, BridgeRequestPolicy>>(
      node, topic_name, qos, options);
  };

  const agnocast::ipc_shared_ptr<const MessageT> takeData() { return subscriber_->take(true); };
  AGNOCAST_PUBLIC
  const agnocast::ipc_shared_ptr<const MessageT> take_data() { return subscriber_->take(true); };
};

// Bridge-request policy plugged into the public aliases below; only forward-declared here.
// The Basic* subscription templates call its static member template
// request_bridge<MessageT>(topic_name, subscriber_id) during construction.
struct RosToAgnocastRequestPolicy;

// Callback-driven subscription: BasicSubscription with RosToAgnocastRequestPolicy.
AGNOCAST_PUBLIC
template <typename MessageT>
using Subscription = agnocast::BasicSubscription<MessageT, agnocast::RosToAgnocastRequestPolicy>;

// Subscription polled explicitly through take(): BasicTakeSubscription with
// RosToAgnocastRequestPolicy.
AGNOCAST_PUBLIC
template <typename MessageT>
using TakeSubscription =
  agnocast::BasicTakeSubscription<MessageT, agnocast::RosToAgnocastRequestPolicy>;

// Wrapper around a take subscription exposing take_data(): BasicPollingSubscriber with
// RosToAgnocastRequestPolicy.
AGNOCAST_PUBLIC
template <typename MessageT>
using PollingSubscriber =
  agnocast::BasicPollingSubscriber<MessageT, agnocast::RosToAgnocastRequestPolicy>;

}  // namespace agnocast