train_ssd.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import print_function
4 
5 import argparse
6 import copy
7 import datetime
8 import numpy as np
9 import os.path as osp
10 
import itertools
import pkg_resources
import sys

from distutils.version import LooseVersion

# chainer >= 7.0.0 dropped Python 2 support; refuse to continue on py2.
_chainer_version = pkg_resources.get_distribution("chainer").version
if (LooseVersion(_chainer_version) >= LooseVersion('7.0.0')
        and sys.version_info.major == 2):
    print('''Please install chainer < 7.0.0:

    sudo pip install chainer==6.7.0

c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
    sys.exit(1)

# Warn (but do not abort) when no CuPy distribution is installed,
# since GPU training needs it.
_distributions = itertools.chain(
    *[pkg_resources.find_distributions(_) for _ in sys.path])
if not any("cupy-" in d.project_name for d in _distributions):
    print('''Please install CuPy

    sudo pip install cupy-cuda[your cuda version]
i.e.
    sudo pip install cupy-cuda91

''', file=sys.stderr)
    # sys.exit(1)
31 import chainer
32 from chainer.datasets import TransformDataset
33 from chainer.optimizer_hooks import WeightDecay
34 from chainer import training
35 from chainer.training import extensions
36 from chainer.training import triggers
37 
38 from chainercv.extensions import DetectionVOCEvaluator
39 from chainercv.links.model.ssd import GradientScaling
40 from chainercv.links.model.ssd import multibox_loss
41 from chainercv.links import SSD300
42 from chainercv.links import SSD512
43 from chainercv import transforms
44 
45 from chainercv.links.model.ssd import random_crop_with_bbox_constraints
46 from chainercv.links.model.ssd import random_distort
47 from chainercv.links.model.ssd import resize_with_random_interpolation
48 
49 from jsk_recognition_utils.datasets import DetectionDataset
50 from jsk_recognition_utils.datasets import BboxDetectionDataset
51 import rospkg
52 
# OpenCV's internal thread pool can deadlock worker processes spawned by
# chainer's MultiprocessIterator, so single-thread cv2 and make chainer
# use cv2 (not PIL) as its resize backend.  See:
# https://docs.chainer.org/en/stable/tips.html#my-training-process-gets-stuck-when-using-multiprocessiterator
import cv2
cv2.setNumThreads(0)
chainer.config.cv_resize_backend = 'cv2'
57 
58 
class MultiboxTrainChain(chainer.Chain):
    """Training chain wrapping an SSD model with the multibox loss.

    The reported keys ('loss', 'loss/loc', 'loss/conf') are consumed by
    the LogReport/PrintReport trainer extensions in main().
    """

    def __init__(self, model, alpha=1, k=3):
        super(MultiboxTrainChain, self).__init__()
        with self.init_scope():
            self.model = model
        # alpha weights the localization loss against the confidence
        # loss; k is the hard-negative-mining ratio passed to
        # multibox_loss.
        self.alpha = alpha
        self.k = k

    def forward(self, imgs, gt_mb_locs, gt_mb_labels):
        """Return the combined SSD loss for one batch and report parts."""
        pred_locs, pred_confs = self.model(imgs)
        loc_loss, conf_loss = multibox_loss(
            pred_locs, pred_confs, gt_mb_locs, gt_mb_labels, self.k)
        total_loss = loc_loss * self.alpha + conf_loss
        chainer.reporter.report(
            {'loss': total_loss, 'loss/loc': loc_loss,
             'loss/conf': conf_loss},
            self)
        return total_loss
79 
80 
82 
83  def __init__(self, coder, size, mean):
84  # to send cpu, make a copy
85  self.coder = copy.copy(coder)
86  self.coder.to_cpu()
87 
88  self.size = size
89  self.mean = mean
90 
    def __call__(self, in_data):
        """Augment one (img, bbox, label) example and encode it for SSD.

        Returns (img, mb_loc, mb_label) ready for multibox_loss.
        There are five data augmentation steps:
          1. Color augmentation
          2. Random expansion
          3. Random cropping
          4. Resizing with random interpolation
          5. Random horizontal flipping
        """
        img, bbox, label = in_data

        # 1. Color augmentation (random brightness/contrast/etc.).
        img = random_distort(img)

        # 2. Random expansion: with probability 1/2, pad the image with
        # the dataset mean and shift the boxes accordingly.
        if np.random.randint(2):
            img, param = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            bbox = transforms.translate_bbox(
                bbox, y_offset=param['y_offset'], x_offset=param['x_offset'])

        # 3. Random cropping; boxes whose centers fall outside the crop
        # are dropped, and labels are filtered to the kept boxes.
        img, param = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        bbox, param = transforms.crop_bbox(
            bbox, y_slice=param['y_slice'], x_slice=param['x_slice'],
            allow_outside_center=False, return_param=True)
        label = label[param['index']]

        # 4. Resizing with random interpolation to the SSD input size.
        _, H, W = img.shape
        img = resize_with_random_interpolation(img, (self.size, self.size))
        bbox = transforms.resize_bbox(bbox, (H, W), (self.size, self.size))

        # 5. Random horizontal flipping (image and boxes together).
        img, params = transforms.random_flip(
            img, x_random=True, return_param=True)
        bbox = transforms.flip_bbox(
            bbox, (self.size, self.size), x_flip=params['x_flip'])

        # Preparation for SSD network: mean subtraction and encoding the
        # ground-truth boxes/labels against the default boxes.
        img -= self.mean
        mb_loc, mb_label = self.coder.encode(bbox, label)

        return img, mb_loc, mb_label
135 
136 
def main():
    """Train an SSD detector on a jsk_recognition detection dataset.

    Command-line options choose the train/val dataset directories and
    type ('instance' or 'bbox'), the SSD variant, GPU id, batch size and
    number of epochs.  Logs and the final model snapshot are written to
    --out-dir (default: ~/.ros/learning_logs/<timestamp>).
    """
    rospack = rospkg.RosPack()
    jsk_perception_datasets_path = osp.join(
        rospack.get_path('jsk_perception'), 'learning_datasets')

    parser = argparse.ArgumentParser()
    # Dataset directory
    parser.add_argument('--train-dataset-dir', type=str,
                        default=osp.join(jsk_perception_datasets_path,
                                         'kitchen_dataset', 'train'))
    parser.add_argument('--val-dataset-dir', type=str,
                        default=osp.join(jsk_perception_datasets_path,
                                         'kitchen_dataset', 'test'))
    parser.add_argument('--dataset-type', type=str,
                        default='instance')
    parser.add_argument(
        '--model-name', choices=('ssd300', 'ssd512'), default='ssd512')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--max-epoch', type=int, default=100)
    parser.add_argument('--out-dir', type=str, default=None)
    args = parser.parse_args()

    if args.dataset_type == 'instance':
        train_dataset = DetectionDataset(args.train_dataset_dir)
    elif args.dataset_type == 'bbox':
        train_dataset = BboxDetectionDataset(args.train_dataset_dir)
    else:
        # BUGFIX: message had a typo ('unsuppported') and was printed to
        # stdout; report the invalid value on stderr instead.
        print('unsupported dataset type: %s' % args.dataset_type,
              file=sys.stderr)
        return

    fg_label_names = train_dataset.fg_class_names

    # argparse 'choices' guarantees model_name is one of these two.
    if args.model_name == 'ssd300':
        model = SSD300(
            n_fg_class=len(fg_label_names),
            pretrained_model='imagenet')
    elif args.model_name == 'ssd512':
        model = SSD512(
            n_fg_class=len(fg_label_names),
            pretrained_model='imagenet')

    model.use_preset('evaluate')
    train_chain = MultiboxTrainChain(model)
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    train = TransformDataset(
        train_dataset,
        Transform(model.coder, model.insize, model.mean))
    train_iter = chainer.iterators.MultiprocessIterator(train, args.batch_size)

    if args.dataset_type == 'instance':
        test_dataset = DetectionDataset(args.val_dataset_dir)
    elif args.dataset_type == 'bbox':
        test_dataset = BboxDetectionDataset(args.val_dataset_dir)

    test_iter = chainer.iterators.SerialIterator(
        test_dataset, args.batch_size, repeat=False, shuffle=False)

    # initial lr is set to 1e-3 by ExponentialShift
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(train_chain)
    for param in train_chain.params():
        if param.name == 'b':
            # Biases: double the gradient instead of weight decay
            # (standard SSD training recipe).
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))

    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)

    now = datetime.datetime.now()
    timestamp = now.strftime('%Y%m%d-%H%M%S')
    if args.out_dir is None:
        out_dir = osp.join(
            rospkg.get_ros_home(), 'learning_logs', timestamp)
    else:
        out_dir = args.out_dir

    # Decay lr by 0.1 at 2/3 and 5/6 of the training schedule.
    step_epoch = [args.max_epoch * 2 // 3, args.max_epoch * 5 // 6]
    trainer = training.Trainer(
        updater, (args.max_epoch, 'epoch'), out_dir)
    trainer.extend(
        extensions.ExponentialShift('lr', 0.1, init=1e-3),
        trigger=triggers.ManualScheduleTrigger(step_epoch, 'epoch'))

    # Evaluate VOC-2007-metric mAP at each lr step and at the final epoch.
    trainer.extend(
        DetectionVOCEvaluator(
            test_iter, model, use_07_metric=True,
            label_names=fg_label_names),
        trigger=triggers.ManualScheduleTrigger(
            step_epoch + [args.max_epoch], 'epoch'))

    log_interval = 10, 'iteration'
    trainer.extend(
        extensions.LogReport(log_name='log.json', trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'lr',
         'main/loss', 'main/loss/loc', 'main/loss/conf',
         'validation/main/map']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Save only the model (not the full trainer state) at the end.
    trainer.extend(
        extensions.snapshot_object(
            model, 'model_snapshot.npz'),
        trigger=(args.max_epoch, 'epoch'))

    trainer.run()
249 
250 
# Script entry point: run training when executed directly.
if __name__ == '__main__':
    main()
def __init__(self, coder, size, mean)
Definition: train_ssd.py:83
def __call__(self, in_data)
Definition: train_ssd.py:91
def forward(self, imgs, gt_mb_locs, gt_mb_labels)
Definition: train_ssd.py:68
def __init__(self, model, alpha=1, k=3)
Definition: train_ssd.py:61
def main()
Definition: train_ssd.py:137


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27