deep_sort_tracker.py
import cv2
import numpy as np
import chainer

from jsk_recognition_utils.chainermodels.deep_sort_net\
    import DeepSortFeatureExtractor

from vis_bboxes import vis_bboxes
import deep_sort


def extract_image_patch(image, bbox, patch_shape):
    """Extract image patch from bounding box.
    copied from
    https://github.com/nwojke/deep_sort/blob/master/tools/generate_detections.py

    Parameters
    ----------
    image : ndarray
        The full image.
    bbox : array_like
        The bounding box in format (x, y, width, height).
    patch_shape : Optional[array_like]
        This parameter can be used to enforce a desired patch shape
        (height, width). First, the `bbox` is adapted to the aspect ratio
        of the patch shape, then it is clipped at the image boundaries.
        If None, the shape is computed from :arg:`bbox`.

    Returns
    -------
    ndarray | NoneType
        An image patch showing the :arg:`bbox`, optionally reshaped to
        :arg:`patch_shape`.
        Returns None if the bounding box is empty or fully outside of the
        image boundaries.

    """
    bbox = np.array(bbox)
    if patch_shape is not None:
        # correct aspect ratio to patch shape
        target_aspect = float(patch_shape[1]) / patch_shape[0]
        new_width = target_aspect * bbox[3]
        bbox[0] -= (new_width - bbox[2]) / 2
        bbox[2] = new_width

    # convert to top left, bottom right
    bbox[2:] += bbox[:2]
    bbox = bbox.astype(int)

    # clip at image boundaries
    bbox[:2] = np.maximum(0, bbox[:2])
    bbox[2:] = np.minimum(np.asarray(image.shape[:2][::-1]) - 1, bbox[2:])
    if np.any(bbox[:2] >= bbox[2:]):
        return None
    sx, sy, ex, ey = bbox
    image = image[sy:ey, sx:ex]
    if patch_shape is not None:
        # resize only when a target shape was requested; an unconditional
        # resize would raise a TypeError for patch_shape=None
        image = cv2.resize(image, tuple(patch_shape[::-1]))
    return image


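# A minimal sketch of calling extract_image_patch on its own; the image path
# and box values here are assumptions for illustration only.
#
#     frame = cv2.imread('frame.jpg')  # HxWx3 uint8 BGR image (assumed)
#     patch = extract_image_patch(
#         frame, bbox=(10, 20, 50, 100), patch_shape=(128, 64))
#     if patch is not None:
#         assert patch.shape == (128, 64, 3)
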
def encoder(image_encoder):
    """Wrap a Chainer feature extractor into an (image, boxes) -> features
    callable as consumed by the deep_sort detection interface."""

    def _encoder(image, boxes):
        image_shape = 128, 64, 3
        image_patches = []
        for box in boxes:
            patch = extract_image_patch(
                image, box, image_shape[:2])
            if patch is None:
                # box is empty or fully outside the frame: fall back to a
                # random patch so the batch keeps a valid shape
                patch = np.random.uniform(
                    0., 255., image_shape).astype(np.uint8)
            image_patches.append(patch)
        # stack into a float32 batch and reorder NHWC -> NCHW for Chainer
        image_patches = np.asarray(image_patches, 'f')
        image_patches = image_patches.transpose(0, 3, 1, 2)
        image_patches = image_encoder.xp.asarray(image_patches)
        with chainer.using_config('train', False):
            ret = image_encoder(image_patches)
        return chainer.cuda.to_cpu(ret.data)

    return _encoder


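# A sketch of building and using the encoder closure; the weight path below
# is an assumption for the example.
#
#     extractor = DeepSortFeatureExtractor()
#     chainer.serializers.load_npz('deep_sort_net.npz', extractor)
#     encode = encoder(extractor)
#     features = encode(frame, boxes)  # one feature vector per (x, y, w, h) box
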
class DeepSortTracker(object):

    def __init__(self, gpu=-1,
                 pretrained_model=None,
                 nms_max_overlap=1.0,
                 max_cosine_distance=0.2,
                 budget=None):
        self.max_cosine_distance = max_cosine_distance
        self.nms_max_overlap = nms_max_overlap
        self.budget = budget

        # feature extractor
        self.gpu = gpu
        self.extractor = DeepSortFeatureExtractor()
        if pretrained_model is not None:
            chainer.serializers.load_npz(
                pretrained_model, self.extractor)
        if self.gpu >= 0:
            self.extractor = self.extractor.to_gpu()
        self.encoder = encoder(self.extractor)

        # variables for tracking objects
        self.n_tracked = 0  # number of tracked objects
        self.tracking_objects = {}
        self.tracker = None
        self.track_id_to_object_id = {}
        self.reset()

    def reset(self):
        self.track_id_to_object_id = {}
        self.tracking_objects = {}
        metric = deep_sort.deep_sort.nn_matching.NearestNeighborDistanceMetric(
            'cosine',
            matching_threshold=self.max_cosine_distance,
            budget=self.budget)
        self.tracker = deep_sort.deep_sort.tracker.Tracker(metric)

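    # Illustrative note: reset() drops all track state but keeps the loaded
    # feature extractor, so one instance can be reused across sequences
    # (the weight path here is an assumption):
    #
    #     tracker = DeepSortTracker(pretrained_model='deep_sort_net.npz')
    #     # ... process sequence A frame by frame ...
    #     tracker.reset()
    #     # ... process sequence B with fresh track state ...
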
    def track(self, frame, bboxes, scores):
        # run non-maximum suppression.
        indices = deep_sort.application_util.preprocessing.non_max_suppression(
            bboxes, self.nms_max_overlap, scores)
        bboxes = bboxes[indices]
        scores = scores[indices]

        # generate detections.
        features = self.encoder(frame, np.array(bboxes))
        n_bbox = len(bboxes)
        detections = [
            deep_sort.deep_sort.detection.Detection(
                bboxes[i], scores[i], features[i]) for i in range(n_bbox)]

        # update tracker.
        self.tracker.predict()
        self.tracker.update(detections)

        # mark every object as out of frame until a confirmed track
        # re-observes it below
        for target_object in self.tracking_objects.values():
            target_object['out_of_frame'] = True

        # store results
        for track in self.tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            bbox = track.to_tlwh()

            if track.track_id in self.track_id_to_object_id:
                # update tracked object
                target_object = self.tracking_objects[
                    self.track_id_to_object_id[track.track_id]]
                target_object['out_of_frame'] = False
                target_object['bbox'] = bbox
            else:
                # detected for the first time
                object_id = self.n_tracked
                self.n_tracked += 1
                self.track_id_to_object_id[track.track_id] = object_id
                self.tracking_objects[object_id] = dict(
                    out_of_frame=False,
                    bbox=bbox)

    def visualize(self, frame, bboxes):
        vis_frame = frame.copy()
        # draw the raw detections in white
        for x1, y1, w, h in bboxes:
            x2, y2 = x1 + w, y1 + h
            x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
            cv2.rectangle(vis_frame,
                          (x1, y1), (x2, y2),
                          (255, 255, 255), 3)
        # draw the tracked objects with their persistent ids
        labels, bboxes = [], []
        for object_id, target_object in self.tracking_objects.items():
            if target_object['out_of_frame']:
                continue
            labels.append(object_id)
            bboxes.append(target_object['bbox'])
        vis_bboxes(vis_frame, bboxes, labels)
        return vis_frame
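

# Illustrative end-to-end sketch of the per-frame loop. The pretrained weight
# path, the camera source, and the run_detector helper (a hypothetical
# detector returning (x, y, w, h) boxes and scores as numpy arrays) are
# assumptions for the example, not part of this module.
if __name__ == '__main__':
    tracker = DeepSortTracker(gpu=-1, pretrained_model='deep_sort_net.npz')
    cap = cv2.VideoCapture(0)  # assumed camera source
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        bboxes, scores = run_detector(frame)  # hypothetical detector
        if len(bboxes) > 0:
            tracker.track(frame, bboxes, scores)
        vis = tracker.visualize(frame, bboxes)
        cv2.imshow('deep_sort_tracker', vis)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()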

