ssd_train_dataset.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yuki Furuta <furushchev@jsk.imi.i.u-tokyo.ac.jp>

from __future__ import print_function

import argparse
import copy
import json
import numpy as np
import os
import sys
import yaml

# chainer
import chainer
from chainer import serializers
from chainer import training
from chainer.datasets import TransformDataset
from chainer.training import extensions
from chainer.optimizer import WeightDecay

# chainercv
from chainercv import transforms
from chainercv.extensions import DetectionVOCEvaluator
from chainercv.links import SSD300
from chainercv.links.model.ssd import GradientScaling
from chainercv.links.model.ssd import multibox_loss
from chainercv.links.model.ssd import random_distort
from chainercv.links.model.ssd import random_crop_with_bbox_constraints
from chainercv.links.model.ssd import resize_with_random_interpolation
from chainercv.utils import read_image


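# SSDDataset loads a flat directory of JPEG images, each paired with an
# annotation file named "<image stem>__labels.json". Judging from the
# parsing in get_example() below, an annotation file looks roughly like
# this (values are illustrative):
#
#   {
#     "labels": [
#       {
#         "label_class": "bottle",
#         "centre": {"x": 120.0, "y": 80.0},
#         "size": {"x": 40.0, "y": 60.0}
#       }
#     ]
#   }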
class SSDDataset(chainer.dataset.DatasetMixin):

    def __init__(self, base_dir, label_names):
        self.base_dir = base_dir
        self.label_names = label_names

        self.img_filenames = []
        for name in os.listdir(base_dir):
            # If the file is not an image, ignore the file.
            if os.path.splitext(name)[1] != '.jpg':
                continue
            self.img_filenames.append(os.path.join(base_dir, name))

    def __len__(self):
        return len(self.img_filenames)

    def get_example(self, i):
        img_filename = self.img_filenames[i]
        img = read_image(img_filename)

        anno_filename = os.path.splitext(img_filename)[0] + '__labels.json'

        with open(anno_filename, 'r') as f:
            anno = json.load(f)
        anno = anno['labels']

        bbox = []
        label = []
        for anno_i in anno:
            h = anno_i['size']['y']
            w = anno_i['size']['x']
            center_y = anno_i['centre']['y']
            center_x = anno_i['centre']['x']
            try:
                label_id = self.label_names.index(anno_i['label_class'])
            except Exception:
                print("Failed to index label class: {}".format(anno_i), file=sys.stderr)
                print("image file name: {}".format(img_filename), file=sys.stderr)
                print("annotation file name: {}".format(anno_filename), file=sys.stderr)
                continue
            # chainercv expects bounding boxes as (y_min, x_min, y_max, x_max).
            bbox.append(
                [center_y - h / 2, center_x - w / 2,
                 center_y + h / 2, center_x + w / 2])
            label.append(label_id)
        return img, np.array(bbox, dtype=np.float32), np.array(label, dtype=np.int32)


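# Training wrapper around the detector. The multibox loss combines the
# localization and confidence terms as
#     loss = alpha * loc_loss + conf_loss
# where k is the hard negative mining ratio passed to chainercv's
# multibox_loss (roughly k mined negatives per positive; the SSD paper
# uses k=3).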
class MultiboxTrainChain(chainer.Chain):

    def __init__(self, model, alpha=1, k=3):
        super(MultiboxTrainChain, self).__init__()
        with self.init_scope():
            self.model = model
        self.alpha = alpha
        self.k = k

    def __call__(self, imgs, gt_mb_locs, gt_mb_labs):
        mb_locs, mb_confs = self.model(imgs)
        loc_loss, conf_loss = multibox_loss(
            mb_locs, mb_confs, gt_mb_locs, gt_mb_labs, self.k)
        loss = loc_loss * self.alpha + conf_loss

        chainer.reporter.report(
            {'loss': loss, 'loss/loc': loc_loss, 'loss/conf': conf_loss},
            self)

        return loss


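# The augmentation pipeline below (steps 1-4) follows the data
# augmentation scheme of the SSD paper: photometric distortion, random
# expansion, random cropping and resizing. Step 5 subtracts the mean and
# encodes the ground-truth boxes into multibox location/label targets.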
class Transform(object):
    """Class for augmentation"""

    def __init__(self, coder, size, mean):
        # Copy the coder so that moving it to the CPU
        # does not affect the model's own coder.
        self.coder = copy.copy(coder)
        self.coder.to_cpu()

        self.size = size
        self.mean = mean

    def __call__(self, in_data):
        img, bbox, label = in_data

        # 1. Color augmentation
        img = random_distort(img)

        # 2. Random expansion
        if np.random.randint(2):
            img, param = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            bbox = transforms.translate_bbox(
                bbox, y_offset=param["y_offset"], x_offset=param["x_offset"])

        # 3. Random cropping
        img, param = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        bbox, param = transforms.crop_bbox(
            bbox, y_slice=param["y_slice"], x_slice=param["x_slice"],
            allow_outside_center=False, return_param=True)
        label = label[param["index"]]

        # 4. Resizing with random interpolation
        _, H, W = img.shape
        img = resize_with_random_interpolation(img, (self.size, self.size))
        bbox = transforms.resize_bbox(bbox, (H, W), (self.size, self.size))

        # 5. Transformation for SSD network input
        img -= self.mean
        mb_loc, mb_lab = self.coder.encode(bbox, label)

        return img, mb_loc, mb_lab


if __name__ == '__main__':

    p = argparse.ArgumentParser()
    p.add_argument("label_file", help="path to label file")
    p.add_argument("train", help="path to train dataset directory")
    p.add_argument("--val", help="path to validation dataset directory. If not specified, the train dataset is split into train and validation sets with ratio 8:2.", default=None)
    p.add_argument("--base-model", help="base model name", default="voc0712")
    p.add_argument("--batchsize", "-b", type=int, default=16)
    p.add_argument("--iteration", type=int, default=120000)
    p.add_argument("--gpu", "-g", type=int, default=-1)  # use CPU by default
    p.add_argument("--out", "-o", type=str, default="results")
    p.add_argument("--resume", type=str, default="")
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--val-iter", type=int, default=100)
    p.add_argument("--log-iter", type=int, default=10)
    p.add_argument("--model-iter", type=int, default=200)

    args = p.parse_args()

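    # The label file is parsed with yaml and turned into a tuple, so a
    # plain YAML list of class names is expected, e.g. (names are
    # illustrative):
    #
    #   - bottle
    #   - cup
    #   - bowl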
    # load label file
    with open(args.label_file, "r") as f:
        label_names = tuple(yaml.safe_load(f))

    print("Loaded %d labels" % len(label_names))

    if args.val is None:
        dataset = SSDDataset(args.train, label_names)
        train, test = chainer.datasets.split_dataset_random(
            dataset, int(len(dataset) * 0.8))
    else:
        train = SSDDataset(args.train, label_names)
        test = SSDDataset(args.val, label_names)

    print("train: {}, test: {}".format(len(train), len(test)))

    pretrained_model = SSD300(pretrained_model=args.base_model)

    # Copy the feature extractor and localization branch from the
    # pretrained model; the confidence branch is left newly initialized
    # since the number of classes differs.
    model = SSD300(n_fg_class=len(label_names))
    model.extractor.copyparams(pretrained_model.extractor)
    model.multibox.loc.copyparams(pretrained_model.multibox.loc)

    # 'evaluate' preset selects prediction thresholds suited for evaluation.
    model.use_preset("evaluate")

    train_chain = MultiboxTrainChain(model)

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu()

    train = TransformDataset(
        train, Transform(model.coder, model.insize, model.mean))
    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize)

    test_iter = chainer.iterators.SerialIterator(
        test, args.batchsize,
        repeat=False, shuffle=False)

    optimizer = chainer.optimizers.MomentumSGD(lr=args.lr)
    optimizer.setup(train_chain)

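    # Per-parameter update rules, as in the standard SSD training setup:
    # bias terms ('b') have their gradients scaled by 2, all other
    # parameters get L2 weight decay of 0.0005.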
    for param in train_chain.params():
        if param.name == 'b':
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))

    updater = training.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(
        updater, (args.iteration, "iteration"), args.out)

    val_interval = args.val_iter, "iteration"
    trainer.extend(
        DetectionVOCEvaluator(
            test_iter, model, use_07_metric=True,
            label_names=label_names),
        trigger=val_interval)

    log_interval = args.log_iter, "iteration"
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'lr',
         'main/loss', 'main/loss/loc', 'main/loss/conf',
         'validation/main/map']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(model, 'model_iter_{.updater.iteration}'),
        trigger=(args.model_iter, 'iteration'))

    if args.resume:
        serializers.load_npz(args.resume, trainer)

    trainer.run()
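
# Example invocation (paths are illustrative):
#   $ python ssd_train_dataset.py labels.yaml dataset/train \
#       --val dataset/val --gpu 0 --out results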

