deep_sort_tracker_node.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 import numpy as np
00004 import rospy
00005 import chainer
00006 from sensor_msgs.msg import Image
00007 from cv_bridge import CvBridge
00008 import message_filters
00009 from jsk_recognition_msgs.msg import RectArray
00010 from jsk_recognition_msgs.msg import ClassificationResult
00011 from jsk_recognition_msgs.msg import Label
00012 from jsk_recognition_msgs.msg import LabelArray
00013 
00014 from deep_sort.deep_sort_tracker import DeepSortTracker
00015 
00016 
class DeepSortTrackerNode(object):
    """ROS node that tracks detected objects across frames with DeepSORT.

    Synchronizes an image (``~input``), detection rectangles
    (``~input/rects``) and per-rectangle classification results
    (``~input/class``), feeds the detections to a ``DeepSortTracker`` and
    publishes the IDs of in-frame tracked objects (``~output/labels``) and a
    visualization image (``~output/viz``).

    ROS parameters:
        ~target_labels (optional): label names to keep; presumably a list of
            strings. ``None`` (default) keeps every detection.
        ~gpu (int, default 0): device id; a negative value skips selecting a
            CUDA device. Also passed to ``DeepSortTracker``.
        ~pretrained_model (required): passed to ``DeepSortTracker``.
        ~queue_size (int, default 100): synchronizer queue size.
        ~approximate_sync (bool, default False): use
            ``ApproximateTimeSynchronizer`` instead of exact time matching.
        ~slop (float, default 0.1): time tolerance for approximate sync.
    """

    def __init__(self):
        super(DeepSortTrackerNode, self).__init__()
        self.bridge = CvBridge()

        self.target_labels = rospy.get_param('~target_labels', None)
        self.gpu = rospy.get_param('~gpu', 0)
        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
        # Inference only: no training-mode behaviour, no backprop graph.
        chainer.global_config.train = False
        chainer.global_config.enable_backprop = False

        pretrained_model = rospy.get_param('~pretrained_model')
        self.tracker = DeepSortTracker(gpu=self.gpu,
                                       pretrained_model=pretrained_model)
        self.image_pub = rospy.Publisher(
            '~output/viz',
            Image, queue_size=1)
        self.pub_labels = rospy.Publisher(
            '~output/labels', LabelArray, queue_size=1)

        self.subscribe()

    def subscribe(self):
        """Subscribe to the three inputs and register the synced callback."""
        queue_size = rospy.get_param('~queue_size', 100)
        sub_img = message_filters.Subscriber(
            '~input', Image, queue_size=1)
        sub_rects = message_filters.Subscriber(
            '~input/rects', RectArray, queue_size=1)
        sub_class = message_filters.Subscriber(
            '~input/class', ClassificationResult, queue_size=1)
        self.subs = [sub_img, sub_rects, sub_class]
        if rospy.get_param('~approximate_sync', False):
            slop = rospy.get_param('~slop', 0.1)
            sync = message_filters.ApproximateTimeSynchronizer(
                fs=self.subs, queue_size=queue_size, slop=slop)
        else:
            sync = message_filters.TimeSynchronizer(
                fs=self.subs, queue_size=queue_size)
        sync.registerCallback(self.callback)
        # Keep an explicit reference to the synchronizer for the node's life.
        self.sync = sync

    @staticmethod
    def _extract_detections(rects_msg, class_msg, target_labels=None):
        """Convert rect/classification messages into tracker input arrays.

        Args:
            rects_msg: ``RectArray``-like message; every element of ``.rects``
                must expose ``x``, ``y``, ``width`` and ``height``.
            class_msg: ``ClassificationResult``-like message whose
                ``label_proba[i]`` (and ``label_names[i]`` when filtering)
                belong to ``rects_msg.rects[i]``.
            target_labels: optional collection of label names to keep;
                ``None`` keeps every detection.

        Returns:
            tuple: ``(rects, scores)``.
            ``rects`` is a float32 array of shape ``(N, 4)`` holding
            ``(x, y, width, height)`` truncated with ``int()``.
            ``scores`` is a float32 array of shape ``(N,)``.
            ``rects`` stays 2-D even when ``N == 0``.

        Raises:
            ValueError: if ``class_msg`` has fewer entries than there are
                rectangles, i.e. the messages do not describe the same
                detections.
        """
        n_rects = len(rects_msg.rects)
        filtering = target_labels is not None
        if (len(class_msg.label_proba) < n_rects or
                (filtering and len(class_msg.label_names) < n_rects)):
            raise ValueError(
                'ClassificationResult does not match RectArray: '
                '{} rects, {} label_proba, {} label_names'.format(
                    n_rects, len(class_msg.label_proba),
                    len(class_msg.label_names)))

        rects = []
        scores = []
        for i, r in enumerate(rects_msg.rects):
            if filtering and class_msg.label_names[i] not in target_labels:
                continue
            rects.append((int(r.x), int(r.y),
                          int(r.width),
                          int(r.height)))
            scores.append(class_msg.label_proba[i])
        # reshape keeps an (0, 4) shape when nothing survives the filter, so
        # downstream code can index columns regardless of detection count.
        rects = np.array(rects, 'f').reshape(-1, 4)
        scores = np.array(scores, 'f')
        return rects, scores

    def callback(self, img_msg, rects_msg, class_msg):
        """Track one synchronized set of detections and publish results.

        Frames whose classification result does not match the rectangles
        are logged and skipped.
        """
        bridge = self.bridge
        tracker = self.tracker

        try:
            rects, scores = self._extract_detections(
                rects_msg, class_msg, self.target_labels)
        except ValueError as e:
            rospy.logerr('Skipping frame: {}'.format(e))
            return

        frame = bridge.imgmsg_to_cv2(
            img_msg, desired_encoding='bgr8')

        # The tracker is only updated when there is at least one detection.
        if len(rects) > 0:
            tracker.track(frame, rects, scores)

        if self.pub_labels.get_num_connections() > 0:
            msg_labels = LabelArray(header=img_msg.header)
            for object_id, target_object in tracker.tracking_objects.items():
                if target_object['out_of_frame']:
                    continue
                msg_labels.labels.append(Label(id=object_id))
            self.pub_labels.publish(msg_labels)

        if self.image_pub.get_num_connections() > 0:
            frame = tracker.visualize(frame, rects)
            msg = bridge.cv2_to_imgmsg(frame, "bgr8")
            msg.header = img_msg.header
            self.image_pub.publish(msg)
00095 
00096 
def main():
    """Initialise the ROS node, build the tracker node and spin until shutdown."""
    rospy.init_node('deep_sort_tracker_node')
    # Hold a reference for the whole spin; the node's synchronized
    # subscriber callback does all the work.
    tracker_node = DeepSortTrackerNode()  # NOQA
    rospy.spin()


if __name__ == '__main__':
    main()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07