3 from __future__
import print_function
10 os.environ[
'MPLBACKEND'] =
'Agg' 12 import itertools, pkg_resources, sys
13 from distutils.version
import LooseVersion
14 if LooseVersion(pkg_resources.get_distribution(
"chainer").version) >= LooseVersion(
'7.0.0')
and \
15 sys.version_info.major == 2:
16 print(
'''Please install chainer < 7.0.0: 18 sudo pip install chainer==6.7.0 20 c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485 23 if [p
for p
in list(itertools.chain(*[pkg_resources.find_distributions(_)
for _
in sys.path]))
if "cupy-" in p.project_name ] == []:
24 print(
'''Please install CuPy 26 sudo pip install cupy-cuda[your cuda version] 28 sudo pip install cupy-cuda91 33 from chainer
import cuda
34 from chainer.datasets
import TransformDataset
35 import chainer.serializers
as S
36 from chainer.training
import extensions
47 rospack = rospkg.RosPack()
48 jsk_perception_datasets_path = osp.join(
49 rospack.get_path(
'jsk_perception'),
'learning_datasets')
51 parser = argparse.ArgumentParser()
54 parser.add_argument(
'--train_dataset_dir', type=str,
55 default=osp.join(jsk_perception_datasets_path,
56 'kitchen_dataset',
'train'))
57 parser.add_argument(
'--val_dataset_dir', type=str,
58 default=osp.join(jsk_perception_datasets_path,
59 'kitchen_dataset',
'test'))
63 '--model_name', type=str, default=
'fcn32s',
64 choices=[
'fcn32s',
'fcn16s',
'fcn8s',
'fcn8s_at_once'])
67 parser.add_argument(
'--gpu', type=int, default=0)
68 parser.add_argument(
'--batch_size', type=int, default=1)
69 parser.add_argument(
'--max_epoch', type=int, default=100)
70 parser.add_argument(
'--lr', type=float, default=1e-10)
71 parser.add_argument(
'--weight_decay', type=float, default=0.0001)
72 parser.add_argument(
'--out_dir', type=str, default=
None)
73 parser.add_argument(
'--progressbar_update_interval', type=float,
75 parser.add_argument(
'--print_interval', type=float, default=100)
76 parser.add_argument(
'--print_interval_type', type=str,
78 choices=[
'epoch',
'iteration'])
79 parser.add_argument(
'--log_interval', type=float, default=10)
80 parser.add_argument(
'--log_interval_type', type=str,
82 choices=[
'epoch',
'iteration'])
83 parser.add_argument(
'--plot_interval', type=float, default=5)
84 parser.add_argument(
'--plot_interval_type', type=str,
86 choices=[
'epoch',
'iteration'])
87 parser.add_argument(
'--eval_interval', type=float, default=10)
88 parser.add_argument(
'--eval_interval_type', type=str,
90 choices=[
'epoch',
'iteration'])
91 parser.add_argument(
'--save_interval', type=float, default=10)
92 parser.add_argument(
'--save_interval_type', type=str,
94 choices=[
'epoch',
'iteration'])
96 args = parser.parse_args()
119 now = datetime.datetime.now()
121 timestamp = now.strftime(
'%Y%m%d-%H%M%S')
124 rospkg.get_ros_home(),
'learning_logs', timestamp)
140 rgb_img, lbl = in_data
142 mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
143 bgr_img = rgb_img[:, :, ::-1]
144 bgr_img = bgr_img.astype(np.float32)
147 bgr_img = bgr_img.transpose((2, 0, 1))
152 train_dataset_transformed = TransformDataset(
154 val_dataset_transformed = TransformDataset(
157 train_dataset_transformed, batch_size=self.
batch_size,
160 val_dataset_transformed, batch_size=self.
batch_size,
161 shared_mem=10 ** 7, repeat=
False, shuffle=
False)
164 n_class = len(self.train_dataset.class_names)
166 self.
model = fcn.models.FCN32s(n_class=n_class)
167 vgg = fcn.models.VGG16()
168 vgg_path = vgg.download()
169 S.load_npz(vgg_path, vgg)
170 self.model.init_from_vgg16(vgg)
172 self.
model = fcn.models.FCN16s(n_class=n_class)
173 fcn32s = fcn.models.FCN32s()
174 fcn32s_path = fcn32s.download()
175 S.load_npz(fcn32s_path, fcn32s)
176 self.model.init_from_fcn32s(fcn32s_path, fcn32s)
178 self.
model = fcn.models.FCN8s(n_class=n_class)
179 fcn16s = fcn.models.FCN16s()
180 fcn16s_path = fcn16s.download()
181 S.load_npz(fcn16s_path, fcn16s)
182 self.model.init_from_fcn16s(fcn16s_path, fcn16s)
184 self.
model = fcn.models.FCN8sAtOnce(n_class=n_class)
185 vgg = fcn.models.VGG16()
186 vgg_path = vgg.download()
187 S.load_npz(vgg_path, vgg)
188 self.model.init_from_vgg16(vgg)
191 'Unsupported model_name: {}'.format(self.
model_name))
194 cuda.get_device_from_id(self.
gpu).use()
199 lr=self.
lr, momentum=0.9)
200 self.optimizer.setup(self.
model)
201 self.optimizer.add_hook(
205 self.
updater = chainer.training.updater.StandardUpdater(
211 extensions.Evaluator(
217 extensions.snapshot_object(
220 filename=
'model_snapshot.npz'),
221 trigger=chainer.training.triggers.MinValueTrigger(
222 'validation/main/loss',
227 extensions.dump_graph(
228 root_name=
'main/loss',
229 out_name=
'network_architecture.dot'))
233 extensions.ProgressBar(
236 extensions.observe_lr(),
239 extensions.LogReport(
243 extensions.PrintReport([
249 'validation/main/loss',
254 extensions.PlotReport([
256 'validation/main/loss',
258 file_name=
'loss_plot.png',
268 params[
'class_names'] = self.train_dataset.class_names
270 params[
'out_dir'] = self.
out_dir 271 params[
'gpu'] = self.
gpu 274 params[
'lr'] = self.
lr 277 fcn.extensions.ParamsReport(params, file_name=
'params.yaml'))
283 fcn.extensions.ParamsReport(
284 model_name, file_name=
'model_name.yaml'))
285 target_names = dict()
286 target_names[
'target_names'] = self.train_dataset.class_names
288 fcn.extensions.ParamsReport(
289 target_names, file_name=
'target_names.yaml'))
292 if __name__ ==
'__main__':
progressbar_update_interval
def setup_optimizer(self)
def transform_dataset(self, in_data)