SuperPoint.h
Go to the documentation of this file.
1 
5 #ifndef SUPERPOINT_H
6 #define SUPERPOINT_H
7 
8 
9 #include <torch/torch.h>
10 #include <opencv2/opencv.hpp>
11 
12 #include <vector>
13 
14 #ifdef EIGEN_MPL2_ONLY
15 #undef EIGEN_MPL2_ONLY
16 #endif
17 
18 
19 namespace find_object
20 {
21 
22 struct SuperPoint : torch::nn::Module {
23  SuperPoint();
24 
25  std::vector<torch::Tensor> forward(torch::Tensor x);
26 
27 
28  torch::nn::Conv2d conv1a;
29  torch::nn::Conv2d conv1b;
30 
31  torch::nn::Conv2d conv2a;
32  torch::nn::Conv2d conv2b;
33 
34  torch::nn::Conv2d conv3a;
35  torch::nn::Conv2d conv3b;
36 
37  torch::nn::Conv2d conv4a;
38  torch::nn::Conv2d conv4b;
39 
40  torch::nn::Conv2d convPa;
41  torch::nn::Conv2d convPb;
42 
43  // descriptor
44  torch::nn::Conv2d convDa;
45  torch::nn::Conv2d convDb;
46 
47 };
48 
49 class SPDetector {
50 public:
51  SPDetector(const std::string & modelPath, float threshold = 0.2f, bool nms = true, int minDistance = 4, bool cuda = false);
52  virtual ~SPDetector();
53  std::vector<cv::KeyPoint> detect(const cv::Mat &img);
54  cv::Mat compute(const std::vector<cv::KeyPoint> &keypoints);
55 
56  void setThreshold(float threshold) {threshold_ = threshold;}
57  void SetNMS(bool enabled) {nms_ = enabled;}
58  void setMinDistance(float minDistance) {minDistance_ = minDistance;}
59 
60 private:
61  std::shared_ptr<SuperPoint> model_;
62  torch::Tensor prob_;
63  torch::Tensor desc_;
64 
65  float threshold_;
66  bool nms_;
68  bool cuda_;
69 
70  bool detected_;
71 };
72 
73 }
74 
75 #endif
torch::nn::Conv2d convDb
Definition: SuperPoint.h:45
torch::nn::Conv2d conv3b
Definition: SuperPoint.h:35
std::shared_ptr< SuperPoint > model_
Definition: SuperPoint.h:61
torch::Tensor desc_
Definition: SuperPoint.h:63
f
torch::nn::Conv2d conv2a
Definition: SuperPoint.h:31
torch::nn::Conv2d conv3a
Definition: SuperPoint.h:34
void setMinDistance(float minDistance)
Definition: SuperPoint.h:58
torch::nn::Conv2d convPb
Definition: SuperPoint.h:41
torch::nn::Conv2d convPa
Definition: SuperPoint.h:40
torch::nn::Conv2d conv1a
Definition: SuperPoint.h:28
std::vector< torch::Tensor > forward(torch::Tensor x)
Definition: SuperPoint.cc:61
torch::nn::Conv2d conv2b
Definition: SuperPoint.h:32
void SetNMS(bool enabled)
Definition: SuperPoint.h:57
torch::nn::Conv2d convDa
Definition: SuperPoint.h:44
void setThreshold(float threshold)
Definition: SuperPoint.h:56
torch::nn::Conv2d conv1b
Definition: SuperPoint.h:29
torch::nn::Conv2d conv4b
Definition: SuperPoint.h:38
torch::Tensor prob_
Definition: SuperPoint.h:62
torch::nn::Conv2d conv4a
Definition: SuperPoint.h:37


find_object_2d
Author(s): Mathieu Labbe
autogenerated on Mon Dec 12 2022 03:20:09