3 from __future__
import print_function
9 import itertools, pkg_resources, sys
10 from distutils.version
import LooseVersion
11 if LooseVersion(pkg_resources.get_distribution(
"chainer").version) >= LooseVersion(
'7.0.0')
and \
12 sys.version_info.major == 2:
13 print(
'''Please install chainer < 7.0.0: 15 sudo pip install chainer==6.7.0 17 c.f https://github.com/jsk-ros-pkg/jsk_recognition/pull/2485 20 if [p
for p
in list(itertools.chain(*[pkg_resources.find_distributions(_)
for _
in sys.path]))
if "cupy-" in p.project_name ] == []:
21 print(
'''Please install CuPy 23 sudo pip install cupy-cuda[your cuda version] 25 sudo pip install cupy-cuda91 30 from chainer
import cuda
34 from sklearn.metrics
import classification_report
35 from sklearn.neighbors
import KNeighborsClassifier
def get_templates(template_dir):
    """Yield each template image under ``template_dir`` with its class id.

    Class names are read from ``<template_dir>/target_names.yaml``. For every
    class name, the ``*.jpg`` files in ``<template_dir>/<class name>/`` are
    loaded together with the file of the same basename found in that
    directory's ``masks`` sub-directory.

    Args:
        template_dir: Path of the directory holding ``target_names.yaml``
            and one sub-directory per class name.

    Yields:
        ``(cls_id, img, mask)`` where ``cls_id`` is the index of the class
        name within ``target_names.yaml``, ``img`` is whatever
        ``skimage.io.imread`` returns for the template image, and ``mask``
        is the result of comparing the loaded mask image with ``>= 127``.
    """
    names_file = osp.join(template_dir, 'target_names.yaml')
    # Use a context manager so the handle is closed deterministically, and
    # safe_load because a bare yaml.load() without a Loader is deprecated
    # since PyYAML 5.1 and rejected by PyYAML 6. A plain list of names
    # needs nothing beyond the safe subset.
    with open(names_file) as f:
        target_names = yaml.safe_load(f)

    for cls_id, cls_name in enumerate(target_names):
        obj_dir = osp.join(template_dir, cls_name)
        for img_file in glob.glob(osp.join(obj_dir, '*.jpg')):
            dirname, basename = osp.split(img_file)
            # The mask shares the image's basename under "<class dir>/masks".
            mask_file = osp.join(dirname, 'masks', basename)
            img = skimage.io.imread(img_file)
            # Threshold the mask image at 127.
            mask = skimage.io.imread(mask_file) >= 127
            yield cls_id, img, mask
56 parser = argparse.ArgumentParser()
57 parser.add_argument(
'template_dir', help=
'Template dir')
58 parser.add_argument(
'db_file', help=
'DB file which will be created')
59 parser.add_argument(
'-g',
'--gpu', type=int, default=0)
60 args = parser.parse_args()
62 template_dir = args.template_dir
63 db_file = args.db_file
66 pkg_path = rospkg.RosPack().get_path(
'jsk_perception')
67 mean_file = osp.join(pkg_path,
'trained_data/resnet_lsvrc2012_mean.npy')
68 pretrained_model = osp.join(
69 pkg_path,
'trained_data/resnet152_from_caffe.npz')
71 target_names = yaml.load(open(osp.join(template_dir,
'target_names.yaml')))
73 mean = np.load(mean_file)
75 model = ResNet152Feature()
76 chainer.serializers.load_npz(pretrained_model, model)
78 chainer.cuda.get_device_from_id(gpu).use()
81 chainer.global_config.train =
False 82 chainer.global_config.enable_backprop =
False 88 for cls_id, img, mask
in tqdm.tqdm(
get_templates(template_dir)):
90 img = img.astype(np.float64)
91 img[mask] -= mean[mask]
94 img = img.transpose(2, 0, 1)
95 img = img.astype(np.float32)
96 x_data = np.asarray([img])
97 x_data = cuda.to_gpu(x_data)
98 x = chainer.Variable(x_data)
101 feat = cuda.to_cpu(feat.data)
102 feat = feat.squeeze(axis=(2, 3))
108 np.savez_compressed(db_file, X=X, y=y, target_names=target_names)
110 knn = KNeighborsClassifier(n_neighbors=1)
112 y_pred = knn.predict(X)
114 print(classification_report(y, y_pred, labels=range(len(target_names)),
115 target_names=target_names))
118 if __name__ ==
'__main__':
def get_templates(template_dir)