3 from __future__
import print_function
7 import itertools, pkg_resources, sys
8 from distutils.version
import LooseVersion
9 if LooseVersion(pkg_resources.get_distribution(
"chainer").version) >= LooseVersion(
'7.0.0')
and \
10 sys.version_info.major == 2:
11 print(
'''Please install chainer < 7.0.0:
13 sudo pip install chainer==6.7.0
15 c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
18 if [p
for p
in list(itertools.chain(*[pkg_resources.find_distributions(_)
for _
in sys.path]))
if "cupy-" in p.project_name
or "cupy" == p.project_name ] == []:
19 print(
'''Please install CuPy
21 sudo pip install cupy-cuda[your cuda version]
23 sudo pip install cupy-cuda91
28 from sensor_msgs.msg
import Image
29 from cv_bridge
import CvBridge
30 import message_filters
31 from jsk_recognition_msgs.msg
import RectArray
32 from jsk_recognition_msgs.msg
import ClassificationResult
33 from jsk_recognition_msgs.msg
import Label
34 from jsk_recognition_msgs.msg
import LabelArray
36 from deep_sort.deep_sort_tracker
import DeepSortTracker
42 super(DeepSortTrackerNode, self).
__init__()
46 self.
gpu = rospy.get_param(
'~gpu', 0)
48 chainer.cuda.get_device_from_id(self.
gpu).use()
49 chainer.global_config.train =
False
50 chainer.global_config.enable_backprop =
False
52 pretrained_model = rospy.get_param(
'~pretrained_model')
54 pretrained_model=pretrained_model)
59 '~output/labels', LabelArray, queue_size=1)
64 queue_size = rospy.get_param(
'~queue_size', 100)
66 '~input', Image, queue_size=1)
68 '~input/rects', RectArray, queue_size=1)
70 '~input/class', ClassificationResult, queue_size=1)
71 self.
subs = [sub_img, sub_rects, sub_class]
72 if rospy.get_param(
'~approximate_sync',
False):
73 slop = rospy.get_param(
'~slop', 0.1)
74 sync = message_filters.ApproximateTimeSynchronizer(
75 fs=self.
subs, queue_size=queue_size, slop=slop)
78 fs=self.
subs, queue_size=queue_size)
81 def callback(self, img_msg, rects_msg, class_msg):
85 frame = bridge.imgmsg_to_cv2(
86 img_msg, desired_encoding=
'bgr8')
91 if len(rects_msg.rects) != len(class_msg.label_proba):
92 rospy.logwarn(
'The sizes of RectArray and LabelArray does not match. Skipping...')
95 for i, r
in enumerate(rects_msg.rects):
99 rects.append((
int(r.x),
int(r.y),
102 scores.append(class_msg.label_proba[i])
103 rects = np.array(rects,
'f')
104 scores = np.array(scores,
'f')
107 tracker.track(frame, rects, scores)
110 msg_labels = LabelArray(header=img_msg.header)
111 for object_id, target_object
in tracker.tracking_objects.items():
112 if target_object[
'out_of_frame']:
114 msg_labels.labels.append(Label(id=object_id))
117 if self.
image_pub.get_num_connections() > 0:
118 frame = tracker.visualize(frame, rects)
119 msg = bridge.cv2_to_imgmsg(frame,
"bgr8")
120 msg.header = img_msg.header
125 rospy.init_node(
'deep_sort_tracker_node')
130 if __name__ ==
'__main__':