5 from __future__
import print_function
16 import itertools, pkg_resources
17 from distutils.version
import LooseVersion
18 if LooseVersion(pkg_resources.get_distribution(
"chainer").version) >= LooseVersion(
'7.0.0')
and \
19 sys.version_info.major == 2:
20 print(
'''Please install chainer < 7.0.0: 22 sudo pip install chainer==6.7.0 24 c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485 27 if [p
for p
in list(itertools.chain(*[pkg_resources.find_distributions(_)
for _
in sys.path]))
if "cupy-" in p.project_name ] == []:
28 print(
'''Please install CuPy 30 sudo pip install cupy-cuda[your cuda version] 32 sudo pip install cupy-cuda91 37 from chainer
import serializers
38 from chainer
import training
39 from chainer.datasets
import TransformDataset
40 from chainer.training
import extensions
41 from chainer.optimizer
import WeightDecay
44 from chainercv
import transforms
45 from chainercv.extensions
import DetectionVOCEvaluator
46 from chainercv.links
import SSD300
47 from chainercv.links.model.ssd
import GradientScaling
48 from chainercv.links.model.ssd
import multibox_loss
49 from chainercv.links.model.ssd
import random_distort
50 from chainercv.links.model.ssd
import random_crop_with_bbox_constraints
51 from chainercv.links.model.ssd
import resize_with_random_interpolation
52 from chainercv.utils
import read_image
# Fragment of the dataset constructor (`def __init__(self, base_dir, label_names)`):
# collect every .jpg image file under base_dir into self.img_filenames.
for name in os.listdir(base_dir):
    # Skip anything that is not a .jpg (the `*__labels.json` annotation files
    # live in the same directory).
    if os.path.splitext(name)[1] != '.jpg':
        # NOTE(review): `continue` restored — as pasted, non-jpg names fell
        # straight through to the append below. Confirm against the original.
        continue
    self.img_filenames.append(os.path.join(base_dir, name))
73 img = read_image(img_filename)
75 anno_filename = os.path.splitext(img_filename)[0] +
'__labels.json' 77 with open(anno_filename,
'r') as f: 84 h = anno_i[
'size'][
'y']
85 w = anno_i[
'size'][
'x']
86 center_y = anno_i[
'centre'][
'y']
87 center_x = anno_i[
'centre'][
'x']
89 l = self.label_names.index(anno_i[
'label_class'])
90 except Exception
as e:
91 print(
"Failed to index label class: {}".format(anno_i), file=sys.stderr)
92 print(
"image file name: {}".format(img_filename), file=sys.stderr)
93 print(
"annotation file name: {}".format(anno_filename), file=sys.stderr)
96 [center_y - h / 2, center_x - w / 2,
97 center_y + h / 2, center_x + w / 2])
99 return img, np.array(bbox, dtype=np.float32), np.array(label, dtype=np.int32)
class MultiboxTrainChain(chainer.Chain):
    """Training chain wrapping an SSD model with the multibox loss.

    NOTE(review): the class header and parts of both methods were missing
    from this mangled chunk; the base class and attribute assignments are
    restored from the usage visible in ``__call__`` and from the signature
    fragments at the end of the chunk — confirm against the original file.
    """

    def __init__(self, model, alpha=1, k=3):
        # alpha weights the localization loss; k is the hard-negative-mining
        # ratio passed through to multibox_loss.
        super(MultiboxTrainChain, self).__init__()
        with self.init_scope():
            self.model = model
        self.alpha = alpha
        self.k = k

    def __call__(self, imgs, gt_mb_locs, gt_mb_labs):
        # Forward the SSD model, compute the two loss terms, and report them.
        mb_locs, mb_confs = self.model(imgs)
        loc_loss, conf_loss = multibox_loss(
            mb_locs, mb_confs, gt_mb_locs, gt_mb_labs, self.k)
        loss = loc_loss * self.alpha + conf_loss
        chainer.reporter.report(
            {'loss': loss,
             'loss/loc': loc_loss,
             'loss/conf': conf_loss},
            self)  # NOTE(review): report observer arg and return restored
        return loss
class Transform(object):
    """Data-augmentation pipeline applied to each (img, bbox, label) example.

    NOTE(review): the class header and ``__init__`` were missing from this
    mangled chunk; ``__call__`` reads ``self.mean``, ``self.size`` and
    ``self.coder``, so the constructor must set those — confirm against the
    original file.
    """

    def __call__(self, in_data):
        img, bbox, label = in_data

        # 1. Color augmentation (brightness/contrast/saturation/hue jitter).
        img = random_distort(img)

        # 2. Random expansion ("zoom out") half of the time, padding with the
        #    dataset mean and shifting the boxes accordingly.
        if np.random.randint(2):
            img, param = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            bbox = transforms.translate_bbox(
                bbox, y_offset=param["y_offset"], x_offset=param["x_offset"])

        # 3. Random cropping constrained so boxes stay usable; drop boxes
        #    whose center falls outside the crop.
        img, param = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        bbox, param = transforms.crop_bbox(
            bbox, y_slice=param["y_slice"], x_slice=param["x_slice"],
            allow_outside_center=False, return_param=True)
        label = label[param["index"]]

        # 4. Resize to the fixed network input size with a random
        #    interpolation method, rescaling boxes from the cropped size.
        # NOTE(review): H/W are used below but never assigned in the visible
        # lines of this chunk; this shape unpack is restored — confirm.
        _, H, W = img.shape
        img = resize_with_random_interpolation(img, (self.size, self.size))
        bbox = transforms.resize_bbox(bbox, (H, W), (self.size, self.size))

        # NOTE(review): the lines between the resize above and the encode
        # below are missing from this chunk (a horizontal flip and mean
        # subtraction typically happen here) — confirm against the original.

        # 5. Encode boxes/labels into SSD multibox targets.
        mb_loc, mb_lab = self.coder.encode(bbox, label)

        return img, mb_loc, mb_lab
if __name__ == '__main__':
    # ---- Command-line interface -------------------------------------------
    p = argparse.ArgumentParser()
    p.add_argument("label_file", help="path to label file")
    p.add_argument("train", help="path to train dataset directory")
    p.add_argument(
        "--val",
        help="path to validation dataset directory. If this argument is not specified, train dataset is used with ratio train:val = 8:2.",
        default=None)
    p.add_argument("--base-model", help="base model name", default="voc0712")
    p.add_argument("--batchsize", "-b", type=int, default=16)
    p.add_argument("--iteration", type=int, default=120000)
    p.add_argument("--gpu", "-g", type=int, default=-1)
    p.add_argument("--out", "-o", type=str, default="results")
    p.add_argument("--resume", type=str, default="")
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--val-iter", type=int, default=100)
    p.add_argument("--log-iter", type=int, default=10)
    p.add_argument("--model-iter", type=int, default=200)
    args = p.parse_args()

    # ---- Labels and datasets ----------------------------------------------
    with open(args.label_file, "r") as f:
        # yaml.load without an explicit Loader is deprecated and unsafe on
        # untrusted input; safe_load suffices for a plain list of label names.
        label_names = tuple(yaml.safe_load(f))
    print("Loaded %d labels" % len(label_names))

    # NOTE(review): construction of `dataset` (and the --val branch) is
    # missing from this chunk; the split below expects a dataset object with
    # a `label_names` attribute — confirm against the original file.
    train, test = chainer.datasets.split_dataset_random(
        dataset, int(len(dataset) * 0.8))
    print("train: {}, test: {}".format(len(train), len(test)))

    # ---- Model: copy pretrained VOC weights into a model with our classes --
    pretrained_model = SSD300(pretrained_model=args.base_model)
    model = SSD300(n_fg_class=len(dataset.label_names))
    model.extractor.copyparams(pretrained_model.extractor)
    model.multibox.loc.copyparams(pretrained_model.multibox.loc)
    model.use_preset("evaluate")

    # NOTE(review): `train_chain = MultiboxTrainChain(model)` is missing from
    # this chunk but train_chain is used below — confirm.
    if args.gpu >= 0:  # NOTE(review): guard restored; get_device(-1) would fail
        chainer.cuda.get_device(args.gpu).use()
        # NOTE(review): model.to_gpu() presumably follows here — missing.

    # ---- Iterators ---------------------------------------------------------
    train = TransformDataset(
        train, Transform(model.coder, model.insize, model.mean))
    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(
        test, args.batchsize,
        repeat=False, shuffle=False)

    # ---- Optimizer with SSD's per-parameter hooks --------------------------
    optimizer = chainer.optimizers.MomentumSGD(lr=args.lr)
    optimizer.setup(train_chain)
    for param in train_chain.params():
        if param.name == 'b':
            # Biases: double the gradient, no weight decay.
            param.update_rule.add_hook(GradientScaling(2))
        else:  # NOTE(review): else-branch restored; line missing in chunk
            param.update_rule.add_hook(WeightDecay(0.0005))

    # ---- Trainer and extensions -------------------------------------------
    updater = training.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(
        updater, (args.iteration, "iteration"), args.out)

    val_interval = args.val_iter, "iteration"
    trainer.extend(  # NOTE(review): extend wrapper restored; missing in chunk
        DetectionVOCEvaluator(
            test_iter, model, use_07_metric=True,
            label_names=label_names),
        trigger=val_interval)

    log_interval = args.log_iter, "iteration"
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport(
            ['epoch', 'iteration', 'lr',
             'main/loss', 'main/loss/loc', 'main/loss/conf',
             'validation/main/map']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(  # NOTE(review): extend wrapper restored; missing in chunk
        extensions.snapshot_object(model, 'model_iter_{.updater.iteration}'),
        trigger=(args.model_iter, 'iteration'))

    # NOTE(review): guard restored — default --resume is "" and a bare
    # load_npz("") would raise.
    if args.resume:
        serializers.load_npz(args.resume, trainer)

    # NOTE(review): trainer.run() is missing from this chunk — confirm.
# NOTE(review): the three lines below are extraction artifacts of the mangled
# paste — they duplicate method signatures belonging to the classes defined
# earlier in the file (the MultiboxTrainChain methods and the dataset
# constructor) and are not valid statements here (no trailing colons, no
# bodies). Kept byte-identical pending reconciliation with the original file.
def __init__(self, model, alpha=1, k=3)
def __call__(self, imgs, gt_mb_locs, gt_mb_labs)
def __init__(self, base_dir, label_names)