00001
00002
00003 from __future__ import division
00004
00005 import argparse
00006 import datetime
00007 import functools
00008 import os
00009 import os.path as osp
00010
00011 os.environ['MPLBACKEND'] = 'Agg'
00012
00013 import chainer
00014 from chainer import cuda
00015 from chainer.datasets import TransformDataset
00016 from chainer.training import extensions
00017 import chainer_mask_rcnn as cmr
00018 import fcn
00019
00020 from jsk_recognition_utils.datasets import InstanceSegmentationDataset
00021 import rospkg
00022
00023
class TrainMaskRCNN(object):
    """Train Mask R-CNN for instance segmentation.

    Parses command-line options, then builds the datasets, model,
    optimizer, iterators, and a chainer ``Trainer`` — and immediately
    runs the full training loop on construction.
    """

    def __init__(self):
        rospack = rospkg.RosPack()
        # Datasets bundled with the jsk_perception ROS package serve as
        # defaults when no dataset directories are given on the CLI.
        jsk_perception_datasets_path = osp.join(
            rospack.get_path('jsk_perception'), 'learning_datasets')

        parser = argparse.ArgumentParser()

        # Dataset directories.
        parser.add_argument('--train_dataset_dir', type=str,
                            default=osp.join(jsk_perception_datasets_path,
                                             'kitchen_dataset', 'train'))
        parser.add_argument('--val_dataset_dir', type=str,
                            default=osp.join(jsk_perception_datasets_path,
                                             'kitchen_dataset', 'test'))

        # Backbone network.
        parser.add_argument(
            '--model_name', type=str, default='resnet50',
            choices=['vgg16', 'resnet50', 'resnet101'])

        # Training hyperparameters and reporting/saving intervals.
        parser.add_argument('--gpu', type=int, default=0)
        parser.add_argument('--batch_size', type=int, default=1)
        parser.add_argument('--max_epoch', type=int, default=100)
        parser.add_argument('--lr', type=float, default=0.00125)
        parser.add_argument('--weight_decay', type=float, default=0.0001)
        parser.add_argument('--out_dir', type=str, default=None)
        parser.add_argument('--progressbar_update_interval', type=float,
                            default=10)
        parser.add_argument('--print_interval', type=float, default=100)
        parser.add_argument('--print_interval_type', type=str,
                            default='iteration',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--log_interval', type=float, default=10)
        parser.add_argument('--log_interval_type', type=str,
                            default='iteration',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--plot_interval', type=float, default=5)
        parser.add_argument('--plot_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--eval_interval', type=float, default=10)
        parser.add_argument('--eval_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--save_interval', type=float, default=10)
        parser.add_argument('--save_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])

        args = parser.parse_args()

        self.train_dataset_dir = args.train_dataset_dir
        self.val_dataset_dir = args.val_dataset_dir
        self.model_name = args.model_name
        self.gpu = args.gpu
        self.batch_size = args.batch_size
        self.max_epoch = args.max_epoch
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.out_dir = args.out_dir
        self.progressbar_update_interval = args.progressbar_update_interval
        self.print_interval = args.print_interval
        self.print_interval_type = args.print_interval_type
        self.log_interval = args.log_interval
        self.log_interval_type = args.log_interval_type
        self.plot_interval = args.plot_interval
        self.plot_interval_type = args.plot_interval_type
        self.eval_interval = args.eval_interval
        self.eval_interval_type = args.eval_interval_type
        self.save_interval = args.save_interval
        self.save_interval_type = args.save_interval_type

        now = datetime.datetime.now()
        self.timestamp_iso = now.isoformat()
        timestamp = now.strftime('%Y%m%d-%H%M%S')
        # Default output directory is timestamped under the ROS home so
        # repeated runs never overwrite each other.
        if self.out_dir is None:
            self.out_dir = osp.join(
                rospkg.get_ros_home(), 'learning_logs', timestamp)

        # Build the pipeline and start training immediately.
        self.load_dataset()
        self.load_model()
        self.setup_optimizer()
        self.setup_iterator()
        self.setup_trainer()
        self.trainer.run()

    def load_dataset(self):
        """Load the train/val instance segmentation datasets from disk."""
        self.train_dataset = InstanceSegmentationDataset(
            self.train_dataset_dir)
        self.val_dataset = InstanceSegmentationDataset(self.val_dataset_dir)

    def load_model(self):
        """Instantiate Mask R-CNN for the selected backbone.

        Sets ``self.mask_rcnn`` (the detector) and ``self.model`` (the
        training chain wrapping it), moving both to GPU when requested.

        Raises:
            ValueError: if ``self.model_name`` is not a supported backbone.
        """
        n_fg_class = len(self.train_dataset.fg_class_names)

        pooling_func = cmr.functions.roi_align_2d
        anchor_scales = (4, 8, 16, 32)
        roi_size = 14
        min_size = 600
        max_size = 1000
        mask_initialW = chainer.initializers.Normal(0.01)

        if self.model_name == 'vgg16':
            self.mask_rcnn = cmr.models.MaskRCNNVGG16(
                n_fg_class=n_fg_class,
                pretrained_model='imagenet',
                pooling_func=pooling_func,
                anchor_scales=anchor_scales,
                roi_size=roi_size,
                min_size=min_size,
                max_size=max_size,
                mask_initialW=mask_initialW,
            )
        elif self.model_name in ['resnet50', 'resnet101']:
            # Take the layer count from the name suffix
            # ('resnet50' -> 50).  NOTE: str.lstrip('resnet') would strip
            # a *character set*, not the prefix, so slice by prefix
            # length instead.
            n_layers = int(self.model_name[len('resnet'):])
            self.mask_rcnn = cmr.models.MaskRCNNResNet(
                n_layers=n_layers,
                n_fg_class=n_fg_class,
                pooling_func=pooling_func,
                anchor_scales=anchor_scales,
                roi_size=roi_size,
                min_size=min_size,
                max_size=max_size,
                mask_initialW=mask_initialW,
            )
        else:
            raise ValueError(
                'Unsupported model_name: {}'.format(self.model_name))
        self.model = cmr.models.MaskRCNNTrainChain(self.mask_rcnn)

        if self.gpu >= 0:
            cuda.get_device_from_id(self.gpu).use()
            self.model.to_gpu()

    def setup_optimizer(self):
        """Set up MomentumSGD + weight decay; freeze early ResNet layers."""
        self.optimizer = chainer.optimizers.MomentumSGD(
            lr=self.lr, momentum=0.9)
        self.optimizer.setup(self.model)
        self.optimizer.add_hook(
            chainer.optimizer.WeightDecay(rate=self.weight_decay))

        if self.model_name in ['resnet50', 'resnet101']:
            # Freeze the stem (conv1/bn1) and the res2 stage, plus every
            # AffineChannel2D link (frozen batch-norm replacement), so
            # fine-tuning only updates the later stages.
            self.mask_rcnn.extractor.conv1.disable_update()
            self.mask_rcnn.extractor.bn1.disable_update()
            self.mask_rcnn.extractor.res2.disable_update()
            for link in self.mask_rcnn.links():
                if isinstance(link, cmr.links.AffineChannel2D):
                    link.disable_update()

    def setup_iterator(self):
        """Wrap datasets with Mask R-CNN transforms and build iterators."""
        train_dataset_transformed = TransformDataset(
            self.train_dataset, cmr.datasets.MaskRCNNTransform(self.mask_rcnn))
        val_dataset_transformed = TransformDataset(
            self.val_dataset,
            cmr.datasets.MaskRCNNTransform(self.mask_rcnn, train=False))

        self.train_iterator = chainer.iterators.SerialIterator(
            train_dataset_transformed, batch_size=self.batch_size)
        # Validation iterates the dataset exactly once, in order.
        self.val_iterator = chainer.iterators.SerialIterator(
            val_dataset_transformed, batch_size=self.batch_size,
            repeat=False, shuffle=False)

    def setup_trainer(self):
        """Create the trainer and register all reporting extensions."""
        converter = functools.partial(
            cmr.datasets.concat_examples,
            padding=0,
            # NOTE(review): index meanings depend on MaskRCNNTransform's
            # output tuple order — confirm against cmr before changing.
            indices_concat=[0, 2, 3, 4],
            indices_to_device=[0, 1],
        )
        self.updater = chainer.training.updater.StandardUpdater(
            self.train_iterator, self.optimizer, device=self.gpu,
            converter=converter)
        self.trainer = chainer.training.Trainer(
            self.updater, (self.max_epoch, 'epoch'), out=self.out_dir)

        # Drop LR by 10x at the same relative points as a 120k/160k step
        # schedule over 180k iterations, scaled to max_epoch.
        step_size = [
            (120e3 / 180e3) * self.max_epoch,
            (160e3 / 180e3) * self.max_epoch,
        ]
        self.trainer.extend(
            extensions.ExponentialShift('lr', 0.1),
            trigger=chainer.training.triggers.ManualScheduleTrigger(
                step_size, 'epoch'))

        evaluator = cmr.extensions.InstanceSegmentationVOCEvaluator(
            self.val_iterator, self.model.mask_rcnn, device=self.gpu,
            use_07_metric=True, label_names=self.train_dataset.fg_class_names)
        self.trainer.extend(
            evaluator, trigger=(self.eval_interval, self.eval_interval_type))

        # Snapshot the detector only when validation mAP improves.
        self.trainer.extend(
            extensions.snapshot_object(
                self.model.mask_rcnn, 'snapshot_model.npz'),
            trigger=chainer.training.triggers.MaxValueTrigger(
                'validation/main/map',
                (self.save_interval, self.save_interval_type)))

        # Dump the computational graph once for offline inspection.
        self.trainer.extend(
            extensions.dump_graph(
                root_name='main/loss',
                out_name='network_architecture.dot'))

        # Progress / logging extensions.
        self.trainer.extend(
            extensions.ProgressBar(
                update_interval=self.progressbar_update_interval))
        self.trainer.extend(
            extensions.observe_lr(),
            trigger=(self.log_interval, self.log_interval_type))
        self.trainer.extend(
            extensions.LogReport(
                log_name='log.json',
                trigger=(self.log_interval, self.log_interval_type)))
        self.trainer.extend(
            extensions.PrintReport([
                'iteration',
                'epoch',
                'elapsed_time',
                'lr',
                'main/loss',
                'main/roi_loc_loss',
                'main/roi_cls_loss',
                'main/roi_mask_loss',
                'main/rpn_loc_loss',
                'main/rpn_cls_loss',
                'validation/main/map',
            ]), trigger=(self.print_interval, self.print_interval_type))

        # Loss and accuracy plots.
        self.trainer.extend(
            extensions.PlotReport([
                'main/loss',
                'main/roi_loc_loss',
                'main/roi_cls_loss',
                'main/roi_mask_loss',
                'main/rpn_loc_loss',
                'main/rpn_cls_loss',
            ],
                file_name='loss_plot.png',
                x_key=self.plot_interval_type,
                trigger=(self.plot_interval, self.plot_interval_type)),
            trigger=(self.plot_interval, self.plot_interval_type))
        self.trainer.extend(
            extensions.PlotReport(
                ['validation/main/map'],
                file_name='accuracy_plot.png',
                x_key=self.plot_interval_type,
                trigger=(self.plot_interval, self.plot_interval_type)),
            trigger=(self.eval_interval, self.eval_interval_type))

        # Record the training configuration for reproducibility.
        params = dict()
        params['model_name'] = self.model_name
        params['train_dataset_dir'] = self.train_dataset_dir
        params['val_dataset_dir'] = self.val_dataset_dir
        params['fg_class_names'] = self.train_dataset.fg_class_names
        params['timestamp'] = self.timestamp_iso
        params['out_dir'] = self.out_dir
        params['gpu'] = self.gpu
        params['batch_size'] = self.batch_size
        params['max_epoch'] = self.max_epoch
        params['lr'] = self.lr
        params['weight_decay'] = self.weight_decay
        self.trainer.extend(
            fcn.extensions.ParamsReport(params, file_name='params.yaml'))

        # Record the foreground class names separately so downstream
        # consumers can load just the label list.
        target_names = dict()
        target_names['fg_class_names'] = self.train_dataset.fg_class_names
        self.trainer.extend(
            fcn.extensions.ParamsReport(
                target_names, file_name='fg_class_names.yaml'))
00304
00305
if __name__ == '__main__':
    # Constructing the trainer parses CLI args and runs the full
    # training loop; the instance itself is not needed afterwards.
    TrainMaskRCNN()