9 from chainer
import dataset
10 from chainer
import training
11 from chainer.links
import VGG16Layers
12 import chainer.backends.cuda
13 from chainer.serializers
import npz
14 from chainer.training
import extensions
21 from os
import makedirs
23 from PIL
import Image
as Image_
# --- Dataset setup (fragmented extract) ---
# NOTE(review): this chunk is line-fragmented and the enclosing class/def
# headers (original lines 33, 36, 39-42) were dropped by the extraction,
# so only comments are added here. Visible intent, to confirm against the
# full file:
#   self.root = <pkg sound_classification>/train_data
#   self.base = chainer.datasets.LabeledImageDataset(
#       path, osp.join(self.root, 'dataset'))
#   with open(osp.join(self.root, 'n_class.txt'), mode='r') as f:
#       ...  # with-body (original lines 44+) not visible in this extract
32 rospack = rospkg.RosPack()
34 self.
root = osp.join(rospack.get_path(
35 'sound_classification'),
'train_data')
37 self.
base = chainer.datasets.LabeledImageDataset(
38 path, osp.join(self.
root,
'dataset'))
43 with open(osp.join(self.
root,
'n_class.txt'), mode=
'r')
as f:
# NOTE(review): fragmented extract — loads the dataset's precomputed mean
# image and subtracts it per example. Visible intent (transpose axes and
# the surrounding def headers, e.g. get_example, are not visible here):
#   mean = np.array(Image_.open(<.../dataset/mean_of_dataset.png>),
#                   np.float32).transpose(...)  # axes unknown — confirm
#   self.mean = mean.astype(chainer.get_dtype())
#   image, label = self.base[i]   # presumably inside get_example(self, i)
#   ret = image - self.mean       # mean-subtracted image
48 mean_img_path = osp.join(rospack.get_path(
'sound_classification'),
49 'train_data',
'dataset',
'mean_of_dataset.png')
50 mean = np.array(Image_.open(mean_img_path), np.float32).transpose(
52 self.
mean = mean.astype(chainer.get_dtype())
58 image, label = self.
base[i]
63 ret = image - self.
mean
# NOTE(review): model factory (fragmented). `archs` maps model names to
# model classes; only the 'vgg16': VGG16BatchNormalization entry is
# visible (the dict opener and presumably a 'nin' entry, original line 70,
# are missing). The body of the 'nin' branch (original line 75) is also
# not visible. model_path points at the cached NPZ-converted VGG16
# weights under <pkg sound_classification>/scripts/vgg16/.
71 'vgg16': VGG16BatchNormalization
73 model = archs[model_name](n_class=n_class)
74 if model_name ==
'nin':
76 elif model_name ==
'vgg16':
77 rospack = rospkg.RosPack()
78 model_path = osp.join(rospack.get_path(
'sound_classification'),
'scripts',
79 'vgg16',
'VGG_ILSVRC_16_layers.npz')
# Cache the ImageNet-pretrained VGG16 weights locally: on the first run,
# download the original Caffe release and convert it to Chainer's NPZ
# format at model_path so later runs skip the slow conversion.
if not osp.exists(model_path):
    from chainer.dataset import download
    from chainer.links.caffe.caffe_function import CaffeFunction
    # Fix: this previously fetched VGG_ILSVRC_19_layers.caffemodel (the
    # 19-layer network) even though the result is stored as
    # VGG_ILSVRC_16_layers.npz and later loaded into VGG16Layers; fetch
    # the matching 16-layer caffemodel instead.
    path_caffemodel = download.cached_download(
        'http://www.robots.ox.ac.uk/%7Evgg/software/very_deep/caffe/'
        'VGG_ILSVRC_16_layers.caffemodel')
    caffemodel = CaffeFunction(path_caffemodel)  # parses .caffemodel (slow)
    npz.save_npz(model_path, caffemodel, compression=False)
# NOTE(review): copies pretrained VGG16 weights into the freshly built
# model. conv* layers are copied with exact shape checks; fc6/fc7 are
# copied via reshape (same element count, different layout). Original
# line 91 (between the conv branch header and the first copy) and the
# else/indentation guarding the final "invalid model" print (original
# line 105) are not visible in this extract — confirm against the full
# file before restructuring.
87 vgg16 = VGG16Layers(pretrained_model=model_path)
88 print(
'Load model from {}'.format(model_path))
89 for l
in model.children():
90 if l.name.startswith(
'conv'):
92 l1 = getattr(vgg16, l.name)
93 l2 = getattr(model, l.name)
94 assert l1.W.shape == l2.W.shape
95 assert l1.b.shape == l2.b.shape
96 l2.W.data[...] = l1.W.data[...]
97 l2.b.data[...] = l1.b.data[...]
98 elif l.name
in [
'fc6',
'fc7']:
99 l1 = getattr(vgg16, l.name)
100 l2 = getattr(model, l.name)
101 assert l1.W.size == l2.W.size
102 assert l1.b.size == l2.b.size
103 l2.W.data[...] = l1.W.data.reshape(l2.W.shape)[...]
104 l2.b.data[...] = l1.b.data.reshape(l2.b.shape)[...]
106 print(
'Model type {} is invalid.'.format(model_name))
# NOTE(review): start of the script's main body (fragmented extract);
# RosPack resolves package-relative paths used below.
113 rospack = rospkg.RosPack()
# Command-line options for the training script (defaults favor a quick
# NIN run on GPU 0 for 100 epochs).
parser = argparse.ArgumentParser(
    description='Learning convnet from ILSVRC2012 dataset')
# (flags, keyword options) for each supported argument.
_cli_options = (
    (('--epoch', '-e'),
     dict(type=int, default=100, help='Number of epochs to train')),
    (('--gpu', '-g'),
     dict(type=int, default=0,
          help='GPU ID (negative value indicates CPU)')),
    (('-m', '--model'),
     dict(type=str, choices=['nin', 'vgg16'], default='nin',
          help='Neural network model to use dataset')),
)
for _flags, _kwargs in _cli_options:
    parser.add_argument(*_flags, **_kwargs)
# NOTE(review): '__name:'/'__log:' look like the remap-style arguments a
# ROS launcher injects when the script runs as a node; they are accepted
# as optional positionals and hidden (SUPPRESS) so parse_args() does not
# fail under roslaunch — confirm against how this script is launched.
# Original lines 124 and 127 are not visible in this extract.
125 parser.add_argument(
'__name:', help=argparse.SUPPRESS, nargs=
'?')
126 parser.add_argument(
'__log:', help=argparse.SUPPRESS, nargs=
'?')
128 args = parser.parse_args()
# NOTE(review): resolves the compute device from --gpu and the train/val
# label files under <pkg sound_classification>/train_data/dataset/.
# Original lines 133-136 and 139 are not visible in this extract.
132 device = chainer.cuda.get_device_from_id(args.gpu)
137 train_labels = osp.join(rospack.get_path(
'sound_classification'),
138 'train_data',
'dataset',
'train_images.txt')
140 val_labels = osp.join(rospack.get_path(
'sound_classification'),
141 'train_data',
'dataset',
'test_images.txt')
# Echo the resolved run configuration before training starts.
for _line in ('Device: {}'.format(device),
              'Model: {}'.format(args.model),
              'Dtype: {}'.format(chainer.config.dtype),
              'Minibatch-size: {}'.format(batchsize),
              'epoch: {}'.format(args.epoch)):
    print(_line)
# Place the model on the selected device when one was resolved; skipped
# when the model does not expose to_device().
if device is not None and hasattr(model, 'to_device'):
    model.to_device(device)
# NOTE(review): data pipeline + optimizer (fragmented: the argument list
# of the train iterator, original line 166, and lines 170-171 are not
# visible). Both iterators are MultiprocessIterator; the validation one
# is non-repeating so the Evaluator makes exactly one pass. Batches are
# assembled by chainer's default concat_examples and optimized with
# plain MomentumSGD(lr=0.01, momentum=0.9).
165 train_iter = chainer.iterators.MultiprocessIterator(
167 val_iter = chainer.iterators.MultiprocessIterator(
168 val, batchsize, repeat=
False)
169 converter = dataset.concat_examples
172 optimizer = chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9)
173 optimizer.setup(model)
# NOTE(review): trainer assembly. Output directory is
# <pkg sound_classification>/train_data/result/<model>; the body of the
# osp.exists guard (original line 180, presumably makedirs(out) given the
# top-of-file import) is not visible — confirm. val/log intervals are
# (10, 'iteration') tuples. The Evaluator runs on val_iter each interval,
# and the model with the minimum validation/main/loss seen so far is
# snapshotted to model_best.npz. Original lines 184 and 187 are not
# visible in this extract.
177 out = osp.join(rospack.get_path(
'sound_classification'),
178 'train_data',
'result', args.model)
179 if not osp.exists(out):
181 updater = training.updaters.StandardUpdater(
182 train_iter, optimizer, converter=converter, device=device)
183 trainer = training.Trainer(updater, (args.epoch,
'epoch'), out)
185 val_interval = 10,
'iteration'
186 log_interval = 10,
'iteration'
188 trainer.extend(extensions.Evaluator(val_iter, model, converter=converter,
189 device=device), trigger=val_interval)
190 trainer.extend(extensions.snapshot_object(
191 target=model, filename=
'model_best.npz'),
192 trigger=chainer.training.triggers.MinValueTrigger(
193 key=
'validation/main/loss',
194 trigger=val_interval))
# Console/reporting extensions: periodic LogReport, learning-rate
# observation, a printed metrics table, loss/accuracy plots, and a
# progress bar refreshed every 10 updates.
_printed_metrics = [
    'epoch', 'iteration', 'main/loss', 'validation/main/loss',
    'main/accuracy', 'validation/main/accuracy', 'lr',
]
trainer.extend(extensions.LogReport(trigger=log_interval))
trainer.extend(extensions.observe_lr(), trigger=log_interval)
trainer.extend(
    extensions.PrintReport(_printed_metrics), trigger=log_interval)
trainer.extend(extensions.PlotReport(
    ['main/loss', 'validation/main/loss'],
    x_key='iteration', file_name='loss.png'))
trainer.extend(extensions.PlotReport(
    ['main/accuracy', 'validation/main/accuracy'],
    x_key='iteration', file_name='accuracy.png'))
trainer.extend(extensions.ProgressBar(update_interval=10))
# Script entry point; the guarded call (original lines after 210,
# presumably main()) is not visible in this extract.
210 if __name__ ==
'__main__':