SuperPoint.cc
Go to the documentation of this file.
1 
7 
8 
9 namespace find_object
10 {
11 
// Channel widths of the SuperPoint network built by SuperPoint::SuperPoint().
constexpr int c1 = 64;   // output channels of conv1a/conv1b
constexpr int c2 = 64;   // output channels of conv2a/conv2b
constexpr int c3 = 128;  // output channels of conv3a/conv3b
constexpr int c4 = 128;  // output channels of conv4a/conv4b (input of both heads)
constexpr int c5 = 256;  // hidden channels of the two heads (convPa / convDa)
constexpr int d1 = 256;  // output channels of convDb, i.e. descriptor length
18 
19 
20 
22  : conv1a(torch::nn::Conv2dOptions( 1, c1, 3).stride(1).padding(1)),
23  conv1b(torch::nn::Conv2dOptions(c1, c1, 3).stride(1).padding(1)),
24 
25  conv2a(torch::nn::Conv2dOptions(c1, c2, 3).stride(1).padding(1)),
26  conv2b(torch::nn::Conv2dOptions(c2, c2, 3).stride(1).padding(1)),
27 
28  conv3a(torch::nn::Conv2dOptions(c2, c3, 3).stride(1).padding(1)),
29  conv3b(torch::nn::Conv2dOptions(c3, c3, 3).stride(1).padding(1)),
30 
31  conv4a(torch::nn::Conv2dOptions(c3, c4, 3).stride(1).padding(1)),
32  conv4b(torch::nn::Conv2dOptions(c4, c4, 3).stride(1).padding(1)),
33 
34  convPa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)),
35  convPb(torch::nn::Conv2dOptions(c5, 65, 1).stride(1).padding(0)),
36 
37  convDa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)),
38  convDb(torch::nn::Conv2dOptions(c5, d1, 1).stride(1).padding(0))
39 
40  {
41  register_module("conv1a", conv1a);
42  register_module("conv1b", conv1b);
43 
44  register_module("conv2a", conv2a);
45  register_module("conv2b", conv2b);
46 
47  register_module("conv3a", conv3a);
48  register_module("conv3b", conv3b);
49 
50  register_module("conv4a", conv4a);
51  register_module("conv4b", conv4b);
52 
53  register_module("convPa", convPa);
54  register_module("convPb", convPb);
55 
56  register_module("convDa", convDa);
57  register_module("convDb", convDb);
58  }
59 
60 
61 std::vector<torch::Tensor> SuperPoint::forward(torch::Tensor x) {
62 
63  x = torch::relu(conv1a->forward(x));
64  x = torch::relu(conv1b->forward(x));
65  x = torch::max_pool2d(x, 2, 2);
66 
67  x = torch::relu(conv2a->forward(x));
68  x = torch::relu(conv2b->forward(x));
69  x = torch::max_pool2d(x, 2, 2);
70 
71  x = torch::relu(conv3a->forward(x));
72  x = torch::relu(conv3b->forward(x));
73  x = torch::max_pool2d(x, 2, 2);
74 
75  x = torch::relu(conv4a->forward(x));
76  x = torch::relu(conv4b->forward(x));
77 
78  auto cPa = torch::relu(convPa->forward(x));
79  auto semi = convPb->forward(cPa); // [B, 65, H/8, W/8]
80 
81  auto cDa = torch::relu(convDa->forward(x));
82  auto desc = convDb->forward(cDa); // [B, d1, H/8, W/8]
83 
84  auto dn = torch::norm(desc, 2, 1);
85  desc = desc.div(torch::unsqueeze(dn, 1));
86 
87  semi = torch::softmax(semi, 1);
88  semi = semi.slice(1, 0, 64);
89  semi = semi.permute({0, 2, 3, 1}); // [B, H/8, W/8, 64]
90 
91 
92  int Hc = semi.size(1);
93  int Wc = semi.size(2);
94  semi = semi.contiguous().view({-1, Hc, Wc, 8, 8});
95  semi = semi.permute({0, 1, 3, 2, 4});
96  semi = semi.contiguous().view({-1, Hc * 8, Wc * 8}); // [B, H, W]
97 
98 
99  std::vector<torch::Tensor> ret;
100  ret.push_back(semi);
101  ret.push_back(desc);
102 
103  return ret;
104  }
105 
106 void NMS(const std::vector<cv::KeyPoint> & ptsIn,
107  const cv::Mat & conf,
108  const cv::Mat & descriptorsIn,
109  std::vector<cv::KeyPoint> & ptsOut,
110  cv::Mat & descriptorsOut,
111  int border, int dist_thresh, int img_width, int img_height);
112 
113 SPDetector::SPDetector(const std::string & modelPath, float threshold, bool nms, int minDistance, bool cuda) :
114  threshold_(threshold),
115  nms_(nms),
116  minDistance_(minDistance),
117  detected_(false)
118 {
119  UDEBUG("modelPath=%s thr=%f nms=%d cuda=%d", modelPath.c_str(), threshold, nms?1:0, cuda?1:0);
120  if(modelPath.empty())
121  {
122  return;
123  }
124  model_ = std::make_shared<SuperPoint>();
125  torch::load(model_, modelPath);
126 
127  if(cuda && !torch::cuda::is_available())
128  {
129  UWARN("Cuda option is enabled but torch doesn't have cuda support on this platform, using CPU instead.");
130  }
131  cuda_ = cuda && torch::cuda::is_available();
132  torch::Device device(cuda_?torch::kCUDA:torch::kCPU);
133  model_->to(device);
134 }
135 
137 {
138 }
139 
140 std::vector<cv::KeyPoint> SPDetector::detect(const cv::Mat &img)
141 {
142  detected_ = false;
143  if(model_)
144  {
145  torch::NoGradGuard no_grad_guard;
146  auto x = torch::from_blob(img.data, {1, 1, img.rows, img.cols}, torch::kByte);
147  x = x.to(torch::kFloat) / 255;
148 
149  torch::Device device(cuda_?torch::kCUDA:torch::kCPU);
150  x = x.set_requires_grad(false);
151  auto out = model_->forward(x.to(device));
152 
153  prob_ = out[0].squeeze(0); // [H, W]
154  desc_ = out[1]; // [1, 256, H/8, W/8]
155 
156  auto kpts = (prob_ > threshold_);
157  kpts = torch::nonzero(kpts); // [n_keypoints, 2] (y, x)
158 
159  std::vector<cv::KeyPoint> keypoints_no_nms;
160  for (int i = 0; i < kpts.size(0); i++) {
161  float response = prob_[kpts[i][0]][kpts[i][1]].item<float>();
162  keypoints_no_nms.push_back(cv::KeyPoint(kpts[i][1].item<float>(), kpts[i][0].item<float>(), 8, -1, response));
163  }
164 
165  detected_ = true;
166  if (nms_ && !keypoints_no_nms.empty()) {
167  cv::Mat conf(keypoints_no_nms.size(), 1, CV_32F);
168  for (size_t i = 0; i < keypoints_no_nms.size(); i++) {
169  int x = keypoints_no_nms[i].pt.x;
170  int y = keypoints_no_nms[i].pt.y;
171  conf.at<float>(i, 0) = prob_[y][x].item<float>();
172  }
173 
174  int border = 0;
175  int dist_thresh = minDistance_;
176  int height = img.rows;
177  int width = img.cols;
178 
179  std::vector<cv::KeyPoint> keypoints;
180  cv::Mat descEmpty;
181  NMS(keypoints_no_nms, conf, descEmpty, keypoints, descEmpty, border, dist_thresh, width, height);
182  return keypoints;
183  }
184  else {
185  return keypoints_no_nms;
186  }
187  }
188  else
189  {
190  UERROR("No model is loaded!");
191  return std::vector<cv::KeyPoint>();
192  }
193 }
194 
195 cv::Mat SPDetector::compute(const std::vector<cv::KeyPoint> &keypoints)
196 {
197  if(!detected_)
198  {
199  UERROR("SPDetector has been reset before extracting the descriptors! detect() should be called before compute().");
200  return cv::Mat();
201  }
202  if(model_.get())
203  {
204  cv::Mat kpt_mat(keypoints.size(), 2, CV_32F); // [n_keypoints, 2] (y, x)
205 
206  for (size_t i = 0; i < keypoints.size(); i++) {
207  kpt_mat.at<float>(i, 0) = (float)keypoints[i].pt.y;
208  kpt_mat.at<float>(i, 1) = (float)keypoints[i].pt.x;
209  }
210 
211  auto fkpts = torch::from_blob(kpt_mat.data, {(long int)keypoints.size(), 2}, torch::kFloat);
212 
213  torch::Device device(cuda_?torch::kCUDA:torch::kCPU);
214  auto grid = torch::zeros({1, 1, fkpts.size(0), 2}).to(device); // [1, 1, n_keypoints, 2]
215  grid[0][0].slice(1, 0, 1) = 2.0 * fkpts.slice(1, 1, 2) / prob_.size(1) - 1; // x
216  grid[0][0].slice(1, 1, 2) = 2.0 * fkpts.slice(1, 0, 1) / prob_.size(0) - 1; // y
217 
218  auto desc = torch::grid_sampler(desc_, grid, 0, 0, true); // [1, 256, 1, n_keypoints]
219  desc = desc.squeeze(0).squeeze(1); // [256, n_keypoints]
220 
221  // normalize to 1
222  auto dn = torch::norm(desc, 2, 1);
223  desc = desc.div(torch::unsqueeze(dn, 1));
224 
225  desc = desc.transpose(0, 1).contiguous(); // [n_keypoints, 256]
226  if(cuda_)
227  desc = desc.to(torch::kCPU);
228 
229  cv::Mat desc_mat(cv::Size(desc.size(1), desc.size(0)), CV_32FC1, desc.data<float>());
230 
231  return desc_mat.clone();
232  }
233  else
234  {
235  UERROR("No model is loaded!");
236  return cv::Mat();
237  }
238 }
239 
240 void NMS(const std::vector<cv::KeyPoint> & ptsIn,
241  const cv::Mat & conf,
242  const cv::Mat & descriptorsIn,
243  std::vector<cv::KeyPoint> & ptsOut,
244  cv::Mat & descriptorsOut,
245  int border, int dist_thresh, int img_width, int img_height)
246 {
247 
248  std::vector<cv::Point2f> pts_raw;
249 
250  for (size_t i = 0; i < ptsIn.size(); i++)
251  {
252  int u = (int) ptsIn[i].pt.x;
253  int v = (int) ptsIn[i].pt.y;
254 
255  pts_raw.push_back(cv::Point2f(u, v));
256  }
257 
258  //Grid Value Legend:
259  // 255 : Kept.
260  // 0 : Empty or suppressed.
261  // 100 : To be processed (converted to either kept or suppressed).
262  cv::Mat grid = cv::Mat(cv::Size(img_width, img_height), CV_8UC1);
263  cv::Mat inds = cv::Mat(cv::Size(img_width, img_height), CV_16UC1);
264 
265  cv::Mat confidence = cv::Mat(cv::Size(img_width, img_height), CV_32FC1);
266 
267  grid.setTo(0);
268  inds.setTo(0);
269  confidence.setTo(0);
270 
271  for (size_t i = 0; i < pts_raw.size(); i++)
272  {
273  int uu = (int) pts_raw[i].x;
274  int vv = (int) pts_raw[i].y;
275 
276  grid.at<unsigned char>(vv, uu) = 100;
277  inds.at<unsigned short>(vv, uu) = i;
278 
279  confidence.at<float>(vv, uu) = conf.at<float>(i, 0);
280  }
281 
282  // debug
283  //cv::Mat confidenceVis = confidence.clone() * 255;
284  //confidenceVis.convertTo(confidenceVis, CV_8UC1);
285  //cv::imwrite("confidence.bmp", confidenceVis);
286  //cv::imwrite("grid_in.bmp", grid);
287 
288  cv::copyMakeBorder(grid, grid, dist_thresh, dist_thresh, dist_thresh, dist_thresh, cv::BORDER_CONSTANT, 0);
289 
290  for (size_t i = 0; i < pts_raw.size(); i++)
291  {
292  // account for top left padding
293  int uu = (int) pts_raw[i].x + dist_thresh;
294  int vv = (int) pts_raw[i].y + dist_thresh;
295  float c = confidence.at<float>(vv-dist_thresh, uu-dist_thresh);
296 
297  if (grid.at<unsigned char>(vv, uu) == 100) // If not yet suppressed.
298  {
299  for(int k = -dist_thresh; k < (dist_thresh+1); k++)
300  {
301  for(int j = -dist_thresh; j < (dist_thresh+1); j++)
302  {
303  if(j==0 && k==0)
304  continue;
305 
306  if ( confidence.at<float>(vv + k - dist_thresh, uu + j - dist_thresh) <= c )
307  {
308  grid.at<unsigned char>(vv + k, uu + j) = 0;
309  }
310  }
311  }
312  grid.at<unsigned char>(vv, uu) = 255;
313  }
314  }
315 
316  size_t valid_cnt = 0;
317  std::vector<int> select_indice;
318 
319  grid = cv::Mat(grid, cv::Rect(dist_thresh, dist_thresh, img_width, img_height));
320 
321  //debug
322  //cv::imwrite("grid_nms.bmp", grid);
323 
324  for (int v = 0; v < img_height; v++)
325  {
326  for (int u = 0; u < img_width; u++)
327  {
328  if (grid.at<unsigned char>(v,u) == 255)
329  {
330  int select_ind = (int) inds.at<unsigned short>(v, u);
331  float response = conf.at<float>(select_ind, 0);
332  ptsOut.push_back(cv::KeyPoint(pts_raw[select_ind], 8.0f, -1, response));
333 
334  select_indice.push_back(select_ind);
335  valid_cnt++;
336  }
337  }
338  }
339 
340  if(!descriptorsIn.empty())
341  {
342  UASSERT(descriptorsIn.rows == (int)ptsIn.size());
343  descriptorsOut.create(select_indice.size(), 256, CV_32F);
344 
345  for (size_t i=0; i<select_indice.size(); i++)
346  {
347  for (int j=0; j < 256; j++)
348  {
349  descriptorsOut.at<float>(i, j) = descriptorsIn.at<float>(select_indice[i], j);
350  }
351  }
352  }
353 }
354 
355 }
torch::nn::Conv2d convDb
Definition: SuperPoint.h:45
const int c1
Definition: SuperPoint.cc:12
const int c3
Definition: SuperPoint.cc:14
torch::nn::Conv2d conv3b
Definition: SuperPoint.h:35
std::shared_ptr< SuperPoint > model_
Definition: SuperPoint.h:61
torch::Tensor desc_
Definition: SuperPoint.h:63
f
torch::nn::Conv2d conv2a
Definition: SuperPoint.h:31
torch::nn::Conv2d conv3a
Definition: SuperPoint.h:34
cv::Mat compute(const std::vector< cv::KeyPoint > &keypoints)
Definition: SuperPoint.cc:195
torch::nn::Conv2d convPb
Definition: SuperPoint.h:41
torch::nn::Conv2d convPa
Definition: SuperPoint.h:40
torch::nn::Conv2d conv1a
Definition: SuperPoint.h:28
#define UASSERT(condition)
const int c5
Definition: SuperPoint.cc:16
std::vector< torch::Tensor > forward(torch::Tensor x)
Definition: SuperPoint.cc:61
torch::nn::Conv2d conv2b
Definition: SuperPoint.h:32
torch::nn::Conv2d convDa
Definition: SuperPoint.h:44
SPDetector(const std::string &modelPath, float threshold=0.2f, bool nms=true, int minDistance=4, bool cuda=false)
Definition: SuperPoint.cc:113
torch::nn::Conv2d conv1b
Definition: SuperPoint.h:29
const int c2
Definition: SuperPoint.cc:13
#define UDEBUG(...)
torch::nn::Conv2d conv4b
Definition: SuperPoint.h:38
torch::Tensor prob_
Definition: SuperPoint.h:62
#define UERROR(...)
const int c4
Definition: SuperPoint.cc:15
const int d1
Definition: SuperPoint.cc:17
torch::nn::Conv2d conv4a
Definition: SuperPoint.h:37
ULogger class and convenient macros.
#define UWARN(...)
void NMS(const std::vector< cv::KeyPoint > &ptsIn, const cv::Mat &conf, const cv::Mat &descriptorsIn, std::vector< cv::KeyPoint > &ptsOut, cv::Mat &descriptorsOut, int border, int dist_thresh, int img_width, int img_height)
Definition: SuperPoint.cc:240
std::vector< cv::KeyPoint > detect(const cv::Mat &img)
Definition: SuperPoint.cc:140
const std::string response


find_object_2d
Author(s): Mathieu Labbe
autogenerated on Mon Dec 12 2022 03:20:09