00001
00002 import argparse
00003
00004 import cv2
00005 import numpy as np
00006
00007 import deep_sort_app
00008 from deep_sort.iou_matching import iou
00009 from application_util import visualization
00010
00011
# Fallback delay between frames, in milliseconds. run() uses it when neither
# the caller nor the sequence info provides an update interval.
DEFAULT_UPDATE_MS = 20
00013
00014
def run(sequence_dir, result_file, show_false_alarms=False, detection_file=None,
        update_ms=None, video_filename=None):
    """Run tracking result visualization.

    Parameters
    ----------
    sequence_dir : str
        Path to the MOTChallenge sequence directory.
    result_file : str
        Path to the tracking output file in MOTChallenge ground truth format.
    show_false_alarms : Optional[bool]
        If True, false alarms are highlighted as red boxes.
    detection_file : Optional[str]
        Path to the detection file.
    update_ms : Optional[int]
        Number of milliseconds between consecutive frames. Defaults to the
        value provided by the sequence info, or DEFAULT_UPDATE_MS if the
        sequence info does not provide one.
    video_filename : Optional[str]
        If not None, a video of the tracking results is written to this file.

    Raises
    ------
    ValueError
        If `show_false_alarms` is set but no ground truth is available.
    IOError
        If a frame image cannot be read (raised from the frame callback).

    """
    seq_info = deep_sort_app.gather_sequence_info(sequence_dir, detection_file)
    # ndmin=2 keeps a result file with a single row two-dimensional, so the
    # column slicing below keeps working.
    results = np.loadtxt(result_file, delimiter=',', ndmin=2)

    if show_false_alarms and seq_info["groundtruth"] is None:
        raise ValueError("No groundtruth available. Cannot show false alarms.")

    # The frame index columns do not change between frames, so convert them
    # once instead of on every callback. The builtin `int` replaces the
    # `np.int` alias, which was removed in NumPy 1.24.
    result_frames = results[:, 0].astype(int)
    groundtruth = seq_info["groundtruth"] if show_false_alarms else None
    groundtruth_frames = (
        groundtruth[:, 0].astype(int) if show_false_alarms else None)

    # A tracked box overlapping no ground truth box by at least this IoU is
    # reported as a false alarm.
    min_iou_overlap = 0.5

    def frame_callback(vis, frame_idx):
        print("Frame idx", frame_idx)
        image_filename = seq_info["image_filenames"][frame_idx]
        image = cv2.imread(image_filename, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError("Failed to read image %s" % image_filename)

        vis.set_image(image.copy())

        if seq_info["detections"] is not None:
            detections = deep_sort_app.create_detections(
                seq_info["detections"], frame_idx)
            vis.draw_detections(detections)

        # Columns of the result table: frame, track id, then four box values.
        mask = result_frames == frame_idx
        track_ids = results[mask, 1].astype(int)
        boxes = results[mask, 2:6]
        vis.draw_groundtruth(track_ids, boxes)

        if show_false_alarms:
            gt_boxes = groundtruth[groundtruth_frames == frame_idx, 2:6]
            for box in boxes:
                # Without any ground truth box in this frame there is nothing
                # to overlap with (and max() of an empty array would raise),
                # so the box counts as a false alarm.
                if (len(gt_boxes) == 0
                        or iou(box, gt_boxes).max() < min_iou_overlap):
                    vis.viewer.color = 0, 0, 255
                    vis.viewer.thickness = 4
                    vis.viewer.rectangle(*box.astype(int))

    if update_ms is None:
        update_ms = seq_info["update_ms"]
    if update_ms is None:
        update_ms = DEFAULT_UPDATE_MS
    visualizer = visualization.Visualization(seq_info, update_ms)
    if video_filename is not None:
        visualizer.viewer.enable_videowriter(video_filename)
    visualizer.run(frame_callback)
00081
00082
def _str2bool(value):
    """Convert a command line string such as "true" or "0" to a bool."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "on", "1"):
        return True
    if text in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value, got %r." % value)


def parse_args():
    """Parse command line arguments.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes `sequence_dir`, `result_file`,
        `detection_file`, `update_ms` (int or None), `output_file` and
        `show_false_alarms` (bool).

    """
    parser = argparse.ArgumentParser(description="Siamese Tracking")
    parser.add_argument(
        "--sequence_dir", help="Path to the MOTChallenge sequence directory.",
        required=True)
    parser.add_argument(
        "--result_file", help="Tracking output in MOTChallenge file format.",
        required=True)
    parser.add_argument(
        "--detection_file", help="Path to custom detections (optional).",
        default=None)
    # Parsed as int so a number, not a string, is passed on as the delay.
    parser.add_argument(
        "--update_ms", help="Time between consecutive frames in milliseconds. "
        "Defaults to the frame_rate specified in seqinfo.ini, if available.",
        type=int, default=None)
    parser.add_argument(
        "--output_file", help="Filename of the (optional) output video.",
        default=None)
    # `type=bool` would turn every non-empty string (including "False") into
    # True, so the value is parsed explicitly. The bare flag without a value
    # is accepted as True.
    parser.add_argument(
        "--show_false_alarms", help="Show false alarms as red bounding boxes.",
        type=_str2bool, nargs="?", const=True, default=False)
    return parser.parse_args()
00107
00108
if __name__ == "__main__":
    # Script entry point: forward the parsed command line options to run(),
    # spelling out which option feeds which parameter.
    args = parse_args()
    run(
        sequence_dir=args.sequence_dir,
        result_file=args.result_file,
        show_false_alarms=args.show_false_alarms,
        detection_file=args.detection_file,
        update_ms=args.update_ms,
        video_filename=args.output_file)