Program Listing for File agnocast_subscription.hpp
↰ Return to documentation for file (include/agnocast/agnocast_subscription.hpp)
#pragma once
#include "agnocast/agnocast_callback_info.hpp"
#include "agnocast/agnocast_ioctl.hpp"
#include "agnocast/agnocast_mq.hpp"
#include "agnocast/agnocast_public_api.hpp"
#include "agnocast/agnocast_smart_pointer.hpp"
#include "agnocast/agnocast_tracepoint_wrapper.h"
#include "agnocast/agnocast_utils.hpp"
#include "rclcpp/detail/qos_parameters.hpp"
#include "rclcpp/rclcpp.hpp"
#include <fcntl.h>
#include <mqueue.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include <atomic>
#include <cstdint>
#include <cstring>
#include <functional>
#include <string>
#include <thread>
#include <vector>
namespace agnocast
{
class Node;
extern std::mutex mmap_mtx;
void map_read_only_area(const pid_t pid, const uint64_t shm_addr, const uint64_t shm_size);
// Get the default callback group from an agnocast::Node for tracepoint use.
// Defined in .cpp to avoid circular inclusion between agnocast_subscription.hpp and
// agnocast_node.hpp.
rclcpp::CallbackGroup::SharedPtr get_default_callback_group_for_tracepoint(agnocast::Node * node);
const void * get_node_base_address(Node * node);
AGNOCAST_PUBLIC
struct SubscriptionOptions
{
rclcpp::CallbackGroup::SharedPtr callback_group{nullptr};
bool ignore_local_publications{false};
rclcpp::QosOverridingOptions qos_overriding_options{};
};
// These are cut out of the class for information hiding.
mqd_t open_mq_for_subscription(
const std::string & topic_name, const topic_local_id_t subscriber_id,
std::pair<mqd_t, std::string> & mq_subscription);
void remove_mq(const std::pair<mqd_t, std::string> & mq_subscription);
uint32_t get_publisher_count_core(const std::string & topic_name);
template <typename NodeT>
rclcpp::CallbackGroup::SharedPtr get_valid_callback_group(
NodeT * node, const SubscriptionOptions & options)
{
rclcpp::CallbackGroup::SharedPtr callback_group = options.callback_group;
if (callback_group) {
if (!node->get_node_base_interface()->callback_group_in_node(callback_group)) {
RCLCPP_ERROR(logger, "Cannot create agnocast subscription, callback group not in node.");
close(agnocast_fd);
exit(EXIT_FAILURE);
}
} else {
callback_group = node->get_node_base_interface()->get_default_callback_group();
}
return callback_group;
}
class SubscriptionBase
{
protected:
topic_local_id_t id_;
const std::string topic_name_;
union ioctl_add_subscriber_args initialize(
const rclcpp::QoS & qos, const bool is_take_sub, const bool ignore_local_publications,
const bool is_bridge, const std::string & node_name);
public:
SubscriptionBase(rclcpp::Node * node, const std::string & topic_name);
SubscriptionBase(agnocast::Node * node, const std::string & topic_name);
uint32_t get_publisher_count() const { return get_publisher_count_core(topic_name_); }
virtual ~SubscriptionBase()
{
// NOTE: Unmapping memory when a subscriber is destroyed is not implemented. Multiple
// subscribers
// may share the same mmap region, requiring reference counting in kmod. Since leaving the
// memory mapped should not cause any functional issues, this is left as future work.
struct ioctl_remove_subscriber_args remove_subscriber_args
{
};
remove_subscriber_args.topic_name = {topic_name_.c_str(), topic_name_.size()};
remove_subscriber_args.subscriber_id = id_;
if (ioctl(agnocast_fd, AGNOCAST_REMOVE_SUBSCRIBER_CMD, &remove_subscriber_args) < 0) {
RCLCPP_WARN(logger, "Failed to remove subscriber (id=%d) from kernel.", id_);
}
}
};
// Internal implementation — users should use agnocast::Subscription<MessageT> instead.
template <typename MessageT, typename BridgeRequestPolicy>
class BasicSubscription : public SubscriptionBase
{
std::pair<mqd_t, std::string> mq_subscription_;
uint32_t callback_info_id_;
template <typename NodeT, typename Func>
rclcpp::QoS constructor_impl(
NodeT * node, const rclcpp::QoS & qos, Func && callback,
rclcpp::CallbackGroup::SharedPtr callback_group, agnocast::SubscriptionOptions options,
const bool is_bridge)
{
auto node_parameters = node->get_node_parameters_interface();
const rclcpp::QoS actual_qos =
options.qos_overriding_options.get_policy_kinds().size()
? rclcpp::detail::declare_qos_parameters(
options.qos_overriding_options, node_parameters, topic_name_, qos,
rclcpp::detail::SubscriptionQosParametersTraits{})
: qos;
validate_qos(actual_qos);
union ioctl_add_subscriber_args add_subscriber_args = initialize(
actual_qos, false, options.ignore_local_publications, is_bridge,
node->get_fully_qualified_name());
id_ = add_subscriber_args.ret_id;
BridgeRequestPolicy::template request_bridge<MessageT>(topic_name_, id_);
mqd_t mq = open_mq_for_subscription(topic_name_, id_, mq_subscription_);
const bool is_transient_local =
actual_qos.durability() == rclcpp::DurabilityPolicy::TransientLocal;
callback_info_id_ = agnocast::register_callback<MessageT>(
std::forward<Func>(callback), topic_name_, id_, is_transient_local, mq, callback_group);
return actual_qos;
}
public:
using SharedPtr = std::shared_ptr<BasicSubscription<MessageT, BridgeRequestPolicy>>;
template <typename Func>
BasicSubscription(
rclcpp::Node * node, const std::string & topic_name, const rclcpp::QoS & qos, Func && callback,
agnocast::SubscriptionOptions options, const bool is_bridge = false)
: SubscriptionBase(node, topic_name)
{
rclcpp::CallbackGroup::SharedPtr callback_group = get_valid_callback_group(node, options);
const void * callback_addr = static_cast<const void *>(&callback);
const char * callback_symbol = tracetools::get_symbol(callback);
const rclcpp::QoS actual_qos =
constructor_impl(node, qos, std::forward<Func>(callback), callback_group, options, is_bridge);
{
uint64_t pid_callback_info_id = (static_cast<uint64_t>(getpid()) << 32) | callback_info_id_;
TRACEPOINT(
agnocast_subscription_init, static_cast<const void *>(this),
static_cast<const void *>(
node->get_node_base_interface()->get_shared_rcl_node_handle().get()),
callback_addr, static_cast<const void *>(callback_group.get()), callback_symbol,
topic_name_.c_str(), actual_qos.depth(), pid_callback_info_id);
}
}
template <typename Func>
BasicSubscription(
agnocast::Node * node, const std::string & topic_name, const rclcpp::QoS & qos,
Func && callback, agnocast::SubscriptionOptions options)
: SubscriptionBase(node, topic_name)
{
rclcpp::CallbackGroup::SharedPtr callback_group = get_valid_callback_group(node, options);
const void * callback_addr = static_cast<const void *>(&callback);
const char * callback_symbol = tracetools::get_symbol(callback);
const rclcpp::QoS actual_qos =
constructor_impl(node, qos, std::forward<Func>(callback), callback_group, options, false);
{
uint64_t pid_callback_info_id = (static_cast<uint64_t>(getpid()) << 32) | callback_info_id_;
TRACEPOINT(
agnocast_subscription_init, static_cast<const void *>(this),
static_cast<const void *>(get_node_base_address(node)), callback_addr,
static_cast<const void *>(callback_group.get()), callback_symbol, topic_name_.c_str(),
actual_qos.depth(), pid_callback_info_id);
}
}
~BasicSubscription()
{
// Remove from callback info map to prevent stale references on re-subscription and to avoid
// fd reuse conflicts. When mq_close() is called in remove_mq(), the OS may later reuse the
// same fd number for a new subscription. If the old entry remains in id2_callback_info,
// adding the new fd to epoll (EPOLL_CTL_ADD) can fail with EEXIST because epoll still
// associates that fd number with the stale entry.
{
std::lock_guard<std::mutex> lock(id2_callback_info_mtx);
id2_callback_info.erase(callback_info_id_);
}
remove_mq(mq_subscription_);
}
};
// Internal implementation — users should use agnocast::TakeSubscription<MessageT> instead.
template <typename MessageT, typename BridgeRequestPolicy>
class BasicTakeSubscription : public SubscriptionBase
{
private:
// Cached pointer from the most recent take(allow_same_message=true) call.
// When the same entry is returned again, a copy sharing the same control_block is returned
// so that the kernel-side reference is not released until all userspace copies are destroyed.
agnocast::ipc_shared_ptr<const MessageT> last_taken_ptr_;
std::mutex last_taken_ptr_mtx_;
template <typename NodeT>
rclcpp::QoS constructor_impl(
NodeT * node, const rclcpp::QoS & qos, agnocast::SubscriptionOptions options)
{
auto node_parameters = node->get_node_parameters_interface();
const rclcpp::QoS actual_qos =
options.qos_overriding_options.get_policy_kinds().size()
? rclcpp::detail::declare_qos_parameters(
options.qos_overriding_options, node_parameters, topic_name_, qos,
rclcpp::detail::SubscriptionQosParametersTraits{})
: qos;
validate_qos(actual_qos);
union ioctl_add_subscriber_args add_subscriber_args = initialize(
actual_qos, true, options.ignore_local_publications, false, node->get_fully_qualified_name());
id_ = add_subscriber_args.ret_id;
BridgeRequestPolicy::template request_bridge<MessageT>(topic_name_, id_);
return actual_qos;
}
public:
using SharedPtr = std::shared_ptr<BasicTakeSubscription<MessageT, BridgeRequestPolicy>>;
BasicTakeSubscription(
rclcpp::Node * node, const std::string & topic_name, const rclcpp::QoS & qos,
agnocast::SubscriptionOptions options = agnocast::SubscriptionOptions())
: SubscriptionBase(node, topic_name)
{
const rclcpp::QoS actual_qos = constructor_impl(node, qos, options);
{
auto default_cbg = node->get_node_base_interface()->get_default_callback_group();
auto dummy_cb = []() {};
std::string dummy_cb_symbols = "dummy_take" + topic_name_;
TRACEPOINT(
agnocast_subscription_init, static_cast<const void *>(this),
static_cast<const void *>(
node->get_node_base_interface()->get_shared_rcl_node_handle().get()),
static_cast<const void *>(&dummy_cb), static_cast<const void *>(default_cbg.get()),
dummy_cb_symbols.c_str(), topic_name_.c_str(), actual_qos.depth(), 0);
}
}
BasicTakeSubscription(
agnocast::Node * node, const std::string & topic_name, const rclcpp::QoS & qos,
agnocast::SubscriptionOptions options = agnocast::SubscriptionOptions())
: SubscriptionBase(node, topic_name)
{
const rclcpp::QoS actual_qos = constructor_impl(node, qos, options);
{
auto default_cbg = get_default_callback_group_for_tracepoint(node);
auto dummy_cb = []() {};
std::string dummy_cb_symbols = "dummy_take" + topic_name_;
TRACEPOINT(
agnocast_subscription_init, static_cast<const void *>(this),
static_cast<const void *>(get_node_base_address(node)),
static_cast<const void *>(&dummy_cb), static_cast<const void *>(default_cbg.get()),
dummy_cb_symbols.c_str(), topic_name_.c_str(), actual_qos.depth(), 0);
}
}
AGNOCAST_PUBLIC
agnocast::ipc_shared_ptr<const MessageT> take(bool allow_same_message = false)
{
publisher_shm_info pub_shm_infos[MAX_PUBLISHER_NUM]{};
union ioctl_take_msg_args take_args;
take_args.topic_name = {topic_name_.c_str(), topic_name_.size()};
take_args.subscriber_id = id_;
take_args.allow_same_message = allow_same_message;
take_args.pub_shm_info_addr = reinterpret_cast<uint64_t>(pub_shm_infos);
take_args.pub_shm_info_size = MAX_PUBLISHER_NUM;
{
std::lock_guard<std::mutex> lock(mmap_mtx);
if (ioctl(agnocast_fd, AGNOCAST_TAKE_MSG_CMD, &take_args) < 0) {
RCLCPP_ERROR(logger, "AGNOCAST_TAKE_MSG_CMD failed: %s", strerror(errno));
close(agnocast_fd);
exit(EXIT_FAILURE);
}
for (uint32_t i = 0; i < take_args.ret_pub_shm_num; i++) {
const pid_t pid = pub_shm_infos[i].pid;
const uint64_t addr = pub_shm_infos[i].shm_addr;
const uint64_t size = pub_shm_infos[i].shm_size;
map_read_only_area(pid, addr, size);
}
}
if (take_args.ret_addr == 0) {
TRACEPOINT(agnocast_take, static_cast<void *>(this), 0, 0);
return agnocast::ipc_shared_ptr<const MessageT>();
}
TRACEPOINT(
agnocast_take, static_cast<void *>(this), reinterpret_cast<void *>(take_args.ret_addr),
take_args.ret_entry_id);
if (allow_same_message) {
// Declared outside the lock scope so that its destructor (which may call ioctl to release
// the kernel reference) runs after the mutex is released, avoiding unnecessary contention.
agnocast::ipc_shared_ptr<const MessageT> old_ptr;
{
std::lock_guard<std::mutex> lock(last_taken_ptr_mtx_);
// When the kernel returned the same entry as last time, return a copy of the cached
// pointer (sharing the same control_block) instead of creating a new one.
// This keeps the kernel-side reference alive until all copies are destroyed.
if (last_taken_ptr_ && last_taken_ptr_.get_entry_id() == take_args.ret_entry_id) {
return last_taken_ptr_;
}
MessageT * ptr = reinterpret_cast<MessageT *>(take_args.ret_addr);
auto result =
agnocast::ipc_shared_ptr<const MessageT>(ptr, topic_name_, id_, take_args.ret_entry_id);
old_ptr = std::move(last_taken_ptr_);
last_taken_ptr_ = result;
return result;
}
}
MessageT * ptr = reinterpret_cast<MessageT *>(take_args.ret_addr);
return agnocast::ipc_shared_ptr<const MessageT>(ptr, topic_name_, id_, take_args.ret_entry_id);
}
};
// Internal implementation — users should use agnocast::PollingSubscriber<MessageT> instead.
template <typename MessageT, typename BridgeRequestPolicy>
class BasicPollingSubscriber
{
typename BasicTakeSubscription<MessageT, BridgeRequestPolicy>::SharedPtr subscriber_;
public:
using SharedPtr = std::shared_ptr<BasicPollingSubscriber<MessageT, BridgeRequestPolicy>>;
explicit BasicPollingSubscriber(
rclcpp::Node * node, const std::string & topic_name, const rclcpp::QoS & qos = rclcpp::QoS{1},
agnocast::SubscriptionOptions options = agnocast::SubscriptionOptions())
{
subscriber_ = std::make_shared<BasicTakeSubscription<MessageT, BridgeRequestPolicy>>(
node, topic_name, qos, options);
};
explicit BasicPollingSubscriber(
agnocast::Node * node, const std::string & topic_name, const rclcpp::QoS & qos = rclcpp::QoS{1},
agnocast::SubscriptionOptions options = agnocast::SubscriptionOptions())
{
subscriber_ = std::make_shared<BasicTakeSubscription<MessageT, BridgeRequestPolicy>>(
node, topic_name, qos, options);
};
const agnocast::ipc_shared_ptr<const MessageT> takeData() { return subscriber_->take(true); };
AGNOCAST_PUBLIC
const agnocast::ipc_shared_ptr<const MessageT> take_data() { return subscriber_->take(true); };
};
struct RosToAgnocastRequestPolicy;
AGNOCAST_PUBLIC
template <typename MessageT>
using Subscription = agnocast::BasicSubscription<MessageT, agnocast::RosToAgnocastRequestPolicy>;
AGNOCAST_PUBLIC
template <typename MessageT>
using TakeSubscription =
agnocast::BasicTakeSubscription<MessageT, agnocast::RosToAgnocastRequestPolicy>;
AGNOCAST_PUBLIC
template <typename MessageT>
using PollingSubscriber =
agnocast::BasicPollingSubscriber<MessageT, agnocast::RosToAgnocastRequestPolicy>;
} // namespace agnocast