25 : conv1a(torch::nn::Conv2dOptions( 1, c1, 3).stride(1).padding(1)),
26 conv1b(torch::nn::Conv2dOptions(c1, c1, 3).stride(1).padding(1)),
28 conv2a(torch::nn::Conv2dOptions(c1, c2, 3).stride(1).padding(1)),
29 conv2b(torch::nn::Conv2dOptions(c2, c2, 3).stride(1).padding(1)),
31 conv3a(torch::nn::Conv2dOptions(c2, c3, 3).stride(1).padding(1)),
32 conv3b(torch::nn::Conv2dOptions(c3, c3, 3).stride(1).padding(1)),
34 conv4a(torch::nn::Conv2dOptions(c3, c4, 3).stride(1).padding(1)),
35 conv4b(torch::nn::Conv2dOptions(c4, c4, 3).stride(1).padding(1)),
37 convPa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)),
38 convPb(torch::nn::Conv2dOptions(c5, 65, 1).stride(1).padding(0)),
40 convDa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)),
41 convDb(torch::nn::Conv2dOptions(c5, d1, 1).stride(1).padding(0))
44 register_module(
"conv1a",
conv1a);
45 register_module(
"conv1b",
conv1b);
47 register_module(
"conv2a",
conv2a);
48 register_module(
"conv2b",
conv2b);
50 register_module(
"conv3a",
conv3a);
51 register_module(
"conv3b",
conv3b);
53 register_module(
"conv4a",
conv4a);
54 register_module(
"conv4b",
conv4b);
56 register_module(
"convPa",
convPa);
57 register_module(
"convPb",
convPb);
59 register_module(
"convDa",
convDa);
60 register_module(
"convDb",
convDb);
// Fragment: statements of the network forward pass applied to tensor x. The enclosing
// signature is not among the visible lines (a forward(torch::Tensor x) returning
// std::vector<torch::Tensor> is referenced elsewhere in this listing) -- confirm against the full file.
//
// Shared encoder, stage 1: conv1a and conv1b, each followed by ReLU, then max pooling
// with kernel size 2 and stride 2.
66 x = torch::relu(
conv1a->forward(x));
67 x = torch::relu(
conv1b->forward(x));
68 x = torch::max_pool2d(x, 2, 2);
// Stage 2: same pattern with conv2a and conv2b.
70 x = torch::relu(
conv2a->forward(x));
71 x = torch::relu(
conv2b->forward(x));
72 x = torch::max_pool2d(x, 2, 2);
// Stage 3: same pattern with conv3a and conv3b.
74 x = torch::relu(
conv3a->forward(x));
75 x = torch::relu(
conv3b->forward(x));
76 x = torch::max_pool2d(x, 2, 2);
// Stage 4: conv4a and conv4b with ReLU; no pooling follows in the visible lines.
78 x = torch::relu(
conv4a->forward(x));
79 x = torch::relu(
conv4b->forward(x));
// First head: convPa + ReLU, then convPb without activation, producing semi.
81 auto cPa = torch::relu(
convPa->forward(x));
82 auto semi =
convPb->forward(cPa);
// Second head: convDa + ReLU, then convDb without activation, producing desc.
// Both heads read the same encoder output x.
84 auto cDa = torch::relu(
convDa->forward(x));
85 auto desc =
convDb->forward(cDa);
// 2-norm of desc along dim 1, then divide by it (re-expanded at dim 1) so that
// every vector along dim 1 has unit length.
87 auto dn = torch::norm(desc, 2, 1);
88 desc = desc.div(torch::unsqueeze(dn, 1));
// Softmax along dim 1, then keep indices 0..63 of dim 1 (the slice end is exclusive).
// convPb is constructed with 65 output channels, so this drops the last one.
90 semi = torch::softmax(semi, 1);
91 semi = semi.slice(1, 0, 64);
// Move dim 1 to the end (N x Hc x Wc x 64, assuming the usual N x C x H x W conv layout -- TODO confirm).
92 semi = semi.permute({0, 2, 3, 1});
95 int Hc = semi.size(1);
96 int Wc = semi.size(2);
// Split the 64 values of each location into an 8 x 8 cell, interleave the cell rows
// with the grid rows, and flatten to a map of size (Hc * 8) x (Wc * 8).
// contiguous() is required before each view because permute only changes strides.
97 semi = semi.contiguous().view({-1, Hc, Wc, 8, 8});
98 semi = semi.permute({0, 1, 3, 2, 4});
99 semi = semi.contiguous().view({-1, Hc * 8, Wc * 8});
// Result container; the statements that fill and return it are not part of this listing.
102 std::vector<torch::Tensor> ret;
// Forward declaration of the non-maximum-suppression helper defined later in this file.
//   ptsIn           candidate keypoints; the definition truncates their coordinates to integer pixels
//   conf            float matrix read at (i, 0) for keypoint i
//   descriptorsIn   descriptors, one row per input keypoint; may be empty
//   ptsOut          receives the kept keypoints (appended with push_back, not cleared first)
//   descriptorsOut  receives the descriptor rows of the kept keypoints when descriptorsIn is not empty
//   border          not referenced in the visible lines of the definition
//   dist_thresh     half-size, in pixels, of the square suppression window
//   img_width, img_height  size of the per-pixel work matrices
109 void NMS(
const std::vector<cv::KeyPoint> & ptsIn,
110 const cv::Mat & conf,
111 const cv::Mat & descriptorsIn,
112 std::vector<cv::KeyPoint> & ptsOut,
113 cv::Mat & descriptorsOut,
114 int border,
int dist_thresh,
int img_width,
int img_height);
// Fragment of the SPDetector constructor: the signature and the start of the member
// initializer list are not among the visible lines -- confirm against the full file.
// Visible behaviour: stores threshold and minDistance, logs the parameters, reports an
// error when modelPath is empty or when the resolved path does not exist, creates the
// SuperPoint module, and warns when CUDA was requested but torch reports it unavailable.
// The statements that load the weights and move the module to the device are not visible here.
117 threshold_(threshold),
119 minDistance_(minDistance),
122 UDEBUG(
"modelPath=%s thr=%f nms=%d cuda=%d", modelPath.c_str(), threshold, nms?1:0, cuda?1:0);
123 if(modelPath.empty())
125 UERROR(
"Model's path is empty!");
131 UERROR(
"Model's path \"%s\" doesn't exist!", path.c_str());
134 model_ = std::make_shared<SuperPoint>();
137 if(cuda && !torch::cuda::is_available())
139 UWARN(
"Cuda option is enabled but torch doesn't have cuda support on this platform, using CPU instead.");
// cuda_ is true only when CUDA was both requested and is available, so later code
// that picks a device from cuda_ silently falls back to the CPU.
141 cuda_ = cuda && torch::cuda::is_available();
142 torch::Device
device(
cuda_?torch::kCUDA:torch::kCPU);
// Fragment of the keypoint detection routine (signature not among the visible lines;
// a detect(const cv::Mat &img, const cv::Mat &mask) is referenced elsewhere in this listing).
// Preconditions: img is single-channel 8-bit; mask is either empty or single-channel
// 8-bit with the same size as img.
152 UASSERT(img.type() == CV_8UC1);
153 UASSERT(mask.empty() || (mask.type() == CV_8UC1 && img.cols == mask.cols && img.rows == mask.rows));
// Inference only: disable gradient tracking for the rest of the scope.
157 torch::NoGradGuard no_grad_guard;
// Wrap the image buffer as a 1 x 1 x rows x cols byte tensor without copying.
// NOTE(review): this reads img.data as densely packed rows; a non-continuous cv::Mat
// (for example a ROI view) would need to be cloned first -- confirm what callers pass.
158 auto x = torch::from_blob(img.data, {1, 1, img.rows, img.cols}, torch::kByte);
// Convert to float and scale from 0..255 to 0..1.
159 x = x.to(torch::kFloat) / 255;
161 torch::Device
device(
cuda_?torch::kCUDA:torch::kCPU);
162 x = x.set_requires_grad(
false);
163 auto out =
model_->forward(x.to(device));
// First network output with the leading dimension removed; kept in a member for later use.
165 prob_ = out[0].squeeze(0);
// Indices of the non-zero entries of kpts. The statement that defines kpts is not part
// of this listing (presumably a comparison of prob_ against the threshold -- TODO confirm).
169 kpts = torch::nonzero(kpts);
// Copy to the CPU once so the per-element reads below do not touch the device.
172 auto kpts_cpu = kpts.to(torch::kCPU);
173 auto prob_cpu =
prob_.to(torch::kCPU);
175 std::vector<cv::KeyPoint> keypoints_no_nms;
// Column 0 of kpts_cpu is used as the row (y) and column 1 as the column (x).
// A candidate is kept when there is no mask or the mask value at (row, col) is non-zero.
176 for (
int i = 0; i < kpts_cpu.size(0); i++) {
177 if(mask.empty() || mask.at<
unsigned char>(kpts_cpu[i][0].item<
int>(), kpts_cpu[i][1].item<int>()) != 0)
179 float response = prob_cpu[kpts_cpu[i][0]][kpts_cpu[i][1]].item<
float>();
// KeyPoint(x, y, size 8, angle -1, response).
180 keypoints_no_nms.push_back(cv::KeyPoint(kpts_cpu[i][1].item<float>(), kpts_cpu[i][0].item<float>(), 8, -1, response));
// Optional non-maximum suppression when enabled and at least one candidate exists.
185 if (
nms_ && !keypoints_no_nms.empty()) {
// One confidence value per candidate, read back from the probability map at the
// integer-truncated keypoint position.
186 cv::Mat conf(keypoints_no_nms.size(), 1, CV_32F);
187 for (
size_t i = 0; i < keypoints_no_nms.size(); i++) {
188 int x = keypoints_no_nms[i].pt.x;
189 int y = keypoints_no_nms[i].pt.y;
190 conf.at<
float>(i, 0) = prob_cpu[y][x].item<float>();
195 int height = img.rows;
196 int width = img.cols;
198 std::vector<cv::KeyPoint> keypoints;
// descEmpty, border and dist_thresh are defined on lines that are not part of this listing.
// The same descEmpty is passed as both descriptor input and output.
200 NMS(keypoints_no_nms, conf, descEmpty, keypoints, descEmpty, border, dist_thresh, width, height);
// Presumably the path taken when suppression is skipped -- the surrounding braces and
// the return of the suppressed keypoints are not visible here.
204 return keypoints_no_nms;
// Presumably the branch taken when no model is available: report and return an empty list.
209 UERROR(
"No model is loaded!");
210 return std::vector<cv::KeyPoint>();
// Fragment of the descriptor extraction routine (signature not among the visible lines;
// a compute(const std::vector<cv::KeyPoint> &keypoints) is referenced elsewhere in this listing).
// The condition guarding this first error message is not visible here.
218 UERROR(
"SPDetector has been reset before extracting the descriptors! detect() should be called before compute().");
// One row per keypoint; column 0 holds y and column 1 holds x, each shifted by -s/2 + 0.5.
// s is defined on a line that is not part of this listing -- TODO confirm its value.
223 cv::Mat kpt_mat(keypoints.size(), 2, CV_32F);
228 for (
size_t i = 0; i < keypoints.size(); i++) {
229 kpt_mat.at<
float>(i, 0) = (
float)keypoints[i].pt.y - s/2 + 0.5;
230 kpt_mat.at<
float>(i, 1) = (
float)keypoints[i].pt.x - s/2 + 0.5;
// Tensor view over kpt_mat memory (no copy), so kpt_mat must stay alive while fkpts is used.
233 auto fkpts = torch::from_blob(kpt_mat.data, {(long int)keypoints.size(), 2}, torch::kFloat);
// Width and height of the stored descriptor map (its dims 3 and 2).
235 float w =
desc_.size(3);
236 float h =
desc_.size(2);
238 torch::Device
device(
cuda_?torch::kCUDA:torch::kCPU);
// Sampling grid of shape 1 x 1 x N x 2. Its first component is built from the x column
// of fkpts and its second from the y column; both are mapped as 2 * v / extent - 1.
239 auto grid = torch::zeros({1, 1, fkpts.size(0), 2}).to(device);
240 grid[0][0].slice(1, 0, 1) = 2.0 * fkpts.slice(1, 1, 2) / (w*s - s/2 - 0.5) - 1;
241 grid[0][0].slice(1, 1, 2) = 2.0 * fkpts.slice(1, 0, 1) / (h*s - s/2 - 0.5) - 1;
// Sample desc_ at the grid positions. The integer arguments 0 and 0 are the interpolation
// and padding modes (bilinear and zeros in the ATen enums -- TODO confirm for the torch
// version in use); the final true is align_corners.
243 auto desc = torch::grid_sampler(
desc_, grid, 0, 0,
true);
// NOTE(review): squeeze() without an argument removes every size-1 dimension. With exactly
// one keypoint the result may become 1-D, in which case transpose(0, 1) below would fail;
// squeezing only the leading dimension looks safer -- confirm with a single-keypoint input.
247 desc = desc.squeeze();
// Swap the two remaining dimensions and make the memory dense so it can be wrapped row-major.
248 desc = desc.transpose(0, 1).contiguous();
251 desc = desc.to(torch::kCPU);
// Wrap the tensor memory as a float matrix with desc.size(0) rows and desc.size(1) columns,
// then return a deep copy because the wrapped memory is owned by the local tensor.
253 cv::Mat desc_mat(cv::Size(desc.size(1), desc.size(0)), CV_32FC1, desc.data_ptr<
float>());
255 return desc_mat.clone();
// Presumably the branch taken when no model is available; the value returned after this
// message is not part of this listing.
259 UERROR(
"No model is loaded!");
// Grid-based non-maximum suppression of keypoints.
//   ptsIn           candidate keypoints; coordinates are truncated to integer pixels
//   conf            float matrix read at (i, 0) for keypoint i
//   descriptorsIn   descriptors, one row per input keypoint; may be empty
//   ptsOut          kept keypoints are appended here (size 8, angle -1, response from conf)
//   descriptorsOut  when descriptorsIn is not empty, recreated as kept-count x 256 floats
//   border          not referenced in the visible lines
//   dist_thresh     half-size of the square suppression window
//   img_width, img_height  size of the per-pixel work matrices
// Grid value legend used below: 100 = candidate still to be processed, 0 = suppressed,
// 255 = kept. Candidates are processed in input order, not sorted by confidence.
// NOTE(review): this listing omits some original lines (braces and a few statements);
// confirm every remark below against the full file.
264 void NMS(
const std::vector<cv::KeyPoint> & ptsIn,
265 const cv::Mat & conf,
266 const cv::Mat & descriptorsIn,
267 std::vector<cv::KeyPoint> & ptsOut,
268 cv::Mat & descriptorsOut,
269 int border,
int dist_thresh,
int img_width,
int img_height)
// Integer-truncated copy of the input positions.
272 std::vector<cv::Point2f> pts_raw;
274 for (
size_t i = 0; i < ptsIn.size(); i++)
276 int u = (int) ptsIn[i].pt.x;
277 int v = (
int) ptsIn[i].pt.y;
279 pts_raw.push_back(cv::Point2f(u, v));
// Per-pixel state, keypoint index and confidence.
// NOTE(review): this cv::Mat constructor does not initialize the data and no zero-fill is
// visible in this listing; the final scan relies on untouched cells not holding 255 --
// confirm that the omitted lines clear these matrices.
// NOTE(review): inds stores indices as 16-bit values, so indices above 65535 would wrap.
286 cv::Mat grid = cv::Mat(cv::Size(img_width, img_height), CV_8UC1);
287 cv::Mat inds = cv::Mat(cv::Size(img_width, img_height), CV_16UC1);
289 cv::Mat confidence = cv::Mat(cv::Size(img_width, img_height), CV_32FC1);
// Mark every candidate pixel and remember its index and confidence. Candidates that
// truncate to the same pixel overwrite each other; the last one wins.
295 for (
size_t i = 0; i < pts_raw.size(); i++)
297 int uu = (int) pts_raw[i].x;
298 int vv = (int) pts_raw[i].y;
300 grid.at<
unsigned char>(vv, uu) = 100;
301 inds.at<
unsigned short>(vv, uu) = i;
303 confidence.at<
float>(vv, uu) = conf.at<
float>(i, 0);
// Pad the state grid by dist_thresh zeros on every side so the window writes below stay
// inside it. Only grid is padded; inds and confidence keep the image size.
312 cv::copyMakeBorder(grid, grid, dist_thresh, dist_thresh, dist_thresh, dist_thresh, cv::BORDER_CONSTANT, 0);
314 for (
size_t i = 0; i < pts_raw.size(); i++)
// uu and vv are coordinates in the padded grid; subtracting dist_thresh gives image coordinates.
317 int uu = (int) pts_raw[i].x + dist_thresh;
318 int vv = (int) pts_raw[i].y + dist_thresh;
319 float c = confidence.at<
float>(vv-dist_thresh, uu-dist_thresh);
// Skip candidates already suppressed by an earlier one.
321 if (grid.at<
unsigned char>(vv, uu) == 100)
323 for(
int k = -dist_thresh; k < (dist_thresh+1); k++)
325 for(
int j = -dist_thresh; j < (dist_thresh+1); j++)
// Suppress any window cell whose confidence does not exceed this candidate.
// NOTE(review): confidence is not padded, so for a candidate closer than dist_thresh to
// an image edge these indices fall outside the matrix -- confirm that callers exclude
// such keypoints or add bounds handling.
330 if ( confidence.at<
float>(vv + k - dist_thresh, uu + j - dist_thresh) <= c )
332 grid.at<
unsigned char>(vv + k, uu + j) = 0;
// The candidate itself is kept.
336 grid.at<
unsigned char>(vv, uu) = 255;
// valid_cnt is not used in the visible lines.
340 size_t valid_cnt = 0;
341 std::vector<int> select_indice;
// Drop the padding again so grid matches the image size.
343 grid = cv::Mat(grid, cv::Rect(dist_thresh, dist_thresh, img_width, img_height));
// Collect the kept keypoints in row-major pixel order.
348 for (
int v = 0; v < img_height; v++)
350 for (
int u = 0; u < img_width; u++)
352 if (grid.at<
unsigned char>(v,u) == 255)
354 int select_ind = (int) inds.at<
unsigned short>(v, u);
355 float response = conf.at<
float>(select_ind, 0);
356 ptsOut.push_back(cv::KeyPoint(pts_raw[select_ind], 8.0
f, -1, response));
358 select_indice.push_back(select_ind);
// Copy the descriptor rows of the kept keypoints; the descriptor length is fixed at 256.
364 if(!descriptorsIn.empty())
366 UASSERT(descriptorsIn.rows == (
int)ptsIn.size());
367 descriptorsOut.create(select_indice.size(), 256, CV_32F);
369 for (
size_t i=0; i<select_indice.size(); i++)
371 for (
int j=0; j < 256; j++)
373 descriptorsOut.at<
float>(i, j) = descriptorsIn.at<
float>(select_indice[i], j);
GLM_FUNC_DECL genIType mask(genIType const &count)
static std::string homeDir()
std::shared_ptr< SuperPoint > model_
cv::Mat compute(const std::vector< cv::KeyPoint > &keypoints)
std::vector< torch::Tensor > forward(torch::Tensor x)
SPDetector(const std::string &modelPath, float threshold=0.2f, bool nms=true, int minDistance=4, bool cuda=false)
Some conversion functions.
#define UASSERT(condition)
GLM_FUNC_DECL genType normalize(genType const &x)
void NMS(const std::vector< cv::KeyPoint > &ptsIn, const cv::Mat &conf, const cv::Mat &descriptorsIn, std::vector< cv::KeyPoint > &ptsOut, cv::Mat &descriptorsOut, int border, int dist_thresh, int img_width, int img_height)
std::vector< cv::KeyPoint > detect(const cv::Mat &img, const cv::Mat &mask=cv::Mat())
std::string UTILITE_EXP uReplaceChar(const std::string &str, char before, char after)
ULogger class and convenient macros.