Program Listing for File ImuConverter.hpp

Return to documentation for file (/tmp/ws/src/depthai-ros/depthai_bridge/include/depthai_bridge/ImuConverter.hpp)

#pragma once

#include <chrono>
#include <cstdint>
#include <deque>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "depthai-shared/datatype/RawIMUData.hpp"
#include "depthai/pipeline/datatype/IMUData.hpp"
#include "depthai_bridge/depthaiUtility.hpp"
#include "depthai_ros_msgs/msg/imu_with_magnetic_field.hpp"
#include "rclcpp/time.hpp"
#include "sensor_msgs/msg/imu.hpp"
#include "sensor_msgs/msg/magnetic_field.hpp"

namespace dai {

namespace ros {

namespace ImuMsgs = sensor_msgs::msg;
using ImuPtr = ImuMsgs::Imu::SharedPtr;

enum class ImuSyncMethod { COPY, LINEAR_INTERPOLATE_GYRO, LINEAR_INTERPOLATE_ACCEL };

/**
 * @brief Converts DepthAI IMU packets (dai::IMUData) into ROS 2 IMU messages.
 *
 * Two output types are supported: sensor_msgs/msg/Imu (toRosMsg) and
 * depthai_ros_msgs/msg/ImuWithMagneticField (toRosDaiMsg). Both are defined
 * out of line. For the LINEAR_INTERPOLATE_* sync modes, the helpers in this
 * header buffer accelerometer and gyroscope reports. They then linearly
 * interpolate one stream at the timestamps of the other.
 *
 * The interpolation histories are members, so each converter instance keeps
 * its own samples. No locking is performed. NOTE(review): presumably each
 * instance is driven from a single thread — confirm at the call sites.
 */
class ImuConverter {
   public:
    /**
     * @brief Construct a converter (defined out of line).
     *
     * @param frameName              Written to header.frame_id of every message built by CreateUnitMessage.
     * @param syncMode               How accelerometer and gyroscope reports are combined.
     * @param linear_accel_cov       Stored covariance value. NOTE(review): consumed out of line, not visible here.
     * @param angular_velocity_cov   Stored covariance value. NOTE(review): consumed out of line, not visible here.
     * @param rotation_cov           Stored covariance value. NOTE(review): consumed out of line, not visible here.
     * @param magnetic_field_cov     Stored covariance value. NOTE(review): consumed out of line, not visible here.
     * @param enable_rotation        In the interpolating path, attach rotation-vector reports to each message.
     * @param enable_magn            In the interpolating path, also attach magnetometer reports.
     *                               Only honoured when enable_rotation is set.
     * @param getBaseDeviceTimestamp Presumably stored in _getBaseDeviceTimestamp. When set, interpolated
     *                               messages are stamped from getTimestampDevice() instead of getTimestamp().
     */
    ImuConverter(const std::string& frameName,
                 ImuSyncMethod syncMode = ImuSyncMethod::LINEAR_INTERPOLATE_ACCEL,
                 double linear_accel_cov = 0.0,
                 double angular_velocity_cov = 0.0,
                 double rotation_cov = 0.0,
                 double magnetic_field_cov = 0.0,
                 bool enable_rotation = false,
                 bool enable_magn = false,
                 bool getBaseDeviceTimestamp = false);
    ~ImuConverter();

    /// Presumably re-samples the ROS base time used for stamping (defined out of line).
    void updateRosBaseTime();

    /// Choose whether the ROS base time is updated on each message conversion.
    void setUpdateRosBaseTimeOnToRosMsg(bool update = true) {
        _updateRosBaseTimeOnToRosMsg = update;
    }

    /// Convert @p inData into sensor_msgs/msg/Imu messages stored in @p outImuMsgs (defined out of line).
    void toRosMsg(std::shared_ptr<dai::IMUData> inData, std::deque<ImuMsgs::Imu>& outImuMsgs);
    /// Convert @p inData into depthai_ros_msgs/msg/ImuWithMagneticField messages stored in @p outImuMsgs (defined out of line).
    void toRosDaiMsg(std::shared_ptr<dai::IMUData> inData, std::deque<depthai_ros_msgs::msg::ImuWithMagneticField>& outImuMsgs);

    /// Linear interpolation: yields @p a at t == 0 and @p b at t == 1.
    template <typename T>
    T lerp(const T& a, const T& b, const double t) {
        return a * (1.0 - t) + b * t;
    }

    /**
     * @brief Component-wise lerp of the x/y/z fields of two reports.
     *
     * Only x, y and z are interpolated. Every other field of the result
     * (e.g. sequence and timestamp) keeps its default-constructed value.
     */
    template <typename T>
    T lerpImu(const T& a, const T& b, const double t) {
        T res;
        res.x = lerp(a.x, b.x, t);
        res.y = lerp(a.y, b.y, t);
        res.z = lerp(a.z, b.z, t);
        return res;
    }

   private:
    /**
     * @brief Buffer the reports of @p imuPackets and emit interpolated messages into @p imuMsgs.
     *
     * Each report is appended to its history only when its sequence number
     * differs from the last buffered one. A report repeated across packets is
     * therefore stored once. The interpolated stream is the accelerometer for
     * LINEAR_INTERPOLATE_ACCEL and the gyroscope for LINEAR_INTERPOLATE_GYRO.
     * Once it holds at least three samples, it is interpolated at the timestamps
     * of the other stream. Other sync modes produce nothing here.
     */
    template <typename T>
    void FillImuData_LinearInterpolation(std::vector<IMUPacket>& imuPackets, std::deque<T>& imuMsgs) {
        for(const auto& packet : imuPackets) {
            appendIfNewSequence(_accelHist, packet.acceleroMeter);
            appendIfNewSequence(_gyroHist, packet.gyroscope);

            // Rotation and magnetometer histories are consumed in lock-step with the anchor stream.
            // When the sensor is disabled, or its report repeats the last sequence number, the history
            // is resized to the accelerometer-history length (padding with default reports or truncating).
            // NOTE(review): padding an *enabled* stream with default reports looks questionable — confirm intent.
            if(_enable_rotation && (_rotationHist.empty() || _rotationHist.back().sequence != packet.rotationVector.sequence)) {
                _rotationHist.push_back(packet.rotationVector);
            } else {
                _rotationHist.resize(_accelHist.size());
            }

            if(_enable_magn && (_magnHist.empty() || _magnHist.back().sequence != packet.magneticField.sequence)) {
                _magnHist.push_back(packet.magneticField);
            } else {
                _magnHist.resize(_accelHist.size());
            }

            if(_syncMode == ImuSyncMethod::LINEAR_INTERPOLATE_ACCEL) {
                if(readyToInterpolate(_accelHist, _gyroHist)) {
                    interpolateEnabledStreams(_accelHist, _gyroHist, imuMsgs);
                }
            } else if(_syncMode == ImuSyncMethod::LINEAR_INTERPOLATE_GYRO) {
                if(readyToInterpolate(_gyroHist, _accelHist)) {
                    interpolateEnabledStreams(_gyroHist, _accelHist, imuMsgs);
                }
            }
        }
    }

    /// Append @p report to @p hist unless it repeats the sequence number of the newest buffered report.
    template <typename R>
    static void appendIfNewSequence(std::deque<R>& hist, const R& report) {
        if(hist.empty() || hist.back().sequence != report.sequence) {
            hist.push_back(report);
        }
    }

    /// True when the interpolated stream holds at least three samples and all other histories are non-empty.
    template <typename I, typename S>
    bool readyToInterpolate(const std::deque<I>& interpolated, const std::deque<S>& anchors) const {
        return interpolated.size() >= 3 && !anchors.empty() && !_rotationHist.empty() && !_magnHist.empty();
    }

    /// Interpolate @p interpolated at the @p anchors timestamps.
    /// Rotation reports (and, if also enabled, magnetometer reports) are attached when enabled.
    template <typename I, typename S, typename M>
    void interpolateEnabledStreams(std::deque<I>& interpolated, std::deque<S>& anchors, std::deque<M>& imuMsgs) {
        if(!_enable_rotation) {
            interpolate(interpolated, anchors, imuMsgs);
        } else if(_enable_magn) {
            interpolate(interpolated, anchors, _rotationHist, _magnHist, imuMsgs);
        } else {
            interpolate(interpolated, anchors, _rotationHist, imuMsgs);
        }
    }

    uint32_t _sequenceNum;
    double _linear_accel_cov, _angular_velocity_cov, _rotation_cov, _magnetic_field_cov;
    bool _enable_rotation;
    bool _enable_magn;
    const std::string _frameName = "";
    ImuSyncMethod _syncMode;
    std::chrono::time_point<std::chrono::steady_clock> _steadyBaseTime;
    rclcpp::Time _rosBaseTime;
    bool _getBaseDeviceTimestamp;
    // For handling ROS time shifts and debugging
    int64_t _totalNsChange{0};
    // Whether to update the ROS base time on each message conversion
    bool _updateRosBaseTimeOnToRosMsg{false};
    // Sample histories for FillImuData_LinearInterpolation. They are kept per instance
    // (not as function-local statics) so independent converters never mix samples.
    std::deque<dai::IMUReportAccelerometer> _accelHist;
    std::deque<dai::IMUReportGyroscope> _gyroHist;
    std::deque<dai::IMUReportRotationVectorWAcc> _rotationHist;
    std::deque<dai::IMUReportMagneticField> _magnHist;

    void fillImuMsg(ImuMsgs::Imu& msg, dai::IMUReportAccelerometer report);
    void fillImuMsg(ImuMsgs::Imu& msg, dai::IMUReportGyroscope report);
    void fillImuMsg(ImuMsgs::Imu& msg, dai::IMUReportRotationVectorWAcc report);
    void fillImuMsg(ImuMsgs::Imu& msg, dai::IMUReportMagneticField report);

    void fillImuMsg(depthai_ros_msgs::msg::ImuWithMagneticField& msg, dai::IMUReportAccelerometer report);
    void fillImuMsg(depthai_ros_msgs::msg::ImuWithMagneticField& msg, dai::IMUReportGyroscope report);
    void fillImuMsg(depthai_ros_msgs::msg::ImuWithMagneticField& msg, dai::IMUReportRotationVectorWAcc report);
    void fillImuMsg(depthai_ros_msgs::msg::ImuWithMagneticField& msg, dai::IMUReportMagneticField report);

    /// Set header.frame_id to the configured frame.
    /// Set header.stamp from @p timestamp, mapped through getFrameTime().
    template <typename M>
    void fillHeader(M& msg, std::chrono::steady_clock::time_point timestamp) {
        msg.header.frame_id = _frameName;
        msg.header.stamp = getFrameTime(_rosBaseTime, _steadyBaseTime, timestamp);
    }

    /// Fill @p msg from four reports (in argument order), then set its header.
    template <typename I, typename S, typename T, typename F, typename M>
    void CreateUnitMessage(M& msg, std::chrono::steady_clock::time_point timestamp, I first, S second, T third, F fourth) {
        fillImuMsg(msg, first);
        fillImuMsg(msg, second);
        fillImuMsg(msg, third);
        fillImuMsg(msg, fourth);
        fillHeader(msg, timestamp);
    }

    /// Fill @p msg from three reports (in argument order), then set its header.
    template <typename I, typename S, typename T, typename M>
    void CreateUnitMessage(M& msg, std::chrono::steady_clock::time_point timestamp, I first, S second, T third) {
        fillImuMsg(msg, first);
        fillImuMsg(msg, second);
        fillImuMsg(msg, third);
        fillHeader(msg, timestamp);
    }

    /// Fill @p msg from two reports (in argument order), then set its header.
    template <typename I, typename S, typename M>
    void CreateUnitMessage(M& msg, std::chrono::steady_clock::time_point timestamp, I first, S second) {
        fillImuMsg(msg, first);
        fillImuMsg(msg, second);
        fillHeader(msg, timestamp);
    }

    /// Copy the oldest report of @p hist into @p current and pop it.
    /// If @p hist is empty, @p current keeps its previous value.
    template <typename R>
    static void takeFront(std::deque<R>& hist, R& current) {
        if(!hist.empty()) {
            current = hist.front();
            hist.pop_front();
        }
    }

    /// Discard the oldest report of @p hist, if there is one.
    template <typename R>
    static void dropFront(std::deque<R>& hist) {
        if(!hist.empty()) {
            hist.pop_front();
        }
    }

    /**
     * @brief Shared driver for the interpolate() overloads.
     *
     * Walks consecutive pairs (interp0, interp1) of @p interpolated. Each anchor
     * in @p second whose timestamp lies in (interp0, interp1] yields one message.
     * The sample is lerped to the anchor's timestamp, the message is stamped from
     * the anchor (device or host timestamp per _getBaseDeviceTimestamp), and
     * @p emit fills it. Anchors at or before interp0 are discarded via
     * @p dropCompanions. An anchor after interp1 advances the pair. The newest
     * interpolated sample is pushed back so the next call continues from it.
     *
     * NOTE(review): if @p second runs out first, the remaining interpolated
     * samples are discarded, except the newest.
     *
     * @param emit           Called as emit(msg, stamp, interpolatedSample, anchor). It must fill
     *                       the message and consume any companion samples.
     * @param dropCompanions Called when an anchor is discarded. It must consume the matching
     *                       companion samples.
     */
    template <typename I, typename S, typename M, typename Emit, typename Drop>
    void interpolateWith(std::deque<I>& interpolated, std::deque<S>& second, std::deque<M>& imuMsgs, Emit emit, Drop dropCompanions) {
        if(interpolated.empty()) {
            // Nothing to interpolate against; do not push an uninitialised sample into the history.
            return;
        }
        I interp0 = interpolated.front();
        interpolated.pop_front();
        while(!interpolated.empty()) {
            I interp1 = interpolated.front();
            interpolated.pop_front();
            // alpha = diff / dt is unit-free, so milliseconds are as good as seconds here.
            std::chrono::duration<double, std::milli> dt = interp1.timestamp.get() - interp0.timestamp.get();
            while(!second.empty()) {
                S anchor = second.front();
                const auto anchorTime = anchor.timestamp.get();
                if(anchorTime > interp0.timestamp.get() && anchorTime <= interp1.timestamp.get()) {
                    std::chrono::duration<double, std::milli> diff = anchorTime - interp0.timestamp.get();
                    const double alpha = diff.count() / dt.count();
                    M msg;
                    const std::chrono::steady_clock::time_point stamp =
                        _getBaseDeviceTimestamp ? anchor.getTimestampDevice() : anchor.getTimestamp();
                    emit(msg, stamp, lerpImu(interp0, interp1, alpha), anchor);
                    imuMsgs.push_back(msg);
                    second.pop_front();
                } else if(anchorTime > interp1.timestamp.get()) {
                    // The anchor lies beyond the current pair: slide the pair forward.
                    interp0 = interp1;
                    if(interpolated.empty()) {
                        break;
                    }
                    interp1 = interpolated.front();
                    interpolated.pop_front();
                    dt = interp1.timestamp.get() - interp0.timestamp.get();
                } else {
                    // The anchor is not newer than interp0, so it cannot be interpolated: drop it.
                    second.pop_front();
                    dropCompanions();
                }
            }
            interp0 = interp1;
        }
        interpolated.push_back(interp0);
    }

    /// Interpolate @p interpolated at the timestamps of @p second. Messages carry (interpolated, anchor).
    template <typename I, typename S, typename M>
    void interpolate(std::deque<I>& interpolated, std::deque<S>& second, std::deque<M>& imuMsgs) {
        interpolateWith(
            interpolated,
            second,
            imuMsgs,
            [this](M& msg, std::chrono::steady_clock::time_point stamp, const I& interp, const S& anchor) {
                this->CreateUnitMessage(msg, stamp, interp, anchor);
            },
            []() {});
    }

    /// As above, also attaching the oldest unconsumed @p third report.
    /// If @p third is empty, the last one used is reused.
    template <typename I, typename S, typename T, typename M>
    void interpolate(std::deque<I>& interpolated, std::deque<S>& second, std::deque<T>& third, std::deque<M>& imuMsgs) {
        T currThird{};
        interpolateWith(
            interpolated,
            second,
            imuMsgs,
            [&](M& msg, std::chrono::steady_clock::time_point stamp, const I& interp, const S& anchor) {
                takeFront(third, currThird);
                this->CreateUnitMessage(msg, stamp, interp, anchor, currThird);
            },
            [&]() { dropFront(third); });
    }

    /// As above, also attaching the oldest unconsumed @p third and @p fourth reports.
    /// An empty companion history reuses its last value.
    template <typename I, typename S, typename T, typename F, typename M>
    void interpolate(std::deque<I>& interpolated, std::deque<S>& second, std::deque<T>& third, std::deque<F>& fourth, std::deque<M>& imuMsgs) {
        T currThird{};
        F currFourth{};
        interpolateWith(
            interpolated,
            second,
            imuMsgs,
            [&](M& msg, std::chrono::steady_clock::time_point stamp, const I& interp, const S& anchor) {
                takeFront(third, currThird);
                takeFront(fourth, currFourth);
                this->CreateUnitMessage(msg, stamp, interp, anchor, currThird, currFourth);
            },
            [&]() {
                dropFront(third);
                dropFront(fourth);
            });
    }
};

}  // namespace ros

namespace rosBridge = ros;

}  // namespace dai