train_mask_rcnn.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import division
4 from __future__ import print_function
5 
6 import argparse
7 import datetime
8 import functools
9 import os
10 import os.path as osp
11 
# Force a non-interactive matplotlib backend before anything imports pyplot:
# training runs headless and plots are written to files by PlotReport.
os.environ['MPLBACKEND'] = 'Agg'  # NOQA

import itertools, pkg_resources, sys
from distutils.version import LooseVersion
# chainer >= 7.0.0 dropped Python 2 support; abort early with install advice
# instead of failing later with an obscure import error.
if LooseVersion(pkg_resources.get_distribution("chainer").version) >= LooseVersion('7.0.0') and \
   sys.version_info.major == 2:
    print('''Please install chainer < 7.0.0:

  sudo pip install chainer==6.7.0

c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''', file=sys.stderr)
    sys.exit(1)
# Warn (but deliberately do not abort — note the commented-out sys.exit)
# when no "cupy-*" distribution is found on sys.path: GPU training needs
# CuPy, but the script can still be started without it.
if [p for p in list(itertools.chain(*[pkg_resources.find_distributions(_) for _ in sys.path])) if "cupy-" in p.project_name ] == []:
    print('''Please install CuPy

  sudo pip install cupy-cuda[your cuda version]
i.e.
  sudo pip install cupy-cuda91

''', file=sys.stderr)
    # sys.exit(1)
34 import chainer
35 from chainer import cuda
36 from chainer.datasets import TransformDataset
37 from chainer.training import extensions
38 import chainer_mask_rcnn as cmr
39 import fcn
40 
41 from jsk_recognition_utils.datasets import InstanceSegmentationDataset
42 import rospkg
43 
44 
46 
    def __init__(self):
        """Parse command-line options and run Mask R-CNN training end-to-end.

        Side effects: reads ``sys.argv`` via argparse, resolves default
        dataset paths through rospkg, then immediately loads the data and
        model and starts the chainer Trainer loop — this constructor blocks
        until training finishes.
        """
        rospack = rospkg.RosPack()
        # Default datasets ship inside the jsk_perception ROS package.
        jsk_perception_datasets_path = osp.join(
            rospack.get_path('jsk_perception'), 'learning_datasets')

        parser = argparse.ArgumentParser()

        # Dataset directory
        parser.add_argument('--train_dataset_dir', type=str,
                            default=osp.join(jsk_perception_datasets_path,
                                             'kitchen_dataset', 'train'))
        parser.add_argument('--val_dataset_dir', type=str,
                            default=osp.join(jsk_perception_datasets_path,
                                             'kitchen_dataset', 'test'))

        # Model
        parser.add_argument(
            '--model_name', type=str, default='resnet50',
            choices=['vgg16', 'resnet50', 'resnet101'])

        # Training parameters.  Each *_interval option is paired with an
        # *_interval_type option selecting its unit ('epoch' or 'iteration').
        parser.add_argument('--gpu', type=int, default=0)
        parser.add_argument('--batch_size', type=int, default=1)
        parser.add_argument('--max_epoch', type=int, default=100)
        parser.add_argument('--lr', type=float, default=0.00125)
        parser.add_argument('--weight_decay', type=float, default=0.0001)
        parser.add_argument('--out_dir', type=str, default=None)
        parser.add_argument('--progressbar_update_interval', type=float,
                            default=10)
        parser.add_argument('--print_interval', type=float, default=100)
        parser.add_argument('--print_interval_type', type=str,
                            default='iteration',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--log_interval', type=float, default=10)
        parser.add_argument('--log_interval_type', type=str,
                            default='iteration',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--plot_interval', type=float, default=5)
        parser.add_argument('--plot_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--eval_interval', type=float, default=10)
        parser.add_argument('--eval_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--save_interval', type=float, default=10)
        parser.add_argument('--save_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])

        args = parser.parse_args()

        # Mirror every parsed option onto the instance so the setup_*
        # methods below can read them as attributes.
        self.train_dataset_dir = args.train_dataset_dir
        self.val_dataset_dir = args.val_dataset_dir
        self.model_name = args.model_name
        self.gpu = args.gpu
        self.batch_size = args.batch_size
        self.max_epoch = args.max_epoch
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.out_dir = args.out_dir
        self.progressbar_update_interval = args.progressbar_update_interval
        self.print_interval = args.print_interval
        self.print_interval_type = args.print_interval_type
        self.log_interval = args.log_interval
        self.log_interval_type = args.log_interval_type
        self.plot_interval = args.plot_interval
        self.plot_interval_type = args.plot_interval_type
        self.eval_interval = args.eval_interval
        self.eval_interval_type = args.eval_interval_type
        self.save_interval = args.save_interval
        self.save_interval_type = args.save_interval_type

        # Default output directory: <ROS home>/learning_logs/<timestamp>.
        now = datetime.datetime.now()
        self.timestamp_iso = now.isoformat()
        timestamp = now.strftime('%Y%m%d-%H%M%S')
        if self.out_dir is None:
            self.out_dir = osp.join(
                rospkg.get_ros_home(), 'learning_logs', timestamp)

        # Main process — order matters: the model must exist before the
        # optimizer and iterators, and everything before the trainer.
        self.load_dataset()
        self.load_model()
        self.setup_optimizer()
        self.setup_iterator()
        self.setup_trainer()
        self.trainer.run()
134 
135  def load_dataset(self):
136  self.train_dataset = InstanceSegmentationDataset(
137  self.train_dataset_dir)
138  self.val_dataset = InstanceSegmentationDataset(self.val_dataset_dir)
139 
140  def load_model(self):
141  n_fg_class = len(self.train_dataset.fg_class_names)
142 
143  pooling_func = cmr.functions.roi_align_2d
144  anchor_scales = (4, 8, 16, 32)
145  roi_size = 14
146  min_size = 600
147  max_size = 1000
148  mask_initialW = chainer.initializers.Normal(0.01)
149 
150  if self.model_name == 'vgg16':
151  self.mask_rcnn = cmr.models.MaskRCNNVGG16(
152  n_fg_class=n_fg_class,
153  pretrained_model='imagenet',
154  pooling_func=pooling_func,
155  anchor_scales=anchor_scales,
156  roi_size=roi_size,
157  min_size=min_size,
158  max_size=max_size,
159  mask_initialW=mask_initialW,
160  )
161  elif self.model_name in ['resnet50', 'resnet101']:
162  n_layers = int(self.model_name.lstrip('resnet'))
163  self.mask_rcnn = cmr.models.MaskRCNNResNet(
164  n_layers=n_layers,
165  n_fg_class=n_fg_class,
166  pooling_func=pooling_func,
167  anchor_scales=anchor_scales,
168  roi_size=roi_size,
169  min_size=min_size,
170  max_size=max_size,
171  mask_initialW=mask_initialW,
172  )
173  else:
174  raise ValueError(
175  'Unsupported model_name: {}'.format(self.model_name))
176  self.model = cmr.models.MaskRCNNTrainChain(self.mask_rcnn)
177 
178  if self.gpu >= 0:
179  cuda.get_device_from_id(self.gpu).use()
180  self.model.to_gpu()
181 
182  def setup_optimizer(self):
183  self.optimizer = chainer.optimizers.MomentumSGD(
184  lr=self.lr, momentum=0.9)
185  self.optimizer.setup(self.model)
186  self.optimizer.add_hook(
187  chainer.optimizer.WeightDecay(rate=self.weight_decay))
188 
189  if self.model_name in ['resnet50', 'resnet101']:
190  # ResNetExtractor.freeze_at is not enough to freeze params
191  # since WeightDecay updates the param little by little.
192  self.mask_rcnn.extractor.conv1.disable_update()
193  self.mask_rcnn.extractor.bn1.disable_update()
194  self.mask_rcnn.extractor.res2.disable_update()
195  for link in self.mask_rcnn.links():
196  if isinstance(link, cmr.links.AffineChannel2D):
197  link.disable_update()
198 
199  def setup_iterator(self):
200  train_dataset_transformed = TransformDataset(
201  self.train_dataset, cmr.datasets.MaskRCNNTransform(self.mask_rcnn))
202  val_dataset_transformed = TransformDataset(
203  self.val_dataset,
204  cmr.datasets.MaskRCNNTransform(self.mask_rcnn, train=False))
205  # FIXME: MultiProcessIterator sometimes hangs
206  self.train_iterator = chainer.iterators.SerialIterator(
207  train_dataset_transformed, batch_size=self.batch_size)
208  self.val_iterator = chainer.iterators.SerialIterator(
209  val_dataset_transformed, batch_size=self.batch_size,
210  repeat=False, shuffle=False)
211 
    def setup_trainer(self):
        """Assemble the chainer Trainer: updater, LR schedule, evaluator,
        best-model snapshots, logging, plotting and parameter dumps.

        Must be called after load_model(), setup_optimizer() and
        setup_iterator(); the result is stored in ``self.trainer``.
        """
        converter = functools.partial(
            cmr.datasets.concat_examples,
            padding=0,
            # img, bboxes, labels, masks, scales
            indices_concat=[0, 2, 3, 4],  # img, _, labels, masks, scales
            indices_to_device=[0, 1],  # img, bbox
        )
        self.updater = chainer.training.updater.StandardUpdater(
            self.train_iterator, self.optimizer, device=self.gpu,
            converter=converter)
        self.trainer = chainer.training.Trainer(
            self.updater, (self.max_epoch, 'epoch'), out=self.out_dir)

        # LR drops (x0.1) at 120k/180k and 160k/180k of the run, rescaled
        # from iterations to fractions of max_epoch.
        step_size = [
            (120e3 / 180e3) * self.max_epoch,
            (160e3 / 180e3) * self.max_epoch,
        ]
        self.trainer.extend(
            extensions.ExponentialShift('lr', 0.1),
            trigger=chainer.training.triggers.ManualScheduleTrigger(
                step_size, 'epoch'))

        # Periodic VOC-style mAP evaluation on the validation set.
        evaluator = cmr.extensions.InstanceSegmentationVOCEvaluator(
            self.val_iterator, self.model.mask_rcnn, device=self.gpu,
            use_07_metric=True, label_names=self.train_dataset.fg_class_names)
        self.trainer.extend(
            evaluator, trigger=(self.eval_interval, self.eval_interval_type))

        # Save snapshot — only when validation mAP reaches a new maximum.
        self.trainer.extend(
            extensions.snapshot_object(
                self.model.mask_rcnn, 'snapshot_model.npz'),
            trigger=chainer.training.triggers.MaxValueTrigger(
                'validation/main/map',
                (self.save_interval, self.save_interval_type)))

        # Dump network architecture
        self.trainer.extend(
            extensions.dump_graph(
                root_name='main/loss',
                out_name='network_architecture.dot'))

        # Logging
        self.trainer.extend(
            extensions.ProgressBar(
                update_interval=self.progressbar_update_interval))
        self.trainer.extend(
            extensions.observe_lr(),
            trigger=(self.log_interval, self.log_interval_type))
        self.trainer.extend(
            extensions.LogReport(
                log_name='log.json',
                trigger=(self.log_interval, self.log_interval_type)))
        self.trainer.extend(
            extensions.PrintReport([
                'iteration',
                'epoch',
                'elapsed_time',
                'lr',
                'main/loss',
                'main/roi_loc_loss',
                'main/roi_cls_loss',
                'main/roi_mask_loss',
                'main/rpn_loc_loss',
                'main/rpn_cls_loss',
                'validation/main/map',
            ]), trigger=(self.print_interval, self.print_interval_type))

        # Plot
        self.trainer.extend(
            extensions.PlotReport([
                'main/loss',
                'main/roi_loc_loss',
                'main/roi_cls_loss',
                'main/roi_mask_loss',
                'main/rpn_loc_loss',
                'main/rpn_cls_loss',
            ],
                file_name='loss_plot.png',
                x_key=self.plot_interval_type,
                trigger=(self.plot_interval, self.plot_interval_type)),
            trigger=(self.plot_interval, self.plot_interval_type))
        # Accuracy plot refreshes on the eval interval since new mAP values
        # only appear after each evaluation.
        self.trainer.extend(
            extensions.PlotReport(
                ['validation/main/map'],
                file_name='accuracy_plot.png',
                x_key=self.plot_interval_type,
                trigger=(self.plot_interval, self.plot_interval_type)),
            trigger=(self.eval_interval, self.eval_interval_type))

        # Dump params — full training configuration for reproducibility.
        params = dict()
        params['model_name'] = self.model_name
        params['train_dataset_dir'] = self.train_dataset_dir
        params['val_dataset_dir'] = self.val_dataset_dir
        params['fg_class_names'] = self.train_dataset.fg_class_names
        params['timestamp'] = self.timestamp_iso
        params['out_dir'] = self.out_dir
        params['gpu'] = self.gpu
        params['batch_size'] = self.batch_size
        params['max_epoch'] = self.max_epoch
        params['lr'] = self.lr
        params['weight_decay'] = self.weight_decay
        self.trainer.extend(
            fcn.extensions.ParamsReport(params, file_name='params.yaml'))

        # Dump param for mask_rcnn_instance_segmentation.py
        target_names = dict()
        target_names['fg_class_names'] = self.train_dataset.fg_class_names
        self.trainer.extend(
            fcn.extensions.ParamsReport(
                target_names, file_name='fg_class_names.yaml'))
325 
326 
327 if __name__ == '__main__':


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27