lsh_index.h
Go to the documentation of this file.
00001 /***********************************************************************
00002  * Software License Agreement (BSD License)
00003  *
00004  * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
00005  * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
00006  *
00007  * THE BSD LICENSE
00008  *
00009  * Redistribution and use in source and binary forms, with or without
00010  * modification, are permitted provided that the following conditions
00011  * are met:
00012  *
00013  * 1. Redistributions of source code must retain the above copyright
00014  *    notice, this list of conditions and the following disclaimer.
00015  * 2. Redistributions in binary form must reproduce the above copyright
00016  *    notice, this list of conditions and the following disclaimer in the
00017  *    documentation and/or other materials provided with the distribution.
00018  *
00019  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00020  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00021  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00022  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00023  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00024  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00025  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00026  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00028  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029  *************************************************************************/
00030 
00031 /***********************************************************************
00032  * Author: Vincent Rabaud
00033  *************************************************************************/
00034 
00035 #ifndef RTABMAP_FLANN_LSH_INDEX_H_
00036 #define RTABMAP_FLANN_LSH_INDEX_H_
00037 
00038 #include <algorithm>
00039 #include <cassert>
00040 #include <cstring>
00041 #include <map>
00042 #include <vector>
00043 
00044 #include "rtflann/general.h"
00045 #include "rtflann/algorithms/nn_index.h"
00046 #include "rtflann/util/matrix.h"
00047 #include "rtflann/util/result_set.h"
00048 #include "rtflann/util/heap.h"
00049 #include "rtflann/util/lsh_table.h"
00050 #include "rtflann/util/allocator.h"
00051 #include "rtflann/util/random.h"
00052 #include "rtflann/util/saving.h"
00053 
00054 namespace rtflann
00055 {
00056 
00057 struct LshIndexParams : public IndexParams
00058 {
00059     LshIndexParams(unsigned int table_number = 12, unsigned int key_size = 20, unsigned int multi_probe_level = 2)
00060     {
00061         (* this)["algorithm"] = FLANN_INDEX_LSH;
00062         // The number of hash tables to use
00063         (*this)["table_number"] = table_number;
00064         // The length of the key in the hash tables
00065         (*this)["key_size"] = key_size;
00066         // Number of levels to use in multi-probe (0 for standard LSH)
00067         (*this)["multi_probe_level"] = multi_probe_level;
00068     }
00069 };
00070 
00077 template<typename Distance>
00078 class LshIndex : public NNIndex<Distance>
00079 {
00080 public:
00081     typedef typename Distance::ElementType ElementType;
00082     typedef typename Distance::ResultType DistanceType;
00083 
00084     typedef NNIndex<Distance> BaseClass;
00085 
00090     LshIndex(const IndexParams& params = LshIndexParams(), Distance d = Distance()) :
00091         BaseClass(params, d)
00092     {
00093         table_number_ = get_param<unsigned int>(index_params_,"table_number",12);
00094         key_size_ = get_param<unsigned int>(index_params_,"key_size",20);
00095         multi_probe_level_ = get_param<unsigned int>(index_params_,"multi_probe_level",2);
00096 
00097         fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_);
00098     }
00099 
00100 
00106     LshIndex(const Matrix<ElementType>& input_data, const IndexParams& params = LshIndexParams(), Distance d = Distance()) :
00107         BaseClass(params, d)
00108     {
00109         table_number_ = get_param<unsigned int>(index_params_,"table_number",12);
00110         key_size_ = get_param<unsigned int>(index_params_,"key_size",20);
00111         multi_probe_level_ = get_param<unsigned int>(index_params_,"multi_probe_level",2);
00112 
00113         fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_);
00114 
00115         setDataset(input_data);
00116     }
00117 
00118     LshIndex(const LshIndex& other) : BaseClass(other),
00119         tables_(other.tables_),
00120         table_number_(other.table_number_),
00121         key_size_(other.key_size_),
00122         multi_probe_level_(other.multi_probe_level_),
00123         xor_masks_(other.xor_masks_)
00124     {
00125     }
00126     
00127     LshIndex& operator=(LshIndex other)
00128     {
00129         this->swap(other);
00130         return *this;
00131     }
00132 
00133     virtual ~LshIndex()
00134     {
00135         freeIndex();
00136     }
00137 
00138 
00139     BaseClass* clone() const
00140     {
00141         return new LshIndex(*this);
00142     }
00143     
00144     using BaseClass::buildIndex;
00145 
00146     void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00147     {
00148         assert(points.cols==veclen_);
00149         size_t old_size = size_;
00150 
00151         extendDataset(points);
00152         
00153         if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
00154             buildIndex();
00155         }
00156         else {
00157             for (unsigned int i = 0; i < table_number_; ++i) {
00158                 lsh::LshTable<ElementType>& table = tables_[i];                
00159                 for (size_t i=old_size;i<size_;++i) {
00160                     table.add(i, points_[i]);
00161                 }            
00162             }
00163         }
00164     }
00165 
00166 
00167     flann_algorithm_t getType() const
00168     {
00169         return FLANN_INDEX_LSH;
00170     }
00171 
00172 
00173     template<typename Archive>
00174     void serialize(Archive& ar)
00175     {
00176         ar.setObject(this);
00177 
00178         ar & *static_cast<NNIndex<Distance>*>(this);
00179 
00180         ar & table_number_;
00181         ar & key_size_;
00182         ar & multi_probe_level_;
00183 
00184         ar & xor_masks_;
00185         ar & tables_;
00186 
00187         if (Archive::is_loading::value) {
00188             index_params_["algorithm"] = getType();
00189             index_params_["table_number"] = table_number_;
00190             index_params_["key_size"] = key_size_;
00191             index_params_["multi_probe_level"] = multi_probe_level_;
00192         }
00193     }
00194 
00195     void saveIndex(FILE* stream)
00196     {
00197         serialization::SaveArchive sa(stream);
00198         sa & *this;
00199     }
00200 
00201     void loadIndex(FILE* stream)
00202     {
00203         serialization::LoadArchive la(stream);
00204         la & *this;
00205     }
00206 
00211     int usedMemory() const
00212     {
00213         return size_ * sizeof(int);
00214     }
00215 
00224     int knnSearch(const Matrix<ElementType>& queries,
00225                                         Matrix<size_t>& indices,
00226                                         Matrix<DistanceType>& dists,
00227                                         size_t knn,
00228                                         const SearchParams& params) const
00229     {
00230         assert(queries.cols == veclen_);
00231         assert(indices.rows >= queries.rows);
00232         assert(dists.rows >= queries.rows);
00233         assert(indices.cols >= knn);
00234         assert(dists.cols >= knn);
00235 
00236         int count = 0;
00237         if (params.use_heap==FLANN_True) {
00238 #pragma omp parallel num_threads(params.cores)
00239                 {
00240                         KNNUniqueResultSet<DistanceType> resultSet(knn);
00241 #pragma omp for schedule(static) reduction(+:count)
00242                         for (int i = 0; i < (int)queries.rows; i++) {
00243                                 resultSet.clear();
00244                                 findNeighbors(resultSet, queries[i], params);
00245                                 size_t n = std::min(resultSet.size(), knn);
00246                                 resultSet.copy(indices[i], dists[i], n, params.sorted);
00247                                 indices_to_ids(indices[i], indices[i], n);
00248                                 count += n;
00249                         }
00250                 }
00251         }
00252         else {
00253 #pragma omp parallel num_threads(params.cores)
00254                 {
00255                         KNNResultSet<DistanceType> resultSet(knn);
00256 #pragma omp for schedule(static) reduction(+:count)
00257                         for (int i = 0; i < (int)queries.rows; i++) {
00258                                 resultSet.clear();
00259                                 findNeighbors(resultSet, queries[i], params);
00260                                 size_t n = std::min(resultSet.size(), knn);
00261                                 resultSet.copy(indices[i], dists[i], n, params.sorted);
00262                                 indices_to_ids(indices[i], indices[i], n);
00263                                 count += n;
00264                         }
00265                 }
00266         }
00267 
00268         return count;
00269     }
00270 
00279     int knnSearch(const Matrix<ElementType>& queries,
00280                                         std::vector< std::vector<size_t> >& indices,
00281                                         std::vector<std::vector<DistanceType> >& dists,
00282                                 size_t knn,
00283                                 const SearchParams& params) const
00284     {
00285         assert(queries.cols == veclen_);
00286                 if (indices.size() < queries.rows ) indices.resize(queries.rows);
00287                 if (dists.size() < queries.rows ) dists.resize(queries.rows);
00288 
00289                 int count = 0;
00290                 if (params.use_heap==FLANN_True) {
00291 #pragma omp parallel num_threads(params.cores)
00292                         {
00293                                 KNNUniqueResultSet<DistanceType> resultSet(knn);
00294 #pragma omp for schedule(static) reduction(+:count)
00295                                 for (int i = 0; i < (int)queries.rows; i++) {
00296                                         resultSet.clear();
00297                                         findNeighbors(resultSet, queries[i], params);
00298                                         size_t n = std::min(resultSet.size(), knn);
00299                                         indices[i].resize(n);
00300                                         dists[i].resize(n);
00301                                         if (n > 0) {
00302                                                 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00303                                                 indices_to_ids(&indices[i][0], &indices[i][0], n);
00304                                         }
00305                                         count += n;
00306                                 }
00307                         }
00308                 }
00309                 else {
00310 #pragma omp parallel num_threads(params.cores)
00311                         {
00312                                 KNNResultSet<DistanceType> resultSet(knn);
00313 #pragma omp for schedule(static) reduction(+:count)
00314                                 for (int i = 0; i < (int)queries.rows; i++) {
00315                                         resultSet.clear();
00316                                         findNeighbors(resultSet, queries[i], params);
00317                                         size_t n = std::min(resultSet.size(), knn);
00318                                         indices[i].resize(n);
00319                                         dists[i].resize(n);
00320                                         if (n > 0) {
00321                                                 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00322                                                 indices_to_ids(&indices[i][0], &indices[i][0], n);
00323                                         }
00324                                         count += n;
00325                                 }
00326                         }
00327                 }
00328 
00329                 return count;
00330     }
00331 
00341     void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& /*searchParams*/) const
00342     {
00343         getNeighbors(vec, result);
00344     }
00345 
00346 protected:
00347 
00351     void buildIndexImpl()
00352     {
00353         tables_.resize(table_number_);
00354         std::vector<std::pair<size_t,ElementType*> > features;
00355         features.reserve(points_.size());
00356         for (size_t i=0;i<points_.size();++i) {
00357                 features.push_back(std::make_pair(i, points_[i]));
00358         }
00359         for (unsigned int i = 0; i < table_number_; ++i) {
00360             lsh::LshTable<ElementType>& table = tables_[i];
00361             table = lsh::LshTable<ElementType>(veclen_, key_size_);
00362 
00363             // Add the features to the table
00364             table.add(features);
00365         }
00366     }
00367 
00368     void freeIndex()
00369     {
00370         /* nothing to do here */
00371     }
00372 
00373 
00374 private:
00377     typedef std::pair<float, unsigned int> ScoreIndexPair;
00378     struct SortScoreIndexPairOnSecond
00379     {
00380         bool operator()(const ScoreIndexPair& left, const ScoreIndexPair& right) const
00381         {
00382             return left.second < right.second;
00383         }
00384     };
00385 
00392     void fill_xor_mask(lsh::BucketKey key, int lowest_index, unsigned int level,
00393                        std::vector<lsh::BucketKey>& xor_masks)
00394     {
00395         xor_masks.push_back(key);
00396         if (level == 0) return;
00397         for (int index = lowest_index - 1; index >= 0; --index) {
00398             // Create a new key
00399             lsh::BucketKey new_key = key | (lsh::BucketKey(1) << index);
00400             fill_xor_mask(new_key, index, level - 1, xor_masks);
00401         }
00402     }
00403 
00412     void getNeighbors(const ElementType* vec, bool do_radius, float radius, bool do_k, unsigned int k_nn,
00413                       float& checked_average)
00414     {
00415         static std::vector<ScoreIndexPair> score_index_heap;
00416 
00417         if (do_k) {
00418             unsigned int worst_score = std::numeric_limits<unsigned int>::max();
00419             typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
00420             typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
00421             for (; table != table_end; ++table) {
00422                 size_t key = table->getKey(vec);
00423                 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
00424                 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
00425                 for (; xor_mask != xor_mask_end; ++xor_mask) {
00426                     size_t sub_key = key ^ (*xor_mask);
00427                     const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
00428                     if (bucket == 0) continue;
00429 
00430                     // Go over each descriptor index
00431                     std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
00432                     std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
00433                     DistanceType hamming_distance;
00434 
00435                     // Process the rest of the candidates
00436                     for (; training_index < last_training_index; ++training_index) {
00437                         if (removed_ && removed_points_.test(*training_index)) continue;
00438                         hamming_distance = distance_(vec, points_[*training_index].point, veclen_);
00439 
00440                         if (hamming_distance < worst_score) {
00441                             // Insert the new element
00442                             score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index));
00443                             std::push_heap(score_index_heap.begin(), score_index_heap.end());
00444 
00445                             if (score_index_heap.size() > (unsigned int)k_nn) {
00446                                 // Remove the highest distance value as we have too many elements
00447                                 std::pop_heap(score_index_heap.begin(), score_index_heap.end());
00448                                 score_index_heap.pop_back();
00449                                 // Keep track of the worst score
00450                                 worst_score = score_index_heap.front().first;
00451                             }
00452                         }
00453                     }
00454                 }
00455             }
00456         }
00457         else {
00458             typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
00459             typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
00460             for (; table != table_end; ++table) {
00461                 size_t key = table->getKey(vec);
00462                 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
00463                 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
00464                 for (; xor_mask != xor_mask_end; ++xor_mask) {
00465                     size_t sub_key = key ^ (*xor_mask);
00466                     const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
00467                     if (bucket == 0) continue;
00468 
00469                     // Go over each descriptor index
00470                     std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
00471                     std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
00472                     DistanceType hamming_distance;
00473 
00474                     // Process the rest of the candidates
00475                     for (; training_index < last_training_index; ++training_index) {
00476                         if (removed_ && removed_points_.test(*training_index)) continue;
00477                         // Compute the Hamming distance
00478                         hamming_distance = distance_(vec, points_[*training_index].point, veclen_);
00479                         if (hamming_distance < radius) score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index));
00480                     }
00481                 }
00482             }
00483         }
00484     }
00485 
00490     void getNeighbors(const ElementType* vec, ResultSet<DistanceType>& result) const
00491     {
00492         typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
00493         typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
00494         for (; table != table_end; ++table) {
00495             size_t key = table->getKey(vec);
00496             std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
00497             std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
00498             for (; xor_mask != xor_mask_end; ++xor_mask) {
00499                 size_t sub_key = key ^ (*xor_mask);
00500                 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
00501                 if (bucket == 0) continue;
00502 
00503                 // Go over each descriptor index
00504                 std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
00505                 std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
00506                 DistanceType hamming_distance;
00507 
00508                 // Process the rest of the candidates
00509                 for (; training_index < last_training_index; ++training_index) {
00510                         if (removed_ && removed_points_.test(*training_index)) continue;
00511                     // Compute the Hamming distance
00512                     hamming_distance = distance_(vec, points_[*training_index], veclen_);
00513                     result.addPoint(hamming_distance, *training_index);
00514                 }
00515             }
00516         }
00517     }
00518 
00519 
00520     void swap(LshIndex& other)
00521     {
00522         BaseClass::swap(other);
00523         std::swap(tables_, other.tables_);
00524         std::swap(size_at_build_, other.size_at_build_);
00525         std::swap(table_number_, other.table_number_);
00526         std::swap(key_size_, other.key_size_);
00527         std::swap(multi_probe_level_, other.multi_probe_level_);
00528         std::swap(xor_masks_, other.xor_masks_);
00529     }
00530 
00532     std::vector<lsh::LshTable<ElementType> > tables_;
00533     
00535     unsigned int table_number_;
00537     unsigned int key_size_;
00539     unsigned int multi_probe_level_;
00540 
00542     std::vector<lsh::BucketKey> xor_masks_;
00543 
00544     USING_BASECLASS_SYMBOLS
00545 };
00546 }
00547 
00548 #endif //FLANN_LSH_INDEX_H_


rtabmap
Author(s): Mathieu Labbe
autogenerated on Thu Jun 6 2019 21:59:20