rtabmap_netvlad.py
Go to the documentation of this file.
1 #! /usr/bin/env python3
2 #
3 # Drop this file into the "python" folder of the NetVLAD git repository (uses the TensorFlow v1 API): https://github.com/uzh-rpg/netvlad_tf_open/
4 # Updated to work with https://github.com/uzh-rpg/netvlad_tf_open/pull/9
5 # To use with rtabmap:
6 # --Mem/GlobalDescriptorStrategy 1 --Kp/TfIdfLikelihoodUsed false --Mem/RehearsalSimilarity 1 --PyDescriptor/Dim 128 --PyDescriptor/Path ~/netvlad_tf_open/python/rtabmap_netvlad.py
7 #
8 
9 import sys
10 import os
11 import numpy as np
12 import time
# Make modules that sit next to this script importable (e.g. the
# netvlad_tf package when this file is dropped into its "python" folder).
sys.path.append(os.path.split(os.path.realpath(__file__))[0])

# An embedding host may start the interpreter without sys.argv; provide a
# minimal one. Presumably needed because TensorFlow reads sys.argv on
# import -- TODO confirm.
try:
    sys.argv
except AttributeError:
    sys.argv = ['']
16 
17 #print(os.sys.path)
18 #print(sys.version)
19 
20 import tensorflow as tf
21 import netvlad_tf.net_from_mat as nfm
22 import netvlad_tf.nets as nets
23 
# TensorFlow state created by init() and read by extract():
# input placeholder, network output tensor, checkpoint Saver and Session.
image_batch = net_out = saver = sess = None
# Number of leading descriptor dimensions extract() keeps; init() replaces
# this with its descriptorDim argument.
dim = 4096
29 
def init(descriptorDim):
    """Build the ``nets.vgg16NetvladPca`` graph and restore its checkpoint.

    Must be called before :func:`extract`. Calling it again rebuilds the
    graph from scratch and closes the previously opened session.

    Args:
        descriptorDim: number of leading output dimensions that
            :func:`extract` keeps; stored in the module-level ``dim``.

    Side effects:
        Disables TF eager execution, resets the default TF graph, and replaces
        the module-level ``image_batch``, ``net_out``, ``saver``, ``sess`` and
        ``dim``.
    """
    global image_batch
    global net_out
    global saver
    global sess
    global dim

    print("NetVLAD python init()")

    dim = descriptorDim

    # A previous init() left a Session bound to the old graph. Close it before
    # the graph is reset, otherwise re-initialisation leaks its resources.
    if sess is not None:
        sess.close()
        sess = None

    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.reset_default_graph()

    # Variable-size 3-channel image batch; presumably NHWC (batch, height,
    # width, 3), matching how extract() builds its batch -- TODO confirm.
    image_batch = tf.compat.v1.placeholder(
        dtype=tf.float32, shape=[None, None, None, 3])

    net_out = nets.vgg16NetvladPca(image_batch)
    saver = tf.compat.v1.train.Saver()

    sess = tf.compat.v1.Session()
    saver.restore(sess, nets.defaultCheckpoint())
51 
52 
def extract(image):
    """Compute a NetVLAD global descriptor for a single image.

    Requires :func:`init` to have been called. Reads the module-level
    ``sess``, ``net_out``, ``image_batch`` and ``dim``.

    Args:
        image: numpy image array. A single-channel image, either 2-D
            ``(H, W)`` or ``(H, W, 1)``, is replicated to three channels to
            match the 3-channel input placeholder built by :func:`init`.

    Returns:
        ``float32`` array of the network output, presumably shaped ``(1, D)``
        for the single-image batch. When ``D`` exceeds ``dim``, only the first
        ``dim`` columns are kept and each row is re-L2-normalised.
    """
    print(f"NetVLAD python extract{image.shape}")

    # Grayscale -> 3 channels. Checking only shape[2] would raise IndexError
    # for a plain 2-D grayscale array, so test ndim first.
    if image.ndim == 2 or image.shape[2] == 1:
        image = np.dstack((image, image, image))

    batch = np.expand_dims(image, axis=0)
    result = sess.run(net_out, feed_dict={image_batch: batch})

    # All that needs to be done (only valid for NetVLAD+whitening networks!)
    # to reduce the dimensionality of the NetVLAD representation below 4096 to D
    # is to keep the first D dimensions and L2-normalize.
    if result.shape[1] > dim:
        v = result[:, :dim]
        # Normalise each row independently; identical to a whole-matrix norm
        # for the batch of one built above.
        norms = np.linalg.norm(v, axis=1, keepdims=True)
        # A zero row stays zero instead of becoming NaN (0/0).
        result = v / np.where(norms > 0, norms, 1.0)

    return np.float32(result)
74 
75 
if __name__ == '__main__':
    # Smoke test: describe an all-white 100x100 RGB image, keeping 128 dims.
    white = np.full((100, 100, 3), 255, dtype=np.uint8)
    init(128)
    desc = extract(white)
    print(desc.shape)
    print(desc)
rtabmap_netvlad.extract
def extract(image)
Definition: rtabmap_netvlad.py:53
rtabmap_netvlad.init
def init(descriptorDim)
Definition: rtabmap_netvlad.py:30
hasattr
bool hasattr(handle obj, const char *name)
init


rtabmap
Author(s): Mathieu Labbe
autogenerated on Thu Jul 25 2024 02:50:15