deep_sort_tracker_node.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import print_function
4 
5 import numpy as np
6 import rospy
7 import itertools, pkg_resources, sys
8 from distutils.version import LooseVersion
# chainer >= 7.0.0 no longer supports Python 2, so refuse to start there.
if LooseVersion(pkg_resources.get_distribution("chainer").version) >= LooseVersion('7.0.0') and \
        sys.version_info.major == 2:
    print('''Please install chainer < 7.0.0:

    sudo pip install chainer==6.7.0

c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
    sys.exit(1)
# CuPy is distributed either as prebuilt wheels named ``cupy-cudaXXX`` or as
# the source distribution named plain ``cupy``; accept both.  any() stops at
# the first match instead of listing every distribution on sys.path.
if not any(dist.project_name == 'cupy' or 'cupy-' in dist.project_name
           for dist in itertools.chain.from_iterable(
               pkg_resources.find_distributions(path) for path in sys.path)):
    print('''Please install CuPy

    sudo pip install cupy-cuda[your cuda version]
i.e.
    sudo pip install cupy-cuda91

''', file=sys.stderr)
    # Warning only: unlike the chainer check above, startup deliberately
    # continues without CuPy.
27 import chainer
28 from sensor_msgs.msg import Image
29 from cv_bridge import CvBridge
30 import message_filters
31 from jsk_recognition_msgs.msg import RectArray
32 from jsk_recognition_msgs.msg import ClassificationResult
33 from jsk_recognition_msgs.msg import Label
34 from jsk_recognition_msgs.msg import LabelArray
35 
36 from deep_sort.deep_sort_tracker import DeepSortTracker
37 
38 
40 
class DeepSortTrackerNode(object):
    """ROS node that tracks detected objects across frames with DeepSort.

    Subscribes (synchronized) to:
        ~input (sensor_msgs/Image): camera frame.
        ~input/rects (jsk_recognition_msgs/RectArray): detection boxes.
        ~input/class (jsk_recognition_msgs/ClassificationResult): per-box
            label names and probabilities.
    Publishes (only while the topic has subscribers):
        ~output/labels (jsk_recognition_msgs/LabelArray): ids of tracked
            objects that are not marked out of frame.
        ~output/viz (sensor_msgs/Image): tracker visualization.

    ROS parameters:
        ~target_labels (list, default None): if set, only detections whose
            label name is in this list are tracked.
        ~gpu (int, default 0): GPU id; a negative value skips selecting a
            CUDA device.
        ~pretrained_model (no default): passed to ``DeepSortTracker``.
        ~queue_size (int, default 100): synchronizer queue size.
        ~approximate_sync (bool, default False): use an approximate-time
            synchronizer instead of an exact one.
        ~slop (float, default 0.1): time tolerance for approximate sync.
    """

    def __init__(self):
        super(DeepSortTrackerNode, self).__init__()
        self.bridge = CvBridge()

        self.target_labels = rospy.get_param('~target_labels', None)
        self.gpu = rospy.get_param('~gpu', 0)
        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
        # Inference only: disable training-mode behaviour and graph building.
        chainer.global_config.train = False
        chainer.global_config.enable_backprop = False

        pretrained_model = rospy.get_param('~pretrained_model')
        self.tracker = DeepSortTracker(gpu=self.gpu,
                                       pretrained_model=pretrained_model)
        self.image_pub = rospy.Publisher(
            '~output/viz',
            Image, queue_size=1)
        self.pub_labels = rospy.Publisher(
            '~output/labels', LabelArray, queue_size=1)

        # Publishers exist before subscribing, so callback() can use them.
        self.subscribe()

    def subscribe(self):
        """Subscribe to the three inputs and register ``callback`` on them."""
        queue_size = rospy.get_param('~queue_size', 100)
        sub_img = message_filters.Subscriber(
            '~input', Image, queue_size=1)
        sub_rects = message_filters.Subscriber(
            '~input/rects', RectArray, queue_size=1)
        sub_class = message_filters.Subscriber(
            '~input/class', ClassificationResult, queue_size=1)
        self.subs = [sub_img, sub_rects, sub_class]
        if rospy.get_param('~approximate_sync', False):
            slop = rospy.get_param('~slop', 0.1)
            sync = message_filters.ApproximateTimeSynchronizer(
                fs=self.subs, queue_size=queue_size, slop=slop)
        else:
            sync = message_filters.TimeSynchronizer(
                fs=self.subs, queue_size=queue_size)
        sync.registerCallback(self.callback)

    def _extract_detections(self, rects_msg, class_msg):
        """Return ``(rects, scores)`` for the detections that should be tracked.

        ``rects`` is a float32 array of shape (N, 4) holding
        ``(x, y, width, height)`` truncated to integers.  ``scores`` is a
        float32 array of shape (N,) taken from ``class_msg.label_proba``.
        The i-th rectangle is paired with the i-th entry of
        ``class_msg.label_names`` / ``label_proba``, so both messages are
        assumed to list the same detections in the same order.  When
        ``self.target_labels`` is set, detections with other labels are
        dropped.
        """
        rects = []
        scores = []
        for i, r in enumerate(rects_msg.rects):
            if self.target_labels is not None and \
                    class_msg.label_names[i] not in self.target_labels:
                continue
            rects.append((int(r.x), int(r.y), int(r.width), int(r.height)))
            scores.append(class_msg.label_proba[i])
        # reshape keeps the (N, 4) layout even when nothing was kept;
        # np.array([]) on its own would be 1-D with shape (0,).
        rects = np.array(rects, 'f').reshape(-1, 4)
        scores = np.array(scores, 'f')
        return rects, scores

    def callback(self, img_msg, rects_msg, class_msg):
        """Track the detections of one synchronized frame and publish results.

        Args:
            img_msg (sensor_msgs/Image): frame, converted to a BGR8 image.
            rects_msg (jsk_recognition_msgs/RectArray): detection boxes.
            class_msg (jsk_recognition_msgs/ClassificationResult): labels and
                probabilities, assumed index-aligned with ``rects_msg``.
        """
        bridge = self.bridge
        tracker = self.tracker

        frame = bridge.imgmsg_to_cv2(
            img_msg, desired_encoding='bgr8')
        rects, scores = self._extract_detections(rects_msg, class_msg)

        # The tracker is only updated when at least one detection remains.
        if len(rects) > 0:
            tracker.track(frame, rects, scores)

        if self.pub_labels.get_num_connections() > 0:
            msg_labels = LabelArray(header=img_msg.header)
            for object_id, target_object in tracker.tracking_objects.items():
                # Objects flagged as out of frame are not reported.
                if target_object['out_of_frame']:
                    continue
                msg_labels.labels.append(Label(id=object_id))
            self.pub_labels.publish(msg_labels)

        if self.image_pub.get_num_connections() > 0:
            frame = tracker.visualize(frame, rects)
            msg = bridge.cv2_to_imgmsg(frame, "bgr8")
            msg.header = img_msg.header
            self.image_pub.publish(msg)
117 
118 
def main():
    """Initialise the ROS node, build the tracker node and spin until shutdown."""
    rospy.init_node('deep_sort_tracker_node')
    # Hold a reference while spinning: the instance owns the publishers and
    # the subscribers it creates in its constructor.
    _node = DeepSortTrackerNode()  # NOQA
    rospy.spin()


if __name__ == '__main__':
    main()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27