00001 #!/usr/bin/env python
00002 # -*- coding: utf-8 -*-
"""
Extract Bag of Features and create BoF Histograms

    usage: extract_bof.py [-h] {fit,dataset,extract} ...

    positional arguments:
    {fit,dataset,extract}
        fit                 fit feature extractor using dataset.
        dataset             create bof histogram dataset.
        extract             extract feature in realtime.

"""
00016 import os
00017 import sys
00018 import gzip
00019 import cPickle as pickle
00020 import argparse
00022 import progressbar
00023 import numpy as np
00024 from sklearn.datasets.base import Bunch
00025 from sklearn.preprocessing import normalize
00026 from sklearn.neighbors import NearestNeighbors
00027 from sklearn.cluster import MiniBatchKMeans
00029 import rospy
00030 from posedetection_msgs.msg import Feature0D
00031 from jsk_recognition_msgs.msg import Histogram
class BagOfFeatures(object):
    """Bag-of-features encoder built on a k-means codebook.

    ``fit`` clusters local descriptors into ``hist_size`` visual words.
    ``transform`` / ``make_hist`` then turn one image's descriptors into a
    histogram that counts how many descriptors lie nearest to each word.
    """

    def __init__(self, hist_size=500):
        # 1-nearest-neighbour index over the cluster centres; None until fit().
        self.nn = None
        # Number of k-means clusters, which is also the histogram length.
        self.hist_size = hist_size

    def fit(self, X):
        """Fit features and extract bag of features.

        Runs MiniBatchKMeans with ``hist_size`` clusters on ``X`` (one
        descriptor per row). The resulting cluster centres are indexed with
        a 1-nearest-neighbour model that ``make_hist`` queries later.
        """
        k = self.hist_size
        km = MiniBatchKMeans(n_clusters=k, init_size=3 * k, max_iter=300)
        km.fit(X)
        nn = NearestNeighbors(n_neighbors=1)
        nn.fit(km.cluster_centers_)
        self.nn = nn

    def transform(self, X):
        """Return a 2-D array holding one histogram row per image.

        Each element of ``X`` holds one image's descriptors. It may be a numpy
        array or any sequence ``np.asarray`` accepts (e.g. a plain list loaded
        from a pickle), and is reshaped to rows of 128 values.
        """
        return np.vstack([self.make_hist(np.asarray(xi).reshape((-1, 128)))
                          for xi in X])

    def make_hist(self, descriptors):
        """Make histogram for one image.

        Returns a float64 array of length ``hist_size``. Entry ``i`` counts the
        rows of ``descriptors`` whose nearest cluster centre is ``i``. An image
        without descriptors yields an all-zero histogram.

        Raises ValueError if called before ``fit``.
        """
        nn = self.nn
        if nn is None:
            raise ValueError('must fit features before making histogram')
        descriptors = np.asarray(descriptors)
        if descriptors.size == 0:
            # Nothing to assign: skip the neighbour query, since sklearn's
            # input validation rejects 0-sample arrays.
            return np.zeros(self.hist_size)
        indices = nn.kneighbors(descriptors, return_distance=False)
        # One counting pass over all assignments instead of a mask per word.
        counts = np.bincount(indices.ravel(), minlength=self.hist_size)
        return counts.astype(np.float64)
def create_dataset(data_path, bof_path, bof_hist_path):
    """Encode a descriptor dataset as normalized BoF histograms and save it.

    Reads the gzipped pickle at ``data_path``, which must provide
    'descriptors', 'target' and 'target_names'. Every descriptor entry is
    turned into a histogram with the gzipped, pickled extractor at
    ``bof_path``. The rows are normalized in place with sklearn's
    ``normalize``, and the result is written to ``bof_hist_path`` as a
    gzipped, pickled ``Bunch``.

    NOTE: pickle.load can run arbitrary code; only pass trusted files.
    """
    print('creating dataset')
    with gzip.open(data_path, 'rb') as src:
        raw = pickle.load(src)
    descriptors = raw['descriptors']
    labels = raw['target']
    label_names = raw['target_names']

    print('extracting feature')
    with gzip.open(bof_path, 'rb') as src:
        extractor = pickle.load(src)
    hists = extractor.transform(descriptors)
    normalize(hists, copy=False)  # scales each row in place
    result = Bunch(data=hists, target=labels, target_names=label_names)

    print('saving dataset')
    with gzip.open(bof_hist_path, 'wb') as dst:
        pickle.dump(result, dst)
def fit_bof_extractor(data_path, bof_path):
    """Fit a BagOfFeatures codebook on a descriptor dataset and save it.

    Reads the gzipped pickle at ``data_path`` and stacks every entry of its
    'descriptors' field into rows of 128 values. It then fits a
    ``BagOfFeatures`` (default ``hist_size``) on all rows and writes the
    fitted extractor to ``bof_path`` as a gzipped pickle.

    Raises ValueError if the dataset contains no descriptor entries.

    NOTE: pickle.load can run arbitrary code; only pass trusted files.
    """
    print('loading data')
    with gzip.open(data_path, 'rb') as f:
        dataset = pickle.load(f)
    descs = dataset['descriptors']
    # Build a concrete list: np.vstack needs a sequence, and map() is a lazy
    # iterator on Python 3.
    rows = [np.asarray(x).reshape((-1, 128)) for x in descs]
    if not rows:
        raise ValueError('no descriptors found in {0}'.format(data_path))
    X = np.vstack(rows)
    # extract feature
    print('fitting bag of features extractor')
    bof = BagOfFeatures()
    bof.fit(X)
    # save bof extractor
    print('saving bof')
    with gzip.open(bof_path, 'wb') as f:
        pickle.dump(bof, f)
class ExtractInRealtime(object):
    """ROS node helper that converts Feature0D messages into BoF histograms.

    Subscribes to ``Feature0D``. For every message it publishes one
    normalized ``Histogram`` on ``~output/bof_hist``.
    """

    def __init__(self, bof_path):
        # Load the fitted extractor written by the 'fit' command.
        with gzip.open(bof_path, 'rb') as stream:
            self.bof = pickle.load(stream)
        # The publisher must exist before the subscriber can fire a callback.
        self.pub = rospy.Publisher('~output/bof_hist', Histogram, queue_size=1)
        rospy.Subscriber('Feature0D', Feature0D, self._cb_feature0d)
        rospy.loginfo('Initialized bof extractor')

    def _cb_feature0d(self, msg):
        """Publish the normalized BoF histogram of one Feature0D message."""
        hist_rows = self.bof.transform([np.array(msg.descriptors)])
        normalize(hist_rows, copy=False)  # scales the single row in place
        out_msg = Histogram(header=msg.header, histogram=hist_rows[0])
        self.pub.publish(out_msg)
def main(argv=None):
    """Command-line entry point: run the fit, dataset or extract command.

    argv: argument list without the program name. When None (the default),
    ``rospy.myargv(sys.argv[1:])`` is used, exactly as before.
    """
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='command')
    # Python 3's argparse treats sub-commands as optional by default. Without
    # this, an empty command line would match no branch below and silently
    # do nothing.
    subparsers.required = True
    # fit command
    fit_parser = subparsers.add_parser(
        'fit', help='fit feature extractor using dataset')
    fit_parser.add_argument('data_path', help='data path')
    fit_parser.add_argument('-O', '--output', default='bof.pkl.gz',
                            help='bof feature extractor instance save path')
    # dataset command
    dataset_parser = subparsers.add_parser(
        'dataset', help='create bof histogram dataset')
    dataset_parser.add_argument('data_path', help='data path')
    dataset_parser.add_argument('bof_path', help='bof data path')
    dataset_parser.add_argument(
        '-O', '--output', default='bof_hist.pkl.gz',
        help='save path of bof histogram (default: bof_hist.pkl.gz)')
    # extract command
    extract_parser = subparsers.add_parser(
        'extract', help='extract feature in realtime')
    extract_parser.add_argument('bof_path', help='bof data path')

    if argv is None:
        argv = rospy.myargv(sys.argv[1:])
    args = parser.parse_args(argv)

    if args.command == 'fit':
        fit_bof_extractor(data_path=args.data_path, bof_path=args.output)
    elif args.command == 'dataset':
        create_dataset(data_path=args.data_path,
                       bof_path=args.bof_path,
                       bof_hist_path=args.output)
    elif args.command == 'extract':
        rospy.init_node('extract_bof')
        # Keep an explicit reference to the extractor while spinning.
        extractor = ExtractInRealtime(bof_path=args.bof_path)  # noqa: F841
        rospy.spin()


if __name__ == '__main__':
    main()

