import cv2
import numpy as np
import chainer

from jsk_recognition_utils.chainermodels.deep_sort_net \
    import DeepSortFeatureExtractor

from vis_bboxes import vis_bboxes
import deep_sort


def extract_image_patch(image, bbox, patch_shape):
    """Extract image patch from bounding box.

    Copied from
    https://github.com/nwojke/deep_sort/blob/master/tools/generate_detections.py

    Parameters
    ----------
    image : ndarray
        The full image.
    bbox : array_like
        The bounding box in format (x, y, width, height).
    patch_shape : Optional[array_like]
        This parameter can be used to enforce a desired patch shape
        (height, width). First, the `bbox` is adapted to the aspect ratio
        of the patch shape, then it is clipped at the image boundaries.
        If None, the shape is computed from :arg:`bbox`.

    Returns
    -------
    ndarray | NoneType
        An image patch showing the :arg:`bbox`, optionally reshaped to
        :arg:`patch_shape`.
        Returns None if the bounding box is empty or fully outside of the
        image boundaries.

    """
    bbox = np.array(bbox)
    if patch_shape is not None:
        # correct the box to the aspect ratio of the patch shape
        target_aspect = float(patch_shape[1]) / patch_shape[0]
        new_width = target_aspect * bbox[3]
        bbox[0] -= (new_width - bbox[2]) / 2
        bbox[2] = new_width

    # convert to (top left, bottom right) corner format
    bbox[2:] += bbox[:2]
    # np.int was removed in NumPy 1.24; use the builtin int instead
    bbox = bbox.astype(int)

    # clip at image boundaries
    bbox[:2] = np.maximum(0, bbox[:2])
    bbox[2:] = np.minimum(np.asarray(image.shape[:2][::-1]) - 1, bbox[2:])
    if np.any(bbox[:2] >= bbox[2:]):
        return None
    sx, sy, ex, ey = bbox
    image = image[sy:ey, sx:ex]
    image = cv2.resize(image, tuple(patch_shape[::-1]))
    return image
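
# Example usage (illustrative only; the frame and box values below are
# made up). Crop a region from a dummy frame and resize it to the
# 128x64 patch shape the feature extractor expects; a box that lies
# fully outside the frame yields None.
#
#   frame = np.zeros((480, 640, 3), dtype=np.uint8)
#   patch = extract_image_patch(frame, bbox=(100, 80, 64, 64),
#                               patch_shape=(128, 64))
#   assert patch.shape == (128, 64, 3)
#   assert extract_image_patch(
#       frame, (-200, -200, 10, 10), (128, 64)) is None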


def encoder(image_encoder):
    """Return a closure that embeds image patches with image_encoder."""

    def _encoder(image, boxes):
        image_shape = 128, 64, 3
        image_patches = []
        for box in boxes:
            patch = extract_image_patch(
                image, box, image_shape[:2])
            if patch is None:
                # fall back to a random patch when the box lies outside
                # the image, so every detection still gets a feature
                patch = np.random.uniform(
                    0., 255., image_shape).astype(np.uint8)
            image_patches.append(patch)
        image_patches = np.asarray(image_patches, 'f')
        # NHWC -> NCHW, as expected by the chainer model
        image_patches = image_patches.transpose(0, 3, 1, 2)
        image_patches = image_encoder.xp.asarray(image_patches)
        with chainer.using_config('train', False):
            ret = image_encoder(image_patches)
        return chainer.cuda.to_cpu(ret.data)

    return _encoder
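
# Example usage (a minimal sketch; the weights path below is hypothetical):
# wrap the feature extractor once, then call the returned closure on each
# frame with its detection boxes.
#
#   extractor = DeepSortFeatureExtractor()
#   chainer.serializers.load_npz('deep_sort_net.npz', extractor)
#   encode = encoder(extractor)
#   features = encode(frame, boxes)  # boxes: (N, 4) array of (x, y, w, h)
#   # features: (N, D) float32 array, one appearance vector per box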


class DeepSortTracker(object):

    def __init__(self, gpu=-1,
                 pretrained_model=None,
                 nms_max_overlap=1.0,
                 max_cosine_distance=0.2,
                 budget=None):
        self.max_cosine_distance = max_cosine_distance
        self.nms_max_overlap = nms_max_overlap
        self.budget = budget

        # appearance feature extractor
        self.gpu = gpu
        self.extractor = DeepSortFeatureExtractor()
        if pretrained_model is not None:
            chainer.serializers.load_npz(
                pretrained_model, self.extractor)
        if self.gpu >= 0:
            # pass the device id so the model lands on the selected GPU
            self.extractor = self.extractor.to_gpu(self.gpu)
        self.encoder = encoder(self.extractor)

        # tracker state
        self.n_tracked = 0
        self.tracking_objects = {}
        self.tracker = None
        self.track_id_to_object_id = {}
        self.reset()

    def reset(self):
        self.track_id_to_object_id = {}
        self.tracking_objects = {}
        metric = deep_sort.deep_sort.nn_matching.NearestNeighborDistanceMetric(
            'cosine',
            matching_threshold=self.max_cosine_distance,
            budget=self.budget)
        self.tracker = deep_sort.deep_sort.tracker.Tracker(metric)

    def track(self, frame, bboxes, scores):
        # suppress overlapping detections before feature extraction
        indices = deep_sort.application_util.preprocessing.non_max_suppression(
            bboxes, self.nms_max_overlap, scores)
        bboxes = bboxes[indices]
        scores = scores[indices]

        # embed each detection and wrap it for the deep_sort tracker
        features = self.encoder(frame, np.array(bboxes))
        n_bbox = len(bboxes)
        detections = [
            deep_sort.deep_sort.detection.Detection(
                bboxes[i], scores[i], features[i]) for i in range(n_bbox)]

        # run the Kalman prediction and data-association update
        self.tracker.predict()
        self.tracker.update(detections)

        # mark every tracked object as out of frame ...
        for target_object in self.tracking_objects.values():
            target_object['out_of_frame'] = True

        # ... then flip the flag back for each confirmed, updated track
        for track in self.tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            bbox = track.to_tlwh()

            if track.track_id in self.track_id_to_object_id:
                # update an object that is already being tracked
                target_object = self.tracking_objects[
                    self.track_id_to_object_id[track.track_id]]
                target_object['out_of_frame'] = False
                target_object['bbox'] = bbox
            else:
                # register a newly confirmed track under a fresh object id
                object_id = self.n_tracked
                self.n_tracked += 1
                self.track_id_to_object_id[track.track_id] = object_id
                self.tracking_objects[object_id] = dict(
                    out_of_frame=False,
                    bbox=bbox)

    def visualize(self, frame, bboxes):
        vis_frame = frame.copy()
        # draw the raw detections in white
        for x1, y1, w, h in bboxes:
            x2, y2 = x1 + w, y1 + h
            x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
            cv2.rectangle(vis_frame,
                          (x1, y1), (x2, y2),
                          (255, 255, 255), 3)
        # draw the tracked objects labeled with their object ids
        labels, bboxes = [], []
        for object_id, target_object in self.tracking_objects.items():
            if target_object['out_of_frame']:
                continue
            labels.append(object_id)
            bboxes.append(target_object['bbox'])
        vis_bboxes(vis_frame, bboxes, labels)
        return vis_frame
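
# A minimal end-to-end sketch (illustrative; the weight file, video path,
# and detect() function are assumptions, not part of this module): feed
# per-frame detections to DeepSortTracker.track, then render the state.
#
#   tracker = DeepSortTracker(gpu=0, pretrained_model='deep_sort_net.npz')
#   video = cv2.VideoCapture('input.mp4')
#   while True:
#       ret, frame = video.read()
#       if not ret:
#           break
#       # bboxes: (N, 4) array of (x, y, w, h); scores: (N,) confidences,
#       # both produced by some external detector (not shown here)
#       bboxes, scores = detect(frame)
#       tracker.track(frame, bboxes, scores)
#       cv2.imshow('deep_sort', tracker.visualize(frame, bboxes))
#       if cv2.waitKey(1) & 0xFF == ord('q'):
#           break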