train_mask_rcnn.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 from __future__ import division
00004 
00005 import argparse
00006 import datetime
00007 import functools
00008 import os
00009 import os.path as osp
00010 
00011 os.environ['MPLBACKEND'] = 'Agg'  # NOQA
00012 
00013 import chainer
00014 from chainer import cuda
00015 from chainer.datasets import TransformDataset
00016 from chainer.training import extensions
00017 import chainer_mask_rcnn as cmr
00018 import fcn
00019 
00020 from jsk_recognition_utils.datasets import InstanceSegmentationDataset
00021 import rospkg
00022 
00023 
class TrainMaskRCNN(object):

    """Command-line trainer for a Mask R-CNN instance-segmentation model.

    NOTE: constructing this object is the whole program — ``__init__``
    parses ``sys.argv``, loads the datasets, builds the model, optimizer
    and iterators, and then blocks while running the Chainer training
    loop.  Results (snapshots, logs, plots, params) are written under
    ``--out_dir`` (default: a timestamped directory in the ROS home).
    """

    def __init__(self):
        # Default dataset locations live inside the jsk_perception package.
        rospack = rospkg.RosPack()
        jsk_perception_datasets_path = osp.join(
            rospack.get_path('jsk_perception'), 'learning_datasets')

        parser = argparse.ArgumentParser()

        # Dataset directory
        parser.add_argument('--train_dataset_dir', type=str,
                            default=osp.join(jsk_perception_datasets_path,
                                             'kitchen_dataset', 'train'))
        parser.add_argument('--val_dataset_dir', type=str,
                            default=osp.join(jsk_perception_datasets_path,
                                             'kitchen_dataset', 'test'))

        # Model
        parser.add_argument(
            '--model_name', type=str, default='resnet50',
            choices=['vgg16', 'resnet50', 'resnet101'])

        # Training parameters.  Each *_interval option is paired with an
        # *_interval_type option selecting its unit ('epoch' or 'iteration').
        parser.add_argument('--gpu', type=int, default=0)
        parser.add_argument('--batch_size', type=int, default=1)
        parser.add_argument('--max_epoch', type=int, default=100)
        parser.add_argument('--lr', type=float, default=0.00125)
        parser.add_argument('--weight_decay', type=float, default=0.0001)
        parser.add_argument('--out_dir', type=str, default=None)
        parser.add_argument('--progressbar_update_interval', type=float,
                            default=10)
        parser.add_argument('--print_interval', type=float, default=100)
        parser.add_argument('--print_interval_type', type=str,
                            default='iteration',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--log_interval', type=float, default=10)
        parser.add_argument('--log_interval_type', type=str,
                            default='iteration',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--plot_interval', type=float, default=5)
        parser.add_argument('--plot_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--eval_interval', type=float, default=10)
        parser.add_argument('--eval_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--save_interval', type=float, default=10)
        parser.add_argument('--save_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])

        args = parser.parse_args()

        self.train_dataset_dir = args.train_dataset_dir
        self.val_dataset_dir = args.val_dataset_dir
        self.model_name = args.model_name
        self.gpu = args.gpu
        self.batch_size = args.batch_size
        self.max_epoch = args.max_epoch
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.out_dir = args.out_dir
        self.progressbar_update_interval = args.progressbar_update_interval
        self.print_interval = args.print_interval
        self.print_interval_type = args.print_interval_type
        self.log_interval = args.log_interval
        self.log_interval_type = args.log_interval_type
        self.plot_interval = args.plot_interval
        self.plot_interval_type = args.plot_interval_type
        self.eval_interval = args.eval_interval
        self.eval_interval_type = args.eval_interval_type
        self.save_interval = args.save_interval
        self.save_interval_type = args.save_interval_type

        now = datetime.datetime.now()
        self.timestamp_iso = now.isoformat()
        timestamp = now.strftime('%Y%m%d-%H%M%S')
        if self.out_dir is None:
            # Default output: ~/.ros/learning_logs/<timestamp>
            self.out_dir = osp.join(
                rospkg.get_ros_home(), 'learning_logs', timestamp)

        # Main process: build everything, then run the training loop
        # (this call blocks until training finishes).
        self.load_dataset()
        self.load_model()
        self.setup_optimizer()
        self.setup_iterator()
        self.setup_trainer()
        self.trainer.run()

    def load_dataset(self):
        """Load train/val instance-segmentation datasets from disk."""
        self.train_dataset = InstanceSegmentationDataset(
            self.train_dataset_dir)
        self.val_dataset = InstanceSegmentationDataset(self.val_dataset_dir)

    def load_model(self):
        """Build the Mask R-CNN network selected by ``--model_name``.

        Sets ``self.mask_rcnn`` (the inference network) and ``self.model``
        (the training chain wrapping it), moving both to GPU if requested.
        """
        n_fg_class = len(self.train_dataset.fg_class_names)

        # Hyper-parameters shared by all backbones.
        pooling_func = cmr.functions.roi_align_2d
        anchor_scales = (4, 8, 16, 32)
        roi_size = 14
        min_size = 600
        max_size = 1000
        mask_initialW = chainer.initializers.Normal(0.01)

        if self.model_name == 'vgg16':
            self.mask_rcnn = cmr.models.MaskRCNNVGG16(
                n_fg_class=n_fg_class,
                pretrained_model='imagenet',
                pooling_func=pooling_func,
                anchor_scales=anchor_scales,
                roi_size=roi_size,
                min_size=min_size,
                max_size=max_size,
                mask_initialW=mask_initialW,
            )
        elif self.model_name in ['resnet50', 'resnet101']:
            # Slice off the literal 'resnet' prefix to get the layer count.
            # NOTE: the previous lstrip('resnet') strips a *character set*,
            # not a prefix, and only produced the right answer because the
            # remaining characters are digits.
            n_layers = int(self.model_name[len('resnet'):])
            self.mask_rcnn = cmr.models.MaskRCNNResNet(
                n_layers=n_layers,
                n_fg_class=n_fg_class,
                pooling_func=pooling_func,
                anchor_scales=anchor_scales,
                roi_size=roi_size,
                min_size=min_size,
                max_size=max_size,
                mask_initialW=mask_initialW,
            )
        else:
            raise ValueError(
                'Unsupported model_name: {}'.format(self.model_name))
        self.model = cmr.models.MaskRCNNTrainChain(self.mask_rcnn)

        if self.gpu >= 0:
            cuda.get_device_from_id(self.gpu).use()
            self.model.to_gpu()

    def setup_optimizer(self):
        """Create MomentumSGD with weight decay and freeze early layers."""
        self.optimizer = chainer.optimizers.MomentumSGD(
            lr=self.lr, momentum=0.9)
        self.optimizer.setup(self.model)
        self.optimizer.add_hook(
            chainer.optimizer.WeightDecay(rate=self.weight_decay))

        if self.model_name in ['resnet50', 'resnet101']:
            # ResNetExtractor.freeze_at is not enough to freeze params
            # since WeightDecay updates the param little by little.
            self.mask_rcnn.extractor.conv1.disable_update()
            self.mask_rcnn.extractor.bn1.disable_update()
            self.mask_rcnn.extractor.res2.disable_update()
            for link in self.mask_rcnn.links():
                if isinstance(link, cmr.links.AffineChannel2D):
                    link.disable_update()

    def setup_iterator(self):
        """Wrap datasets with the Mask R-CNN transform and make iterators."""
        train_dataset_transformed = TransformDataset(
            self.train_dataset, cmr.datasets.MaskRCNNTransform(self.mask_rcnn))
        val_dataset_transformed = TransformDataset(
            self.val_dataset,
            cmr.datasets.MaskRCNNTransform(self.mask_rcnn, train=False))
        # FIXME: MultiProcessIterator sometimes hangs
        self.train_iterator = chainer.iterators.SerialIterator(
            train_dataset_transformed, batch_size=self.batch_size)
        self.val_iterator = chainer.iterators.SerialIterator(
            val_dataset_transformed, batch_size=self.batch_size,
            repeat=False, shuffle=False)

    def setup_trainer(self):
        """Assemble the Chainer Trainer with all reporting extensions.

        Registers: LR decay schedule, VOC evaluator, best-mAP snapshot,
        graph dump, progress bar, log/print/plot reports, and YAML dumps
        of the training params and foreground class names.
        """
        converter = functools.partial(
            cmr.datasets.concat_examples,
            padding=0,
            # img, bboxes, labels, masks, scales
            indices_concat=[0, 2, 3, 4],  # img, _, labels, masks, scales
            indices_to_device=[0, 1],  # img, bbox
        )
        self.updater = chainer.training.updater.StandardUpdater(
            self.train_iterator, self.optimizer, device=self.gpu,
            converter=converter)
        self.trainer = chainer.training.Trainer(
            self.updater, (self.max_epoch, 'epoch'), out=self.out_dir)

        # Scale the canonical 120k/160k-of-180k iteration LR schedule to
        # this run's epoch budget, then decay the LR x0.1 at each step.
        step_size = [
            (120e3 / 180e3) * self.max_epoch,
            (160e3 / 180e3) * self.max_epoch,
        ]
        self.trainer.extend(
            extensions.ExponentialShift('lr', 0.1),
            trigger=chainer.training.triggers.ManualScheduleTrigger(
                step_size, 'epoch'))

        evaluator = cmr.extensions.InstanceSegmentationVOCEvaluator(
            self.val_iterator, self.model.mask_rcnn, device=self.gpu,
            use_07_metric=True, label_names=self.train_dataset.fg_class_names)
        self.trainer.extend(
            evaluator, trigger=(self.eval_interval, self.eval_interval_type))

        # Save snapshot (only when validation mAP improves).
        self.trainer.extend(
            extensions.snapshot_object(
                self.model.mask_rcnn, 'snapshot_model.npz'),
            trigger=chainer.training.triggers.MaxValueTrigger(
                'validation/main/map',
                (self.save_interval, self.save_interval_type)))

        # Dump network architecture
        self.trainer.extend(
            extensions.dump_graph(
                root_name='main/loss',
                out_name='network_architecture.dot'))

        # Logging
        self.trainer.extend(
            extensions.ProgressBar(
                update_interval=self.progressbar_update_interval))
        self.trainer.extend(
            extensions.observe_lr(),
            trigger=(self.log_interval, self.log_interval_type))
        self.trainer.extend(
            extensions.LogReport(
                log_name='log.json',
                trigger=(self.log_interval, self.log_interval_type)))
        self.trainer.extend(
            extensions.PrintReport([
                'iteration',
                'epoch',
                'elapsed_time',
                'lr',
                'main/loss',
                'main/roi_loc_loss',
                'main/roi_cls_loss',
                'main/roi_mask_loss',
                'main/rpn_loc_loss',
                'main/rpn_cls_loss',
                'validation/main/map',
            ]), trigger=(self.print_interval, self.print_interval_type))

        # Plot
        self.trainer.extend(
            extensions.PlotReport([
                'main/loss',
                'main/roi_loc_loss',
                'main/roi_cls_loss',
                'main/roi_mask_loss',
                'main/rpn_loc_loss',
                'main/rpn_cls_loss',
            ],
                file_name='loss_plot.png',
                x_key=self.plot_interval_type,
                trigger=(self.plot_interval, self.plot_interval_type)),
            trigger=(self.plot_interval, self.plot_interval_type))
        self.trainer.extend(
            extensions.PlotReport(
                ['validation/main/map'],
                file_name='accuracy_plot.png',
                x_key=self.plot_interval_type,
                trigger=(self.plot_interval, self.plot_interval_type)),
            trigger=(self.eval_interval, self.eval_interval_type))

        # Dump params
        params = dict()
        params['model_name'] = self.model_name
        params['train_dataset_dir'] = self.train_dataset_dir
        params['val_dataset_dir'] = self.val_dataset_dir
        params['fg_class_names'] = self.train_dataset.fg_class_names
        params['timestamp'] = self.timestamp_iso
        params['out_dir'] = self.out_dir
        params['gpu'] = self.gpu
        params['batch_size'] = self.batch_size
        params['max_epoch'] = self.max_epoch
        params['lr'] = self.lr
        params['weight_decay'] = self.weight_decay
        self.trainer.extend(
            fcn.extensions.ParamsReport(params, file_name='params.yaml'))

        # Dump param for mask_rcnn_instance_segmentation.py
        target_names = dict()
        target_names['fg_class_names'] = self.train_dataset.fg_class_names
        self.trainer.extend(
            fcn.extensions.ParamsReport(
                target_names, file_name='fg_class_names.yaml'))
00304 
00305 
if __name__ == '__main__':
    # Construction is the whole program: it parses CLI args and runs the
    # training loop.  The instance is never used afterwards, so do not
    # bind it to an (unused) name.
    TrainMaskRCNN()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07