Go to the documentation of this file.00001
00002
00003 import numpy as np
00004 import rospy
00005 import chainer
00006 from sensor_msgs.msg import Image
00007 from cv_bridge import CvBridge
00008 import message_filters
00009 from jsk_recognition_msgs.msg import RectArray
00010 from jsk_recognition_msgs.msg import ClassificationResult
00011 from jsk_recognition_msgs.msg import Label
00012 from jsk_recognition_msgs.msg import LabelArray
00013
00014 from deep_sort.deep_sort_tracker import DeepSortTracker
00015
00016
class DeepSortTrackerNode(object):
    """ROS node that feeds synchronized detections into a DeepSortTracker.

    Subscribed topics (time-synchronized):
        ~input (sensor_msgs/Image): image the detections belong to.
        ~input/rects (jsk_recognition_msgs/RectArray): detection rectangles.
        ~input/class (jsk_recognition_msgs/ClassificationResult): label names
            and probabilities, indexed parallel to ``rects``.

    Published topics:
        ~output/viz (sensor_msgs/Image): tracker visualization (bgr8); only
            computed while the topic has subscribers.
        ~output/labels (jsk_recognition_msgs/LabelArray): ids of tracked
            objects not marked ``out_of_frame``; only built while the topic
            has subscribers.

    Parameters:
        ~target_labels (list of str, optional): if set, only detections whose
            label name is in this collection are tracked. A single scalar
            value is treated as a one-element list.
        ~gpu (int, default 0): device id; a negative value skips selecting a
            CUDA device here.
        ~pretrained_model (str, required): forwarded to DeepSortTracker.
        ~queue_size (int, default 100): synchronizer queue size.
        ~approximate_sync (bool, default False): use ApproximateTimeSynchronizer.
        ~slop (float, default 0.1): slop for approximate synchronization.
    """

    def __init__(self):
        super(DeepSortTrackerNode, self).__init__()
        self.bridge = CvBridge()

        self.target_labels = self._load_target_labels()
        self.gpu = rospy.get_param('~gpu', 0)
        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
        # Inference only: disable train-mode behavior and graph construction.
        chainer.global_config.train = False
        chainer.global_config.enable_backprop = False

        pretrained_model = rospy.get_param('~pretrained_model')
        self.tracker = DeepSortTracker(gpu=self.gpu,
                                       pretrained_model=pretrained_model)
        self.image_pub = rospy.Publisher(
            '~output/viz',
            Image, queue_size=1)
        self.pub_labels = rospy.Publisher(
            '~output/labels', LabelArray, queue_size=1)

        self.subscribe()

    @staticmethod
    def _load_target_labels():
        """Read ``~target_labels`` as a set of label names, or None if unset.

        A scalar value (e.g. a single string) is wrapped in a list first, so
        membership tests compare whole label names rather than substrings.
        """
        labels = rospy.get_param('~target_labels', None)
        if labels is None:
            return None
        if not isinstance(labels, (list, tuple)):
            labels = [labels]
        return set(labels)

    def subscribe(self):
        """Create the three input subscribers and register the synchronizer."""
        queue_size = rospy.get_param('~queue_size', 100)
        sub_img = message_filters.Subscriber(
            '~input', Image, queue_size=1)
        sub_rects = message_filters.Subscriber(
            '~input/rects', RectArray, queue_size=1)
        sub_class = message_filters.Subscriber(
            '~input/class', ClassificationResult, queue_size=1)
        self.subs = [sub_img, sub_rects, sub_class]
        if rospy.get_param('~approximate_sync', False):
            slop = rospy.get_param('~slop', 0.1)
            sync = message_filters.ApproximateTimeSynchronizer(
                fs=self.subs, queue_size=queue_size, slop=slop)
        else:
            sync = message_filters.TimeSynchronizer(
                fs=self.subs, queue_size=queue_size)
        sync.registerCallback(self.callback)
        # Keep an explicit reference to the synchronizer next to its inputs.
        self.sync = sync

    def _extract_detections(self, rects_msg, class_msg):
        """Build the tracker inputs from one RectArray/ClassificationResult pair.

        Returns:
            ``(rects, scores)`` as float32 arrays: ``rects`` has shape (N, 4)
            holding (x, y, width, height) and ``scores`` has shape (N,). Only
            detections whose label name is in ``self.target_labels`` are kept
            when that attribute is not None. Returns None, after logging a
            warning, when ``class_msg`` has fewer entries than
            ``rects_msg.rects`` (indexing it would raise IndexError).
        """
        n_rects = len(rects_msg.rects)
        filtering = self.target_labels is not None
        if len(class_msg.label_proba) < n_rects or (
                filtering and len(class_msg.label_names) < n_rects):
            rospy.logwarn_throttle(
                10.0,
                'Skipping frame: %d rects but only %d label_proba and %d '
                'label_names in ClassificationResult' % (
                    n_rects, len(class_msg.label_proba),
                    len(class_msg.label_names)))
            return None

        rects = []
        scores = []
        for i, r in enumerate(rects_msg.rects):
            if filtering and \
                    class_msg.label_names[i] not in self.target_labels:
                continue
            rects.append((int(r.x), int(r.y),
                          int(r.width),
                          int(r.height)))
            scores.append(class_msg.label_proba[i])
        # reshape keeps a 2-D (0, 4) array when nothing survives filtering;
        # np.array([]) alone would be 1-D with shape (0,).
        rects = np.array(rects, dtype=np.float32).reshape(-1, 4)
        scores = np.array(scores, dtype=np.float32)
        return rects, scores

    def callback(self, img_msg, rects_msg, class_msg):
        """Update the tracker with one synchronized input set and publish.

        Args:
            img_msg (sensor_msgs/Image): input image, converted to bgr8.
            rects_msg (jsk_recognition_msgs/RectArray): detections.
            class_msg (jsk_recognition_msgs/ClassificationResult): label names
                and probabilities indexed parallel to ``rects_msg.rects``.
        """
        bridge = self.bridge
        tracker = self.tracker

        frame = bridge.imgmsg_to_cv2(
            img_msg, desired_encoding='bgr8')

        detections = self._extract_detections(rects_msg, class_msg)
        if detections is None:
            return
        rects, scores = detections

        # NOTE(review): with no detections the tracker is not updated, so the
        # labels published below reflect its previous state.
        if len(rects) > 0:
            tracker.track(frame, rects, scores)

        if self.pub_labels.get_num_connections() > 0:
            msg_labels = LabelArray(header=img_msg.header)
            for object_id, target_object in tracker.tracking_objects.items():
                if target_object['out_of_frame']:
                    continue
                msg_labels.labels.append(Label(id=object_id))
            self.pub_labels.publish(msg_labels)

        if self.image_pub.get_num_connections() > 0:
            frame = tracker.visualize(frame, rects)
            msg = bridge.cv2_to_imgmsg(frame, "bgr8")
            msg.header = img_msg.header
            self.image_pub.publish(msg)
00095
00096
def main():
    """Start the ROS node, build the DeepSort tracker node and spin."""
    rospy.init_node('deep_sort_tracker_node')
    node = DeepSortTrackerNode()  # noqa: F841 -- referenced while spinning
    rospy.spin()
00101
00102
# Launch the node only when this file is executed as a script.
if __name__ == '__main__':
    main()