train.py
#!/usr/bin/env python

# Mainly copied from chainer/train_imagenet.py
# https://github.com/chainer/chainer/blob/master/examples/imagenet/train_imagenet.py

import argparse

import chainer
from chainer import dataset
from chainer import training
from chainer.links import VGG16Layers
import chainer.backends.cuda
from chainer.serializers import npz
from chainer.training import extensions

from sound_classification.nin.nin import NIN
from sound_classification.vgg16.vgg16_batch_normalization import VGG16BatchNormalization

import matplotlib
import numpy as np
from os import makedirs
import os.path as osp
from PIL import Image as Image_
import rospkg

matplotlib.use('Agg')  # necessary to avoid the Tcl_AsyncDelete error
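
# Note (added for clarity, inferred from the paths used in this script): the
# sound_classification package is expected to provide
#   train_data/n_class.txt                  - one target class per line
#   train_data/dataset/train_images.txt     - training image/label list
#   train_data/dataset/test_images.txt      - validation image/label list
#   train_data/dataset/mean_of_dataset.png  - mean image of the dataset
# Training results are written to train_data/result/<model name>.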


class PreprocessedDataset(chainer.dataset.DatasetMixin):
    """Dataset that subtracts the dataset mean image and rescales pixel values."""

    def __init__(self, path=None, random=True):
        rospack = rospkg.RosPack()
        # Root directory path of train data
        self.root = osp.join(rospack.get_path(
            'sound_classification'), 'train_data')
        if path is not None:
            self.base = chainer.datasets.LabeledImageDataset(
                path, osp.join(self.root, 'dataset'))
        self.random = random
        # Number of classes to be classified
        self.n_class = 0
        self.target_classes = []
        with open(osp.join(self.root, 'n_class.txt'), mode='r') as f:
            for row in f:
                self.n_class += 1
                self.target_classes.append(row)
        # Load mean image of dataset
        mean_img_path = osp.join(rospack.get_path('sound_classification'),
                                 'train_data', 'dataset', 'mean_of_dataset.png')
        mean = np.array(Image_.open(mean_img_path), np.float32).transpose(
            (2, 0, 1))  # (height, width, channel) -> (channel, height, width), rgb
        self.mean = mean.astype(chainer.get_dtype())

    def __len__(self):
        return len(self.base)

    def get_example(self, i):
        image, label = self.base[i]  # (channel, height, width), rgb
        image = self.process_image(image)
        return image, label

    def process_image(self, image):
        ret = image - self.mean  # Subtract mean image, (channel, height, width), rgb
        ret *= (1.0 / 255.0)  # Scale pixel values by 1/255
        return ret
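
# A minimal usage sketch (added for illustration; the list-file path is the one
# built in main() below). chainer.datasets.LabeledImageDataset expects each line
# of the list file to hold "<image path> <integer label>", with the image paths
# resolved against the dataset directory passed as its root:
#
#   train = PreprocessedDataset(path=train_labels)
#   image, label = train.get_example(0)  # float32 array, (channel, height, width)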


def load_model(model_name, n_class):
    archs = {
        'nin': NIN,
        'vgg16': VGG16BatchNormalization
    }
    model = archs[model_name](n_class=n_class)
    if model_name == 'nin':
        pass
    elif model_name == 'vgg16':
        rospack = rospkg.RosPack()
        model_path = osp.join(rospack.get_path('sound_classification'), 'scripts',
                              'vgg16', 'VGG_ILSVRC_16_layers.npz')
        if not osp.exists(model_path):
            # Download the original VGG16 caffemodel and convert it to npz
            from chainer.dataset import download
            from chainer.links.caffe.caffe_function import CaffeFunction
            path_caffemodel = download.cached_download(
                'http://www.robots.ox.ac.uk/%7Evgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel')
            caffemodel = CaffeFunction(path_caffemodel)
            npz.save_npz(model_path, caffemodel, compression=False)

        vgg16 = VGG16Layers(pretrained_model=model_path)  # original VGG16 model
        print('Load model from {}'.format(model_path))
        for l in model.children():
            if l.name.startswith('conv'):
                # l.disable_update()  # Uncomment for transfer learning (freeze conv layers); keep commented for fine-tuning
                l1 = getattr(vgg16, l.name)
                l2 = getattr(model, l.name)
                assert l1.W.shape == l2.W.shape
                assert l1.b.shape == l2.b.shape
                l2.W.data[...] = l1.W.data[...]
                l2.b.data[...] = l1.b.data[...]
            elif l.name in ['fc6', 'fc7']:
                # The fc weights hold the same number of elements but may be
                # laid out differently, hence the reshape.
                l1 = getattr(vgg16, l.name)
                l2 = getattr(model, l.name)
                assert l1.W.size == l2.W.size
                assert l1.b.size == l2.b.size
                l2.W.data[...] = l1.W.data.reshape(l2.W.shape)[...]
                l2.b.data[...] = l1.b.data.reshape(l2.b.shape)[...]
    else:
        print('Model type {} is invalid.'.format(model_name))
        exit()

    return model
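
# A minimal sketch (added for illustration) of restoring the best snapshot that
# main() below writes as model_best.npz under train_data/result/<model name>:
#
#   model = load_model('nin', n_class=5)  # n_class must match the trained model
#   chainer.serializers.load_npz('<path to>/train_data/result/nin/model_best.npz', model)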


def main():
    rospack = rospkg.RosPack()

    parser = argparse.ArgumentParser(
        description='Train a convnet on the sound_classification dataset')
    parser.add_argument('--epoch', '-e', type=int, default=100,
                        help='Number of epochs to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('-m', '--model', type=str,
                        choices=['nin', 'vgg16'], default='nin',
                        help='Neural network model to train on the dataset')
    # Ignore arguments sent by roslaunch.
    parser.add_argument('__name:', help=argparse.SUPPRESS, nargs='?')
    parser.add_argument('__log:', help=argparse.SUPPRESS, nargs='?')

    args = parser.parse_args()

    # Configuration for training with chainer
    if args.gpu >= 0:
        device = chainer.cuda.get_device_from_id(args.gpu)  # for python2
    else:
        device = None
    batchsize = 32
    # Path to training image-label list file
    train_labels = osp.join(rospack.get_path('sound_classification'),
                            'train_data', 'dataset', 'train_images.txt')
    # Path to validation image-label list file
    val_labels = osp.join(rospack.get_path('sound_classification'),
                          'train_data', 'dataset', 'test_images.txt')

    # Print the training configuration
    print('Device: {}'.format(device))
    print('Model: {}'.format(args.model))
    print('Dtype: {}'.format(chainer.config.dtype))
    print('Minibatch-size: {}'.format(batchsize))
    print('epoch: {}'.format(args.epoch))
    print('')

    # Load the dataset files
    train = PreprocessedDataset(train_labels)
    val = PreprocessedDataset(val_labels, False)

    # Initialize the model to train
    model = load_model(args.model, train.n_class)
    if device is not None:
        if hasattr(model, 'to_device'):
            model.to_device(device)
            device.use()
        else:
            model.to_gpu(device)

    # These iterators load the images with subprocesses running in parallel
    # to the training/validation.
    train_iter = chainer.iterators.MultiprocessIterator(
        train, batchsize)
    val_iter = chainer.iterators.MultiprocessIterator(
        val, batchsize, repeat=False)
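    # Note (added for clarity): repeat=False makes the validation iterator stop
    # after one pass, which is what the Evaluator extension below expects.
    # chainer.iterators.SerialIterator would be a drop-in replacement if the
    # multiprocess workers cause trouble.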
    converter = dataset.concat_examples

    # Set up an optimizer
    optimizer = chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9)
    optimizer.setup(model)

    # Set up a trainer
    # Output directory of train result
    out = osp.join(rospack.get_path('sound_classification'),
                   'train_data', 'result', args.model)
    if not osp.exists(out):
        makedirs(out)
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, converter=converter, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out)

    val_interval = 10, 'iteration'
    log_interval = 10, 'iteration'

    trainer.extend(extensions.Evaluator(val_iter, model, converter=converter,
                                        device=device), trigger=val_interval)
    trainer.extend(extensions.snapshot_object(
        target=model, filename='model_best.npz'),
        trigger=chainer.training.triggers.MinValueTrigger(
            key='validation/main/loss',
            trigger=val_interval))
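    # Note (added for clarity): MinValueTrigger fires whenever
    # 'validation/main/loss' reaches a new minimum (checked every val_interval),
    # so model_best.npz always holds the weights with the lowest validation loss
    # observed so far.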
    # Be careful to pass the interval directly to LogReport
    # (it determines when to emit log rather than when to read observations)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'lr'
    ]), trigger=log_interval)
    trainer.extend(extensions.PlotReport(
        ['main/loss', 'validation/main/loss'],
        x_key='iteration', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(
        ['main/accuracy', 'validation/main/accuracy'],
        x_key='iteration', file_name='accuracy.png'))
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.run()

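
# Example invocation (added for illustration; assumes the files listed at the
# top of this script have already been generated under train_data/):
#
#   rosrun sound_classification train.py --model vgg16 --gpu 0 --epoch 100
#   ./train.py -m nin -g -1 -e 50   # CPU-only run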

if __name__ == '__main__':
    main()