ssd_train_dataset.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yuki Furuta <furushchev@jsk.imi.i.u-tokyo.ac.jp>

from __future__ import print_function

import argparse
import copy
import json
import numpy as np
import os
import sys
import yaml

# chainer
import itertools, pkg_resources
from distutils.version import LooseVersion
if LooseVersion(pkg_resources.get_distribution("chainer").version) >= LooseVersion('7.0.0') and \
        sys.version_info.major == 2:
    print('''Please install chainer < 7.0.0:

    sudo pip install chainer==6.7.0

c.f. https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
    sys.exit(1)
if [p for p in itertools.chain(*[pkg_resources.find_distributions(_) for _ in sys.path])
        if "cupy-" in p.project_name] == []:
    print('''Please install CuPy

    sudo pip install cupy-cuda[your cuda version]
e.g.
    sudo pip install cupy-cuda91

''', file=sys.stderr)
    # sys.exit(1)
import chainer
from chainer import serializers
from chainer import training
from chainer.datasets import TransformDataset
from chainer.training import extensions
from chainer.optimizer import WeightDecay

# chainercv
from chainercv import transforms
from chainercv.extensions import DetectionVOCEvaluator
from chainercv.links import SSD300
from chainercv.links.model.ssd import GradientScaling
from chainercv.links.model.ssd import multibox_loss
from chainercv.links.model.ssd import random_distort
from chainercv.links.model.ssd import random_crop_with_bbox_constraints
from chainercv.links.model.ssd import resize_with_random_interpolation
from chainercv.utils import read_image

class SSDDataset(chainer.dataset.DatasetMixin):

    def __init__(self, base_dir, label_names):
        self.base_dir = base_dir
        self.label_names = label_names

        self.img_filenames = []
        for name in os.listdir(base_dir):
            # If the file is not an image, ignore the file.
            if os.path.splitext(name)[1] != '.jpg':
                continue
            self.img_filenames.append(os.path.join(base_dir, name))

    def __len__(self):
        return len(self.img_filenames)

    def get_example(self, i):
        img_filename = self.img_filenames[i]
        img = read_image(img_filename)

        # Each image is paired with an annotation file named
        # <image basename>__labels.json.
        anno_filename = os.path.splitext(img_filename)[0] + '__labels.json'

        with open(anno_filename, 'r') as f:
            anno = json.load(f)
        anno = anno['labels']

        bbox = []
        label = []
        for anno_i in anno:
            h = anno_i['size']['y']
            w = anno_i['size']['x']
            center_y = anno_i['centre']['y']
            center_x = anno_i['centre']['x']
            try:
                l = self.label_names.index(anno_i['label_class'])
            except Exception:
                print("Failed to index label class: {}".format(anno_i), file=sys.stderr)
                print("image file name: {}".format(img_filename), file=sys.stderr)
                print("annotation file name: {}".format(anno_filename), file=sys.stderr)
                continue
            # Convert (center, size) to corner coordinates (y_min, x_min, y_max, x_max).
            bbox.append(
                [center_y - h / 2, center_x - w / 2,
                 center_y + h / 2, center_x + w / 2])
            label.append(l)
        return img, np.array(bbox, dtype=np.float32), np.array(label, dtype=np.int32)

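# A minimal sketch of the annotation JSON that get_example() expects,
# inferred from the keys accessed above; the concrete values here are
# hypothetical:
#
#   {"labels": [{"label_class": "dog",
#                "centre": {"x": 320.0, "y": 240.0},
#                "size": {"x": 100.0, "y": 80.0}}]}
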
class MultiboxTrainChain(chainer.Chain):

    def __init__(self, model, alpha=1, k=3):
        super(MultiboxTrainChain, self).__init__()
        with self.init_scope():
            self.model = model
        self.alpha = alpha
        self.k = k

    def __call__(self, imgs, gt_mb_locs, gt_mb_labs):
        mb_locs, mb_confs = self.model(imgs)
        loc_loss, conf_loss = multibox_loss(
            mb_locs, mb_confs, gt_mb_locs, gt_mb_labs, self.k)
        loss = loc_loss * self.alpha + conf_loss

        chainer.reporter.report(
            {'loss': loss, 'loss/loc': loc_loss, 'loss/conf': conf_loss},
            self)

        return loss

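# Note (a sketch of the objective, assuming chainercv's multibox_loss
# semantics): hard negative mining keeps a negative:positive ratio of k
# (3 by default above), and the reported total is
#     loss = alpha * loc_loss + conf_loss
# with alpha = 1, matching the SSD paper.
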
class Transform(object):
    """Class for augmentation"""

    def __init__(self, coder, size, mean):
        # copy to send to cpu
        self.coder = copy.copy(coder)
        self.coder.to_cpu()

        self.size = size
        self.mean = mean

    def __call__(self, in_data):
        img, bbox, label = in_data

        # 1. Color augmentation
        img = random_distort(img)

        # 2. Random expansion
        if np.random.randint(2):
            img, param = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            bbox = transforms.translate_bbox(
                bbox, y_offset=param["y_offset"], x_offset=param["x_offset"])

        # 3. Random cropping
        img, param = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        bbox, param = transforms.crop_bbox(
            bbox, y_slice=param["y_slice"], x_slice=param["x_slice"],
            allow_outside_center=False, return_param=True)
        label = label[param["index"]]

        # 4. Resizing with random interpolation
        _, H, W = img.shape
        img = resize_with_random_interpolation(img, (self.size, self.size))
        bbox = transforms.resize_bbox(bbox, (H, W), (self.size, self.size))

        # 5. Transformation for SSD network input
        img -= self.mean
        mb_loc, mb_lab = self.coder.encode(bbox, label)

        return img, mb_loc, mb_lab

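# Example (sketch): applying the augmentation pipeline to a single sample
# by hand, assuming `dataset` and `model` as constructed in the __main__
# block below:
#
#   transform = Transform(model.coder, model.insize, model.mean)
#   img, mb_loc, mb_lab = transform(dataset[0])
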
if __name__ == '__main__':

    p = argparse.ArgumentParser()
    p.add_argument("label_file", help="path to label file")
    p.add_argument("train", help="path to train dataset directory")
    p.add_argument("--val", help="path to validation dataset directory. "
                   "If not specified, the train dataset is split with ratio "
                   "train:val = 8:2.", default=None)
    p.add_argument("--base-model", help="base model name", default="voc0712")
    p.add_argument("--batchsize", "-b", type=int, default=16)
    p.add_argument("--iteration", type=int, default=120000)
    p.add_argument("--gpu", "-g", type=int, default=-1)  # use CPU by default
    p.add_argument("--out", "-o", type=str, default="results")
    p.add_argument("--resume", type=str, default="")
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--val-iter", type=int, default=100)
    p.add_argument("--log-iter", type=int, default=10)
    p.add_argument("--model-iter", type=int, default=200)

    args = p.parse_args()

    # load label file (safe_load avoids PyYAML's deprecated bare load())
    with open(args.label_file, "r") as f:
        label_names = tuple(yaml.safe_load(f))

    print("Loaded %d labels" % len(label_names))
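    # A minimal sketch of a label file this script could consume: a YAML
    # sequence of class names (the names here are hypothetical):
    #
    #   - dog
    #   - cat
    #   - bottle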

    if args.val is None:
        dataset = SSDDataset(args.train, label_names)
        train, test = chainer.datasets.split_dataset_random(
            dataset, int(len(dataset) * 0.8))
    else:
        train = SSDDataset(args.train, label_names)
        test = SSDDataset(args.val, label_names)

    print("train: {}, test: {}".format(len(train), len(test)))

    pretrained_model = SSD300(pretrained_model=args.base_model)

    # copy from pretrained model
    # (use label_names here: `dataset` is undefined when --val is given)
    model = SSD300(n_fg_class=len(label_names))
    model.extractor.copyparams(pretrained_model.extractor)
    model.multibox.loc.copyparams(pretrained_model.multibox.loc)

    model.use_preset("evaluate")

    train_chain = MultiboxTrainChain(model)

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu()

    train = TransformDataset(
        train, Transform(model.coder, model.insize, model.mean))
    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize)

    test_iter = chainer.iterators.SerialIterator(
        test, args.batchsize,
        repeat=False, shuffle=False)

    optimizer = chainer.optimizers.MomentumSGD(lr=args.lr)
    optimizer.setup(train_chain)

    # scale gradients of biases by 2, apply weight decay to all other params
    for param in train_chain.params():
        if param.name == 'b':
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))

    updater = training.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(
        updater, (args.iteration, "iteration"), args.out)

    val_interval = args.val_iter, "iteration"
    trainer.extend(
        DetectionVOCEvaluator(
            test_iter, model, use_07_metric=True,
            label_names=label_names),
        trigger=val_interval)

    log_interval = args.log_iter, "iteration"
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'lr',
         'main/loss', 'main/loss/loc', 'main/loss/conf',
         'validation/main/map']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(model, 'model_iter_{.updater.iteration}'),
        trigger=(args.model_iter, 'iteration'))

    if args.resume:
        serializers.load_npz(args.resume, trainer)

    trainer.run()
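
# Example invocation (a sketch; paths and values are hypothetical):
#
#   ./ssd_train_dataset.py label_names.yaml path/to/train_dir \
#       --gpu 0 --out results --batchsize 16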