00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035 #ifndef RTABMAP_FLANN_LSH_INDEX_H_
00036 #define RTABMAP_FLANN_LSH_INDEX_H_
00037
00038 #include <algorithm>
00039 #include <cassert>
00040 #include <cstring>
00041 #include <map>
00042 #include <vector>
00043
00044 #include "rtflann/general.h"
00045 #include "rtflann/algorithms/nn_index.h"
00046 #include "rtflann/util/matrix.h"
00047 #include "rtflann/util/result_set.h"
00048 #include "rtflann/util/heap.h"
00049 #include "rtflann/util/lsh_table.h"
00050 #include "rtflann/util/allocator.h"
00051 #include "rtflann/util/random.h"
00052 #include "rtflann/util/saving.h"
00053
00054 namespace rtflann
00055 {
00056
00057 struct LshIndexParams : public IndexParams
00058 {
00059 LshIndexParams(unsigned int table_number = 12, unsigned int key_size = 20, unsigned int multi_probe_level = 2)
00060 {
00061 (* this)["algorithm"] = FLANN_INDEX_LSH;
00062
00063 (*this)["table_number"] = table_number;
00064
00065 (*this)["key_size"] = key_size;
00066
00067 (*this)["multi_probe_level"] = multi_probe_level;
00068 }
00069 };
00070
00077 template<typename Distance>
00078 class LshIndex : public NNIndex<Distance>
00079 {
00080 public:
00081 typedef typename Distance::ElementType ElementType;
00082 typedef typename Distance::ResultType DistanceType;
00083
00084 typedef NNIndex<Distance> BaseClass;
00085
00090 LshIndex(const IndexParams& params = LshIndexParams(), Distance d = Distance()) :
00091 BaseClass(params, d)
00092 {
00093 table_number_ = get_param<unsigned int>(index_params_,"table_number",12);
00094 key_size_ = get_param<unsigned int>(index_params_,"key_size",20);
00095 multi_probe_level_ = get_param<unsigned int>(index_params_,"multi_probe_level",2);
00096
00097 fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_);
00098 }
00099
00100
00106 LshIndex(const Matrix<ElementType>& input_data, const IndexParams& params = LshIndexParams(), Distance d = Distance()) :
00107 BaseClass(params, d)
00108 {
00109 table_number_ = get_param<unsigned int>(index_params_,"table_number",12);
00110 key_size_ = get_param<unsigned int>(index_params_,"key_size",20);
00111 multi_probe_level_ = get_param<unsigned int>(index_params_,"multi_probe_level",2);
00112
00113 fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_);
00114
00115 setDataset(input_data);
00116 }
00117
00118 LshIndex(const LshIndex& other) : BaseClass(other),
00119 tables_(other.tables_),
00120 table_number_(other.table_number_),
00121 key_size_(other.key_size_),
00122 multi_probe_level_(other.multi_probe_level_),
00123 xor_masks_(other.xor_masks_)
00124 {
00125 }
00126
00127 LshIndex& operator=(LshIndex other)
00128 {
00129 this->swap(other);
00130 return *this;
00131 }
00132
00133 virtual ~LshIndex()
00134 {
00135 freeIndex();
00136 }
00137
00138
00139 BaseClass* clone() const
00140 {
00141 return new LshIndex(*this);
00142 }
00143
00144 using BaseClass::buildIndex;
00145
00146 void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00147 {
00148 assert(points.cols==veclen_);
00149 size_t old_size = size_;
00150
00151 extendDataset(points);
00152
00153 if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
00154 buildIndex();
00155 }
00156 else {
00157 for (unsigned int i = 0; i < table_number_; ++i) {
00158 lsh::LshTable<ElementType>& table = tables_[i];
00159 for (size_t i=old_size;i<size_;++i) {
00160 table.add(i, points_[i]);
00161 }
00162 }
00163 }
00164 }
00165
00166
00167 flann_algorithm_t getType() const
00168 {
00169 return FLANN_INDEX_LSH;
00170 }
00171
00172
00173 template<typename Archive>
00174 void serialize(Archive& ar)
00175 {
00176 ar.setObject(this);
00177
00178 ar & *static_cast<NNIndex<Distance>*>(this);
00179
00180 ar & table_number_;
00181 ar & key_size_;
00182 ar & multi_probe_level_;
00183
00184 ar & xor_masks_;
00185 ar & tables_;
00186
00187 if (Archive::is_loading::value) {
00188 index_params_["algorithm"] = getType();
00189 index_params_["table_number"] = table_number_;
00190 index_params_["key_size"] = key_size_;
00191 index_params_["multi_probe_level"] = multi_probe_level_;
00192 }
00193 }
00194
00195 void saveIndex(FILE* stream)
00196 {
00197 serialization::SaveArchive sa(stream);
00198 sa & *this;
00199 }
00200
00201 void loadIndex(FILE* stream)
00202 {
00203 serialization::LoadArchive la(stream);
00204 la & *this;
00205 }
00206
00211 int usedMemory() const
00212 {
00213 return size_ * sizeof(int);
00214 }
00215
00224 int knnSearch(const Matrix<ElementType>& queries,
00225 Matrix<size_t>& indices,
00226 Matrix<DistanceType>& dists,
00227 size_t knn,
00228 const SearchParams& params) const
00229 {
00230 assert(queries.cols == veclen_);
00231 assert(indices.rows >= queries.rows);
00232 assert(dists.rows >= queries.rows);
00233 assert(indices.cols >= knn);
00234 assert(dists.cols >= knn);
00235
00236 int count = 0;
00237 if (params.use_heap==FLANN_True) {
00238 #pragma omp parallel num_threads(params.cores)
00239 {
00240 KNNUniqueResultSet<DistanceType> resultSet(knn);
00241 #pragma omp for schedule(static) reduction(+:count)
00242 for (int i = 0; i < (int)queries.rows; i++) {
00243 resultSet.clear();
00244 findNeighbors(resultSet, queries[i], params);
00245 size_t n = std::min(resultSet.size(), knn);
00246 resultSet.copy(indices[i], dists[i], n, params.sorted);
00247 indices_to_ids(indices[i], indices[i], n);
00248 count += n;
00249 }
00250 }
00251 }
00252 else {
00253 #pragma omp parallel num_threads(params.cores)
00254 {
00255 KNNResultSet<DistanceType> resultSet(knn);
00256 #pragma omp for schedule(static) reduction(+:count)
00257 for (int i = 0; i < (int)queries.rows; i++) {
00258 resultSet.clear();
00259 findNeighbors(resultSet, queries[i], params);
00260 size_t n = std::min(resultSet.size(), knn);
00261 resultSet.copy(indices[i], dists[i], n, params.sorted);
00262 indices_to_ids(indices[i], indices[i], n);
00263 count += n;
00264 }
00265 }
00266 }
00267
00268 return count;
00269 }
00270
00279 int knnSearch(const Matrix<ElementType>& queries,
00280 std::vector< std::vector<size_t> >& indices,
00281 std::vector<std::vector<DistanceType> >& dists,
00282 size_t knn,
00283 const SearchParams& params) const
00284 {
00285 assert(queries.cols == veclen_);
00286 if (indices.size() < queries.rows ) indices.resize(queries.rows);
00287 if (dists.size() < queries.rows ) dists.resize(queries.rows);
00288
00289 int count = 0;
00290 if (params.use_heap==FLANN_True) {
00291 #pragma omp parallel num_threads(params.cores)
00292 {
00293 KNNUniqueResultSet<DistanceType> resultSet(knn);
00294 #pragma omp for schedule(static) reduction(+:count)
00295 for (int i = 0; i < (int)queries.rows; i++) {
00296 resultSet.clear();
00297 findNeighbors(resultSet, queries[i], params);
00298 size_t n = std::min(resultSet.size(), knn);
00299 indices[i].resize(n);
00300 dists[i].resize(n);
00301 if (n > 0) {
00302 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00303 indices_to_ids(&indices[i][0], &indices[i][0], n);
00304 }
00305 count += n;
00306 }
00307 }
00308 }
00309 else {
00310 #pragma omp parallel num_threads(params.cores)
00311 {
00312 KNNResultSet<DistanceType> resultSet(knn);
00313 #pragma omp for schedule(static) reduction(+:count)
00314 for (int i = 0; i < (int)queries.rows; i++) {
00315 resultSet.clear();
00316 findNeighbors(resultSet, queries[i], params);
00317 size_t n = std::min(resultSet.size(), knn);
00318 indices[i].resize(n);
00319 dists[i].resize(n);
00320 if (n > 0) {
00321 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00322 indices_to_ids(&indices[i][0], &indices[i][0], n);
00323 }
00324 count += n;
00325 }
00326 }
00327 }
00328
00329 return count;
00330 }
00331
00341 void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& ) const
00342 {
00343 getNeighbors(vec, result);
00344 }
00345
00346 protected:
00347
00351 void buildIndexImpl()
00352 {
00353 tables_.resize(table_number_);
00354 std::vector<std::pair<size_t,ElementType*> > features;
00355 features.reserve(points_.size());
00356 for (size_t i=0;i<points_.size();++i) {
00357 features.push_back(std::make_pair(i, points_[i]));
00358 }
00359 for (unsigned int i = 0; i < table_number_; ++i) {
00360 lsh::LshTable<ElementType>& table = tables_[i];
00361 table = lsh::LshTable<ElementType>(veclen_, key_size_);
00362
00363
00364 table.add(features);
00365 }
00366 }
00367
00368 void freeIndex()
00369 {
00370
00371 }
00372
00373
00374 private:
00377 typedef std::pair<float, unsigned int> ScoreIndexPair;
00378 struct SortScoreIndexPairOnSecond
00379 {
00380 bool operator()(const ScoreIndexPair& left, const ScoreIndexPair& right) const
00381 {
00382 return left.second < right.second;
00383 }
00384 };
00385
00392 void fill_xor_mask(lsh::BucketKey key, int lowest_index, unsigned int level,
00393 std::vector<lsh::BucketKey>& xor_masks)
00394 {
00395 xor_masks.push_back(key);
00396 if (level == 0) return;
00397 for (int index = lowest_index - 1; index >= 0; --index) {
00398
00399 lsh::BucketKey new_key = key | (lsh::BucketKey(1) << index);
00400 fill_xor_mask(new_key, index, level - 1, xor_masks);
00401 }
00402 }
00403
00412 void getNeighbors(const ElementType* vec, bool do_radius, float radius, bool do_k, unsigned int k_nn,
00413 float& checked_average)
00414 {
00415 static std::vector<ScoreIndexPair> score_index_heap;
00416
00417 if (do_k) {
00418 unsigned int worst_score = std::numeric_limits<unsigned int>::max();
00419 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
00420 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
00421 for (; table != table_end; ++table) {
00422 size_t key = table->getKey(vec);
00423 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
00424 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
00425 for (; xor_mask != xor_mask_end; ++xor_mask) {
00426 size_t sub_key = key ^ (*xor_mask);
00427 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
00428 if (bucket == 0) continue;
00429
00430
00431 std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
00432 std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
00433 DistanceType hamming_distance;
00434
00435
00436 for (; training_index < last_training_index; ++training_index) {
00437 if (removed_ && removed_points_.test(*training_index)) continue;
00438 hamming_distance = distance_(vec, points_[*training_index].point, veclen_);
00439
00440 if (hamming_distance < worst_score) {
00441
00442 score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index));
00443 std::push_heap(score_index_heap.begin(), score_index_heap.end());
00444
00445 if (score_index_heap.size() > (unsigned int)k_nn) {
00446
00447 std::pop_heap(score_index_heap.begin(), score_index_heap.end());
00448 score_index_heap.pop_back();
00449
00450 worst_score = score_index_heap.front().first;
00451 }
00452 }
00453 }
00454 }
00455 }
00456 }
00457 else {
00458 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
00459 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
00460 for (; table != table_end; ++table) {
00461 size_t key = table->getKey(vec);
00462 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
00463 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
00464 for (; xor_mask != xor_mask_end; ++xor_mask) {
00465 size_t sub_key = key ^ (*xor_mask);
00466 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
00467 if (bucket == 0) continue;
00468
00469
00470 std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
00471 std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
00472 DistanceType hamming_distance;
00473
00474
00475 for (; training_index < last_training_index; ++training_index) {
00476 if (removed_ && removed_points_.test(*training_index)) continue;
00477
00478 hamming_distance = distance_(vec, points_[*training_index].point, veclen_);
00479 if (hamming_distance < radius) score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index));
00480 }
00481 }
00482 }
00483 }
00484 }
00485
00490 void getNeighbors(const ElementType* vec, ResultSet<DistanceType>& result) const
00491 {
00492 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
00493 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
00494 for (; table != table_end; ++table) {
00495 size_t key = table->getKey(vec);
00496 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
00497 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
00498 for (; xor_mask != xor_mask_end; ++xor_mask) {
00499 size_t sub_key = key ^ (*xor_mask);
00500 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
00501 if (bucket == 0) continue;
00502
00503
00504 std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
00505 std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
00506 DistanceType hamming_distance;
00507
00508
00509 for (; training_index < last_training_index; ++training_index) {
00510 if (removed_ && removed_points_.test(*training_index)) continue;
00511
00512 hamming_distance = distance_(vec, points_[*training_index], veclen_);
00513 result.addPoint(hamming_distance, *training_index);
00514 }
00515 }
00516 }
00517 }
00518
00519
00520 void swap(LshIndex& other)
00521 {
00522 BaseClass::swap(other);
00523 std::swap(tables_, other.tables_);
00524 std::swap(size_at_build_, other.size_at_build_);
00525 std::swap(table_number_, other.table_number_);
00526 std::swap(key_size_, other.key_size_);
00527 std::swap(multi_probe_level_, other.multi_probe_level_);
00528 std::swap(xor_masks_, other.xor_masks_);
00529 }
00530
00532 std::vector<lsh::LshTable<ElementType> > tables_;
00533
00535 unsigned int table_number_;
00537 unsigned int key_size_;
00539 unsigned int multi_probe_level_;
00540
00542 std::vector<lsh::BucketKey> xor_masks_;
00543
00544 USING_BASECLASS_SYMBOLS
00545 };
00546 }
00547
00548 #endif //FLANN_LSH_INDEX_H_