create_db_for_feature_based_object_recognition.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 import argparse
00004 import glob
00005 import os.path as osp
00006 
00007 import chainer
00008 from chainer import cuda
00009 import numpy as np
00010 import rospkg
00011 import skimage.io
00012 from sklearn.metrics import classification_report
00013 from sklearn.neighbors import KNeighborsClassifier
00014 import tqdm
00015 import yaml
00016 
00017 # TODO(wkentaro): Support Resnet50/101
00018 from jsk_recognition_utils.chainermodels import ResNet152Feature
00019 
00020 
def get_templates(template_dir):
    """Yield ``(cls_id, img, mask)`` for every template image in *template_dir*.

    Expected layout (as read by this function)::

        template_dir/target_names.yaml          list of class names
        template_dir/<cls_name>/*.jpg           template images
        template_dir/<cls_name>/masks/<same>    mask image with the same basename

    Args:
        template_dir: Root directory of the templates.

    Yields:
        cls_id: Index of the class name in ``target_names.yaml``.
        img: Image as returned by ``skimage.io.imread``.
        mask: Boolean array, True where the mask image's value is >= 127.
    """
    # safe_load: the file only needs to hold a plain list of names.
    # yaml.load without a Loader is unsafe and is a TypeError on PyYAML >= 6.
    with open(osp.join(template_dir, 'target_names.yaml')) as f:
        target_names = yaml.safe_load(f)
    for cls_id, cls_name in enumerate(target_names):
        obj_dir = osp.join(template_dir, cls_name)
        # glob order is filesystem-dependent; sort so DB rows are reproducible.
        for img_file in sorted(glob.glob(osp.join(obj_dir, '*.jpg'))):
            dirname, basename = osp.split(img_file)
            mask_file = osp.join(dirname, 'masks', basename)
            img = skimage.io.imread(img_file)
            mask = skimage.io.imread(mask_file) >= 127
            yield cls_id, img, mask
00031 
00032 
def _preprocess(img, mask, mean):
    """Turn one RGB template into a masked, mean-subtracted float32 CHW array.

    Args:
        img: Image indexed as (H, W, C) with RGB channel order
            (assumed 3 channels -- TODO confirm for grayscale templates).
        mask: Boolean array over ``img``'s (H, W) axes; True marks object
            pixels.
        mean: Mean image; indexed with ``mask`` too, so it is assumed to share
            ``img``'s (H, W, C) shape -- TODO confirm.

    Returns:
        float32 array of shape (C, H, W) in BGR order. Object pixels have the
        mean subtracted; background pixels are 0.
    """
    img = img[:, :, ::-1]  # RGB -> BGR
    img = img.astype(np.float64)  # copy, so the caller's array is untouched
    img[mask] -= mean[mask]
    img[~mask] = 0
    return img.transpose(2, 0, 1).astype(np.float32)


def main():
    """Build a ResNet-152 feature DB from template images.

    For every template yielded by ``get_templates``, this extracts a feature
    vector and saves ``X`` (features), ``y`` (class ids) and ``target_names``
    to ``db_file`` with ``np.savez_compressed``. It then prints a 1-NN
    classification report on the same data as a sanity check.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('template_dir', help='Template dir')
    parser.add_argument('db_file', help='DB file which will be created')
    parser.add_argument('-g', '--gpu', type=int, default=0,
                        help='GPU id; a negative value runs on CPU')
    args = parser.parse_args()

    template_dir = args.template_dir
    db_file = args.db_file
    gpu = args.gpu

    pkg_path = rospkg.RosPack().get_path('jsk_perception')
    mean_file = osp.join(pkg_path, 'trained_data/resnet_lsvrc2012_mean.npy')
    pretrained_model = osp.join(
        pkg_path, 'trained_data/resnet152_from_caffe.npz')

    # safe_load + with: yaml.load without a Loader is unsafe (and a TypeError
    # on PyYAML >= 6), and the original never closed the file.
    with open(osp.join(template_dir, 'target_names.yaml')) as f:
        target_names = yaml.safe_load(f)

    mean = np.load(mean_file)

    model = ResNet152Feature()
    chainer.serializers.load_npz(pretrained_model, model)
    if gpu >= 0:
        chainer.cuda.get_device_from_id(gpu).use()
        model.to_gpu()

    # Inference only: test-mode layers, no computational graph.
    chainer.global_config.train = False
    chainer.global_config.enable_backprop = False

    ###########################################################################

    X = []
    y = []
    for cls_id, img, mask in tqdm.tqdm(get_templates(template_dir)):
        x_data = np.asarray([_preprocess(img, mask, mean)])  # batch of one
        # Only move the input to the GPU when the model lives there; an
        # unconditional transfer crashed CPU runs (--gpu -1).
        if gpu >= 0:
            x_data = cuda.to_gpu(x_data)
        feat = model(chainer.Variable(x_data))

        feat = cuda.to_cpu(feat.data)
        feat = feat.squeeze(axis=(2, 3))
        for f in feat:
            X.append(f)
            y.append(cls_id)
    if not X:
        # Fail early rather than writing an empty DB and crashing in knn.fit.
        raise RuntimeError(
            'No template images found under: {}'.format(template_dir))
    X = np.asarray(X)
    y = np.asarray(y)
    np.savez_compressed(db_file, X=X, y=y, target_names=target_names)

    knn = KNeighborsClassifier(n_neighbors=1)
    knn.fit(X, y)
    y_pred = knn.predict(X)
    # validation: must be all 1.0
    print(classification_report(y, y_pred, labels=range(len(target_names)),
                                target_names=target_names))


if __name__ == '__main__':
    main()


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Tue Jul 2 2019 19:41:07