3 from __future__
import print_function
11 import itertools, pkg_resources, sys
12 from distutils.version
import LooseVersion
13 if LooseVersion(pkg_resources.get_distribution(
"chainer").version) >= LooseVersion(
'7.0.0')
and \
14 sys.version_info.major == 2:
15 print(
'''Please install chainer < 7.0.0: 17 sudo pip install chainer==6.7.0 19 c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485 22 if [p
for p
in list(itertools.chain(*[pkg_resources.find_distributions(_)
for _
in sys.path]))
if "cupy-" in p.project_name ] == []:
23 print(
'''Please install CuPy 25 sudo pip install cupy-cuda[your cuda version] 27 sudo pip install cupy-cuda91 32 from chainer.datasets
# NOTE(review): import block reconstructed after extraction mangling and
# regrouped stdlib / third-party / local per PEP 8. The imports not
# visible in the damaged text (argparse, copy, datetime, os.path, numpy,
# rospkg, chainer, the dataset classes) are required by code below —
# confirm their exact form against the upstream file.
import argparse
import copy
import datetime
import os.path as osp

import numpy as np
import rospkg

import chainer
from chainer import training
from chainer.datasets import TransformDataset
from chainer.optimizer_hooks import WeightDecay
from chainer.training import extensions
from chainer.training import triggers

from chainercv import transforms
from chainercv.extensions import DetectionVOCEvaluator
from chainercv.links import SSD300
from chainercv.links import SSD512
from chainercv.links.model.ssd import GradientScaling
from chainercv.links.model.ssd import multibox_loss
from chainercv.links.model.ssd import random_crop_with_bbox_constraints
from chainercv.links.model.ssd import random_distort
from chainercv.links.model.ssd import resize_with_random_interpolation

from jsk_recognition_utils.datasets import BboxDetectionDataset
from jsk_recognition_utils.datasets import DetectionDataset
# Use OpenCV as the resize backend inside chainercv transforms.
chainer.config.cv_resize_backend = 'cv2'


class MultiboxTrainChain(chainer.Chain):
    """Wrap an SSD model and compute the multibox training loss.

    The reported loss is ``loc_loss * alpha + conf_loss``; the three
    values are sent to the chainer reporter as 'loss', 'loss/loc' and
    'loss/conf' so LogReport/PrintReport can pick them up.

    NOTE(review): the class header, the attribute assignments and the
    final ``return loss`` were lost in the mangled extraction and are
    restored here from the visible fragments (the ``__init__``/``forward``
    signatures appear verbatim at the end of the damaged file) — confirm
    against upstream.
    """

    def __init__(self, model, alpha=1, k=3):
        super(MultiboxTrainChain, self).__init__()
        with self.init_scope():
            self.model = model
        # alpha: weight of the localization term in the combined loss.
        # k: hard-negative-mining ratio forwarded to multibox_loss.
        self.alpha = alpha
        self.k = k

    def forward(self, imgs, gt_mb_locs, gt_mb_labels):
        mb_locs, mb_confs = self.model(imgs)
        loc_loss, conf_loss = multibox_loss(
            mb_locs, mb_confs, gt_mb_locs, gt_mb_labels, self.k)
        loss = loc_loss * self.alpha + conf_loss

        chainer.reporter.report(
            {'loss': loss,
             'loss/loc': loc_loss,
             'loss/conf': conf_loss},
            self)

        return loss
class Transform(object):
    """SSD data-augmentation pipeline applied per training sample.

    Takes ``(img, bbox, label)`` and returns ``(img, mb_loc, mb_label)``
    ready for MultiboxTrainChain.

    NOTE(review): the class header and ``__init__`` were lost in the
    mangled extraction; they are restored following the ChainerCV SSD
    training example (the ``__call__`` body fragments match it line for
    line) — confirm against upstream.
    """

    def __init__(self, coder, size, mean):
        # Copy the coder so it can be moved to CPU for worker processes
        # without touching the model's own coder.
        self.coder = copy.copy(coder)
        self.coder.to_cpu()
        self.size = size
        self.mean = mean

    def __call__(self, in_data):
        # Five augmentation steps: color distortion, random expansion,
        # random cropping, resize with random interpolation, random flip.
        img, bbox, label = in_data

        # 1. Color augmentation
        img = random_distort(img)

        # 2. Random expansion (applied half the time)
        if np.random.randint(2):
            img, param = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            bbox = transforms.translate_bbox(
                bbox, y_offset=param['y_offset'],
                x_offset=param['x_offset'])

        # 3. Random cropping; boxes whose center leaves the crop are
        # dropped, and labels are filtered to match.
        img, param = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        bbox, param = transforms.crop_bbox(
            bbox, y_slice=param['y_slice'], x_slice=param['x_slice'],
            allow_outside_center=False, return_param=True)
        label = label[param['index']]

        # 4. Resizing with random interpolation.
        # Restored line: H and W are consumed by resize_bbox below but
        # were never bound in the damaged text.
        _, H, W = img.shape
        img = resize_with_random_interpolation(img, (self.size, self.size))
        bbox = transforms.resize_bbox(bbox, (H, W), (self.size, self.size))

        # 5. Random horizontal flipping
        img, params = transforms.random_flip(
            img, x_random=True, return_param=True)
        bbox = transforms.flip_bbox(
            bbox, (self.size, self.size), x_flip=params['x_flip'])

        # Preparation for the SSD network: mean subtraction + box encoding.
        # NOTE(review): the mean subtraction was not visible in the
        # damaged text and is restored from the reference script — confirm.
        img -= self.mean
        mb_loc, mb_label = self.coder.encode(bbox, label)

        return img, mb_loc, mb_label
def main():
    """Train an SSD300/SSD512 detector on a jsk_perception dataset.

    NOTE(review): the ``def main():`` header, the ``else:`` branches,
    several ``trainer.extend(`` call heads, ``train_chain = ...``,
    ``model.to_gpu()``, ``trainer.run()`` and the final ``main()`` call
    were lost in the mangled extraction and are restored here — confirm
    against upstream.
    """
    rospack = rospkg.RosPack()
    jsk_perception_datasets_path = osp.join(
        rospack.get_path('jsk_perception'), 'learning_datasets')

    parser = argparse.ArgumentParser()

    # Dataset directories (default: bundled kitchen dataset).
    parser.add_argument(
        '--train-dataset-dir', type=str,
        default=osp.join(jsk_perception_datasets_path,
                         'kitchen_dataset', 'train'))
    parser.add_argument(
        '--val-dataset-dir', type=str,
        default=osp.join(jsk_perception_datasets_path,
                         'kitchen_dataset', 'test'))
    # NOTE(review): choices/default for --dataset-type restored from the
    # branches below ('instance' / 'bbox') — confirm the default.
    parser.add_argument(
        '--dataset-type', type=str,
        choices=('instance', 'bbox'), default='instance')
    parser.add_argument(
        '--model-name', choices=('ssd300', 'ssd512'), default='ssd512')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--max-epoch', type=int, default=100)
    parser.add_argument('--out-dir', type=str, default=None)
    args = parser.parse_args()

    if args.dataset_type == 'instance':
        train_dataset = DetectionDataset(args.train_dataset_dir)
    elif args.dataset_type == 'bbox':
        train_dataset = BboxDetectionDataset(args.train_dataset_dir)
    else:
        print('unsuppported dataset type')
        return

    fg_label_names = train_dataset.fg_class_names

    if args.model_name == 'ssd300':
        model = SSD300(
            n_fg_class=len(fg_label_names),
            pretrained_model='imagenet')
    elif args.model_name == 'ssd512':
        model = SSD512(
            n_fg_class=len(fg_label_names),
            pretrained_model='imagenet')
    model.use_preset('evaluate')
    train_chain = MultiboxTrainChain(model)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    train = TransformDataset(
        train_dataset,
        Transform(model.coder, model.insize, model.mean))
    train_iter = chainer.iterators.MultiprocessIterator(
        train, args.batch_size)

    if args.dataset_type == 'instance':
        test_dataset = DetectionDataset(args.val_dataset_dir)
    elif args.dataset_type == 'bbox':
        test_dataset = BboxDetectionDataset(args.val_dataset_dir)
    test_iter = chainer.iterators.SerialIterator(
        test_dataset, args.batch_size, repeat=False, shuffle=False)

    # Initial lr is supplied by the ExponentialShift extension below.
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(train_chain)
    for param in train_chain.params():
        if param.name == 'b':
            # Biases: larger learning rate, no weight decay.
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))

    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)

    now = datetime.datetime.now()
    timestamp = now.strftime('%Y%m%d-%H%M%S')
    if args.out_dir is None:
        out_dir = osp.join(
            rospkg.get_ros_home(), 'learning_logs', timestamp)
    else:
        out_dir = args.out_dir

    # Drop lr by 10x at 2/3 and 5/6 of the training schedule.
    step_epoch = [args.max_epoch * 2 // 3, args.max_epoch * 5 // 6]
    trainer = training.Trainer(
        updater, (args.max_epoch, 'epoch'), out_dir)
    trainer.extend(
        extensions.ExponentialShift('lr', 0.1, init=1e-3),
        trigger=triggers.ManualScheduleTrigger(step_epoch, 'epoch'))

    trainer.extend(
        DetectionVOCEvaluator(
            test_iter, model, use_07_metric=True,
            label_names=fg_label_names),
        trigger=triggers.ManualScheduleTrigger(
            step_epoch + [args.max_epoch], 'epoch'))

    log_interval = 10, 'iteration'
    trainer.extend(
        extensions.LogReport(log_name='log.json', trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'lr',
         'main/loss', 'main/loss/loc', 'main/loss/conf',
         'validation/main/map']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.extend(
        extensions.snapshot_object(
            model, 'model_snapshot.npz'),
        trigger=(args.max_epoch, 'epoch'))

    trainer.run()


if __name__ == '__main__':
    main()