from __future__ import print_function
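
# NOTE: the elided lines of this excerpt are assumed to provide the remaining
# imports used below; `import chainer` presumably follows the version checks.
# A minimal set for the fragments shown here:
import argparse
import json
import os
import sys

import numpy as np
import yaml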
 
import itertools
import pkg_resources
 
from distutils.version import LooseVersion
 
if LooseVersion(pkg_resources.get_distribution("chainer").version) >= LooseVersion('7.0.0') and \
        sys.version_info.major == 2:
    print('''Please install chainer < 7.0.0:

    sudo pip install chainer==6.7.0

c.f. https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485
''')
    sys.exit(1)
if [p for p in list(itertools.chain(*[pkg_resources.find_distributions(_) for _ in sys.path]))
        if "cupy-" in p.project_name] == []:
    print('''Please install CuPy

    sudo pip install cupy-cuda[your cuda version]

    sudo pip install cupy-cuda91
''')
    sys.exit(1)
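# The installed CUDA version, which selects the matching cupy-cuda wheel,
# can be checked with e.g. `nvcc --version`.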
from chainer import serializers
from chainer import training
from chainer.datasets import TransformDataset
from chainer.training import extensions
from chainer.optimizer import WeightDecay
 
from chainercv import transforms
from chainercv.extensions import DetectionVOCEvaluator
from chainercv.links import SSD300
from chainercv.links.model.ssd import GradientScaling
from chainercv.links.model.ssd import multibox_loss
from chainercv.links.model.ssd import random_distort
from chainercv.links.model.ssd import random_crop_with_bbox_constraints
from chainercv.links.model.ssd import resize_with_random_interpolation
from chainercv.utils import read_image
 
        for name in os.listdir(base_dir):
 
            if os.path.splitext(name)[1] != '.jpg':
 
        img = read_image(img_filename)
 
        anno_filename = os.path.splitext(img_filename)[0] + '__labels.json'
        with open(anno_filename, 'r') as f:
 
            h = anno_i['size']['y']
            w = anno_i['size']['x']
            center_y = anno_i['centre']['y']
            center_x = anno_i['centre']['x']
 
            except Exception as e:
                print("Failed to index label class: {}".format(anno_i), file=sys.stderr)
                print("image file name: {}".format(img_filename), file=sys.stderr)
                print("annotation file name: {}".format(anno_filename), file=sys.stderr)
 
            bbox.append(
                [center_y - h / 2, center_x - w / 2,
                 center_y + h / 2, center_x + w / 2])
 
        return img, np.array(bbox, dtype=np.float32), np.array(label, dtype=np.int32)
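
# A hedged sketch of the annotation layout that get_example() above expects:
# each image "foo.jpg" is paired with "foo__labels.json", a JSON list of
# objects whose 'size' and 'centre' fields are given in pixels.  The name of
# the class key ('label_class' below) is an assumption for illustration; only
# 'size' and 'centre' appear in this excerpt.
#
# [
#   {"label_class": "dog",
#    "centre": {"x": 123.0, "y": 45.0},
#    "size":   {"x": 30.0, "y": 60.0}}
# ]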
 
        super(MultiboxTrainChain, self).__init__()
 
        with self.init_scope():
            self.model = model
 
        mb_locs, mb_confs = self.model(imgs)
        # self.k is the hard-negative mining ratio and self.alpha weights the
        # localization loss against the confidence loss
        loc_loss, conf_loss = multibox_loss(
            mb_locs, mb_confs, gt_mb_locs, gt_mb_labs, self.k)
        loss = loc_loss * self.alpha + conf_loss
 
        chainer.reporter.report(
            {'loss': loss, 'loss/loc': loc_loss, 'loss/conf': conf_loss},
            self)
        return loss
 
    """Class for augmentation"""
        img, bbox, label = in_data
 
        # random color distortion (brightness, contrast, saturation, hue)
        img = random_distort(img)
 
        # random expansion (zoom out), applied with probability 0.5
        if np.random.randint(2):
            img, param = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            bbox = transforms.translate_bbox(
                bbox, y_offset=param["y_offset"], x_offset=param["x_offset"])
 
        # random cropping, constrained so that bounding boxes remain usable
        img, param = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        bbox, param = transforms.crop_bbox(
            bbox, y_slice=param["y_slice"], x_slice=param["x_slice"],
            allow_outside_center=False, return_param=True)
        label = label[param["index"]]
 
        # resize image and boxes to the fixed SSD input size
        _, H, W = img.shape
        img = resize_with_random_interpolation(img, (self.size, self.size))
        bbox = transforms.resize_bbox(bbox, (H, W), (self.size, self.size))
 
        # encode ground-truth boxes and labels into SSD multibox targets
        mb_loc, mb_lab = self.coder.encode(bbox, label)
 
        return img, mb_loc, mb_lab
 
if __name__ == '__main__':
 
    p = argparse.ArgumentParser()
 
    p.add_argument("label_file", help="path to label file")
    p.add_argument("train", help="path to train dataset directory")
    p.add_argument("--val", default=None,
                   help="path to validation dataset directory. If not "
                        "specified, the train dataset is split with a "
                        "train:val ratio of 8:2.")
    p.add_argument("--base-model", help="base model name", default="voc0712")
    p.add_argument("--batchsize", "-b", type=int, default=16)
    p.add_argument("--iteration", type=int, default=120000)
    p.add_argument("--gpu", "-g", type=int, default=-1)
    p.add_argument("--out", "-o", type=str, default="results")
    p.add_argument("--resume", type=str, default="")
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--val-iter", type=int, default=100)
    p.add_argument("--log-iter", type=int, default=10)
    p.add_argument("--model-iter", type=int, default=200)
 
    args = p.parse_args()
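
    # Example invocation (the script file name is assumed for illustration):
    #   python train_ssd.py label_names.yaml path/to/train --gpu 0 -b 16
    #   python train_ssd.py label_names.yaml path/to/train --val path/to/val -g 0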
 
    # the label file is a YAML sequence of class names, e.g.:
    #   - dog
    #   - cat
    #   - bottle
    with open(args.label_file, "r") as f:
        label_names = tuple(yaml.safe_load(f))
 
    print("Loaded %d labels" % len(label_names))
 
        train, test = chainer.datasets.split_dataset_random(
            dataset, int(len(dataset) * 0.8))
 
    print("train: {}, test: {}".format(len(train), len(test)))
 
    # SSD300 pretrained on the base dataset (default voc0712); only used to
    # copy its feature extractor and localization weights into the new model
    pretrained_model = SSD300(pretrained_model=args.base_model)
 
    model = SSD300(n_fg_class=len(dataset.label_names))
    model.extractor.copyparams(pretrained_model.extractor)
    model.multibox.loc.copyparams(pretrained_model.multibox.loc)
 
    # 'evaluate' preset selects thresholds suited for mAP evaluation
    model.use_preset("evaluate")
 
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
 
    train = TransformDataset(
        train, Transform(model.coder, model.insize, model.mean))
    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batchsize)
 
    test_iter = chainer.iterators.SerialIterator(
        test, args.batchsize,
        repeat=False, shuffle=False)
 
    optimizer = chainer.optimizers.MomentumSGD(lr=args.lr)
    optimizer.setup(train_chain)
 
    # standard SSD recipe: double the gradient of bias terms, apply weight
    # decay to all other parameters
    for param in train_chain.params():
        if param.name == 'b':
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))
 
    updater = training.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(
        updater, (args.iteration, "iteration"), args.out)
 
    val_interval = args.val_iter, "iteration"
    trainer.extend(
        DetectionVOCEvaluator(
            test_iter, model, use_07_metric=True,
            label_names=label_names),
        trigger=val_interval)
 
    log_interval = args.log_iter, "iteration"
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'lr',
         'main/loss', 'main/loss/loc', 'main/loss/conf',
         'validation/main/map']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
 
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(model, 'model_iter_{.updater.iteration}'),
        trigger=(args.model_iter, 'iteration'))
 
    if args.resume:
        serializers.load_npz(args.resume, trainer)
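
    # The excerpt ends here; a Trainer configured as above is started with
    # trainer.run(), assumed to be in the elided final lines.
    trainer.run()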