src
superpoint_torch
SuperPoint.h
Go to the documentation of this file.
1
5
#ifndef SUPERPOINT_H
6
#define SUPERPOINT_H
7
8
9
#include <torch/torch.h>
#include <opencv2/opencv.hpp>

#include <memory>
#include <string>
#include <vector>
13
14
#ifdef EIGEN_MPL2_ONLY
15
#undef EIGEN_MPL2_ONLY
16
#endif
17
18
19
namespace
find_object
20
{
21
22
struct
SuperPoint
: torch::nn::Module {
23
SuperPoint
();
24
25
std::vector<torch::Tensor>
forward
(torch::Tensor x);
26
27
28
torch::nn::Conv2d
conv1a
;
29
torch::nn::Conv2d
conv1b
;
30
31
torch::nn::Conv2d
conv2a
;
32
torch::nn::Conv2d
conv2b
;
33
34
torch::nn::Conv2d
conv3a
;
35
torch::nn::Conv2d
conv3b
;
36
37
torch::nn::Conv2d
conv4a
;
38
torch::nn::Conv2d
conv4b
;
39
40
torch::nn::Conv2d
convPa
;
41
torch::nn::Conv2d
convPb
;
42
43
// descriptor
44
torch::nn::Conv2d
convDa
;
45
torch::nn::Conv2d
convDb
;
46
47
};
48
49
class
SPDetector
{
50
public
:
51
SPDetector
(
const
std::string & modelPath,
float
threshold = 0.2
f
,
bool
nms =
true
,
int
minDistance = 4,
bool
cuda =
false
);
52
virtual
~SPDetector
();
53
std::vector<cv::KeyPoint>
detect
(
const
cv::Mat &img);
54
cv::Mat
compute
(
const
std::vector<cv::KeyPoint> &keypoints);
55
56
void
setThreshold
(
float
threshold) {
threshold_
= threshold;}
57
void
SetNMS
(
bool
enabled) {
nms_
= enabled;}
58
void
setMinDistance
(
float
minDistance) {
minDistance_
= minDistance;}
59
60
private
:
61
std::shared_ptr<SuperPoint>
model_
;
62
torch::Tensor
prob_
;
63
torch::Tensor
desc_
;
64
65
float
threshold_
;
66
bool
nms_
;
67
int
minDistance_
;
68
bool
cuda_
;
69
70
bool
detected_
;
71
};
72
73
}
74
75
#endif
find_object::SuperPoint::conv2b
torch::nn::Conv2d conv2b
Definition:
SuperPoint.h:32
find_object::SuperPoint::convDa
torch::nn::Conv2d convDa
Definition:
SuperPoint.h:44
find_object::SPDetector::prob_
torch::Tensor prob_
Definition:
SuperPoint.h:62
find_object::SuperPoint
Definition:
SuperPoint.h:22
find_object::SuperPoint::convDb
torch::nn::Conv2d convDb
Definition:
SuperPoint.h:45
find_object::SuperPoint::conv3a
torch::nn::Conv2d conv3a
Definition:
SuperPoint.h:34
find_object::SPDetector::nms_
bool nms_
Definition:
SuperPoint.h:66
find_object::SuperPoint::convPb
torch::nn::Conv2d convPb
Definition:
SuperPoint.h:41
find_object::SPDetector::SPDetector
SPDetector(const std::string &modelPath, float threshold=0.2f, bool nms=true, int minDistance=4, bool cuda=false)
Definition:
SuperPoint.cc:113
find_object::SuperPoint::conv3b
torch::nn::Conv2d conv3b
Definition:
SuperPoint.h:35
find_object::SPDetector::detect
std::vector< cv::KeyPoint > detect(const cv::Mat &img)
Definition:
SuperPoint.cc:140
find_object::SuperPoint::conv1b
torch::nn::Conv2d conv1b
Definition:
SuperPoint.h:29
find_object::SuperPoint::convPa
torch::nn::Conv2d convPa
Definition:
SuperPoint.h:40
find_object::SPDetector::detected_
bool detected_
Definition:
SuperPoint.h:70
f
f
find_object::SuperPoint::forward
std::vector< torch::Tensor > forward(torch::Tensor x)
Definition:
SuperPoint.cc:61
find_object::SuperPoint::SuperPoint
SuperPoint()
Definition:
SuperPoint.cc:21
find_object::SPDetector::cuda_
bool cuda_
Definition:
SuperPoint.h:68
find_object::SuperPoint::conv1a
torch::nn::Conv2d conv1a
Definition:
SuperPoint.h:28
find_object::SPDetector::~SPDetector
virtual ~SPDetector()
Definition:
SuperPoint.cc:136
find_object::SPDetector
Definition:
SuperPoint.h:49
find_object::SuperPoint::conv4b
torch::nn::Conv2d conv4b
Definition:
SuperPoint.h:38
find_object::SuperPoint::conv2a
torch::nn::Conv2d conv2a
Definition:
SuperPoint.h:31
find_object
Definition:
Camera.h:38
find_object::SPDetector::compute
cv::Mat compute(const std::vector< cv::KeyPoint > &keypoints)
Definition:
SuperPoint.cc:195
find_object::SPDetector::desc_
torch::Tensor desc_
Definition:
SuperPoint.h:63
find_object::SPDetector::threshold_
float threshold_
Definition:
SuperPoint.h:65
find_object::SPDetector::SetNMS
void SetNMS(bool enabled)
Definition:
SuperPoint.h:57
find_object::SPDetector::minDistance_
int minDistance_
Definition:
SuperPoint.h:67
find_object::SPDetector::model_
std::shared_ptr< SuperPoint > model_
Definition:
SuperPoint.h:61
find_object::SuperPoint::conv4a
torch::nn::Conv2d conv4a
Definition:
SuperPoint.h:37
find_object::SPDetector::setThreshold
void setThreshold(float threshold)
Definition:
SuperPoint.h:56
find_object::SPDetector::setMinDistance
void setMinDistance(float minDistance)
Definition:
SuperPoint.h:58
find_object_2d
Author(s): Mathieu Labbe
autogenerated on Mon Dec 12 2022 03:43:35