Program Listing for File DetectionNetwork.hpp

Return to documentation for file (include/depthai/pipeline/node/DetectionNetwork.hpp)

#pragma once

#include <depthai/pipeline/Node.hpp>
#include <depthai/pipeline/node/NeuralNetwork.hpp>

#include "depthai/openvino/OpenVINO.hpp"

// standard
#include <cstdint>
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// shared
#include <depthai-shared/properties/DetectionNetworkProperties.hpp>

namespace dai {
namespace node {

/**
 * Common base node for neural networks whose results are published as detections.
 *
 * Layered on top of NeuralNetwork via NodeCRTP, using DetectionNetworkProperties
 * as its properties type. Both constructors are protected, so this class cannot be
 * instantiated directly; it is constructed through a derived node (for example the
 * MobileNetDetectionNetwork / YoloDetectionNetwork classes in this header) or a friend.
 */
class DetectionNetwork : public NodeCRTP<NeuralNetwork, DetectionNetwork, DetectionNetworkProperties> {
   public:
    /// Name string declared for this node class.
    static constexpr const char* NAME = "DetectionNetwork";

    /**
     * Output of type MSender carrying DatatypeEnum::ImgDetections messages.
     * NOTE(review): presumably the decoded detection results — confirm in the implementation.
     */
    Output out{*this, "out", Output::Type::MSender, {{DatatypeEnum::ImgDetections, false}}};

    /**
     * Output of type MSender carrying DatatypeEnum::NNData messages.
     * NOTE(review): presumably the network's raw (undecoded) results — confirm.
     */
    Output outNetwork{*this, "outNetwork", Output::Type::MSender, {{DatatypeEnum::NNData, false}}};

    /**
     * Sets the detection confidence threshold.
     * @param thresh Threshold value.
     * NOTE(review): presumably detections below this confidence are discarded — verify in the .cpp.
     */
    void setConfidenceThreshold(float thresh);

    /**
     * Counterpart of setConfidenceThreshold().
     * @returns The configured detection confidence threshold.
     */
    float getConfidenceThreshold() const;

   protected:
    /**
     * @param par Pipeline implementation this node is created in (presumably its parent).
     * @param nodeId Identifier for this node.
     */
    DetectionNetwork(const std::shared_ptr<PipelineImpl>& par, int64_t nodeId);

    /**
     * @param par Pipeline implementation this node is created in (presumably its parent).
     * @param nodeId Identifier for this node.
     * @param props Properties passed by value as a unique_ptr, so ownership moves into the constructor.
     */
    DetectionNetwork(const std::shared_ptr<PipelineImpl>& par,
                     int64_t nodeId,
                     std::unique_ptr<Properties> props);
};

/**
 * DetectionNetwork node variant for MobileNet-family models.
 *
 * Declares nothing beyond its two public constructors; all outputs and the
 * confidence-threshold API come from DetectionNetwork. It does not declare a
 * NAME of its own, so -- assuming NodeCRTP does not add one -- NAME resolves to
 * the inherited DetectionNetwork::NAME.
 * NOTE(review): any MobileNet-specific decoding is not visible in this header; confirm in the .cpp.
 */
class MobileNetDetectionNetwork : public NodeCRTP<DetectionNetwork, MobileNetDetectionNetwork, DetectionNetworkProperties> {
   public:
    /**
     * @param par Pipeline implementation this node is created in (presumably its parent).
     * @param nodeId Identifier for this node.
     */
    MobileNetDetectionNetwork(const std::shared_ptr<PipelineImpl>& par, int64_t nodeId);

    /**
     * @param par Pipeline implementation this node is created in (presumably its parent).
     * @param nodeId Identifier for this node.
     * @param props Properties passed by value as a unique_ptr, so ownership moves into the constructor.
     */
    MobileNetDetectionNetwork(const std::shared_ptr<PipelineImpl>& par,
                              int64_t nodeId,
                              std::unique_ptr<Properties> props);
};

/**
 * DetectionNetwork node variant for YOLO-family models.
 *
 * Adds accessors for YOLO decoding parameters on top of DetectionNetwork.
 * Each parameter is exposed as a setter/getter pair, grouped below.
 * NOTE(review): the meaning of each parameter is inferred from its name only;
 * confirm against the implementation before relying on these descriptions.
 */
class YoloDetectionNetwork : public NodeCRTP<DetectionNetwork, YoloDetectionNetwork, DetectionNetworkProperties> {
   public:
    /**
     * @param par Pipeline implementation this node is created in (presumably its parent).
     * @param nodeId Identifier for this node.
     */
    YoloDetectionNetwork(const std::shared_ptr<PipelineImpl>& par, int64_t nodeId);

    /**
     * @param par Pipeline implementation this node is created in (presumably its parent).
     * @param nodeId Identifier for this node.
     * @param props Properties passed by value as a unique_ptr, so ownership moves into the constructor.
     */
    YoloDetectionNetwork(const std::shared_ptr<PipelineImpl>& par,
                         int64_t nodeId,
                         std::unique_ptr<Properties> props);

    // ---- number of classes -------------------------------------------------
    /// Sets the number of classes (presumably the model's class count).
    void setNumClasses(int numClasses);
    /// @returns The configured number of classes.
    int getNumClasses() const;

    // ---- coordinate size ---------------------------------------------------
    /// Sets the coordinate size (presumably values per box encoding -- TODO confirm).
    void setCoordinateSize(int coordinates);
    /// @returns The configured coordinate size.
    int getCoordinateSize() const;

    // ---- anchors -----------------------------------------------------------
    /**
     * Sets the anchor values as a flat list of floats.
     * Taken by value: passing an rvalue (std::move) avoids copying into the parameter.
     * NOTE(review): presumably width/height pairs -- confirm the expected layout.
     */
    void setAnchors(std::vector<float> anchors);
    /// @returns A copy of the configured anchor values.
    std::vector<float> getAnchors() const;

    // ---- anchor masks ------------------------------------------------------
    /**
     * Sets the anchor masks: a map from a string key to a list of integer indices.
     * Taken by value: passing an rvalue (std::move) avoids copying into the parameter.
     * NOTE(review): presumably output-layer name -> indices into the anchors -- confirm.
     */
    void setAnchorMasks(std::map<std::string, std::vector<int>> anchorMasks);
    /// @returns A copy of the configured anchor masks.
    std::map<std::string, std::vector<int>> getAnchorMasks() const;

    // ---- IoU threshold -----------------------------------------------------
    /// Sets the IoU threshold (presumably used for overlap suppression -- TODO confirm).
    void setIouThreshold(float thresh);
    /// @returns The configured IoU threshold.
    float getIouThreshold() const;
};

}  // namespace node
}  // namespace dai