train_fcn.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import print_function
4 
5 import argparse
6 import datetime
7 import os
8 import os.path as osp
9 
10 os.environ['MPLBACKEND'] = 'Agg' # NOQA
11 
12 import itertools, pkg_resources, sys
13 from distutils.version import LooseVersion
14 if LooseVersion(pkg_resources.get_distribution("chainer").version) >= LooseVersion('7.0.0') and \
15  sys.version_info.major == 2:
16  print('''Please install chainer < 7.0.0:
17 
18  sudo pip install chainer==6.7.0
19 
20 c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
21 ''', file=sys.stderr)
22  sys.exit(1)
23 if [p for p in list(itertools.chain(*[pkg_resources.find_distributions(_) for _ in sys.path])) if "cupy-" in p.project_name ] == []:
24  print('''Please install CuPy
25 
26  sudo pip install cupy-cuda[your cuda version]
27 i.e.
28  sudo pip install cupy-cuda91
29 
30 ''', file=sys.stderr)
31  # sys.exit(1)
32 import chainer
33 from chainer import cuda
34 from chainer.datasets import TransformDataset
35 import chainer.serializers as S
36 from chainer.training import extensions
37 import fcn
38 import numpy as np
39 
40 from jsk_recognition_utils.datasets import SemanticSegmentationDataset
41 import rospkg
42 
43 
45 
46  def __init__(self):
47  rospack = rospkg.RosPack()
48  jsk_perception_datasets_path = osp.join(
49  rospack.get_path('jsk_perception'), 'learning_datasets')
50 
51  parser = argparse.ArgumentParser()
52 
53  # Dataset directory
54  parser.add_argument('--train_dataset_dir', type=str,
55  default=osp.join(jsk_perception_datasets_path,
56  'kitchen_dataset', 'train'))
57  parser.add_argument('--val_dataset_dir', type=str,
58  default=osp.join(jsk_perception_datasets_path,
59  'kitchen_dataset', 'test'))
60 
61  # Model
62  parser.add_argument(
63  '--model_name', type=str, default='fcn32s',
64  choices=['fcn32s', 'fcn16s', 'fcn8s', 'fcn8s_at_once'])
65 
66  # Training parameters
67  parser.add_argument('--gpu', type=int, default=0)
68  parser.add_argument('--batch_size', type=int, default=1)
69  parser.add_argument('--max_epoch', type=int, default=100)
70  parser.add_argument('--lr', type=float, default=1e-10)
71  parser.add_argument('--weight_decay', type=float, default=0.0001)
72  parser.add_argument('--out_dir', type=str, default=None)
73  parser.add_argument('--progressbar_update_interval', type=float,
74  default=10)
75  parser.add_argument('--print_interval', type=float, default=100)
76  parser.add_argument('--print_interval_type', type=str,
77  default='iteration',
78  choices=['epoch', 'iteration'])
79  parser.add_argument('--log_interval', type=float, default=10)
80  parser.add_argument('--log_interval_type', type=str,
81  default='iteration',
82  choices=['epoch', 'iteration'])
83  parser.add_argument('--plot_interval', type=float, default=5)
84  parser.add_argument('--plot_interval_type', type=str,
85  default='epoch',
86  choices=['epoch', 'iteration'])
87  parser.add_argument('--eval_interval', type=float, default=10)
88  parser.add_argument('--eval_interval_type', type=str,
89  default='epoch',
90  choices=['epoch', 'iteration'])
91  parser.add_argument('--save_interval', type=float, default=10)
92  parser.add_argument('--save_interval_type', type=str,
93  default='epoch',
94  choices=['epoch', 'iteration'])
95 
96  args = parser.parse_args()
97 
98  self.train_dataset_dir = args.train_dataset_dir
99  self.val_dataset_dir = args.val_dataset_dir
100  self.model_name = args.model_name
101  self.gpu = args.gpu
102  self.batch_size = args.batch_size
103  self.max_epoch = args.max_epoch
104  self.lr = args.lr
105  self.weight_decay = args.weight_decay
106  self.out_dir = args.out_dir
107  self.progressbar_update_interval = args.progressbar_update_interval
108  self.print_interval = args.print_interval
109  self.print_interval_type = args.print_interval_type
110  self.log_interval = args.log_interval
111  self.log_interval_type = args.log_interval_type
112  self.plot_interval = args.plot_interval
113  self.plot_interval_type = args.plot_interval_type
114  self.eval_interval = args.eval_interval
115  self.eval_interval_type = args.eval_interval_type
116  self.save_interval = args.save_interval
117  self.save_interval_type = args.save_interval_type
118 
119  now = datetime.datetime.now()
120  self.timestamp_iso = now.isoformat()
121  timestamp = now.strftime('%Y%m%d-%H%M%S')
122  if self.out_dir is None:
123  self.out_dir = osp.join(
124  rospkg.get_ros_home(), 'learning_logs', timestamp)
125 
126  # Main process
127  self.load_dataset()
128  self.setup_iterator()
129  self.load_model()
130  self.setup_optimizer()
131  self.setup_trainer()
132  self.trainer.run()
133 
134  def load_dataset(self):
135  self.train_dataset = SemanticSegmentationDataset(
136  self.train_dataset_dir)
137  self.val_dataset = SemanticSegmentationDataset(self.val_dataset_dir)
138 
139  def transform_dataset(self, in_data):
140  rgb_img, lbl = in_data
141  # RGB -> BGR
142  mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
143  bgr_img = rgb_img[:, :, ::-1]
144  bgr_img = bgr_img.astype(np.float32)
145  bgr_img -= mean_bgr
146  # H, W, C -> C, H, W
147  bgr_img = bgr_img.transpose((2, 0, 1))
148 
149  return bgr_img, lbl
150 
151  def setup_iterator(self):
152  train_dataset_transformed = TransformDataset(
153  self.train_dataset, self.transform_dataset)
154  val_dataset_transformed = TransformDataset(
155  self.val_dataset, self.transform_dataset)
156  self.train_iterator = chainer.iterators.MultiprocessIterator(
157  train_dataset_transformed, batch_size=self.batch_size,
158  shared_mem=10 ** 7)
159  self.val_iterator = chainer.iterators.MultiprocessIterator(
160  val_dataset_transformed, batch_size=self.batch_size,
161  shared_mem=10 ** 7, repeat=False, shuffle=False)
162 
163  def load_model(self):
164  n_class = len(self.train_dataset.class_names)
165  if self.model_name == 'fcn32s':
166  self.model = fcn.models.FCN32s(n_class=n_class)
167  vgg = fcn.models.VGG16()
168  vgg_path = vgg.download()
169  S.load_npz(vgg_path, vgg)
170  self.model.init_from_vgg16(vgg)
171  elif self.model_name == 'fcn16s':
172  self.model = fcn.models.FCN16s(n_class=n_class)
173  fcn32s = fcn.models.FCN32s()
174  fcn32s_path = fcn32s.download()
175  S.load_npz(fcn32s_path, fcn32s)
176  self.model.init_from_fcn32s(fcn32s_path, fcn32s)
177  elif self.model_name == 'fcn8s':
178  self.model = fcn.models.FCN8s(n_class=n_class)
179  fcn16s = fcn.models.FCN16s()
180  fcn16s_path = fcn16s.download()
181  S.load_npz(fcn16s_path, fcn16s)
182  self.model.init_from_fcn16s(fcn16s_path, fcn16s)
183  elif self.model_name == 'fcn8s_at_once':
184  self.model = fcn.models.FCN8sAtOnce(n_class=n_class)
185  vgg = fcn.models.VGG16()
186  vgg_path = vgg.download()
187  S.load_npz(vgg_path, vgg)
188  self.model.init_from_vgg16(vgg)
189  else:
190  raise ValueError(
191  'Unsupported model_name: {}'.format(self.model_name))
192 
193  if self.gpu >= 0:
194  cuda.get_device_from_id(self.gpu).use()
195  self.model.to_gpu()
196 
197  def setup_optimizer(self):
198  self.optimizer = chainer.optimizers.MomentumSGD(
199  lr=self.lr, momentum=0.9)
200  self.optimizer.setup(self.model)
201  self.optimizer.add_hook(
202  chainer.optimizer.WeightDecay(rate=self.weight_decay))
203 
    def setup_trainer(self):
        """Assemble self.trainer: a chainer Trainer wired with evaluation,
        best-model snapshotting, graph dumping, logging, plotting, and
        parameter-dump extensions. Does not start training.
        """
        self.updater = chainer.training.updater.StandardUpdater(
            self.train_iterator, self.optimizer, device=self.gpu)
        self.trainer = chainer.training.Trainer(
            self.updater, (self.max_epoch, 'epoch'), out=self.out_dir)

        # Periodic evaluation on the validation iterator.
        self.trainer.extend(
            extensions.Evaluator(
                self.val_iterator, self.model, device=self.gpu),
            trigger=(self.eval_interval, self.eval_interval_type))

        # Save snapshot
        # Only overwrite the snapshot when validation loss hits a new minimum.
        self.trainer.extend(
            extensions.snapshot_object(
                self.model,
                savefun=S.save_npz,
                filename='model_snapshot.npz'),
            trigger=chainer.training.triggers.MinValueTrigger(
                'validation/main/loss',
                (self.save_interval, self.save_interval_type)))

        # Dump network architecture
        self.trainer.extend(
            extensions.dump_graph(
                root_name='main/loss',
                out_name='network_architecture.dot'))

        # Logging
        self.trainer.extend(
            extensions.ProgressBar(
                update_interval=self.progressbar_update_interval))
        self.trainer.extend(
            extensions.observe_lr(),
            trigger=(self.log_interval, self.log_interval_type))
        self.trainer.extend(
            extensions.LogReport(
                log_name='log.json',
                trigger=(self.log_interval, self.log_interval_type)))
        self.trainer.extend(
            extensions.PrintReport([
                'iteration',
                'epoch',
                'elapsed_time',
                'lr',
                'main/loss',
                'validation/main/loss',
            ]), trigger=(self.print_interval, self.print_interval_type))

        # Plot
        self.trainer.extend(
            extensions.PlotReport([
                'main/loss',
                'validation/main/loss',
            ],
                file_name='loss_plot.png',
                x_key=self.plot_interval_type,
                trigger=(self.plot_interval, self.plot_interval_type)),
            trigger=(self.plot_interval, self.plot_interval_type))

        # Dump params
        # Full training configuration, written to params.yaml in out_dir.
        params = dict()
        params['model_name'] = self.model_name
        params['train_dataset_dir'] = self.train_dataset_dir
        params['val_dataset_dir'] = self.val_dataset_dir
        params['class_names'] = self.train_dataset.class_names
        params['timestamp'] = self.timestamp_iso
        params['out_dir'] = self.out_dir
        params['gpu'] = self.gpu
        params['batch_size'] = self.batch_size
        params['max_epoch'] = self.max_epoch
        params['lr'] = self.lr
        params['weight_decay'] = self.weight_decay
        self.trainer.extend(
            fcn.extensions.ParamsReport(params, file_name='params.yaml'))

        # Dump param for fcn_object_segmentation.py
        # (separate YAML files consumed by the inference node)
        model_name = dict()
        model_name['model_name'] = self.model_name
        self.trainer.extend(
            fcn.extensions.ParamsReport(
                model_name, file_name='model_name.yaml'))
        target_names = dict()
        target_names['target_names'] = self.train_dataset.class_names
        self.trainer.extend(
            fcn.extensions.ParamsReport(
                target_names, file_name='target_names.yaml'))
290 
291 
if __name__ == '__main__':
    # Constructing TrainFCN parses the command line and immediately runs
    # the full training pipeline to completion (see TrainFCN.__init__).
    app = TrainFCN()
def setup_trainer(self)
Definition: train_fcn.py:204
def setup_iterator(self)
Definition: train_fcn.py:151
def load_dataset(self)
Definition: train_fcn.py:134
def load_model(self)
Definition: train_fcn.py:163
def setup_optimizer(self)
Definition: train_fcn.py:197
def transform_dataset(self, in_data)
Definition: train_fcn.py:139
def __init__(self)
Definition: train_fcn.py:46


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27