detection.py
Go to the documentation of this file.
1 from __future__ import print_function
2 
3 import os
4 import os.path as osp
5 
6 from chainercv.chainer_experimental.datasets.sliceable import GetterDataset
7 import cv2
8 import numpy as np
9 
10 
class DetectionDataset(GetterDataset):

    """Object detection dataset read from a VOC-style directory tree.

    Files read under ``root_dir``:

    - ``class_names.txt``: one class name per line. The first entry is the
      background class and is excluded from ``fg_class_names``.
    - ``JPEGImages/<name>.<ext>``: input images (one example per file).
    - ``SegmentationClass/<name>.npy``: 2-D int32 array of per-pixel class
      ids (0 is background).
    - ``SegmentationObject/<name>.npy``: 2-D int32 array of per-pixel
      instance ids (0 is background).

    Each example is ``(img, bbox, label)``:

    - ``img``: uint8 array, channel-first, RGB channel order.
    - ``bbox``: float32 array of shape ``(R, 4)``; each row is
      ``(y_min, x_min, y_max, x_max)`` with exclusive max coordinates.
    - ``label``: int32 array of shape ``(R,)``; indices into
      ``fg_class_names``.

    Here ``R`` is the number of distinct positive instance ids in the image.

    Args:
        root_dir (str): Root directory of the dataset.
    """

    def __init__(self, root_dir):
        super(DetectionDataset, self).__init__()
        self.root_dir = root_dir

        class_names_path = osp.join(root_dir, 'class_names.txt')
        with open(class_names_path, 'r') as f:
            self.class_names = [name.rstrip() for name in f]
        # Index 0 is the background class; detection labels only cover the
        # remaining (foreground) classes.
        self.fg_class_names = self.class_names[1:]

        imgs_dir = osp.join(root_dir, 'JPEGImages')
        class_labels_dir = osp.join(root_dir, 'SegmentationClass')
        instance_labels_dir = osp.join(root_dir, 'SegmentationObject')

        self._imgs = []
        self._class_labels = []
        self._instance_labels = []
        for img_file in sorted(os.listdir(imgs_dir)):
            # splitext removes only the extension. str.rstrip('.jpg') would
            # strip any trailing '.', 'j', 'p', 'g' characters
            # (e.g. 'pig.jpg' -> 'pi') and pair the image with the wrong
            # label files.
            basename = osp.splitext(img_file)[0]
            self._imgs.append(osp.join(imgs_dir, img_file))
            self._class_labels.append(
                osp.join(class_labels_dir, basename + '.npy'))
            self._instance_labels.append(
                osp.join(instance_labels_dir, basename + '.npy'))
        self.add_getter(('img', 'bbox', 'label'), self._get_example)

    def __len__(self):
        return len(self._imgs)

    def _get_example(self, i):
        """Load the i-th image and its label maps; derive boxes and labels.

        Raises:
            IOError: if the image file cannot be read by OpenCV.
        """
        img_path = self._imgs[i]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError('Failed to load image: {}'.format(img_path))
        assert img.dtype == np.uint8
        assert img.ndim == 3

        class_label = np.load(self._class_labels[i])
        assert class_label.dtype == np.int32
        assert class_label.ndim == 2

        instance_label = np.load(self._instance_labels[i])
        assert instance_label.dtype == np.int32
        assert instance_label.ndim == 2

        assert img.shape[:2] == class_label.shape == instance_label.shape

        bbox, label = self._masks_to_bbox_label(class_label, instance_label)

        # OpenCV loads BGR; convert to RGB, then HWC -> CHW.
        img = img[:, :, ::-1].transpose((2, 0, 1))
        return img, bbox, label

    @staticmethod
    def _masks_to_bbox_label(class_label, instance_label):
        """Return a tight bbox and majority class label per positive instance.

        Assumes class ids inside each instance region are non-negative
        (np.bincount rejects negative values).
        """
        # Instance id 0 is '_background_'; non-positive ids are skipped.
        # Iterating over the ids actually present covers every instance and
        # tolerates non-contiguous ids without emitting dummy boxes.
        inst_ids = np.unique(instance_label)
        inst_ids = inst_ids[inst_ids > 0]

        n_inst = len(inst_ids)
        bbox = np.zeros((n_inst, 4), dtype=np.float32)
        label = np.zeros((n_inst, ), dtype=np.int32)
        for r, inst_id in enumerate(inst_ids):
            inst_mask = instance_label == inst_id
            # Majority vote over the instance's pixel classes; subtract 1
            # because class 0 is background and labels index fg_class_names.
            label[r] = np.argmax(np.bincount(class_label[inst_mask])) - 1
            yind, xind = np.where(inst_mask)
            bbox[r] = (yind.min(), xind.min(),
                       yind.max() + 1, xind.max() + 1)
        return bbox, label


jsk_recognition_utils
Author(s):
autogenerated on Mon May 3 2021 03:03:03