train_fcn_depth_prediction.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 
7 import argparse
8 import datetime
9 import os
10 import os.path as osp
11 
12 import itertools, pkg_resources, sys
13 from distutils.version import LooseVersion
14 if LooseVersion(pkg_resources.get_distribution("chainer").version) >= LooseVersion('7.0.0') and \
15  sys.version_info.major == 2:
16  print('''Please install chainer < 7.0.0:
17 
18  sudo pip install chainer==6.7.0
19 
20 c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
21 ''', file=sys.stderr)
22  sys.exit(1)
23 if [p for p in list(itertools.chain(*[pkg_resources.find_distributions(_) for _ in sys.path])) if "cupy-" in p.project_name] == []:
24  print('''Please install CuPy
25 
26  sudo pip install cupy-cuda[your cuda version]
27 i.e.
28  sudo pip install cupy-cuda91
29 
30 ''', file=sys.stderr)
31 import chainer
32 from chainer import cuda
33 from chainer.datasets import TransformDataset
34 from chainer.training import extensions
35 import cv2
36 import fcn
37 from jsk_recognition_utils.chainermodels import FCN8sDepthPredictionConcatFirst
38 from jsk_recognition_utils.datasets import DepthPredictionDataset
39 import numpy as np
40 import rospkg
41 
42 
def colorize_depth(depth, min_value=None, max_value=None):
    """Render a float depth image as a BGR heatmap using the JET colormap.

    NaN pixels are painted black.  Values outside [min_value, max_value]
    saturate at the range ends.  When a bound is None it is taken from the
    NaN-aware min/max of *depth*.
    """
    lo = np.nanmin(depth) if min_value is None else min_value
    hi = np.nanmax(depth) if max_value is None else max_value

    scaled = depth.copy()
    invalid = np.isnan(scaled)
    # Zero out NaNs before scaling so arithmetic stays finite.
    scaled[invalid] = 0
    scaled = 255 * (scaled - lo) / (hi - lo)
    scaled = np.clip(scaled, 0, 255).astype(np.uint8)

    heatmap = cv2.applyColorMap(scaled, cv2.COLORMAP_JET)
    # Paint originally-invalid pixels black so they are visually distinct.
    heatmap[invalid] = (0, 0, 0)

    return heatmap
58 
59 
def transform(in_data):
    """Preprocess one dataset sample for the depth-prediction FCN.

    Expects ``in_data`` as a 5-tuple
    ``(image_rgb, depth, label_gt, depth_gt, _)`` and returns
    ``(image_bgr, depth_bgr, label_gt, depth_gt)`` where both images are
    float32, mean-subtracted, and transposed to (3, H, W).
    """
    # Depth colorization range in meters (dataset-specific working range).
    min_value = 0.5
    max_value = 5.0

    image_rgb, depth, label_gt, depth_gt, _ = in_data

    # RGB -> BGR (VGG16 expects BGR input).
    image_bgr = image_rgb[:, :, ::-1]
    # BUG FIX: the original re-read image_rgb here, silently discarding the
    # channel swap performed on the previous line.
    image_bgr = image_bgr.astype(np.float32)
    # Per-channel BGR mean of the VGG16 training set.
    mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
    image_bgr -= mean_bgr
    # (H, W, 3) -> (3, H, W)
    image_bgr = image_bgr.transpose((2, 0, 1))

    # depth -> depth_bgr: (H, W) -> (H, W, 3) -> (3, H, W)
    depth_bgr = colorize_depth(
        depth, min_value=min_value, max_value=max_value)
    depth_bgr = depth_bgr.astype(np.float32)
    depth_bgr -= mean_bgr
    depth_bgr = depth_bgr.transpose((2, 0, 1))

    return image_bgr, depth_bgr, label_gt, depth_gt
85 
86 
def main():
    """Train an FCN-based depth-prediction model on a JSK learning dataset.

    Parses CLI options, builds train/validation iterators, initializes the
    model from pretrained VGG16 weights, and runs a chainer Trainer with
    snapshotting, logging, plotting, and periodic evaluation extensions.
    """
    rospack = rospkg.RosPack()
    jsk_perception_datasets_path = osp.join(
        rospack.get_path('jsk_perception'), 'learning_datasets')

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-g', '--gpu', default=0, type=int, help='GPU id')
    parser.add_argument(
        '-d', '--dataset_dir',
        default=osp.join(
            jsk_perception_datasets_path, 'human_size_mirror_dataset'),
        type=str, help='Path to root directory of dataset')
    parser.add_argument(
        '-m', '--model', default='FCN8sDepthPredictionConcatFirst', type=str,
        help='Model class name')
    parser.add_argument(
        '-b', '--batch_size', default=1, type=int, help='Batch size')
    parser.add_argument(
        '-e', '--epoch', default=100, type=int, help='Training epoch')
    parser.add_argument(
        '-o', '--out', type=str, default=None, help='Output directory')
    args = parser.parse_args()

    gpu = args.gpu
    out = args.out

    # 0. config

    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    if out is None:
        out = osp.join(rospkg.get_ros_home(), 'learning_logs', timestamp)

    max_iter_epoch = args.epoch, 'epoch'
    progress_bar_update_interval = 10  # iteration
    print_interval = 100, 'iteration'
    log_interval = 100, 'iteration'
    test_interval = 5, 'epoch'
    save_interval = 5, 'epoch'

    # 1. dataset

    dataset_train = DepthPredictionDataset(
        args.dataset_dir, split='train', aug=True)
    dataset_valid = DepthPredictionDataset(
        args.dataset_dir, split='test', aug=False)

    dataset_train_transformed = TransformDataset(dataset_train, transform)
    dataset_valid_transformed = TransformDataset(dataset_valid, transform)

    iter_train = chainer.iterators.MultiprocessIterator(
        dataset_train_transformed, batch_size=args.batch_size,
        shared_mem=10 ** 8)
    iter_valid = chainer.iterators.MultiprocessIterator(
        dataset_valid_transformed, batch_size=1, shared_mem=10 ** 8,
        repeat=False, shuffle=False)

    # 2. model

    # Pretrained VGG16 weights are used to initialize the FCN encoder.
    vgg = fcn.models.VGG16()
    vgg_path = vgg.download()
    chainer.serializers.load_npz(vgg_path, vgg)

    n_class = len(dataset_train.class_names)
    assert n_class == 2

    if args.model == 'FCN8sDepthPredictionConcatFirst':
        model = FCN8sDepthPredictionConcatFirst(n_class=n_class, masking=True)
    else:
        # FIX: report the error on stderr and use sys.exit() instead of the
        # interactive-only exit() builtin.
        print('Invalid model class.', file=sys.stderr)
        sys.exit(1)

    model.init_from_vgg16(vgg)

    if gpu >= 0:
        cuda.get_device_from_id(gpu).use()
        model.to_gpu()

    # 3. optimizer

    optimizer = chainer.optimizers.Adam(alpha=1.0e-5)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(rate=0.0005))

    updater = chainer.training.updater.StandardUpdater(
        iter_train, optimizer, device=gpu)

    trainer = chainer.training.Trainer(updater, max_iter_epoch, out=out)

    # Exponentially decay the learning rate each iteration.
    trainer.extend(extensions.ExponentialShift("alpha", 0.99997))

    if not osp.isdir(out):
        os.makedirs(out)

    # Record the training configuration alongside the logs for traceability.
    with open(osp.join(out, 'dataset.txt'), 'w') as f:
        f.write(dataset_train.__class__.__name__)

    with open(osp.join(out, 'model.txt'), 'w') as f:
        f.write(model.__class__.__name__)

    with open(osp.join(out, 'batch_size.txt'), 'w') as f:
        f.write(str(args.batch_size))

    # Snapshot the model whenever validation depth accuracy (<0.10 m)
    # reaches a new maximum.
    trainer.extend(
        extensions.snapshot_object(
            model,
            savefun=chainer.serializers.save_npz,
            filename='model_snapshot.npz'),
        trigger=chainer.training.triggers.MaxValueTrigger(
            'validation/main/depth_acc<0.10', save_interval))

    trainer.extend(
        extensions.dump_graph(
            root_name='main/loss',
            out_name='network_architecture.dot'))

    trainer.extend(
        extensions.LogReport(
            log_name='log.json',
            trigger=log_interval))

    trainer.extend(
        extensions.PlotReport([
            'main/loss',
            'validation/main/loss',
        ],
            file_name='loss_plot.png',
            x_key='epoch',
            trigger=(5, 'epoch')),
        trigger=(5, 'epoch'))

    trainer.extend(chainer.training.extensions.PrintReport([
        'iteration',
        'epoch',
        'elapsed_time',
        'lr',
        'main/loss',
        'main/seg_loss',
        'main/reg_loss',
        'main/miou',
        'main/depth_acc<0.03',
        'main/depth_acc<0.10',
        'main/depth_acc<0.30',
        'validation/main/miou',
        'validation/main/depth_acc<0.03',
        'validation/main/depth_acc<0.10',
        'validation/main/depth_acc<0.30',
    ]), trigger=print_interval)

    trainer.extend(
        extensions.observe_lr(),
        trigger=log_interval)
    trainer.extend(
        extensions.ProgressBar(update_interval=progress_bar_update_interval))
    trainer.extend(
        extensions.Evaluator(iter_valid, model, device=gpu),
        trigger=test_interval)

    trainer.run()
247 
248 
# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()
def colorize_depth(depth, min_value=None, max_value=None)


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Mon May 3 2021 03:03:27