22 : conv1a(torch::nn::Conv2dOptions( 1, c1, 3).stride(1).padding(1)),
23 conv1b(torch::nn::Conv2dOptions(c1, c1, 3).stride(1).padding(1)),
25 conv2a(torch::nn::Conv2dOptions(c1, c2, 3).stride(1).padding(1)),
26 conv2b(torch::nn::Conv2dOptions(c2, c2, 3).stride(1).padding(1)),
28 conv3a(torch::nn::Conv2dOptions(c2, c3, 3).stride(1).padding(1)),
29 conv3b(torch::nn::Conv2dOptions(c3, c3, 3).stride(1).padding(1)),
31 conv4a(torch::nn::Conv2dOptions(c3, c4, 3).stride(1).padding(1)),
32 conv4b(torch::nn::Conv2dOptions(c4, c4, 3).stride(1).padding(1)),
34 convPa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)),
35 convPb(torch::nn::Conv2dOptions(c5, 65, 1).stride(1).padding(0)),
37 convDa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)),
38 convDb(torch::nn::Conv2dOptions(c5, d1, 1).stride(1).padding(0))
41 register_module(
"conv1a",
conv1a);
42 register_module(
"conv1b",
conv1b);
44 register_module(
"conv2a",
conv2a);
45 register_module(
"conv2b",
conv2b);
47 register_module(
"conv3a",
conv3a);
48 register_module(
"conv3b",
conv3b);
50 register_module(
"conv4a",
conv4a);
51 register_module(
"conv4b",
conv4b);
53 register_module(
"convPa",
convPa);
54 register_module(
"convPb",
convPb);
56 register_module(
"convDa",
convDa);
57 register_module(
"convDb",
convDb);
63 x = torch::relu(
conv1a->forward(x));
64 x = torch::relu(
conv1b->forward(x));
65 x = torch::max_pool2d(x, 2, 2);
67 x = torch::relu(
conv2a->forward(x));
68 x = torch::relu(
conv2b->forward(x));
69 x = torch::max_pool2d(x, 2, 2);
71 x = torch::relu(
conv3a->forward(x));
72 x = torch::relu(
conv3b->forward(x));
73 x = torch::max_pool2d(x, 2, 2);
75 x = torch::relu(
conv4a->forward(x));
76 x = torch::relu(
conv4b->forward(x));
78 auto cPa = torch::relu(
convPa->forward(x));
79 auto semi =
convPb->forward(cPa);
81 auto cDa = torch::relu(
convDa->forward(x));
82 auto desc =
convDb->forward(cDa);
84 auto dn = torch::norm(desc, 2, 1);
85 desc = desc.div(torch::unsqueeze(dn, 1));
87 semi = torch::softmax(semi, 1);
88 semi = semi.slice(1, 0, 64);
89 semi = semi.permute({0, 2, 3, 1});
92 int Hc = semi.size(1);
93 int Wc = semi.size(2);
94 semi = semi.contiguous().view({-1, Hc, Wc, 8, 8});
95 semi = semi.permute({0, 1, 3, 2, 4});
96 semi = semi.contiguous().view({-1, Hc * 8, Wc * 8});
99 std::vector<torch::Tensor> ret;
106 void NMS(
const std::vector<cv::KeyPoint> & ptsIn,
107 const cv::Mat & conf,
108 const cv::Mat & descriptorsIn,
109 std::vector<cv::KeyPoint> & ptsOut,
110 cv::Mat & descriptorsOut,
111 int border,
int dist_thresh,
int img_width,
int img_height);
114 threshold_(threshold),
116 minDistance_(minDistance),
119 UDEBUG(
"modelPath=%s thr=%f nms=%d cuda=%d", modelPath.c_str(), threshold, nms?1:0, cuda?1:0);
120 if(modelPath.empty())
124 model_ = std::make_shared<SuperPoint>();
125 torch::load(
model_, modelPath);
127 if(cuda && !torch::cuda::is_available())
129 UWARN(
"Cuda option is enabled but torch doesn't have cuda support on this platform, using CPU instead.");
131 cuda_ = cuda && torch::cuda::is_available();
132 torch::Device device(
cuda_?torch::kCUDA:torch::kCPU);
145 torch::NoGradGuard no_grad_guard;
146 auto x = torch::from_blob(img.data, {1, 1, img.rows, img.cols}, torch::kByte);
147 x = x.to(torch::kFloat) / 255;
149 torch::Device device(
cuda_?torch::kCUDA:torch::kCPU);
150 x = x.set_requires_grad(
false);
151 auto out =
model_->forward(x.to(device));
153 prob_ = out[0].squeeze(0);
157 kpts = torch::nonzero(kpts);
159 std::vector<cv::KeyPoint> keypoints_no_nms;
160 for (
int i = 0; i < kpts.size(0); i++) {
161 float response =
prob_[kpts[i][0]][kpts[i][1]].item<
float>();
162 keypoints_no_nms.push_back(cv::KeyPoint(kpts[i][1].item<float>(), kpts[i][0].item<float>(), 8, -1, response));
166 if (
nms_ && !keypoints_no_nms.empty()) {
167 cv::Mat conf(keypoints_no_nms.size(), 1, CV_32F);
168 for (
size_t i = 0; i < keypoints_no_nms.size(); i++) {
169 int x = keypoints_no_nms[i].pt.x;
170 int y = keypoints_no_nms[i].pt.y;
171 conf.at<
float>(i, 0) =
prob_[y][x].item<float>();
176 int height = img.rows;
177 int width = img.cols;
179 std::vector<cv::KeyPoint> keypoints;
181 NMS(keypoints_no_nms, conf, descEmpty, keypoints, descEmpty, border, dist_thresh, width, height);
185 return keypoints_no_nms;
190 UERROR(
"No model is loaded!");
191 return std::vector<cv::KeyPoint>();
199 UERROR(
"SPDetector has been reset before extracting the descriptors! detect() should be called before compute().");
204 cv::Mat kpt_mat(keypoints.size(), 2, CV_32F);
206 for (
size_t i = 0; i < keypoints.size(); i++) {
207 kpt_mat.at<
float>(i, 0) = (
float)keypoints[i].pt.y;
208 kpt_mat.at<
float>(i, 1) = (
float)keypoints[i].pt.x;
211 auto fkpts = torch::from_blob(kpt_mat.data, {(long int)keypoints.size(), 2}, torch::kFloat);
213 torch::Device device(
cuda_?torch::kCUDA:torch::kCPU);
214 auto grid = torch::zeros({1, 1, fkpts.size(0), 2}).to(device);
215 grid[0][0].slice(1, 0, 1) = 2.0 * fkpts.slice(1, 1, 2) /
prob_.size(1) - 1;
216 grid[0][0].slice(1, 1, 2) = 2.0 * fkpts.slice(1, 0, 1) /
prob_.size(0) - 1;
218 auto desc = torch::grid_sampler(
desc_, grid, 0, 0,
true);
219 desc = desc.squeeze(0).squeeze(1);
222 auto dn = torch::norm(desc, 2, 1);
223 desc = desc.div(torch::unsqueeze(dn, 1));
225 desc = desc.transpose(0, 1).contiguous();
227 desc = desc.to(torch::kCPU);
229 cv::Mat desc_mat(cv::Size(desc.size(1), desc.size(0)), CV_32FC1, desc.data<
float>());
231 return desc_mat.clone();
235 UERROR(
"No model is loaded!");
240 void NMS(
const std::vector<cv::KeyPoint> & ptsIn,
241 const cv::Mat & conf,
242 const cv::Mat & descriptorsIn,
243 std::vector<cv::KeyPoint> & ptsOut,
244 cv::Mat & descriptorsOut,
245 int border,
int dist_thresh,
int img_width,
int img_height)
248 std::vector<cv::Point2f> pts_raw;
250 for (
size_t i = 0; i < ptsIn.size(); i++)
252 int u = (int) ptsIn[i].pt.x;
253 int v = (
int) ptsIn[i].pt.y;
255 pts_raw.push_back(cv::Point2f(u, v));
262 cv::Mat grid = cv::Mat(cv::Size(img_width, img_height), CV_8UC1);
263 cv::Mat inds = cv::Mat(cv::Size(img_width, img_height), CV_16UC1);
265 cv::Mat confidence = cv::Mat(cv::Size(img_width, img_height), CV_32FC1);
271 for (
size_t i = 0; i < pts_raw.size(); i++)
273 int uu = (int) pts_raw[i].x;
274 int vv = (int) pts_raw[i].y;
276 grid.at<
unsigned char>(vv, uu) = 100;
277 inds.at<
unsigned short>(vv, uu) = i;
279 confidence.at<
float>(vv, uu) = conf.at<
float>(i, 0);
288 cv::copyMakeBorder(grid, grid, dist_thresh, dist_thresh, dist_thresh, dist_thresh, cv::BORDER_CONSTANT, 0);
290 for (
size_t i = 0; i < pts_raw.size(); i++)
293 int uu = (int) pts_raw[i].x + dist_thresh;
294 int vv = (int) pts_raw[i].y + dist_thresh;
295 float c = confidence.at<
float>(vv-dist_thresh, uu-dist_thresh);
297 if (grid.at<
unsigned char>(vv, uu) == 100)
299 for(
int k = -dist_thresh; k < (dist_thresh+1); k++)
301 for(
int j = -dist_thresh; j < (dist_thresh+1); j++)
306 if ( confidence.at<
float>(vv + k - dist_thresh, uu + j - dist_thresh) <= c )
308 grid.at<
unsigned char>(vv + k, uu + j) = 0;
312 grid.at<
unsigned char>(vv, uu) = 255;
316 size_t valid_cnt = 0;
317 std::vector<int> select_indice;
319 grid = cv::Mat(grid, cv::Rect(dist_thresh, dist_thresh, img_width, img_height));
324 for (
int v = 0; v < img_height; v++)
326 for (
int u = 0; u < img_width; u++)
328 if (grid.at<
unsigned char>(v,u) == 255)
330 int select_ind = (int) inds.at<
unsigned short>(v, u);
331 float response = conf.at<
float>(select_ind, 0);
332 ptsOut.push_back(cv::KeyPoint(pts_raw[select_ind], 8.0
f, -1, response));
334 select_indice.push_back(select_ind);
340 if(!descriptorsIn.empty())
342 UASSERT(descriptorsIn.rows == (
int)ptsIn.size());
343 descriptorsOut.create(select_indice.size(), 256, CV_32F);
345 for (
size_t i=0; i<select_indice.size(); i++)
347 for (
int j=0; j < 256; j++)
349 descriptorsOut.at<
float>(i, j) = descriptorsIn.at<
float>(select_indice[i], j);
std::shared_ptr< SuperPoint > model_
cv::Mat compute(const std::vector< cv::KeyPoint > &keypoints)
#define UASSERT(condition)
std::vector< torch::Tensor > forward(torch::Tensor x)
SPDetector(const std::string &modelPath, float threshold=0.2f, bool nms=true, int minDistance=4, bool cuda=false)
ULogger class and convenient macros.
void NMS(const std::vector< cv::KeyPoint > &ptsIn, const cv::Mat &conf, const cv::Mat &descriptorsIn, std::vector< cv::KeyPoint > &ptsOut, cv::Mat &descriptorsOut, int border, int dist_thresh, int img_width, int img_height)
std::vector< cv::KeyPoint > detect(const cv::Mat &img)
const std::string response