00001
00002
00003 import argparse
00004 import datetime
00005 import os
00006 import os.path as osp
00007
00008 os.environ['MPLBACKEND'] = 'Agg'
00009
00010 import chainer
00011 from chainer import cuda
00012 from chainer.datasets import TransformDataset
00013 import chainer.serializers as S
00014 from chainer.training import extensions
00015 import fcn
00016 import numpy as np
00017
00018 from jsk_recognition_utils.datasets import SemanticSegmentationDataset
00019 import rospkg
00020
00021
class TrainFCN(object):
    """Command-line trainer for FCN semantic segmentation models.

    Parses options from ``sys.argv``, loads train/val datasets, builds the
    requested FCN variant initialized from pretrained weights, and runs a
    Chainer training loop with evaluation, snapshot, logging and plotting
    extensions.  Training starts immediately on construction.
    """

    def __init__(self):
        rospack = rospkg.RosPack()
        datasets_root = osp.join(
            rospack.get_path('jsk_perception'), 'learning_datasets')

        parser = argparse.ArgumentParser()

        # Dataset directories (defaults: bundled kitchen dataset).
        parser.add_argument('--train_dataset_dir', type=str,
                            default=osp.join(datasets_root,
                                             'kitchen_dataset', 'train'))
        parser.add_argument('--val_dataset_dir', type=str,
                            default=osp.join(datasets_root,
                                             'kitchen_dataset', 'test'))

        # Network variant.
        parser.add_argument(
            '--model_name', type=str, default='fcn32s',
            choices=['fcn32s', 'fcn16s', 'fcn8s', 'fcn8s_at_once'])

        # Optimization hyper-parameters.
        parser.add_argument('--gpu', type=int, default=0)
        parser.add_argument('--batch_size', type=int, default=1)
        parser.add_argument('--max_epoch', type=int, default=100)
        parser.add_argument('--lr', type=float, default=1e-10)
        parser.add_argument('--weight_decay', type=float, default=0.0001)
        parser.add_argument('--out_dir', type=str, default=None)

        # Reporting / snapshot intervals.
        parser.add_argument('--progressbar_update_interval', type=float,
                            default=10)
        parser.add_argument('--print_interval', type=float, default=100)
        parser.add_argument('--print_interval_type', type=str,
                            default='iteration',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--log_interval', type=float, default=10)
        parser.add_argument('--log_interval_type', type=str,
                            default='iteration',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--plot_interval', type=float, default=5)
        parser.add_argument('--plot_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--eval_interval', type=float, default=10)
        parser.add_argument('--eval_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])
        parser.add_argument('--save_interval', type=float, default=10)
        parser.add_argument('--save_interval_type', type=str,
                            default='epoch',
                            choices=['epoch', 'iteration'])

        args = parser.parse_args()

        # Mirror every parsed option onto the instance under the same name.
        for opt, value in vars(args).items():
            setattr(self, opt, value)

        now = datetime.datetime.now()
        self.timestamp_iso = now.isoformat()
        if self.out_dir is None:
            # Default log directory lives under the ROS home directory.
            self.out_dir = osp.join(
                rospkg.get_ros_home(), 'learning_logs',
                now.strftime('%Y%m%d-%H%M%S'))

        self.load_dataset()
        self.setup_iterator()
        self.load_model()
        self.setup_optimizer()
        self.setup_trainer()
        self.trainer.run()

    def load_dataset(self):
        """Load the train and validation segmentation datasets from disk."""
        self.train_dataset = SemanticSegmentationDataset(
            self.train_dataset_dir)
        self.val_dataset = SemanticSegmentationDataset(self.val_dataset_dir)

    def transform_dataset(self, in_data):
        """Convert an ``(rgb_img, lbl)`` pair into network input.

        RGB -> BGR channel reversal, float32 cast, mean subtraction, and
        HWC -> CHW transpose.  The label passes through unchanged.
        """
        rgb, lbl = in_data

        # Per-channel BGR mean (VGG-style preprocessing constants).
        vgg_mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
        bgr = rgb[:, :, ::-1].astype(np.float32)
        bgr -= vgg_mean_bgr

        return bgr.transpose((2, 0, 1)), lbl

    def setup_iterator(self):
        """Wrap both datasets with the input transform and build iterators."""
        train_transformed = TransformDataset(
            self.train_dataset, self.transform_dataset)
        val_transformed = TransformDataset(
            self.val_dataset, self.transform_dataset)
        self.train_iterator = chainer.iterators.MultiprocessIterator(
            train_transformed, batch_size=self.batch_size,
            shared_mem=10 ** 7)
        # Validation runs exactly one deterministic pass per evaluation.
        self.val_iterator = chainer.iterators.MultiprocessIterator(
            val_transformed, batch_size=self.batch_size,
            shared_mem=10 ** 7, repeat=False, shuffle=False)

    def load_model(self):
        """Build the requested FCN variant and initialize it from a
        pretrained network (VGG16 or a coarser FCN)."""
        n_class = len(self.train_dataset.class_names)
        if self.model_name == 'fcn32s':
            self.model = fcn.models.FCN32s(n_class=n_class)
            vgg = fcn.models.VGG16()
            S.load_npz(vgg.download(), vgg)
            self.model.init_from_vgg16(vgg)
        elif self.model_name == 'fcn16s':
            self.model = fcn.models.FCN16s(n_class=n_class)
            base = fcn.models.FCN32s()
            base_path = base.download()
            S.load_npz(base_path, base)
            # NOTE(review): unlike the VGG16 branches this passes both the
            # npz path and the model — confirm against the fcn package API.
            self.model.init_from_fcn32s(base_path, base)
        elif self.model_name == 'fcn8s':
            self.model = fcn.models.FCN8s(n_class=n_class)
            base = fcn.models.FCN16s()
            base_path = base.download()
            S.load_npz(base_path, base)
            # NOTE(review): same two-argument form as above — verify.
            self.model.init_from_fcn16s(base_path, base)
        elif self.model_name == 'fcn8s_at_once':
            self.model = fcn.models.FCN8sAtOnce(n_class=n_class)
            vgg = fcn.models.VGG16()
            S.load_npz(vgg.download(), vgg)
            self.model.init_from_vgg16(vgg)
        else:
            raise ValueError(
                'Unsupported model_name: {}'.format(self.model_name))

        if self.gpu >= 0:
            cuda.get_device_from_id(self.gpu).use()
            self.model.to_gpu()

    def setup_optimizer(self):
        """Momentum SGD with a weight-decay hook."""
        self.optimizer = chainer.optimizers.MomentumSGD(
            lr=self.lr, momentum=0.9)
        self.optimizer.setup(self.model)
        self.optimizer.add_hook(
            chainer.optimizer.WeightDecay(rate=self.weight_decay))

    def setup_trainer(self):
        """Create the Chainer trainer and register all extensions."""
        self.updater = chainer.training.updater.StandardUpdater(
            self.train_iterator, self.optimizer, device=self.gpu)
        self.trainer = chainer.training.Trainer(
            self.updater, (self.max_epoch, 'epoch'), out=self.out_dir)

        # Periodic evaluation on the validation set.
        self.trainer.extend(
            extensions.Evaluator(
                self.val_iterator, self.model, device=self.gpu),
            trigger=(self.eval_interval, self.eval_interval_type))

        # Snapshot the model whenever validation loss reaches a new minimum.
        self.trainer.extend(
            extensions.snapshot_object(
                self.model,
                savefun=S.save_npz,
                filename='model_snapshot.npz'),
            trigger=chainer.training.triggers.MinValueTrigger(
                'validation/main/loss',
                (self.save_interval, self.save_interval_type)))

        # Dump the computational graph rooted at the training loss.
        self.trainer.extend(
            extensions.dump_graph(
                root_name='main/loss',
                out_name='network_architecture.dot'))

        # Console progress bar and periodic reports.
        self.trainer.extend(
            extensions.ProgressBar(
                update_interval=self.progressbar_update_interval))
        self.trainer.extend(
            extensions.observe_lr(),
            trigger=(self.log_interval, self.log_interval_type))
        self.trainer.extend(
            extensions.LogReport(
                log_name='log.json',
                trigger=(self.log_interval, self.log_interval_type)))
        self.trainer.extend(
            extensions.PrintReport([
                'iteration',
                'epoch',
                'elapsed_time',
                'lr',
                'main/loss',
                'validation/main/loss',
            ]), trigger=(self.print_interval, self.print_interval_type))

        # Train/validation loss curves.
        self.trainer.extend(
            extensions.PlotReport([
                'main/loss',
                'validation/main/loss',
            ],
                file_name='loss_plot.png',
                x_key=self.plot_interval_type,
                trigger=(self.plot_interval, self.plot_interval_type)),
            trigger=(self.plot_interval, self.plot_interval_type))

        # Persist the full run configuration next to the logs.
        params = {
            'model_name': self.model_name,
            'train_dataset_dir': self.train_dataset_dir,
            'val_dataset_dir': self.val_dataset_dir,
            'class_names': self.train_dataset.class_names,
            'timestamp': self.timestamp_iso,
            'out_dir': self.out_dir,
            'gpu': self.gpu,
            'batch_size': self.batch_size,
            'max_epoch': self.max_epoch,
            'lr': self.lr,
            'weight_decay': self.weight_decay,
        }
        self.trainer.extend(
            fcn.extensions.ParamsReport(params, file_name='params.yaml'))

        # Small single-key reports consumed by downstream tooling.
        self.trainer.extend(
            fcn.extensions.ParamsReport(
                {'model_name': self.model_name},
                file_name='model_name.yaml'))
        self.trainer.extend(
            fcn.extensions.ParamsReport(
                {'target_names': self.train_dataset.class_names},
                file_name='target_names.yaml'))
00268
00269
if __name__ == '__main__':
    # Constructing the trainer parses argv and runs the full training loop.
    TrainFCN()