Program Listing for File Node.hpp

Return to documentation for file (include/depthai/pipeline/Node.hpp)

#pragma once

#include <algorithm>
#include <cstdint>
#include <filesystem>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// project
#include "depthai/openvino/OpenVINO.hpp"
#include "depthai/pipeline/AssetManager.hpp"
#include "depthai/pipeline/MessageQueue.hpp"
#include "depthai/utility/RecordReplay.hpp"
#include "depthai/utility/copyable_unique_ptr.hpp"

// depthai
#include "depthai/capabilities/Capability.hpp"
#include "depthai/pipeline/datatype/DatatypeEnum.hpp"
#include "depthai/properties/Properties.hpp"

// libraries
#include <optional>

namespace dai {
// fwd declare Pipeline
class Pipeline;
class PipelineImpl;

// fwd declare input queue class
class InputQueue;

/**
 * @brief Base class of every pipeline node.
 *
 * A Node exposes typed Inputs/Outputs (and keyed maps of them), may own child nodes
 * (`nodeMap`) together with the connections between them (`connections`), and is meant to be
 * managed through std::shared_ptr: it derives from enable_shared_from_this and its copy/move
 * operations are deleted. Concrete nodes implement getName() and runOnHost().
 */
class Node : public std::enable_shared_from_this<Node> {
    friend class Pipeline;
    friend class PipelineImpl;
    friend class Device;

   public:
    // Nodes must always be managed (held through std::shared_ptr), so copy and move are disabled
    Node(const Node&) = delete;
    Node& operator=(const Node&) = delete;
    Node(Node&&) = delete;
    Node& operator=(Node&&) = delete;

    /// Identifier type used for node ids
    using Id = std::int64_t;
    struct Connection;
    struct ConnectionInternal;
    // fwd declare classes
    class Input;
    class Output;
    class InputMap;
    class OutputMap;
    /**
     * A datatype an Input accepts or an Output produces.
     * NOTE(review): `descendants` presumably means "types derived from `datatype` match as well" —
     * confirm against the datatype compatibility check outside this header.
     */
    struct DatatypeHierarchy {
        DatatypeHierarchy(DatatypeEnum d, bool c) : datatype(d), descendants(c) {}
        DatatypeEnum datatype;
        bool descendants;
    };
    // Defaults used by InputDescription / OutputDescription
    static constexpr auto DEFAULT_GROUP = "";
    static constexpr auto DEFAULT_NAME = "";
    // DEFAULT_TYPES is a preprocessor macro, not a class member: it is not scoped to Node and stays
    // defined in every file that includes this header. It expands to a braced initializer holding a
    // single {DatatypeEnum::Buffer, true} entry.
#define DEFAULT_TYPES                  \
    {                                  \
        { DatatypeEnum::Buffer, true } \
    }
    static constexpr auto DEFAULT_BLOCKING = true;
    static constexpr auto DEFAULT_QUEUE_SIZE = 3;
    static constexpr auto DEFAULT_WAIT_FOR_MESSAGE = false;
    static constexpr auto BLOCKING_QUEUE = true;
    static constexpr auto NON_BLOCKING_QUEUE = false;

    /// Name given to an Input constructed with an empty name (called from the Input constructor)
    std::string createUniqueInputName();
    /// Name given to an Output constructed with an empty name (called from the Output constructor)
    std::string createUniqueOutputName();

   protected:
    // Non-owning pointers to this node's inputs, outputs, input/output maps and node members.
    // Filled through the set*Refs helpers below; Input/Output constructors register themselves
    // when their `ref` argument is true.
    std::vector<Output*> outputRefs;
    std::vector<Input*> inputRefs;
    std::vector<OutputMap*> outputMapRefs;
    std::vector<InputMap*> inputMapRefs;
    std::vector<std::shared_ptr<Node>*> nodeRefs;

    // helpers for setting refs
    void setOutputRefs(std::initializer_list<Output*> l);
    void setOutputRefs(Output* outRef);
    void setInputRefs(std::initializer_list<Input*> l);
    void setInputRefs(Input* inRef);
    void setOutputMapRefs(std::initializer_list<OutputMap*> l);
    void setOutputMapRefs(OutputMap* outMapRef);
    void setInputMapRefs(std::initializer_list<InputMap*> l);
    void setInputMapRefs(InputMap* inMapRef);
    void setNodeRefs(std::initializer_list<std::pair<std::string, std::shared_ptr<Node>*>> l);
    void setNodeRefs(std::pair<std::string, std::shared_ptr<Node>*> nodeRef);
    void setNodeRefs(std::string alias, std::shared_ptr<Node>* nodeRef);

   private:
    // NOTE(review): presumably records names handed out by createUnique*Name — confirm outside this header
    std::vector<std::string> uniqueNames;

   public:
    /// Construction parameters for an Output
    struct OutputDescription {
        std::string name{DEFAULT_NAME};                      // empty -> the Output constructor generates a unique name
        std::string group{DEFAULT_GROUP};                    // group of the output
        std::vector<DatatypeHierarchy> types DEFAULT_TYPES;  // possible datatypes of this output
    };

    template <typename U>
    friend class Subnode;

    /**
     * A node output. Messages are sent with send()/trySend(); the output can be linked to Inputs
     * (link(Input&)) and to standalone MessageQueues, which are tracked in queueConnections.
     */
    class Output {
        friend class PipelineImpl;

       public:
        /// An (output, queue) pair recorded by the private link(queue) overload
        struct QueueConnection {
            Output* output;
            std::shared_ptr<MessageQueue> queue;
            bool operator==(const QueueConnection& rhs) const {
                return output == rhs.output && queue == rhs.queue;
            }
        };
        /// Sender kind; SSender (slave sender) is not supported yet
        enum class Type { MSender, SSender };
        virtual ~Output() = default;

       private:
        std::reference_wrapper<Node> parent;
        // Queues this output feeds; the private link(queue)/unlink(queue) overloads update it
        // together with queueConnections
        std::vector<MessageQueue*> connectedInputs;
        std::vector<QueueConnection> queueConnections;
        Type type = Type::MSender;  // Slave sender not supported yet
        OutputDescription desc;

       public:
        // std::vector<Capability> possibleCapabilities;

        /**
         * @param par   Owning node.
         * @param desc  Name, group and possible datatypes; an empty name is replaced with
         *              par.createUniqueOutputName().
         * @param ref   When true, register this output in par's output references (setOutputRefs).
         */
        Output(Node& par, OutputDescription desc, bool ref = true) : parent(par), desc(std::move(desc)) {
            // Place oneself to the parents references
            if(ref) {
                par.setOutputRefs(this);
            }
            if(getName().empty()) {
                setName(par.createUniqueOutputName());
            }
        }

        Node& getParent() {
            return parent;
        }
        const Node& getParent() const {
            return parent;
        }

        std::string toString() const;

        std::string getName() const {
            return desc.name;
        }

        std::string getGroup() const {
            return desc.group;
        }

        void setGroup(std::string group) {
            desc.group = std::move(group);
        }

        void setName(std::string name) {
            desc.name = std::move(name);
        }

        Type getType() const {
            return type;
        }

        std::vector<DatatypeHierarchy> getPossibleDatatypes() const;

        void setPossibleDatatypes(std::vector<DatatypeHierarchy> types);

        bool isSamePipeline(const Input& in);

        bool canConnect(const Input& in);

        std::vector<ConnectionInternal> getConnections();

        /// Copy of the (output, queue) pairs attached through the private link(queue) overload
        std::vector<QueueConnection> getQueueConnections() {
            return queueConnections;
        }

        static constexpr bool OUTPUT_QUEUE_DEFAULT_BLOCKING = false;

        static constexpr unsigned int OUTPUT_QUEUE_DEFAULT_MAX_SIZE = 16;

        /// Create a MessageQueue for this output (defaults: max size 16, non-blocking).
        /// NOTE(review): presumably attached via the private link(queue) overload — confirm.
        std::shared_ptr<dai::MessageQueue> createOutputQueue(unsigned int maxSize = OUTPUT_QUEUE_DEFAULT_MAX_SIZE,
                                                             bool blocking = OUTPUT_QUEUE_DEFAULT_BLOCKING);

       private:
        // Attach a queue: both bookkeeping vectors are updated together
        void link(const std::shared_ptr<dai::MessageQueue>& queue) {
            connectedInputs.push_back(queue.get());
            queueConnections.push_back({this, queue});
        }

        // Detach a queue: removes every matching entry from both bookkeeping vectors
        void unlink(const std::shared_ptr<dai::MessageQueue>& queue) {
            connectedInputs.erase(std::remove(connectedInputs.begin(), connectedInputs.end(), queue.get()), connectedInputs.end());
            queueConnections.erase(std::remove(queueConnections.begin(), queueConnections.end(), QueueConnection{this, queue}), queueConnections.end());
        }

       public:
        void link(Input& in);

        virtual void link(std::shared_ptr<Node> in);

        void unlink(Input& in);

        void send(const std::shared_ptr<ADatatype>& msg);

        bool trySend(const std::shared_ptr<ADatatype>& msg);
    };

    /// Hash for the std::pair keys of OutputMap / InputMap; mixes both element hashes.
    struct PairHash {
        template <class T1, class T2>
        std::size_t operator()(const std::pair<T1, T2>& pair) const {
            // boost-style hash_combine (the same mixing step std::hash<Node::Connection> uses).
            // A plain XOR of the two hashes would make (a, b) and (b, a) collide and map (a, a) to 0.
            std::size_t seed = std::hash<T1>()(pair.first);
            seed ^= std::hash<T2>()(pair.second) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
            return seed;
        }
    };
    /**
     * Outputs stored in an unordered_map keyed by a pair of strings.
     * NOTE(review): the key is presumably (group, name), with `name` acting as the group for the
     * single-string operator[] — confirm outside this header.
     */
    class OutputMap : public std::unordered_map<std::pair<std::string, std::string>, Output, PairHash> {
        OutputDescription defaultOutput;  // presumably the template for outputs created by operator[]
        std::reference_wrapper<Node> parent;

       public:
        std::string name;
        OutputMap(Node& parent, std::string name, OutputDescription defaultOutput, bool ref = true);
        OutputMap(Node& parent, OutputDescription defaultOutput, bool ref = true);
        Output& operator[](const std::string& key);
        Output& operator[](std::pair<std::string, std::string> groupKey);
    };

    // Input extends the message queue with additional option that specifies whether to wait for message or not
    /// Construction parameters for an Input
    struct InputDescription {
        std::string name = DEFAULT_NAME;                     // Name of the input
        std::string group = DEFAULT_GROUP;                   // Group of the input
        bool blocking{DEFAULT_BLOCKING};                     // Whether to block when input queue is full
        int queueSize{DEFAULT_QUEUE_SIZE};                   // Size of the queue
        std::vector<DatatypeHierarchy> types DEFAULT_TYPES;  // Possible datatypes that can be received
        bool waitForMessage{DEFAULT_WAIT_FOR_MESSAGE};
    };

    /**
     * A node input: a MessageQueue that additionally records its parent node, its group,
     * a waitForMessage flag and the datatypes it may receive.
     */
    class Input : public MessageQueue {
        friend class Output;
        friend class OutputMap;

       public:
        enum class Type { SReceiver, MReceiver };  // TODO(Morato) - refactor, make the MReceiver a separate class (shouldn't inherit from MessageQueue)

        ~Input() override;

       protected:
        std::vector<Output*> connectedOutputs;

       private:
        std::reference_wrapper<Node> parent;
        // Options - more information about the input
        bool waitForMessage{false};
        std::string group;
        Type type = Type::SReceiver;

       public:
        std::vector<DatatypeHierarchy> possibleDatatypes;
        /**
         * @param par   Owning node.
         * @param desc  Name, group, queue size/blocking mode, possible datatypes and waitForMessage flag;
         *              an empty name is replaced with par.createUniqueInputName().
         * @param ref   When true, register this input in par's input references (setInputRefs).
         */
        explicit Input(Node& par, InputDescription desc, bool ref = true)
            : MessageQueue(std::move(desc.name), desc.queueSize, desc.blocking),
              parent(par),
              waitForMessage(desc.waitForMessage),
              // take the group from the description, as Output does with OutputDescription::group
              group(std::move(desc.group)),
              possibleDatatypes(std::move(desc.types)) {
            if(ref) {
                par.setInputRefs(this);
            }
            if(getName().empty()) {
                setName(par.createUniqueInputName());
            }
        }

        const Node& getParent() const {
            return parent;
        }
        Node& getParent() {
            return parent;
        }

        Type getType() const {
            return type;
        }

        std::string toString() const;

        void setWaitForMessage(bool waitForMessage);

        bool getWaitForMessage() const;

        std::vector<DatatypeHierarchy> getPossibleDatatypes() const;

        void setPossibleDatatypes(std::vector<DatatypeHierarchy> types);

        void setReusePreviousMessage(bool reusePreviousMessage);

        bool getReusePreviousMessage() const;

        void setGroup(std::string group);

        std::string getGroup() const;

        bool isConnected() const;

        static constexpr bool INPUT_QUEUE_DEFAULT_BLOCKING = false;

        static constexpr unsigned int INPUT_QUEUE_DEFAULT_MAX_SIZE = 16;

        /// Create an InputQueue for this input (defaults: max size 16, non-blocking)
        std::shared_ptr<InputQueue> createInputQueue(unsigned int maxSize = INPUT_QUEUE_DEFAULT_MAX_SIZE, bool blocking = INPUT_QUEUE_DEFAULT_BLOCKING);
    };

    /**
     * Inputs stored in an unordered_map keyed by a pair of strings.
     * NOTE(review): the key is presumably (group, name), mirroring OutputMap — confirm outside this header.
     */
    class InputMap : public std::unordered_map<std::pair<std::string, std::string>, Input, PairHash> {
        std::reference_wrapper<Node> parent;
        InputDescription defaultInput;

       public:
        std::string name;
        // InputMap(Input defaultInput);
        // InputMap(std::string name, Input defaultInput);
        InputMap(Node& parent, InputDescription defaultInput);
        InputMap(Node& parent, std::string name, InputDescription defaultInput);
        Input& operator[](const std::string& key);
        Input& operator[](std::pair<std::string, std::string> groupKey);
        // Check if the input exists
        bool has(const std::string& key) const;
    };

    /**
     * A link between an Output and an Input: weak references to both owning nodes, the endpoint
     * names/groups, and raw pointers to the endpoints. Hashable through ConnectionInternal::Hash.
     */
    struct ConnectionInternal {
        ConnectionInternal(Output& out, Input& in);
        std::weak_ptr<Node> outputNode;
        std::string outputName;
        std::string outputGroup;
        std::weak_ptr<Node> inputNode;
        std::string inputName;
        std::string inputGroup;
        Output* out;
        Input* in;
        bool operator==(const ConnectionInternal& rhs) const;
        struct Hash {
            size_t operator()(const dai::Node::ConnectionInternal& obj) const;
        };
    };

    /// A link identified by node Ids and endpoint names/groups; hashable via std::hash<Connection>.
    struct Connection {
        friend struct std::hash<Connection>;
        Connection(Output out, Input in);
        Connection(ConnectionInternal c);
        Id outputId;
        std::string outputName;
        std::string outputGroup;
        Id inputId;
        std::string inputName;
        std::string inputGroup;
        bool operator==(const Connection& rhs) const;
    };

   protected:
    bool configureMode{false};

    // when Pipeline tries to serialize and construct on remote, it will check if all connected nodes are on same pipeline
    std::weak_ptr<PipelineImpl> parent;

    // Node ID of the parent node
    int parentId{-1};

    // used to improve error messages
    // when pipeline starts all nodes are checked
    virtual bool needsBuild() {
        return false;
    }

   public:
    // TODO(themarpe) - restrict access
    Id id{-1};
    // used for naming inputs/outputs
    Id inputId{0};
    Id outputId{0};

    std::string alias;

   protected:
    AssetManager assetManager;

    // Optimized for adding, searching and removing connections
    // using NodeMap = std::unordered_map<Node::Id, std::shared_ptr<Node>>;
    using NodeMap = std::vector<std::shared_ptr<Node>>;
    NodeMap nodeMap;

    // Connection map, NodeId represents id of node connected TO (input)
    // using NodeConnectionMap = std::unordered_map<Node::Id, std::unordered_set<Node::Connection>>;
    using SetConnectionInternal = std::unordered_set<ConnectionInternal, ConnectionInternal::Hash>;
    SetConnectionInternal connections;
    using ConnectionMap = std::unordered_map<std::shared_ptr<Node>, SetConnectionInternal>;

   public:
    // access
    Pipeline getParentPipeline();
    const Pipeline getParentPipeline() const;

    std::string getAlias() const {
        return alias;
    }
    void setAlias(std::string alias) {
        this->alias = std::move(alias);
    }

    /// Type name of the node (NodeCRTP implements it by returning Derived::NAME)
    virtual const char* getName() const = 0;

    // Lifecycle hooks; no-ops unless a derived node overrides them
    virtual void start() {};

    virtual void wait() {};

    virtual void stop() {};

    void stopPipeline();

    // Build stages (defined outside this header)
    virtual void buildStage1();
    virtual void buildStage2();
    virtual void buildStage3();

    std::vector<Output> getOutputs();

    std::vector<Input> getInputs();

    std::vector<Output*> getOutputRefs();

    std::vector<const Output*> getOutputRefs() const;

    std::vector<Input*> getInputRefs();

    std::vector<const Input*> getInputRefs() const;

    std::vector<OutputMap*> getOutputMapRefs();

    std::vector<InputMap*> getInputMapRefs();

    Output* getOutputRef(std::string name);
    Output* getOutputRef(std::string group, std::string name);

    Input* getInputRef(std::string name);
    Input* getInputRef(std::string group, std::string name);

    OutputMap* getOutputMapRef(std::string group);

    InputMap* getInputMapRef(std::string group);

    // For record and replay
    virtual bool isSourceNode() const;

   protected:
    Node() = default;
    Node(bool conf);
    void removeConnectionToNode(std::shared_ptr<Node> node);

   public:
    virtual ~Node() = default;

    const AssetManager& getAssetManager() const;

    AssetManager& getAssetManager();

    std::vector<uint8_t> loadResource(std::filesystem::path uri);

    std::vector<uint8_t> moveResource(std::filesystem::path uri);

    /**
     * Create a child node of type N and add it to this node.
     * @tparam N  Node subclass; must be default-constructible through std::make_shared.
     */
    template <class N>
    std::shared_ptr<N> create() {
        // Check that passed type 'N' is subclass of Node
        static_assert(std::is_base_of<Node, N>::value, "Specified class is not a subclass of Node");
        // Create and store the node in the map
        auto node = std::make_shared<N>();
        // Add
        add(node);
        // Return shared pointer to this node
        return node;
    }

    void add(std::shared_ptr<Node> node);

    // Access to nodes
    std::vector<std::shared_ptr<Node>> getAllNodes() const;
    std::shared_ptr<const Node> getNode(Node::Id id) const;
    std::shared_ptr<Node> getNode(Node::Id id);
    void remove(std::shared_ptr<Node> node);
    ConnectionMap getConnectionMap();
    void link(const Node::Output& out, const Node::Input& in);
    void unlink(const Node::Output& out, const Node::Input& in);

    virtual void link(std::shared_ptr<Node> in);
    virtual Node::Output* requestOutput(const Capability& capability, bool onHost);
    virtual std::vector<std::pair<Input&, std::shared_ptr<Capability>>> getRequiredInputs();

    /// Whether this node runs on the host
    virtual bool runOnHost() const = 0;

    /// Child nodes owned by this node
    const NodeMap& getNodeMap() const {
        return nodeMap;
    }

    /// Hook called by the variadic NodeCRTP::create right after construction; no-op by default
    virtual void buildInternal() {};
};

// Interface for nodes that take part in record/replay (pairs with Node::isSourceNode()).
// NOTE(review): the hooks below are virtual but not pure and are defined outside this header;
// confirm what the default implementations do before relying on them.
class SourceNode {
   public:
    virtual ~SourceNode() = default;
    // Presumably the recording parameters for this node (NodeRecordParams, assumed to come from RecordReplay.hpp) — confirm
    virtual NodeRecordParams getNodeRecordParams() const;
    // Presumably the output whose messages get recorded — confirm
    virtual Node::Output& getRecordOutput();
    // Presumably the input through which replayed messages are fed — confirm
    virtual Node::Input& getReplayInput();
};

// Node CRTP class
/**
 * CRTP helper that concrete nodes derive from.
 *
 * @tparam Base     Node base class to extend (Node or a subclass of it).
 * @tparam Derived  The concrete node type; must provide a static NAME.
 *
 * Supplies getName() from Derived::NAME and shared_ptr factory functions. The factories use
 * `new` rather than std::make_shared because Derived's constructors are intended to be
 * non-public (see the "No public constructor" note below).
 */
template <typename Base, typename Derived>
class NodeCRTP : public Base {
   public:
    virtual ~NodeCRTP() = default;

    /// Returns Derived::NAME
    const char* getName() const override {
        return Derived::NAME;
    }
    // std::unique_ptr<Node> clone() const override {
    //     return std::make_unique<Derived>(static_cast<const Derived&>(*this));
    // };

    // No public constructor, only a factory function.
    /**
     * Construct a Derived from `args` (perfectly forwarded), then run its buildInternal() hook.
     * @return Owning pointer to the new node.
     */
    template <typename... Args>
    [[nodiscard]] static std::shared_ptr<Derived> create(Args&&... args) {
        auto nodePtr = std::shared_ptr<Derived>(new Derived(std::forward<Args>(args)...));
        nodePtr->buildInternal();
        return nodePtr;
    }
    /**
     * Construct a Derived from a properties object. Unlike the variadic overload, this one
     * does not call buildInternal().
     * @param props  Properties handed over to the Derived constructor.
     * @return Owning pointer to the new node.
     */
    [[nodiscard]] static std::shared_ptr<Derived> create(std::unique_ptr<Properties> props) {
        // `props` is a named (lvalue) move-only unique_ptr: pass it as an rvalue so a constructor
        // taking the unique_ptr by value or by rvalue reference can assume ownership.
        return std::shared_ptr<Derived>(new Derived(std::move(props)));
    }

    friend Derived;
    friend Base;
};

}  // namespace dai

// Specialization of std::hash for Node::Connection
namespace std {
/**
 * std::hash for dai::Node::Connection.
 *
 * Mixes outputId, outputName, inputId and inputName with the boost-style hash_combine step.
 * The group fields are not mixed in. Assuming Connection::operator== compares every field,
 * equal connections still hash equally; connections that differ only by group merely share a hash.
 */
template <>
struct hash<dai::Node::Connection> {
    size_t operator()(const dai::Node::Connection& obj) const {
        size_t seed = 0;
        std::hash<dai::Node::Id> hId;
        std::hash<std::string> hStr;
        seed ^= hId(obj.outputId) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        seed ^= hStr(obj.outputName) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        seed ^= hId(obj.inputId) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        // Hash the input-side name (not outputName a second time) so both endpoints contribute
        seed ^= hStr(obj.inputName) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        return seed;
    }
};

}  // namespace std