3 from __future__ 
import absolute_import
 
    4 from __future__ 
import division
 
    5 from __future__ 
import print_function
 
   12 import itertools, pkg_resources, sys
 
   13 from distutils.version 
import LooseVersion
 
   14 if LooseVersion(pkg_resources.get_distribution(
"chainer").version) >= LooseVersion(
'7.0.0') 
and \
 
   15         sys.version_info.major == 2:
 
   16     print(
'''Please install chainer < 7.0.0: 
   18     sudo pip install chainer==6.7.0 
   20 c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485 
   23 if [p 
for p 
in list(itertools.chain(*[pkg_resources.find_distributions(_) 
for _ 
in sys.path])) 
if "cupy-" in p.project_name] == []:
 
   24     print(
'''Please install CuPy 
   26     sudo pip install cupy-cuda[your cuda version] 
   28     sudo pip install cupy-cuda91 
   32 from chainer 
import cuda
 
   33 from chainer.datasets 
import TransformDataset
 
   34 from chainer.training 
import extensions
 
   44     min_value = np.nanmin(depth) 
if min_value 
is None else min_value
 
   45     max_value = np.nanmax(depth) 
if max_value 
is None else max_value
 
   47     gray_depth = depth.copy()
 
   48     nan_mask = np.isnan(gray_depth)
 
   49     gray_depth[nan_mask] = 0
 
   50     gray_depth = 255 * (gray_depth - min_value) / (max_value - min_value)
 
   51     gray_depth[gray_depth < 0] = 0
 
   52     gray_depth[gray_depth > 255] = 255
 
   53     gray_depth = gray_depth.astype(np.uint8)
 
   54     colorized = cv2.applyColorMap(gray_depth, cv2.COLORMAP_JET)
 
   55     colorized[nan_mask] = (0, 0, 0)
 
   64     label_gt = in_data[0][2]
 
   65     depth_gt = in_data[0][3]
 
   67     image_rgb, depth, label_gt, depth_gt, _ = in_data
 
   70     image_bgr = image_rgb[:, :, ::-1]
 
   71     image_bgr = image_rgb.astype(np.float32)
 
   72     mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
 
   75     image_bgr = image_bgr.transpose((2, 0, 1))
 
   79         depth, min_value=min_value, max_value=max_value)
 
   80     depth_bgr = depth_bgr.astype(np.float32)
 
   82     depth_bgr = depth_bgr.transpose((2, 0, 1))
 
   84     return image_bgr, depth_bgr, label_gt, depth_gt
 
   88     rospack = rospkg.RosPack()
 
   89     jsk_perception_datasets_path = osp.join(
 
   90         rospack.get_path(
'jsk_perception'), 
'learning_datasets')
 
   92     parser = argparse.ArgumentParser(
 
   93         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
   95         '-g', 
'--gpu', default=0, type=int, help=
'GPU id')
 
   97         '-d', 
'--dataset_dir',
 
   99             jsk_perception_datasets_path, 
'human_size_mirror_dataset'),
 
  100         type=str, help=
'Path to root directory of dataset')
 
  102         '-m', 
'--model', default=
'FCN8sDepthPredictionConcatFirst', type=str,
 
  103         help=
'Model class name')
 
  105         '-b', 
'--batch_size', default=1, type=int, help=
'Batch size')
 
  107         '-e', 
'--epoch', default=100, type=int, help=
'Training epoch')
 
  109         '-o', 
'--out', type=str, default=
None, help=
'Output directory')
 
  110     args = parser.parse_args()
 
  117     timestamp = datetime.datetime.now().strftime(
'%Y%m%d-%H%M%S')
 
  119         out = osp.join(rospkg.get_ros_home(), 
'learning_logs', timestamp)
 
  121     max_iter_epoch = args.epoch, 
'epoch' 
  122     progress_bar_update_interval = 10  
 
  123     print_interval = 100, 
'iteration' 
  124     log_interval = 100, 
'iteration' 
  125     test_interval = 5, 
'epoch' 
  126     save_interval = 5, 
'epoch' 
  130     dataset_train = DepthPredictionDataset(
 
  131         args.dataset_dir, split=
'train', aug=
True)
 
  132     dataset_valid = DepthPredictionDataset(
 
  133         args.dataset_dir, split=
'test', aug=
False)
 
  135     dataset_train_transformed = TransformDataset(dataset_train, transform)
 
  136     dataset_valid_transformed = TransformDataset(dataset_valid, transform)
 
  138     iter_train = chainer.iterators.MultiprocessIterator(
 
  139         dataset_train_transformed, batch_size=args.batch_size,
 
  141     iter_valid = chainer.iterators.MultiprocessIterator(
 
  142         dataset_valid_transformed, batch_size=1, shared_mem=10 ** 8,
 
  143         repeat=
False, shuffle=
False)
 
  147     vgg = fcn.models.VGG16()
 
  148     vgg_path = vgg.download()
 
  149     chainer.serializers.load_npz(vgg_path, vgg)
 
  151     n_class = len(dataset_train.class_names)
 
  154     if args.model == 
'FCN8sDepthPredictionConcatFirst':
 
  155         model = FCN8sDepthPredictionConcatFirst(n_class=n_class, masking=
True)
 
  157         print(
'Invalid model class.')
 
  160     model.init_from_vgg16(vgg)
 
  163         cuda.get_device_from_id(gpu).use()
 
  168     optimizer = chainer.optimizers.Adam(alpha=1.0e-5)
 
  169     optimizer.setup(model)
 
  170     optimizer.add_hook(chainer.optimizer.WeightDecay(rate=0.0005))
 
  172     updater = chainer.training.updater.StandardUpdater(
 
  173         iter_train, optimizer, device=gpu)
 
  175     trainer = chainer.training.Trainer(updater, max_iter_epoch, out=out)
 
  177     trainer.extend(extensions.ExponentialShift(
"alpha", 0.99997))
 
  179     if not osp.isdir(out):
 
  182     with open(osp.join(out, 
'dataset.txt'), 
'w') 
as f:
 
  183         f.write(dataset_train.__class__.__name__)
 
  185     with open(osp.join(out, 
'model.txt'), 
'w') 
as f:
 
  186         f.write(model.__class__.__name__)
 
  188     with open(osp.join(out, 
'batch_size.txt'), 
'w') 
as f:
 
  189         f.write(
str(args.batch_size))
 
  192         extensions.snapshot_object(
 
  194             savefun=chainer.serializers.save_npz,
 
  195             filename=
'model_snapshot.npz'),
 
  196         trigger=chainer.training.triggers.MaxValueTrigger(
 
  197             'validation/main/depth_acc<0.10', save_interval))
 
  200         extensions.dump_graph(
 
  201             root_name=
'main/loss',
 
  202             out_name=
'network_architecture.dot'))
 
  205         extensions.LogReport(
 
  207             trigger=log_interval))
 
  210         extensions.PlotReport([
 
  212             'validation/main/loss',
 
  214             file_name=
'loss_plot.png',
 
  216             trigger=(5, 
'epoch')),
 
  217         trigger=(5, 
'epoch'))
 
  219     trainer.extend(chainer.training.extensions.PrintReport([
 
  228         'main/depth_acc<0.03',
 
  229         'main/depth_acc<0.10',
 
  230         'main/depth_acc<0.30',
 
  231         'validation/main/miou',
 
  232         'validation/main/depth_acc<0.03',
 
  233         'validation/main/depth_acc<0.10',
 
  234         'validation/main/depth_acc<0.30',
 
  235     ]), trigger=print_interval)
 
  238         extensions.observe_lr(),
 
  239         trigger=log_interval)
 
  241         extensions.ProgressBar(update_interval=progress_bar_update_interval))
 
  243         extensions.Evaluator(iter_valid, model, device=gpu),
 
  244         trigger=test_interval)
 
  249 if __name__ == 
'__main__':