SuperPoint.cc
Go to the documentation of this file.
1 
10 
11 
12 namespace rtabmap
13 {
14 
// Output channel counts of the SuperPoint convolution layers.
// c1..c4 are the widths of the four shared encoder stages, c5 is the hidden width
// of both heads, and d1 is the number of channels produced by the descriptor head.
constexpr int c1 = 64;
constexpr int c2 = 64;
constexpr int c3 = 128;
constexpr int c4 = 128;
constexpr int c5 = 256;
constexpr int d1 = 256;
21 
22 
23 
// Layer definitions, Conv2dOptions(in_channels, out_channels, kernel_size):
// - conv1a..conv4b: shared encoder. All are 3x3 convolutions with stride 1 and padding 1,
//   so they keep the spatial size; downsampling is done by max pooling in forward().
//   The first layer takes a single input channel (the grayscale image fed by SPDetector).
// - convPa/convPb: detector head, 3x3 (c4 -> c5) followed by 1x1 (c5 -> 65 channels).
// - convDa/convDb: descriptor head, 3x3 (c4 -> c5) followed by 1x1 (c5 -> d1 channels).
	: conv1a(torch::nn::Conv2dOptions( 1, c1, 3).stride(1).padding(1)),
	conv1b(torch::nn::Conv2dOptions(c1, c1, 3).stride(1).padding(1)),

	conv2a(torch::nn::Conv2dOptions(c1, c2, 3).stride(1).padding(1)),
	conv2b(torch::nn::Conv2dOptions(c2, c2, 3).stride(1).padding(1)),

	conv3a(torch::nn::Conv2dOptions(c2, c3, 3).stride(1).padding(1)),
	conv3b(torch::nn::Conv2dOptions(c3, c3, 3).stride(1).padding(1)),

	conv4a(torch::nn::Conv2dOptions(c3, c4, 3).stride(1).padding(1)),
	conv4b(torch::nn::Conv2dOptions(c4, c4, 3).stride(1).padding(1)),

	convPa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)),
	convPb(torch::nn::Conv2dOptions(c5, 65, 1).stride(1).padding(0)),

	convDa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)),
	convDb(torch::nn::Conv2dOptions(c5, d1, 1).stride(1).padding(0))

{
	// Register every layer as a named submodule so its parameters belong to this module
	// (and are therefore moved by to(device) and filled by torch::load()).
	// NOTE(review): the names presumably have to match the parameter names stored in the
	// serialized model file loaded by SPDetector - confirm before renaming any of them.
	register_module("conv1a", conv1a);
	register_module("conv1b", conv1b);

	register_module("conv2a", conv2a);
	register_module("conv2b", conv2b);

	register_module("conv3a", conv3a);
	register_module("conv3b", conv3b);

	register_module("conv4a", conv4a);
	register_module("conv4b", conv4b);

	register_module("convPa", convPa);
	register_module("convPb", convPb);

	register_module("convDa", convDa);
	register_module("convDb", convDb);
}
62 
63 std::vector<torch::Tensor> SuperPoint::forward(torch::Tensor x)
64 {
65  x = torch::relu(conv1a->forward(x));
66  x = torch::relu(conv1b->forward(x));
67  x = torch::max_pool2d(x, 2, 2);
68 
69  x = torch::relu(conv2a->forward(x));
70  x = torch::relu(conv2b->forward(x));
71  x = torch::max_pool2d(x, 2, 2);
72 
73  x = torch::relu(conv3a->forward(x));
74  x = torch::relu(conv3b->forward(x));
75  x = torch::max_pool2d(x, 2, 2);
76 
77  x = torch::relu(conv4a->forward(x));
78  x = torch::relu(conv4b->forward(x));
79 
80  auto cPa = torch::relu(convPa->forward(x));
81  auto semi = convPb->forward(cPa); // [B, 65, H/8, W/8]
82 
83  auto cDa = torch::relu(convDa->forward(x));
84  auto desc = convDb->forward(cDa); // [B, d1, H/8, W/8]
85 
86  auto dn = torch::norm(desc, 2, 1);
87  desc = desc.div(torch::unsqueeze(dn, 1));
88 
89  semi = torch::softmax(semi, 1);
90  semi = semi.slice(1, 0, 64);
91  semi = semi.permute({0, 2, 3, 1}); // [B, H/8, W/8, 64]
92 
93 
94  int Hc = semi.size(1);
95  int Wc = semi.size(2);
96  semi = semi.contiguous().view({-1, Hc, Wc, 8, 8});
97  semi = semi.permute({0, 1, 3, 2, 4});
98  semi = semi.contiguous().view({-1, Hc * 8, Wc * 8}); // [B, H, W]
99 
100 
101  std::vector<torch::Tensor> ret;
102  ret.push_back(semi);
103  ret.push_back(desc);
104 
105  return ret;
106 }
107 
108 SPDetector::SPDetector(const std::string & modelPath, float threshold, bool nms, int minDistance, bool cuda) :
109  threshold_(threshold),
110  nms_(nms),
111  minDistance_(minDistance),
112  detected_(false)
113 {
114  UDEBUG("modelPath=%s thr=%f nms=%d cuda=%d", modelPath.c_str(), threshold, nms?1:0, cuda?1:0);
115  if(modelPath.empty())
116  {
117  UERROR("Model's path is empty!");
118  return;
119  }
120  std::string path = uReplaceChar(modelPath, '~', UDirectory::homeDir());
121  if(!UFile::exists(path))
122  {
123  UERROR("Model's path \"%s\" doesn't exist!", path.c_str());
124  return;
125  }
126  model_ = std::make_shared<SuperPoint>();
127  torch::load(model_, uReplaceChar(path, '~', UDirectory::homeDir()));
128 
129  if(cuda && !torch::cuda::is_available())
130  {
131  UWARN("Cuda option is enabled but torch doesn't have cuda support on this platform, using CPU instead.");
132  }
133  cuda_ = cuda && torch::cuda::is_available();
134  torch::Device device(cuda_?torch::kCUDA:torch::kCPU);
135  model_->to(device);
136 }
137 
{
	// No explicit cleanup: model_ (a std::shared_ptr) and the cached tensors
	// prob_/desc_ are released by their own destructors.
}
141 
142 std::vector<cv::KeyPoint> SPDetector::detect(const cv::Mat &img, const cv::Mat & mask)
143 {
144  UASSERT(img.type() == CV_8UC1);
145  UASSERT(mask.empty() || (mask.type() == CV_8UC1 && img.cols == mask.cols && img.rows == mask.rows));
146  detected_ = false;
147  if(model_)
148  {
149  torch::NoGradGuard no_grad_guard;
150  auto x = torch::from_blob(img.data, {1, 1, img.rows, img.cols}, torch::kByte);
151  x = x.to(torch::kFloat) / 255;
152 
153  torch::Device device(cuda_?torch::kCUDA:torch::kCPU);
154  x = x.set_requires_grad(false);
155  auto out = model_->forward(x.to(device));
156 
157  auto scores = out[0]; // [1, H, W]
158  desc_ = out[1]; // [1, 256, H/8, W/8]
159 
160  if(nms_)
161  {
162  auto options = torch::nn::functional::MaxPool2dFuncOptions(minDistance_*2+1).stride(1).padding(minDistance_);
163  auto options_r1 = torch::nn::functional::MaxPool2dFuncOptions(3).stride(1).padding(1);
164 
165  auto zeros = torch::zeros_like(scores);
166  auto max_mask = scores == torch::nn::functional::max_pool2d(scores, options);
167  auto max_mask_r1 = scores == torch::nn::functional::max_pool2d(scores, options_r1);
168  for(size_t i=0; i<2; i++)
169  {
170  auto supp_mask = torch::nn::functional::max_pool2d(max_mask.to(torch::kF32), options) > 0;
171  auto supp_scores = torch::where(supp_mask, zeros, scores);
172  auto new_max_mask = supp_scores == torch::nn::functional::max_pool2d(supp_scores, options);
173  max_mask = max_mask | (new_max_mask & (~supp_mask) & max_mask_r1);
174  }
175  prob_ = torch::where(max_mask, scores, zeros).squeeze(0);
176  }
177  else
178  {
179  prob_ = scores.squeeze(0);
180  }
181 
182  auto kpts = (prob_ > threshold_);
183  kpts = torch::nonzero(kpts); // [n_keypoints, 2] (y, x)
184 
185  //convert back to cpu if in gpu
186  auto kpts_cpu = kpts.to(torch::kCPU);
187  auto prob_cpu = prob_.to(torch::kCPU);
188 
189  std::vector<cv::KeyPoint> keypoints;
190  for(int i=0; i<kpts_cpu.size(0); i++)
191  {
192  if(mask.empty() || mask.at<unsigned char>(kpts_cpu[i][0].item<int>(), kpts_cpu[i][1].item<int>()) != 0)
193  {
194  float response = prob_cpu[kpts_cpu[i][0]][kpts_cpu[i][1]].item<float>();
195  keypoints.emplace_back(cv::KeyPoint(kpts_cpu[i][1].item<float>(), kpts_cpu[i][0].item<float>(), 8, -1, response));
196  }
197  }
198 
199  detected_ = true;
200  return keypoints;
201  }
202  else
203  {
204  UERROR("No model is loaded!");
205  return std::vector<cv::KeyPoint>();
206  }
207 }
208 
209 cv::Mat SPDetector::compute(const std::vector<cv::KeyPoint> &keypoints)
210 {
211  if(!detected_)
212  {
213  UERROR("SPDetector has been reset before extracting the descriptors! detect() should be called before compute().");
214  return cv::Mat();
215  }
216  if(keypoints.empty())
217  {
218  return cv::Mat();
219  }
220  if(model_.get())
221  {
222  cv::Mat kpt_mat(keypoints.size(), 2, CV_32F); // [n_keypoints, 2] (y, x)
223 
224  // Based on sample_descriptors() of SuperPoint implementation in SuperGlue:
225  // https://github.com/magicleap/SuperGluePretrainedNetwork/blob/45a750e5707696da49472f1cad35b0b203325417/models/superpoint.py#L80-L92
226  float s = 8;
227  for (size_t i = 0; i < keypoints.size(); i++) {
228  kpt_mat.at<float>(i, 0) = (float)keypoints[i].pt.y - s/2 + 0.5;
229  kpt_mat.at<float>(i, 1) = (float)keypoints[i].pt.x - s/2 + 0.5;
230  }
231 
232  auto fkpts = torch::from_blob(kpt_mat.data, {(long int)keypoints.size(), 2}, torch::kFloat);
233 
234  float w = desc_.size(3); //W/8
235  float h = desc_.size(2); //H/8
236 
237  torch::Device device(cuda_?torch::kCUDA:torch::kCPU);
238  auto grid = torch::zeros({1, 1, fkpts.size(0), 2}).to(device); // [1, 1, n_keypoints, 2]
239  grid[0][0].slice(1, 0, 1) = 2.0 * fkpts.slice(1, 1, 2) / (w*s - s/2 - 0.5) - 1; // x
240  grid[0][0].slice(1, 1, 2) = 2.0 * fkpts.slice(1, 0, 1) / (h*s - s/2 - 0.5) - 1; // y
241 
242  auto desc = torch::grid_sampler(desc_, grid, 0, 0, true); // [1, 256, 1, n_keypoints]
243 
244  // normalize to 1
245  desc = torch::nn::functional::normalize(desc.reshape({1, desc_.size(1), -1})); //[1, 256, n_keypoints]
246  desc = desc.squeeze(); //[256, n_keypoints]
247  desc = desc.transpose(0, 1).contiguous(); //[n_keypoints, 256]
248 
249  if(cuda_)
250  desc = desc.to(torch::kCPU);
251 
252  cv::Mat desc_mat(cv::Size(desc.size(1), desc.size(0)), CV_32FC1, desc.data_ptr<float>());
253 
254  return desc_mat.clone();
255  }
256  else
257  {
258  UERROR("No model is loaded!");
259  return cv::Mat();
260  }
261 }
262 
263 }
w
RowVector3d w
rtabmap::c4
const int c4
Definition: SuperPoint.cc:18
rtabmap::SuperPoint::conv3a
torch::nn::Conv2d conv3a
Definition: SuperPoint.h:34
glm::mask
GLM_FUNC_DECL genIType mask(genIType const &count)
rtabmap::d1
const int d1
Definition: SuperPoint.cc:20
s
RealScalar s
ret
int ret
rtabmap::SPDetector::~SPDetector
virtual ~SPDetector()
Definition: SuperPoint.cc:138
h
const double h
rtabmap::SuperPoint::convDa
torch::nn::Conv2d convDa
Definition: SuperPoint.h:44
rtabmap::SuperPoint::convPa
torch::nn::Conv2d convPa
Definition: SuperPoint.h:40
UDirectory.h
rtabmap::SPDetector::minDistance_
int minDistance_
Definition: SuperPoint.h:67
rtabmap::SPDetector::detect
std::vector< cv::KeyPoint > detect(const cv::Mat &img, const cv::Mat &mask=cv::Mat())
Definition: SuperPoint.cc:142
rtabmap::SPDetector::model_
std::shared_ptr< SuperPoint > model_
Definition: SuperPoint.h:61
rtabmap_netvlad.img
img
Definition: rtabmap_netvlad.py:78
threshold_
Index threshold_
rtabmap::SuperPoint::convDb
torch::nn::Conv2d convDb
Definition: SuperPoint.h:45
glm::normalize
GLM_FUNC_DECL genType normalize(genType const &x)
rtabmap::SPDetector::SPDetector
SPDetector(const std::string &modelPath, float threshold=0.2f, bool nms=true, int minDistance=4, bool cuda=false)
Definition: SuperPoint.cc:108
rtabmap::SPDetector::threshold_
float threshold_
Definition: SuperPoint.h:65
rtabmap::SuperPoint::conv4a
torch::nn::Conv2d conv4a
Definition: SuperPoint.h:37
UDirectory::homeDir
static std::string homeDir()
Definition: UDirectory.cpp:355
SuperPoint.h
rtabmap::c5
const int c5
Definition: SuperPoint.cc:19
rtabmap::SPDetector::desc_
torch::Tensor desc_
Definition: SuperPoint.h:63
rtabmap::SuperPoint::conv2b
torch::nn::Conv2d conv2b
Definition: SuperPoint.h:32
rtabmap::SPDetector::nms_
bool nms_
Definition: SuperPoint.h:66
UConversion.h
Some conversion functions.
rtabmap_superglue.device
string device
Definition: rtabmap_superglue.py:21
rtabmap::c3
const int c3
Definition: SuperPoint.cc:17
rtabmap::SPDetector::cuda_
bool cuda_
Definition: SuperPoint.h:68
rtabmap::SuperPoint::conv1b
torch::nn::Conv2d conv1b
Definition: SuperPoint.h:29
rtabmap::SPDetector::compute
cv::Mat compute(const std::vector< cv::KeyPoint > &keypoints)
Definition: SuperPoint.cc:209
UASSERT
#define UASSERT(condition)
x
x
out
std::ofstream out("Result.txt")
rtabmap::SuperPoint::conv1a
torch::nn::Conv2d conv1a
Definition: SuperPoint.h:28
path
path
UWARN
#define UWARN(...)
uReplaceChar
std::string UTILITE_EXPORT uReplaceChar(const std::string &str, char before, char after)
Definition: UConversion.cpp:33
rtabmap::SuperPoint::conv3b
torch::nn::Conv2d conv3b
Definition: SuperPoint.h:35
ULogger.h
ULogger class and convenient macros.
rtabmap::SuperPoint::conv4b
torch::nn::Conv2d conv4b
Definition: SuperPoint.h:38
rtabmap::c1
const int c1
Definition: SuperPoint.cc:15
UDEBUG
#define UDEBUG(...)
rtabmap::c2
const int c2
Definition: SuperPoint.cc:16
rtabmap::SuperPoint::convPb
torch::nn::Conv2d convPb
Definition: SuperPoint.h:41
rtabmap::SPDetector::prob_
torch::Tensor prob_
Definition: SuperPoint.h:62
rtabmap::SPDetector::detected_
bool detected_
Definition: SuperPoint.h:70
rtabmap::SuperPoint::forward
std::vector< torch::Tensor > forward(torch::Tensor x)
Definition: SuperPoint.cc:63
rtabmap::SuperPoint::conv2a
torch::nn::Conv2d conv2a
Definition: SuperPoint.h:31
false
#define false
Definition: ConvertUTF.c:56
rtabmap::SuperPoint::SuperPoint
SuperPoint()
Definition: SuperPoint.cc:24
nn
idx_t * nn
UFile.h
rtabmap
Definition: CameraARCore.cpp:35
UFile::exists
bool exists()
Definition: UFile.h:104
UERROR
#define UERROR(...)
stride
Index stride
options
i
int i


rtabmap
Author(s): Mathieu Labbe
autogenerated on Mon Jul 1 2024 02:42:39