train_fcn.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 import argparse
00004 import datetime
00005 import os
00006 import os.path as osp
00007 
00008 os.environ['MPLBACKEND'] = 'Agg'  # NOQA
00009 
00010 import chainer
00011 from chainer import cuda
00012 from chainer.datasets import TransformDataset
00013 import chainer.serializers as S
00014 from chainer.training import extensions
00015 import fcn
00016 import numpy as np
00017 
00018 from jsk_recognition_utils.datasets import SemanticSegmentationDataset
00019 import rospkg
00020 
00021 
class TrainFCN(object):
    """Train an FCN semantic segmentation model with Chainer.

    Parses command-line options, loads train/val datasets from a
    jsk_perception dataset directory, builds the selected FCN variant
    initialized from pretrained weights, and runs a chainer ``Trainer``
    with evaluation, snapshot, logging and plotting extensions.
    Instantiating the class runs the whole pipeline.
    """

    def __init__(self):
        rospack = rospkg.RosPack()
        jsk_perception_datasets_path = osp.join(
            rospack.get_path('jsk_perception'), 'learning_datasets')

        parser = argparse.ArgumentParser()

        # Dataset directory
        parser.add_argument('--train_dataset_dir', type=str,
                            default=osp.join(jsk_perception_datasets_path,
                                             'kitchen_dataset', 'train'))
        parser.add_argument('--val_dataset_dir', type=str,
                            default=osp.join(jsk_perception_datasets_path,
                                             'kitchen_dataset', 'test'))

        # Model
        parser.add_argument(
            '--model_name', type=str, default='fcn32s',
            choices=['fcn32s', 'fcn16s', 'fcn8s', 'fcn8s_at_once'])

        # Training parameters
        parser.add_argument('--gpu', type=int, default=0)
        parser.add_argument('--batch_size', type=int, default=1)
        parser.add_argument('--max_epoch', type=int, default=100)
        # FCN is trained with unnormalized per-pixel loss, hence the
        # unusually small learning rate (cf. original FCN training recipe).
        parser.add_argument('--lr', type=float, default=1e-10)
        parser.add_argument('--weight_decay', type=float, default=0.0001)
        parser.add_argument('--out_dir', type=str, default=None)
        parser.add_argument('--progressbar_update_interval', type=float,
                            default=10)
        parser.add_argument('--print_interval', type=float, default=100)
        parser.add_argument('--print_interval_type', type=str,
                            default='iteration',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--log_interval', type=float, default=10)
        parser.add_argument('--log_interval_type', type=str,
                            default='iteration',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--plot_interval', type=float, default=5)
        parser.add_argument('--plot_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--eval_interval', type=float, default=10)
        parser.add_argument('--eval_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--save_interval', type=float, default=10)
        parser.add_argument('--save_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])

        args = parser.parse_args()

        self.train_dataset_dir = args.train_dataset_dir
        self.val_dataset_dir = args.val_dataset_dir
        self.model_name = args.model_name
        self.gpu = args.gpu
        self.batch_size = args.batch_size
        self.max_epoch = args.max_epoch
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.out_dir = args.out_dir
        self.progressbar_update_interval = args.progressbar_update_interval
        self.print_interval = args.print_interval
        self.print_interval_type = args.print_interval_type
        self.log_interval = args.log_interval
        self.log_interval_type = args.log_interval_type
        self.plot_interval = args.plot_interval
        self.plot_interval_type = args.plot_interval_type
        self.eval_interval = args.eval_interval
        self.eval_interval_type = args.eval_interval_type
        self.save_interval = args.save_interval
        self.save_interval_type = args.save_interval_type

        now = datetime.datetime.now()
        self.timestamp_iso = now.isoformat()
        timestamp = now.strftime('%Y%m%d-%H%M%S')
        if self.out_dir is None:
            # Default log directory: ~/.ros/learning_logs/<timestamp>
            self.out_dir = osp.join(
                rospkg.get_ros_home(), 'learning_logs', timestamp)

        # Main process
        self.load_dataset()
        self.setup_iterator()
        self.load_model()
        self.setup_optimizer()
        self.setup_trainer()
        self.trainer.run()

    def load_dataset(self):
        """Load train and validation semantic segmentation datasets."""
        self.train_dataset = SemanticSegmentationDataset(
            self.train_dataset_dir)
        self.val_dataset = SemanticSegmentationDataset(self.val_dataset_dir)

    def transform_dataset(self, in_data):
        """Convert an (rgb_img, label) pair into FCN network input.

        The RGB image (H, W, 3) is converted to BGR, mean-subtracted
        with the ImageNet/VGG BGR mean, and transposed to (3, H, W)
        float32.  The label is passed through unchanged.
        """
        rgb_img, lbl = in_data
        # VGG16 training mean in BGR order (B, G, R).
        mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
        # RGB -> BGR
        bgr_img = rgb_img[:, :, ::-1]
        bgr_img = bgr_img.astype(np.float32)
        bgr_img -= mean_bgr
        # H, W, C -> C, H, W
        bgr_img = bgr_img.transpose((2, 0, 1))

        return bgr_img, lbl

    def setup_iterator(self):
        """Create multiprocess train/val iterators over transformed data."""
        train_dataset_transformed = TransformDataset(
            self.train_dataset, self.transform_dataset)
        val_dataset_transformed = TransformDataset(
            self.val_dataset, self.transform_dataset)
        self.train_iterator = chainer.iterators.MultiprocessIterator(
            train_dataset_transformed, batch_size=self.batch_size,
            shared_mem=10 ** 7)
        # Validation iterator runs a single fixed-order pass per evaluation.
        self.val_iterator = chainer.iterators.MultiprocessIterator(
            val_dataset_transformed, batch_size=self.batch_size,
            shared_mem=10 ** 7, repeat=False, shuffle=False)

    def load_model(self):
        """Build the selected FCN model initialized from pretrained weights.

        Raises:
            ValueError: if ``self.model_name`` is not a supported variant.
        """
        n_class = len(self.train_dataset.class_names)
        if self.model_name == 'fcn32s':
            self.model = fcn.models.FCN32s(n_class=n_class)
            vgg = fcn.models.VGG16()
            vgg_path = vgg.download()
            S.load_npz(vgg_path, vgg)
            self.model.init_from_vgg16(vgg)
        elif self.model_name == 'fcn16s':
            self.model = fcn.models.FCN16s(n_class=n_class)
            fcn32s = fcn.models.FCN32s()
            fcn32s_path = fcn32s.download()
            S.load_npz(fcn32s_path, fcn32s)
            # BUGFIX: init_from_fcn32s takes only the pretrained model
            # (the original code also passed the npz path, which raised
            # TypeError), consistent with init_from_vgg16 above.
            self.model.init_from_fcn32s(fcn32s)
        elif self.model_name == 'fcn8s':
            self.model = fcn.models.FCN8s(n_class=n_class)
            fcn16s = fcn.models.FCN16s()
            fcn16s_path = fcn16s.download()
            S.load_npz(fcn16s_path, fcn16s)
            # BUGFIX: same as above for init_from_fcn16s.
            self.model.init_from_fcn16s(fcn16s)
        elif self.model_name == 'fcn8s_at_once':
            self.model = fcn.models.FCN8sAtOnce(n_class=n_class)
            vgg = fcn.models.VGG16()
            vgg_path = vgg.download()
            S.load_npz(vgg_path, vgg)
            self.model.init_from_vgg16(vgg)
        else:
            raise ValueError(
                'Unsupported model_name: {}'.format(self.model_name))

        if self.gpu >= 0:
            cuda.get_device_from_id(self.gpu).use()
            self.model.to_gpu()

    def setup_optimizer(self):
        """Set up MomentumSGD with weight decay on the loaded model."""
        self.optimizer = chainer.optimizers.MomentumSGD(
            lr=self.lr, momentum=0.9)
        self.optimizer.setup(self.model)
        self.optimizer.add_hook(
            chainer.optimizer.WeightDecay(rate=self.weight_decay))

    def setup_trainer(self):
        """Assemble the Trainer with evaluation/snapshot/log extensions."""
        self.updater = chainer.training.StandardUpdater(
            self.train_iterator, self.optimizer, device=self.gpu)
        self.trainer = chainer.training.Trainer(
            self.updater, (self.max_epoch, 'epoch'), out=self.out_dir)

        self.trainer.extend(
            extensions.Evaluator(
                self.val_iterator, self.model, device=self.gpu),
            trigger=(self.eval_interval, self.eval_interval_type))

        # Save snapshot whenever the validation loss hits a new minimum.
        self.trainer.extend(
            extensions.snapshot_object(
                self.model,
                savefun=S.save_npz,
                filename='model_snapshot.npz'),
            trigger=chainer.training.triggers.MinValueTrigger(
                'validation/main/loss',
                (self.save_interval, self.save_interval_type)))

        # Dump network architecture
        self.trainer.extend(
            extensions.dump_graph(
                root_name='main/loss',
                out_name='network_architecture.dot'))

        # Logging
        self.trainer.extend(
            extensions.ProgressBar(
                update_interval=self.progressbar_update_interval))
        self.trainer.extend(
            extensions.observe_lr(),
            trigger=(self.log_interval, self.log_interval_type))
        self.trainer.extend(
            extensions.LogReport(
                log_name='log.json',
                trigger=(self.log_interval, self.log_interval_type)))
        self.trainer.extend(
            extensions.PrintReport([
                'iteration',
                'epoch',
                'elapsed_time',
                'lr',
                'main/loss',
                'validation/main/loss',
            ]), trigger=(self.print_interval, self.print_interval_type))

        # Plot train/val loss curves.
        self.trainer.extend(
            extensions.PlotReport([
                'main/loss',
                'validation/main/loss',
            ],
                file_name='loss_plot.png',
                x_key=self.plot_interval_type,
                trigger=(self.plot_interval, self.plot_interval_type)),
            trigger=(self.plot_interval, self.plot_interval_type))

        # Dump training parameters for reproducibility.
        params = dict()
        params['model_name'] = self.model_name
        params['train_dataset_dir'] = self.train_dataset_dir
        params['val_dataset_dir'] = self.val_dataset_dir
        params['class_names'] = self.train_dataset.class_names
        params['timestamp'] = self.timestamp_iso
        params['out_dir'] = self.out_dir
        params['gpu'] = self.gpu
        params['batch_size'] = self.batch_size
        params['max_epoch'] = self.max_epoch
        params['lr'] = self.lr
        params['weight_decay'] = self.weight_decay
        self.trainer.extend(
            fcn.extensions.ParamsReport(params, file_name='params.yaml'))

        # Dump params consumed by fcn_object_segmentation.py at inference.
        model_name = dict()
        model_name['model_name'] = self.model_name
        self.trainer.extend(
            fcn.extensions.ParamsReport(
                model_name, file_name='model_name.yaml'))
        target_names = dict()
        target_names['target_names'] = self.train_dataset.class_names
        self.trainer.extend(
            fcn.extensions.ParamsReport(
                target_names, file_name='target_names.yaml'))
00268 
00269 
if __name__ == '__main__':
    # Instantiating TrainFCN parses argv and runs the full training loop.
    TrainFCN()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07