Program Listing for File nn_param_handler.hpp
↰ Return to documentation for file (/tmp/ws/src/depthai-ros/depthai_ros_driver/include/depthai_ros_driver/param_handlers/nn_param_handler.hpp
)
#pragma once
#include <fstream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "depthai-shared/common/CameraBoardSocket.hpp"
#include "depthai/pipeline/datatype/CameraControl.hpp"
#include "depthai_ros_driver/param_handlers/base_param_handler.hpp"
#include "nlohmann/json.hpp"
namespace dai {
namespace node {
class NeuralNetwork;
class MobileNetDetectionNetwork;
class MobileNetSpatialDetectionNetwork;
class YoloDetectionNetwork;
class YoloSpatialDetectionNetwork;
class ImageManip;
} // namespace node
} // namespace dai
namespace rclcpp {
class Node;
class Parameter;
} // namespace rclcpp
namespace depthai_ros_driver {
namespace param_handlers {
namespace nn {
enum class NNFamily { Segmentation, Mobilenet, Yolo };
}
class NNParamHandler : public BaseParamHandler {
public:
explicit NNParamHandler(rclcpp::Node* node, const std::string& name, const dai::CameraBoardSocket& socket = dai::CameraBoardSocket::CAM_A);
~NNParamHandler();
nn::NNFamily getNNFamily();
template <typename T>
void declareParams(std::shared_ptr<T> nn, std::shared_ptr<dai::node::ImageManip> imageManip) {
declareAndLogParam<bool>("i_disable_resize", false);
declareAndLogParam<bool>("i_enable_passthrough", false);
declareAndLogParam<bool>("i_enable_passthrough_depth", false);
declareAndLogParam<bool>("i_get_base_device_timestamp", false);
declareAndLogParam<bool>("i_update_ros_base_time_on_ros_msg", false);
auto nn_path = getParam<std::string>("i_nn_config_path");
using json = nlohmann::json;
std::ifstream f(nn_path);
json data = json::parse(f);
parseConfigFile(nn_path, nn, imageManip);
}
void setNNParams(nlohmann::json data, std::shared_ptr<dai::node::NeuralNetwork> nn);
void setNNParams(nlohmann::json data, std::shared_ptr<dai::node::MobileNetDetectionNetwork> nn);
void setNNParams(nlohmann::json data, std::shared_ptr<dai::node::YoloDetectionNetwork> nn);
void setNNParams(nlohmann::json data, std::shared_ptr<dai::node::MobileNetSpatialDetectionNetwork> nn);
void setNNParams(nlohmann::json data, std::shared_ptr<dai::node::YoloSpatialDetectionNetwork> nn);
template <typename T>
void setSpatialParams(std::shared_ptr<T> nn) {
nn->setBoundingBoxScaleFactor(0.5);
nn->setDepthLowerThreshold(100);
nn->setDepthUpperThreshold(10000);
}
template <typename T>
void setYoloParams(nlohmann::json data, std::shared_ptr<T> nn) {
auto metadata = data["nn_config"]["NN_specific_metadata"];
int num_classes = 80;
if(metadata.contains("classes")) {
num_classes = metadata["classes"].get<int>();
nn->setNumClasses(num_classes);
}
int coordinates = 4;
if(metadata.contains("coordinates")) {
coordinates = metadata["coordinates"].get<int>();
nn->setCoordinateSize(coordinates);
}
std::vector<float> anchors = {10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319};
if(metadata.contains("anchors")) {
anchors = metadata["anchors"].get<std::vector<float>>();
nn->setAnchors(anchors);
}
std::map<std::string, std::vector<int>> anchor_masks = {{"side13", {3, 4, 5}}, {"side26", {1, 2, 3}}};
if(metadata.contains("anchor_masks")) {
anchor_masks.clear();
for(auto& el : metadata["anchor_masks"].items()) {
anchor_masks.insert({el.key(), el.value()});
}
}
nn->setAnchorMasks(anchor_masks);
float iou_threshold = 0.5f;
if(metadata.contains("iou_threshold")) {
iou_threshold = metadata["iou_threshold"].get<float>();
nn->setIouThreshold(iou_threshold);
}
}
template <typename T>
void parseConfigFile(const std::string& path, std::shared_ptr<T> nn, std::shared_ptr<dai::node::ImageManip> imageManip) {
using json = nlohmann::json;
std::ifstream f(path);
json data = json::parse(f);
if(data.contains("model") && data.contains("nn_config")) {
auto modelPath = getModelPath(data);
declareAndLogParam("i_model_path", modelPath);
if(!getParam<bool>("i_disable_resize")) {
setImageManip(modelPath, imageManip);
}
nn->setBlobPath(modelPath);
nn->setNumPoolFrames(declareAndLogParam<int>("i_num_pool_frames", 4));
nn->setNumInferenceThreads(declareAndLogParam<int>("i_num_inference_threads", 2));
nn->input.setBlocking(false);
declareAndLogParam<int>("i_max_q_size", 30);
setNNParams(data, nn);
}
}
dai::CameraControl setRuntimeParams(const std::vector<rclcpp::Parameter>& params) override;
private:
void setImageManip(const std::string& model_path, std::shared_ptr<dai::node::ImageManip> imageManip);
std::string getModelPath(const nlohmann::json& data);
std::unordered_map<std::string, nn::NNFamily> nnFamilyMap;
};
} // namespace param_handlers
} // namespace depthai_ros_driver