netvlad_tf_ros.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 # Using netvlad tensorflow-v1 implementation from https://github.com/uzh-rpg/netvlad_tf_open/
4 # For ROS Melodic, follow these instructions to rebuild cv_bridge with Python 3:
5 # https://medium.com/@beta_b0t/how-to-setup-ros-with-python-3-44a69ca36674
6 # On JetPack 4.4 (Ubuntu 18.04 with OpenCV 4), use the noetic branch of vision_opencv. In cv_bridge/CMakeLists.txt,
7 # apply this patch:
8 # -find_package(Boost REQUIRED python37)
9 # +find_package(Boost REQUIRED python3)
10 
11 from __future__ import print_function
12 
13 import roslib
14 import sys
15 import rospy
16 import cv2
17 import numpy as np
18 import tensorflow as tf
19 import time
20 
21 import netvlad_tf.net_from_mat as nfm
22 import netvlad_tf.nets as nets
23 
24 from std_msgs.msg import String
25 from sensor_msgs.msg import Image
26 from cv_bridge import CvBridge, CvBridgeError
27 from rtabmap_ros import compression as cp
28 from rtabmap_ros.msg import GlobalDescriptor
29 
31 
32  def __init__(self):
33 
34  self.dim = rospy.get_param('~dim', 4096)
35  self.scale = rospy.get_param('~scale', 1.0)
36  rospy.loginfo("Parameter dim=%d", self.dim)
37  rospy.loginfo("Parameter scale=%d", self.scale)
38 
39  tf.reset_default_graph()
40 
41  self.image_batch = tf.placeholder(
42  dtype=tf.float32, shape=[None, None, None, 3])
43 
44  self.net_out = nets.vgg16NetvladPca(self.image_batch)
45  self.saver = tf.train.Saver()
46 
47  self.sess = tf.Session()
48  self.saver.restore(self.sess, nets.defaultCheckpoint())
49 
50  self.pub = rospy.Publisher('netvlad_descriptor', GlobalDescriptor, queue_size=1)
51 
52  self.bridge = CvBridge()
53  self.image_sub = rospy.Subscriber("image",Image,self.callback, queue_size=1)
54 
55  def callback(self,data):
56  start = time.time()
57  try:
58  cv_image = self.bridge.imgmsg_to_cv2(data, "rgb8")
59  except CvBridgeError as e:
60  print(e)
61 
62  if self.scale != 1.0:
63  width = int(cv_image.shape[1] * self.scale)
64  height = int(cv_image.shape[0] * self.scale)
65  cv_image = cv2.resize(cv_image, (width, height), interpolation = cv2.INTER_AREA)
66 
67  batch = np.expand_dims(cv_image, axis=0)
68  result = self.sess.run(self.net_out, feed_dict={self.image_batch: batch})
69  result = result[:,:self.dim]
70 
71  descriptor = GlobalDescriptor()
72  descriptor.type = 0
73  descriptor.header = data.header
74  descriptor.data = cp.compress(result)
75  self.pub.publish(descriptor)
76  end = time.time()
77  rospy.loginfo("Extracting descriptor (img=%dx%d, dim=%d): %fs", cv_image.shape[1], cv_image.shape[0], self.dim, end-start)
78 
def main(args):
    """Start the ``netvlad`` node and process images until ROS shuts down.

    Args:
        args: command-line arguments passed by the script entry point
            (not read here).
    """
    rospy.init_node(name='netvlad', anonymous=True)
    # Keep the extractor bound to a name for as long as the node spins.
    extractor = netvlad_ros()
    try:
        rospy.spin()
    except KeyboardInterrupt:
        print("Shutting down")
86 
# Script entry point: run the node with the process's argument vector.
if __name__ == "__main__":
    main(args=sys.argv)
89 
def main(args)


rtabmap_ros
Author(s): Mathieu Labbe
autogenerated on Tue Jan 24 2023 04:04:40