deep_sort_tracker_node.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import print_function
4 
5 import numpy as np
6 import rospy
7 import itertools, pkg_resources, sys
8 from distutils.version import LooseVersion
# chainer >= 7.0.0 dropped Python 2 support; refuse to start under Python 2
# with an incompatible chainer instead of failing later with obscure errors.
# NOTE(review): distutils.LooseVersion / pkg_resources are deprecated on
# recent Python 3 releases; kept here because this script still targets py2.
if LooseVersion(pkg_resources.get_distribution("chainer").version) >= LooseVersion('7.0.0') and \
        sys.version_info.major == 2:
    print('''Please install chainer < 7.0.0:

    sudo pip install chainer==6.7.0

c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
    sys.exit(1)
# CuPy is distributed either as "cupy" or as a CUDA-specific "cupy-cudaXX"
# package; scan every sys.path entry for either name. The scan stops at the
# first match instead of materializing every installed distribution.
_cupy_installed = any(
    "cupy-" in dist.project_name or "cupy" == dist.project_name
    for dist in itertools.chain.from_iterable(
        pkg_resources.find_distributions(path) for path in sys.path))
if not _cupy_installed:
    print('''Please install CuPy

    sudo pip install cupy-cuda[your cuda version]
i.e.
    sudo pip install cupy-cuda91

''', file=sys.stderr)
    # Deliberately only a warning: the node may still run with ~gpu < 0.
    # sys.exit(1)
27 import chainer
28 from sensor_msgs.msg import Image
29 from cv_bridge import CvBridge
30 import message_filters
31 from jsk_recognition_msgs.msg import RectArray
32 from jsk_recognition_msgs.msg import ClassificationResult
33 from jsk_recognition_msgs.msg import Label
34 from jsk_recognition_msgs.msg import LabelArray
35 
36 from deep_sort.deep_sort_tracker import DeepSortTracker
37 
38 
class DeepSortTrackerNode(object):
    """ROS node that feeds detected rectangles into a DeepSortTracker.

    Subscribes (time-synchronized) to:
      ~input          sensor_msgs/Image
      ~input/rects    jsk_recognition_msgs/RectArray
      ~input/class    jsk_recognition_msgs/ClassificationResult
    and publishes:
      ~output/labels  jsk_recognition_msgs/LabelArray (one Label per tracked
                      object that is not out of frame; Label.id = track id)
      ~output/viz     sensor_msgs/Image (tracker visualization, bgr8)

    ROS parameters read by this class:
      ~target_labels     if not None, only rects whose label name is contained
                         in it are tracked (presumably a list of names).
      ~gpu               device id (default 0); a negative value skips
                         selecting a CUDA device. Also passed to the tracker.
      ~pretrained_model  required; passed to DeepSortTracker.
      ~queue_size        synchronizer queue size (default 100).
      ~approximate_sync  use ApproximateTimeSynchronizer (default False).
      ~slop              slop for approximate sync, in seconds (default 0.1).
    """

    def __init__(self):
        super(DeepSortTrackerNode, self).__init__()
        self.bridge = CvBridge()

        self.target_labels = rospy.get_param('~target_labels', None)
        self.gpu = rospy.get_param('~gpu', 0)
        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
        # Inference only: disable training-mode behavior and graph building.
        chainer.global_config.train = False
        chainer.global_config.enable_backprop = False

        pretrained_model = rospy.get_param('~pretrained_model')
        self.tracker = DeepSortTracker(gpu=self.gpu,
                                       pretrained_model=pretrained_model)
        self.image_pub = rospy.Publisher(
            '~output/viz',
            Image, queue_size=1)
        self.pub_labels = rospy.Publisher(
            '~output/labels', LabelArray, queue_size=1)

        self.subscribe()

    def subscribe(self):
        """Create the three input subscribers and synchronize them.

        Uses an ApproximateTimeSynchronizer when ~approximate_sync is true,
        otherwise an exact TimeSynchronizer; either way ``callback`` receives
        (image, rects, classification) triples.
        """
        queue_size = rospy.get_param('~queue_size', 100)
        sub_img = message_filters.Subscriber(
            '~input', Image, queue_size=1)
        sub_rects = message_filters.Subscriber(
            '~input/rects', RectArray, queue_size=1)
        sub_class = message_filters.Subscriber(
            '~input/class', ClassificationResult, queue_size=1)
        self.subs = [sub_img, sub_rects, sub_class]
        if rospy.get_param('~approximate_sync', False):
            slop = rospy.get_param('~slop', 0.1)
            sync = message_filters.ApproximateTimeSynchronizer(
                fs=self.subs, queue_size=queue_size, slop=slop)
        else:
            sync = message_filters.TimeSynchronizer(
                fs=self.subs, queue_size=queue_size)
        sync.registerCallback(self.callback)

    def callback(self, img_msg, rects_msg, class_msg):
        """Track the (optionally label-filtered) rects and publish results.

        Rects are converted to an (N, 4) float32 array of (x, y, width,
        height) truncated to int, with their label_proba values as scores.
        The tracker is only updated when at least one rect survives the
        filter, but labels/visualization are published every call from the
        tracker's current state.
        """
        bridge = self.bridge
        tracker = self.tracker

        frame = bridge.imgmsg_to_cv2(
            img_msg, desired_encoding='bgr8')

        n_rects = len(rects_msg.rects)
        if n_rects != len(class_msg.label_proba):
            rospy.logwarn(
                'The sizes of RectArray and ClassificationResult do not '
                'match. Skipping...')
            return
        # label_names is indexed per rect only when filtering is enabled;
        # without this check a short label_names array raises IndexError.
        if self.target_labels is not None and \
                len(class_msg.label_names) != n_rects:
            rospy.logwarn(
                'The sizes of RectArray and ClassificationResult.label_names '
                'do not match. Skipping...')
            return

        rects = []
        scores = []
        for i, r in enumerate(rects_msg.rects):
            if self.target_labels is not None and \
                    class_msg.label_names[i] not in self.target_labels:
                continue
            rects.append((int(r.x), int(r.y),
                          int(r.width),
                          int(r.height)))
            scores.append(class_msg.label_proba[i])
        rects = np.array(rects, 'f')
        scores = np.array(scores, 'f')

        if len(rects) > 0:
            tracker.track(frame, rects, scores)

        if self.pub_labels.get_num_connections() > 0:
            msg_labels = LabelArray(header=img_msg.header)
            for object_id, target_object in tracker.tracking_objects.items():
                if target_object['out_of_frame']:
                    continue
                msg_labels.labels.append(Label(id=object_id))
            self.pub_labels.publish(msg_labels)

        if self.image_pub.get_num_connections() > 0:
            frame = tracker.visualize(frame, rects)
            msg = bridge.cv2_to_imgmsg(frame, "bgr8")
            msg.header = img_msg.header
            self.image_pub.publish(msg)
122 
123 
def main():
    """Entry point: start the ROS node, build the tracker node, and spin."""
    rospy.init_node('deep_sort_tracker_node')
    # Keep a reference so the node (and its subscribers) stay alive
    # for the lifetime of the spin loop.
    node = DeepSortTrackerNode()  # NOQA
    rospy.spin()


if __name__ == '__main__':
    main()
object
node_scripts.deep_sort.deep_sort_tracker.DeepSortTracker
Definition: deep_sort_tracker.py:105
ssd_train_dataset.int
int
Definition: ssd_train_dataset.py:175
node_scripts.deep_sort_tracker_node.DeepSortTrackerNode.pub_labels
pub_labels
Definition: deep_sort_tracker_node.py:58
node_scripts.deep_sort_tracker_node.DeepSortTrackerNode.subscribe
def subscribe(self)
Definition: deep_sort_tracker_node.py:63
message_filters::Subscriber
node_scripts.deep_sort_tracker_node.main
def main()
Definition: deep_sort_tracker_node.py:124
node_scripts.deep_sort_tracker_node.DeepSortTrackerNode
Definition: deep_sort_tracker_node.py:39
node_scripts.deep_sort_tracker_node.DeepSortTrackerNode.bridge
bridge
Definition: deep_sort_tracker_node.py:43
node_scripts.deep_sort_tracker_node.DeepSortTrackerNode.tracker
tracker
Definition: deep_sort_tracker_node.py:53
node_scripts.deep_sort_tracker_node.DeepSortTrackerNode.image_pub
image_pub
Definition: deep_sort_tracker_node.py:55
node_scripts.deep_sort_tracker_node.DeepSortTrackerNode.target_labels
target_labels
Definition: deep_sort_tracker_node.py:45
message_filters::TimeSynchronizer
node_scripts.deep_sort_tracker_node.DeepSortTrackerNode.gpu
gpu
Definition: deep_sort_tracker_node.py:46
node_scripts.deep_sort_tracker_node.DeepSortTrackerNode.subs
subs
Definition: deep_sort_tracker_node.py:71
node_scripts.deep_sort_tracker_node.DeepSortTrackerNode.callback
def callback(self, img_msg, rects_msg, class_msg)
Definition: deep_sort_tracker_node.py:81
node_scripts.deep_sort_tracker_node.DeepSortTrackerNode.__init__
def __init__(self)
Definition: deep_sort_tracker_node.py:41


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Fri May 16 2025 03:11:16