SuperPoint.cc
// Includes inferred from the symbols used below; the exact header paths are assumptions.
#include "SuperPoint.h"

#include <rtabmap/utilite/ULogger.h>
#include <rtabmap/utilite/UConversion.h>
#include <rtabmap/utilite/UDirectory.h>
#include <rtabmap/utilite/UFile.h>

namespace rtabmap
{

// Channel sizes: c1..c4 for the shared encoder, c5 for the detector and
// descriptor heads, d1 for the descriptor dimension.
const int c1 = 64;
const int c2 = 64;
const int c3 = 128;
const int c4 = 128;
const int c5 = 256;
const int d1 = 256;

SuperPoint::SuperPoint()
    : conv1a(torch::nn::Conv2dOptions( 1, c1, 3).stride(1).padding(1)),
      conv1b(torch::nn::Conv2dOptions(c1, c1, 3).stride(1).padding(1)),

      conv2a(torch::nn::Conv2dOptions(c1, c2, 3).stride(1).padding(1)),
      conv2b(torch::nn::Conv2dOptions(c2, c2, 3).stride(1).padding(1)),

      conv3a(torch::nn::Conv2dOptions(c2, c3, 3).stride(1).padding(1)),
      conv3b(torch::nn::Conv2dOptions(c3, c3, 3).stride(1).padding(1)),

      conv4a(torch::nn::Conv2dOptions(c3, c4, 3).stride(1).padding(1)),
      conv4b(torch::nn::Conv2dOptions(c4, c4, 3).stride(1).padding(1)),

      convPa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)),
      convPb(torch::nn::Conv2dOptions(c5, 65, 1).stride(1).padding(0)),

      convDa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)),
      convDb(torch::nn::Conv2dOptions(c5, d1, 1).stride(1).padding(0))
{
    register_module("conv1a", conv1a);
    register_module("conv1b", conv1b);

    register_module("conv2a", conv2a);
    register_module("conv2b", conv2b);

    register_module("conv3a", conv3a);
    register_module("conv3b", conv3b);

    register_module("conv4a", conv4a);
    register_module("conv4b", conv4b);

    register_module("convPa", convPa);
    register_module("convPb", convPb);

    register_module("convDa", convDa);
    register_module("convDb", convDb);
}

std::vector<torch::Tensor> SuperPoint::forward(torch::Tensor x) {

    // Shared encoder: three 2x2 max-poolings reduce the resolution to H/8 x W/8.
    x = torch::relu(conv1a->forward(x));
    x = torch::relu(conv1b->forward(x));
    x = torch::max_pool2d(x, 2, 2);

    x = torch::relu(conv2a->forward(x));
    x = torch::relu(conv2b->forward(x));
    x = torch::max_pool2d(x, 2, 2);

    x = torch::relu(conv3a->forward(x));
    x = torch::relu(conv3b->forward(x));
    x = torch::max_pool2d(x, 2, 2);

    x = torch::relu(conv4a->forward(x));
    x = torch::relu(conv4b->forward(x));

    // Detector head
    auto cPa = torch::relu(convPa->forward(x));
    auto semi = convPb->forward(cPa); // [B, 65, H/8, W/8]

    // Descriptor head
    auto cDa = torch::relu(convDa->forward(x));
    auto desc = convDb->forward(cDa); // [B, d1, H/8, W/8]

    // L2-normalize the descriptors along the channel dimension.
    auto dn = torch::norm(desc, 2, 1);
    desc = desc.div(torch::unsqueeze(dn, 1));

    semi = torch::softmax(semi, 1);
    semi = semi.slice(1, 0, 64);       // drop the 65th "no keypoint" (dustbin) channel
    semi = semi.permute({0, 2, 3, 1}); // [B, H/8, W/8, 64]

    // Unfold each 64-channel cell into an 8x8 pixel block to get a full-resolution heatmap.
    int Hc = semi.size(1);
    int Wc = semi.size(2);
    semi = semi.contiguous().view({-1, Hc, Wc, 8, 8});
    semi = semi.permute({0, 1, 3, 2, 4});
    semi = semi.contiguous().view({-1, Hc * 8, Wc * 8}); // [B, H, W]

    std::vector<torch::Tensor> ret;
    ret.push_back(semi);
    ret.push_back(desc);

    return ret;
}
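
// Illustrative usage of the network (not part of the original file); the
// 480x640 input size is an arbitrary example with dimensions divisible by 8:
//
//   SuperPoint net;
//   auto input = torch::rand({1, 1, 480, 640});  // [B, 1, H, W], float in [0, 1]
//   auto out = net.forward(input);
//   auto heatmap = out[0];  // [1, 480, 640]    per-pixel keypoint probabilities
//   auto desc = out[1];     // [1, 256, 60, 80] L2-normalized coarse descriptors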

void NMS(const std::vector<cv::KeyPoint> & ptsIn,
        const cv::Mat & conf,
        const cv::Mat & descriptorsIn,
        std::vector<cv::KeyPoint> & ptsOut,
        cv::Mat & descriptorsOut,
        int border, int dist_thresh, int img_width, int img_height);

SPDetector::SPDetector(const std::string & modelPath, float threshold, bool nms, int minDistance, bool cuda) :
    threshold_(threshold),
    nms_(nms),
    minDistance_(minDistance),
    detected_(false)
{
    UDEBUG("modelPath=%s thr=%f nms=%d cuda=%d", modelPath.c_str(), threshold, nms?1:0, cuda?1:0);
    if(modelPath.empty())
    {
        UERROR("Model's path is empty!");
        return;
    }
    std::string path = uReplaceChar(modelPath, '~', UDirectory::homeDir());
    if(!UFile::exists(path))
    {
        UERROR("Model's path \"%s\" doesn't exist!", path.c_str());
        return;
    }
    model_ = std::make_shared<SuperPoint>();
    torch::load(model_, path);

    if(cuda && !torch::cuda::is_available())
    {
        UWARN("Cuda option is enabled but torch doesn't have cuda support on this platform, using CPU instead.");
    }
    cuda_ = cuda && torch::cuda::is_available();
    torch::Device device(cuda_?torch::kCUDA:torch::kCPU);
    model_->to(device);
}
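
// Illustrative only (not part of the original file): the file given as
// modelPath above must be a module archive readable by torch::load(), e.g.
// one produced with torch::save() on a SuperPoint instance holding the
// pretrained weights ("superpoint.pt" is a placeholder name):
//
//   auto net = std::make_shared<SuperPoint>();
//   // ... copy the pretrained weights into net's parameters ...
//   torch::save(net, "superpoint.pt");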

SPDetector::~SPDetector()
{
}

std::vector<cv::KeyPoint> SPDetector::detect(const cv::Mat &img, const cv::Mat & mask)
{
    UASSERT(img.type() == CV_8UC1);
    UASSERT(mask.empty() || (mask.type() == CV_8UC1 && img.cols == mask.cols && img.rows == mask.rows));
    detected_ = false;
    if(model_)
    {
        torch::NoGradGuard no_grad_guard;
        auto x = torch::from_blob(img.data, {1, 1, img.rows, img.cols}, torch::kByte);
        x = x.to(torch::kFloat) / 255;

        torch::Device device(cuda_?torch::kCUDA:torch::kCPU);
        x = x.set_requires_grad(false);
        auto out = model_->forward(x.to(device));

        prob_ = out[0].squeeze(0); // [H, W]
        desc_ = out[1];            // [1, 256, H/8, W/8]

        auto kpts = (prob_ > threshold_);
        kpts = torch::nonzero(kpts); // [n_keypoints, 2] (y, x)

        // Move the results back to CPU if they were computed on GPU.
        auto kpts_cpu = kpts.to(torch::kCPU);
        auto prob_cpu = prob_.to(torch::kCPU);

        std::vector<cv::KeyPoint> keypoints_no_nms;
        for (int i = 0; i < kpts_cpu.size(0); i++) {
            if(mask.empty() || mask.at<unsigned char>(kpts_cpu[i][0].item<int>(), kpts_cpu[i][1].item<int>()) != 0)
            {
                float response = prob_cpu[kpts_cpu[i][0]][kpts_cpu[i][1]].item<float>();
                keypoints_no_nms.push_back(cv::KeyPoint(kpts_cpu[i][1].item<float>(), kpts_cpu[i][0].item<float>(), 8, -1, response));
            }
        }

        detected_ = true;
        if (nms_ && !keypoints_no_nms.empty()) {
            cv::Mat conf(keypoints_no_nms.size(), 1, CV_32F);
            for (size_t i = 0; i < keypoints_no_nms.size(); i++) {
                int x = keypoints_no_nms[i].pt.x;
                int y = keypoints_no_nms[i].pt.y;
                conf.at<float>(i, 0) = prob_cpu[y][x].item<float>();
            }

            int border = 0;
            int dist_thresh = minDistance_;
            int height = img.rows;
            int width = img.cols;

            std::vector<cv::KeyPoint> keypoints;
            cv::Mat descEmpty;
            NMS(keypoints_no_nms, conf, descEmpty, keypoints, descEmpty, border, dist_thresh, width, height);
            if(keypoints.size()>1)
            {
                return keypoints;
            }
            return std::vector<cv::KeyPoint>();
        }
        else if(keypoints_no_nms.size()>1)
        {
            return keypoints_no_nms;
        }
        else
        {
            return std::vector<cv::KeyPoint>();
        }
    }
    else
    {
        UERROR("No model is loaded!");
        return std::vector<cv::KeyPoint>();
    }
}

cv::Mat SPDetector::compute(const std::vector<cv::KeyPoint> &keypoints)
{
    if(!detected_)
    {
        UERROR("SPDetector has been reset before extracting the descriptors! detect() should be called before compute().");
        return cv::Mat();
    }
    if(keypoints.empty())
    {
        return cv::Mat();
    }
    if(model_.get())
    {
        cv::Mat kpt_mat(keypoints.size(), 2, CV_32F); // [n_keypoints, 2] (y, x)

        // Based on sample_descriptors() of the SuperPoint implementation in SuperGlue:
        // https://github.com/magicleap/SuperGluePretrainedNetwork/blob/45a750e5707696da49472f1cad35b0b203325417/models/superpoint.py#L80-L92
        float s = 8;
        for (size_t i = 0; i < keypoints.size(); i++) {
            kpt_mat.at<float>(i, 0) = (float)keypoints[i].pt.y - s/2 + 0.5;
            kpt_mat.at<float>(i, 1) = (float)keypoints[i].pt.x - s/2 + 0.5;
        }

        auto fkpts = torch::from_blob(kpt_mat.data, {(long int)keypoints.size(), 2}, torch::kFloat);

        float w = desc_.size(3); // W/8
        float h = desc_.size(2); // H/8

        torch::Device device(cuda_?torch::kCUDA:torch::kCPU);
        auto grid = torch::zeros({1, 1, fkpts.size(0), 2}).to(device); // [1, 1, n_keypoints, 2]
        grid[0][0].slice(1, 0, 1) = 2.0 * fkpts.slice(1, 1, 2) / (w*s - s/2 - 0.5) - 1; // x
        grid[0][0].slice(1, 1, 2) = 2.0 * fkpts.slice(1, 0, 1) / (h*s - s/2 - 0.5) - 1; // y
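
        // The two lines above map keypoint pixel coordinates into the [-1, 1]
        // range expected by grid_sampler (align_corners=true), so each
        // keypoint samples its descriptor bilinearly from the coarse
        // [H/8, W/8] map. Worked example (illustrative, assuming an image
        // width of 640, i.e. w = 80): a keypoint at x = 639 was shifted above
        // to 639 - 3.5 = 635.5 and maps to 2*635.5/(640 - 4.5) - 1 = 1,
        // the right edge of the sampling grid.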

        auto desc = torch::grid_sampler(desc_, grid, 0, 0, true); // [1, 256, 1, n_keypoints]

        // Re-normalize the sampled descriptors to unit length.
        desc = torch::nn::functional::normalize(desc.reshape({1, desc_.size(1), -1})); // [1, 256, n_keypoints]
        desc = desc.squeeze();                    // [256, n_keypoints]
        desc = desc.transpose(0, 1).contiguous(); // [n_keypoints, 256]

        if(cuda_)
            desc = desc.to(torch::kCPU);

        cv::Mat desc_mat(cv::Size(desc.size(1), desc.size(0)), CV_32FC1, desc.data_ptr<float>());

        return desc_mat.clone();
    }
    else
    {
        UERROR("No model is loaded!");
        return cv::Mat();
    }
}
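
// Illustrative usage of SPDetector (not part of the original file); the model
// path, the image file and the threshold value are placeholders:
//
//   rtabmap::SPDetector detector("superpoint.pt", 0.010f, true, 4, false);
//   cv::Mat gray = cv::imread("image.png", cv::IMREAD_GRAYSCALE);
//   std::vector<cv::KeyPoint> kpts = detector.detect(gray);
//   cv::Mat descriptors = detector.compute(kpts);  // n_keypoints x 256, CV_32FC1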

void NMS(const std::vector<cv::KeyPoint> & ptsIn,
        const cv::Mat & conf,
        const cv::Mat & descriptorsIn,
        std::vector<cv::KeyPoint> & ptsOut,
        cv::Mat & descriptorsOut,
        int border, int dist_thresh, int img_width, int img_height)
{
    // Note: the border parameter is currently unused.

    std::vector<cv::Point2f> pts_raw;

    for (size_t i = 0; i < ptsIn.size(); i++)
    {
        int u = (int) ptsIn[i].pt.x;
        int v = (int) ptsIn[i].pt.y;

        pts_raw.push_back(cv::Point2f(u, v));
    }

    // Grid value legend:
    //  255 : kept,
    //    0 : empty or suppressed,
    //  100 : to be processed (converted to either kept or suppressed).
    cv::Mat grid = cv::Mat(cv::Size(img_width, img_height), CV_8UC1);
    cv::Mat inds = cv::Mat(cv::Size(img_width, img_height), CV_16UC1);

    cv::Mat confidence = cv::Mat(cv::Size(img_width, img_height), CV_32FC1);

    grid.setTo(0);
    inds.setTo(0);
    confidence.setTo(0);

    for (size_t i = 0; i < pts_raw.size(); i++)
    {
        int uu = (int) pts_raw[i].x;
        int vv = (int) pts_raw[i].y;

        grid.at<unsigned char>(vv, uu) = 100;
        inds.at<unsigned short>(vv, uu) = i;

        confidence.at<float>(vv, uu) = conf.at<float>(i, 0);
    }

    // debug
    //cv::Mat confidenceVis = confidence.clone() * 255;
    //confidenceVis.convertTo(confidenceVis, CV_8UC1);
    //cv::imwrite("confidence.bmp", confidenceVis);
    //cv::imwrite("grid_in.bmp", grid);

    // Pad the grid so that the suppression window never leaves the image.
    cv::copyMakeBorder(grid, grid, dist_thresh, dist_thresh, dist_thresh, dist_thresh, cv::BORDER_CONSTANT, 0);

    for (size_t i = 0; i < pts_raw.size(); i++)
    {
        // account for the top-left padding
        int uu = (int) pts_raw[i].x + dist_thresh;
        int vv = (int) pts_raw[i].y + dist_thresh;
        float c = confidence.at<float>(vv-dist_thresh, uu-dist_thresh);

        if (grid.at<unsigned char>(vv, uu) == 100) // If not yet suppressed.
        {
            for(int k = -dist_thresh; k < (dist_thresh+1); k++)
            {
                for(int j = -dist_thresh; j < (dist_thresh+1); j++)
                {
                    if(j==0 && k==0)
                        continue;

                    // Neighbour coordinates in the unpadded confidence map;
                    // skip neighbours falling outside the image.
                    int y = vv + k - dist_thresh;
                    int x = uu + j - dist_thresh;
                    if(x < 0 || x >= img_width || y < 0 || y >= img_height)
                        continue;

                    if ( confidence.at<float>(y, x) <= c )
                    {
                        grid.at<unsigned char>(vv + k, uu + j) = 0;
                    }
                }
            }
            grid.at<unsigned char>(vv, uu) = 255;
        }
    }

    size_t valid_cnt = 0;
    std::vector<int> select_indice;

    // Remove the padding.
    grid = cv::Mat(grid, cv::Rect(dist_thresh, dist_thresh, img_width, img_height));

    // debug
    //cv::imwrite("grid_nms.bmp", grid);

    for (int v = 0; v < img_height; v++)
    {
        for (int u = 0; u < img_width; u++)
        {
            if (grid.at<unsigned char>(v,u) == 255)
            {
                int select_ind = (int) inds.at<unsigned short>(v, u);
                float response = conf.at<float>(select_ind, 0);
                ptsOut.push_back(cv::KeyPoint(pts_raw[select_ind], 8.0f, -1, response));

                select_indice.push_back(select_ind);
                valid_cnt++;
            }
        }
    }

    if(!descriptorsIn.empty())
    {
        UASSERT(descriptorsIn.rows == (int)ptsIn.size());
        descriptorsOut.create(select_indice.size(), 256, CV_32F);

        for (size_t i=0; i<select_indice.size(); i++)
        {
            for (int j=0; j < 256; j++)
            {
                descriptorsOut.at<float>(i, j) = descriptorsIn.at<float>(select_indice[i], j);
            }
        }
    }
}

} // namespace rtabmap