00001
00002
00003 """
00004 Extract Bag of Features and create BoF Histograms
00005
00006 usage: extract_bof.py [-h] {fit,extract,dataset} ...
00007
00008 positional arguments:
00009 {fit,extract,dataset}
00010 fit fit feature extractor using dataset.
00011 dataset create bof histogram dataset.
00012 extract extract feature in realtime.
00013
00014 """
00015
00016 import os
00017 import sys
00018 import gzip
00019 import cPickle as pickle
00020 import argparse
00021
00022 import progressbar
00023 import numpy as np
00024 from sklearn.datasets.base import Bunch
00025 from sklearn.preprocessing import normalize
00026 from sklearn.neighbors import NearestNeighbors
00027 from sklearn.cluster import MiniBatchKMeans
00028
00029 import rospy
00030 from posedetection_msgs.msg import Feature0D
00031 from jsk_recognition_msgs.msg import Histogram
00032
00033
class BagOfFeatures(object):
    """Bag-of-features (visual words) encoder for local descriptors.

    ``fit`` clusters a pool of descriptors into ``hist_size`` visual words
    with mini-batch k-means; ``make_hist``/``transform`` then assign every
    descriptor to its nearest cluster center and count the assignments.

    Attributes:
        nn: ``NearestNeighbors`` index over the cluster centers, or None
            until ``fit`` has been called.
        hist_size: number of clusters, i.e. the histogram length.
    """

    # Row length used by ``transform`` when flattening each sample's
    # descriptors (128, as the training code in this script also assumes).
    descriptor_dim = 128

    def __init__(self, hist_size=500):
        self.nn = None
        self.hist_size = hist_size

    def fit(self, X):
        """Cluster descriptors and index the resulting cluster centers.

        Args:
            X: 2-D array, one descriptor per row.
        """
        k = self.hist_size
        km = MiniBatchKMeans(n_clusters=k, init_size=3 * k, max_iter=300)
        km.fit(X)
        nn = NearestNeighbors(n_neighbors=1)
        nn.fit(km.cluster_centers_)
        self.nn = nn

    def transform(self, X):
        """Return one histogram row per sample in ``X``.

        Args:
            X: iterable of samples; each sample is an array or nested
                sequence of descriptor values that is flattened into rows
                of ``descriptor_dim`` values.

        Returns:
            float64 array of shape (len(X), hist_size).
        """
        return np.vstack([
            self.make_hist(np.asarray(xi).reshape((-1, self.descriptor_dim)))
            for xi in X])

    def make_hist(self, descriptors):
        """Count how many descriptors fall into each visual word.

        Args:
            descriptors: 2-D array, one descriptor per row.

        Returns:
            float64 array of length ``hist_size``; all zeros when
            ``descriptors`` is empty.

        Raises:
            ValueError: if ``fit`` has not been called yet.
        """
        nn = self.nn
        if nn is None:
            raise ValueError('must fit features before making histogram')
        descriptors = np.asarray(descriptors)
        if descriptors.size == 0:
            # sklearn's kneighbors raises on zero samples; an image without
            # keypoints simply has an empty histogram.
            return np.zeros(self.hist_size)
        indices = nn.kneighbors(descriptors, return_distance=False)
        # One pass over the nearest-center indices. Cast to float so callers
        # can L2-normalize the result in place.
        return np.bincount(np.ravel(indices),
                           minlength=self.hist_size).astype(np.float64)
00063
00064
def create_dataset(data_path, bof_path, bof_hist_path):
    """Encode a descriptor dataset as normalized bag-of-features histograms.

    Reads the ``descriptors``, ``target`` and ``target_names`` entries from
    the gzip-compressed pickle at ``data_path``. Encodes the descriptors
    with the gzip-compressed pickled extractor at ``bof_path``. Writes a
    ``Bunch(data=..., target=..., target_names=...)`` as a gzip-compressed
    pickle to ``bof_hist_path``.

    NOTE: pickle.load can execute arbitrary code; only load trusted files.
    """
    print('creating dataset')
    with gzip.open(data_path, 'rb') as src:
        raw = pickle.load(src)
    descriptors = raw['descriptors']
    labels = raw['target']
    label_names = raw['target_names']

    print('extracting feature')
    with gzip.open(bof_path, 'rb') as src:
        extractor = pickle.load(src)
    hists = extractor.transform(descriptors)
    # Row-wise normalization done in place (effective when ``hists`` is a
    # float ndarray).
    normalize(hists, copy=False)
    bunch = Bunch(data=hists, target=labels, target_names=label_names)

    print('saving dataset')
    with gzip.open(bof_hist_path, 'wb') as dst:
        pickle.dump(bunch, dst)
00081
00082
def fit_bof_extractor(data_path, bof_path, hist_size=500):
    """Fit a BagOfFeatures extractor on a descriptor dataset and save it.

    Args:
        data_path: gzip-compressed pickle holding a mapping whose
            ``'descriptors'`` entry has one descriptor collection per
            sample, each reshapeable into rows of 128 values.
        bof_path: destination of the gzip-compressed pickled extractor.
        hist_size: number of visual words (clusters) to learn.
    """
    print('loading data')
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with gzip.open(data_path, 'rb') as f:
        dataset = pickle.load(f)
    descs = dataset['descriptors']
    # np.vstack needs a sequence: on Python 3 ``map`` is a lazy iterator,
    # which recent NumPy releases reject, so build a list explicitly.
    X = np.vstack([np.asarray(x).reshape((-1, 128)) for x in descs])

    print('fitting bag of features extractor')
    bof = BagOfFeatures(hist_size=hist_size)
    bof.fit(X)

    print('saving bof')
    with gzip.open(bof_path, 'wb') as f:
        pickle.dump(bof, f)
00097
00098
class ExtractInRealtime(object):
    """Publish a bag-of-features histogram for every incoming Feature0D.

    Subscribes to ``Feature0D``. For each message it publishes a normalized
    ``Histogram`` on ``~output/bof_hist`` that carries the incoming header.
    """

    def __init__(self, bof_path):
        """Load the pickled extractor at ``bof_path`` and set up ROS I/O."""
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with gzip.open(bof_path, 'rb') as f:
            self.bof = pickle.load(f)
        self.pub = rospy.Publisher('~output/bof_hist', Histogram, queue_size=1)
        rospy.Subscriber('Feature0D', Feature0D, self._cb_feature0d)
        rospy.loginfo('Initialized bof extractor')

    def _cb_feature0d(self, msg):
        """Convert one Feature0D message into a normalized histogram."""
        desc = np.array(msg.descriptors)
        if desc.size == 0:
            # A frame without keypoints would make the nearest-neighbor query
            # raise inside the callback; publish an all-zero histogram instead.
            hist = np.zeros(self.bof.hist_size)
        else:
            X = self.bof.transform([desc])
            # Use the returned array instead of relying on copy=False, which
            # only normalizes in place when no dtype conversion is needed.
            hist = normalize(X)[0]
        self.pub.publish(Histogram(header=msg.header, histogram=hist))
00112
00113
def main():
    """Parse the command line and dispatch to the selected sub-command."""
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='command')
    # Python 3's argparse makes sub-commands optional by default; without
    # this a bare invocation parses fine and silently does nothing.
    subparsers.required = True

    fit_parser = subparsers.add_parser(
        'fit', help='fit feature extractor using dataset')
    fit_parser.add_argument('data_path', help='data path')
    fit_parser.add_argument(
        '-O', '--output', default='bof.pkl.gz',
        help='bof feature extractor instance save path '
             '(default: bof.pkl.gz)')

    dataset_parser = subparsers.add_parser(
        'dataset', help='create bof histogram dataset')
    dataset_parser.add_argument('data_path', help='data path')
    dataset_parser.add_argument('bof_path', help='bof data path')
    dataset_parser.add_argument(
        '-O', '--output', default='bof_hist.pkl.gz',
        help='save path of bof histogram (default: bof_hist.pkl.gz)')

    extract_parser = subparsers.add_parser(
        'extract', help='extract feature in realtime')
    extract_parser.add_argument('bof_path', help='bof data path')

    # rospy.myargv strips ROS remapping arguments before argparse sees them.
    args = parser.parse_args(rospy.myargv(sys.argv[1:]))

    if args.command == 'fit':
        fit_bof_extractor(data_path=args.data_path, bof_path=args.output)
    elif args.command == 'dataset':
        create_dataset(data_path=args.data_path,
                       bof_path=args.bof_path,
                       bof_hist_path=args.output)
    elif args.command == 'extract':
        rospy.init_node('extract_bof')
        # Hold a reference to the extractor for as long as the node spins.
        extractor = ExtractInRealtime(bof_path=args.bof_path)  # noqa: F841
        rospy.spin()
00146
00147
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script.
    main()