import argparse
import copy
import json
import os
import sys

import numpy as np
import yaml

import chainer
from chainer import serializers
from chainer import training
from chainer.datasets import TransformDataset
from chainer.optimizer import WeightDecay
from chainer.training import extensions

from chainercv import transforms
from chainercv.extensions import DetectionVOCEvaluator
from chainercv.links import SSD300
from chainercv.links.model.ssd import GradientScaling
from chainercv.links.model.ssd import multibox_loss
from chainercv.links.model.ssd import random_crop_with_bbox_constraints
from chainercv.links.model.ssd import random_distort
from chainercv.links.model.ssd import resize_with_random_interpolation
from chainercv.utils import read_image

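# The dataset is expected to be a flat directory of '.jpg' images, each
# paired with a '<stem>__labels.json' annotation file. Judging from the
# keys read in get_example() below, an annotation file looks roughly
# like this (illustrative sketch, not taken from a real file):
#
#     {"labels": [{"label_class": "person",
#                  "centre": {"x": 120.0, "y": 80.0},
#                  "size": {"x": 40.0, "y": 100.0}}]}
#
# Boxes are stored as (center, size) in pixels and converted to corner
# coordinates when an example is loaded.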
class SSDDataset(chainer.dataset.DatasetMixin):

    def __init__(self, base_dir, label_names):
        self.base_dir = base_dir
        self.label_names = label_names

        self.img_filenames = []
        for name in os.listdir(base_dir):
            # Only collect JPEG images; annotations are looked up later
            # from the matching '<stem>__labels.json' file.
            if os.path.splitext(name)[1] != '.jpg':
                continue
            self.img_filenames.append(os.path.join(base_dir, name))

    def __len__(self):
        return len(self.img_filenames)

    def get_example(self, i):
        img_filename = self.img_filenames[i]
        img = read_image(img_filename)

        anno_filename = os.path.splitext(img_filename)[0] + '__labels.json'

        with open(anno_filename, 'r') as f:
            anno = json.load(f)
        anno = anno['labels']

        bbox = []
        label = []
        for anno_i in anno:
            h = anno_i['size']['y']
            w = anno_i['size']['x']
            center_y = anno_i['centre']['y']
            center_x = anno_i['centre']['x']
            try:
                lb = self.label_names.index(anno_i['label_class'])
            except ValueError:
                print('Failed to index label class: {}'.format(anno_i),
                      file=sys.stderr)
                print('image file name: {}'.format(img_filename),
                      file=sys.stderr)
                print('annotation file name: {}'.format(anno_filename),
                      file=sys.stderr)
                continue
            # Convert (center, size) to the corner format
            # (y_min, x_min, y_max, x_max) that ChainerCV uses.
            bbox.append(
                [center_y - h / 2, center_x - w / 2,
                 center_y + h / 2, center_x + w / 2])
            label.append(lb)
        return (img, np.array(bbox, dtype=np.float32),
                np.array(label, dtype=np.int32))


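# Training chain wrapping the SSD model with the multibox loss of
# Liu et al. (2016): loss = alpha * loc_loss + conf_loss, where loc_loss
# is a smooth L1 loss on the box offsets and conf_loss is a softmax
# cross-entropy with hard negative mining that keeps at most k negatives
# per positive (k=3 below, as in the paper).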
class MultiboxTrainChain(chainer.Chain):

    def __init__(self, model, alpha=1, k=3):
        super(MultiboxTrainChain, self).__init__()
        with self.init_scope():
            self.model = model
        self.alpha = alpha
        self.k = k

    def __call__(self, imgs, gt_mb_locs, gt_mb_labs):
        mb_locs, mb_confs = self.model(imgs)
        loc_loss, conf_loss = multibox_loss(
            mb_locs, mb_confs, gt_mb_locs, gt_mb_labs, self.k)
        loss = loc_loss * self.alpha + conf_loss

        chainer.reporter.report(
            {'loss': loss, 'loss/loc': loc_loss, 'loss/conf': conf_loss},
            self)

        return loss


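# Training-time augmentation following the SSD paper: photometric
# distortion, random expansion, random cropping constrained so that
# boxes stay usable, and resizing with a randomly chosen interpolation.
# Finally the image is mean-subtracted and the boxes are encoded into
# per-default-box regression/classification targets by the coder.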
class Transform(object):
    """Data augmentation for SSD training."""

    def __init__(self, coder, size, mean):
        # Copy the coder so that target encoding runs on the CPU inside
        # the MultiprocessIterator workers.
        self.coder = copy.copy(coder)
        self.coder.to_cpu()

        self.size = size
        self.mean = mean

    def __call__(self, in_data):
        img, bbox, label = in_data

        # 1. Color augmentation
        img = random_distort(img)

        # 2. Random expansion
        if np.random.randint(2):
            img, param = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            bbox = transforms.translate_bbox(
                bbox, y_offset=param['y_offset'], x_offset=param['x_offset'])

        # 3. Random cropping
        img, param = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        bbox, param = transforms.crop_bbox(
            bbox, y_slice=param['y_slice'], x_slice=param['x_slice'],
            allow_outside_center=False, return_param=True)
        label = label[param['index']]

        # 4. Resizing with random interpolation
        _, H, W = img.shape
        img = resize_with_random_interpolation(img, (self.size, self.size))
        bbox = transforms.resize_bbox(bbox, (H, W), (self.size, self.size))

        # 5. Mean subtraction and encoding to multibox targets
        img -= self.mean
        mb_loc, mb_lab = self.coder.encode(bbox, label)

        return img, mb_loc, mb_lab


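# Example invocation (the script name here is illustrative):
#     python train.py labels.yaml data/train --val data/val --gpu 0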
if __name__ == '__main__':

    p = argparse.ArgumentParser()
    p.add_argument("label_file", help="path to label file")
    p.add_argument("train", help="path to train dataset directory")
    p.add_argument("--val", default=None,
                   help="path to validation dataset directory. If not "
                        "specified, the train dataset is split with "
                        "ratio train:val = 8:2.")
    p.add_argument("--base-model", help="base model name", default="voc0712")
    p.add_argument("--batchsize", "-b", type=int, default=16)
    p.add_argument("--iteration", type=int, default=120000)
    p.add_argument("--gpu", "-g", type=int, default=-1)
    p.add_argument("--out", "-o", type=str, default="results")
    p.add_argument("--resume", type=str, default="")
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--val-iter", type=int, default=100)
    p.add_argument("--log-iter", type=int, default=10)
    p.add_argument("--model-iter", type=int, default=200)

    args = p.parse_args()

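    # The label file is expected to be a YAML sequence of class names,
    # e.g. (illustrative):
    #
    #     - person
    #     - car
    #     - bicycle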
    with open(args.label_file, "r") as f:
        label_names = tuple(yaml.safe_load(f))

    print("Loaded {} labels".format(len(label_names)))

    if args.val is None:
        dataset = SSDDataset(args.train, label_names)
        train, test = chainer.datasets.split_dataset_random(
            dataset, int(len(dataset) * 0.8))
    else:
        train = SSDDataset(args.train, label_names)
        test = SSDDataset(args.val, label_names)

    print("train: {}, test: {}".format(len(train), len(test)))

    # Transfer weights from a model pre-trained on VOC07+12. The feature
    # extractor and the localization branch do not depend on the number
    # of classes, so they are copied; the confidence branch is left
    # randomly initialized because its shape depends on the new classes.
    pretrained_model = SSD300(pretrained_model=args.base_model)

    model = SSD300(n_fg_class=len(label_names))
    model.extractor.copyparams(pretrained_model.extractor)
    model.multibox.loc.copyparams(pretrained_model.multibox.loc)

    model.use_preset("evaluate")

    train_chain = MultiboxTrainChain(model)

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu()

    train = TransformDataset(
        train, Transform(model.coder, model.insize, model.mean))
    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize)

    test_iter = chainer.iterators.SerialIterator(
        test, args.batchsize,
        repeat=False, shuffle=False)

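    # Per-parameter update rules, following the ChainerCV SSD example:
    # biases ('b') get their gradients scaled by 2 (effectively doubling
    # their learning rate) and no weight decay, while all other
    # parameters are regularized with weight decay 0.0005.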
    optimizer = chainer.optimizers.MomentumSGD(lr=args.lr)
    optimizer.setup(train_chain)

    for param in train_chain.params():
        if param.name == 'b':
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))

    updater = training.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(
        updater, (args.iteration, "iteration"), args.out)

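    # Evaluate periodically on the validation set with the PASCAL VOC
    # detection metric; use_07_metric=True selects the 11-point
    # interpolated average precision from the VOC2007 development kit.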
    val_interval = args.val_iter, "iteration"
    trainer.extend(
        DetectionVOCEvaluator(
            test_iter, model, use_07_metric=True,
            label_names=label_names),
        trigger=val_interval)

    log_interval = args.log_iter, "iteration"
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'lr',
         'main/loss', 'main/loss/loc', 'main/loss/conf',
         'validation/main/map']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Snapshot the whole trainer for resuming, and the model alone for
    # later inference.
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(model, 'model_iter_{.updater.iteration}'),
        trigger=(args.model_iter, 'iteration'))

    if args.resume:
        serializers.load_npz(args.resume, trainer)

    trainer.run()