kdtree_index.h
Go to the documentation of this file.
00001 /***********************************************************************
00002  * Software License Agreement (BSD License)
00003  *
00004  * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
00005  * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
00006  *
00007  * THE BSD LICENSE
00008  *
00009  * Redistribution and use in source and binary forms, with or without
00010  * modification, are permitted provided that the following conditions
00011  * are met:
00012  *
00013  * 1. Redistributions of source code must retain the above copyright
00014  *    notice, this list of conditions and the following disclaimer.
00015  * 2. Redistributions in binary form must reproduce the above copyright
00016  *    notice, this list of conditions and the following disclaimer in the
00017  *    documentation and/or other materials provided with the distribution.
00018  *
00019  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00020  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00021  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00022  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00023  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00024  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00025  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00026  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00028  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029  *************************************************************************/
00030 
00031 #ifndef RTABMAP_FLANN_KDTREE_INDEX_H_
00032 #define RTABMAP_FLANN_KDTREE_INDEX_H_
00033 
00034 #include <algorithm>
00035 #include <map>
00036 #include <cassert>
00037 #include <cstring>
00038 #include <stdarg.h>
00039 #include <cmath>
00040 
00041 #include "rtflann/general.h"
00042 #include "rtflann/algorithms/nn_index.h"
00043 #include "rtflann/util/dynamic_bitset.h"
00044 #include "rtflann/util/matrix.h"
00045 #include "rtflann/util/result_set.h"
00046 #include "rtflann/util/heap.h"
00047 #include "rtflann/util/allocator.h"
00048 #include "rtflann/util/random.h"
00049 #include "rtflann/util/saving.h"
00050 
00051 
00052 namespace rtflann
00053 {
00054 
00055 struct KDTreeIndexParams : public IndexParams
00056 {
00057     KDTreeIndexParams(int trees = 4)
00058     {
00059         (*this)["algorithm"] = FLANN_INDEX_KDTREE;
00060         (*this)["trees"] = trees;
00061     }
00062 };
00063 
00064 
00071 template <typename Distance>
00072 class KDTreeIndex : public NNIndex<Distance>
00073 {
00074 public:
00075     typedef typename Distance::ElementType ElementType;
00076     typedef typename Distance::ResultType DistanceType;
00077 
00078     typedef NNIndex<Distance> BaseClass;
00079 
00080     typedef bool needs_kdtree_distance;
00081 
00082 
00090     KDTreeIndex(const IndexParams& params = KDTreeIndexParams(), Distance d = Distance() ) :
00091         BaseClass(params, d), mean_(NULL), var_(NULL)
00092     {
00093         trees_ = get_param(index_params_,"trees",4);
00094     }
00095 
00096 
00104     KDTreeIndex(const Matrix<ElementType>& dataset, const IndexParams& params = KDTreeIndexParams(),
00105                 Distance d = Distance() ) : BaseClass(params,d ), mean_(NULL), var_(NULL)
00106     {
00107         trees_ = get_param(index_params_,"trees",4);
00108 
00109         setDataset(dataset);
00110     }
00111 
00112     KDTreeIndex(const KDTreeIndex& other) : BaseClass(other),
00113                 trees_(other.trees_)
00114     {
00115         tree_roots_.resize(other.tree_roots_.size());
00116         for (size_t i=0;i<tree_roots_.size();++i) {
00117                 copyTree(tree_roots_[i], other.tree_roots_[i]);
00118         }
00119     }
00120 
00121     KDTreeIndex& operator=(KDTreeIndex other)
00122     {
00123         this->swap(other);
00124         return *this;
00125     }
00126 
00130     virtual ~KDTreeIndex()
00131     {
00132         freeIndex();
00133     }
00134 
00135     BaseClass* clone() const
00136     {
00137         return new KDTreeIndex(*this);
00138     }
00139 
00140     using BaseClass::buildIndex;
00141     
00142     void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00143     {
00144         assert(points.cols==veclen_);
00145 
00146         size_t old_size = size_;
00147         extendDataset(points);
00148         
00149         if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
00150             buildIndex();
00151         }
00152         else {
00153             for (size_t i=old_size;i<size_;++i) {
00154                 for (int j = 0; j < trees_; j++) {
00155                     addPointToTree(tree_roots_[j], i);
00156                 }
00157             }
00158         }        
00159     }
00160 
00161     flann_algorithm_t getType() const
00162     {
00163         return FLANN_INDEX_KDTREE;
00164     }
00165 
00166 
00167     template<typename Archive>
00168     void serialize(Archive& ar)
00169     {
00170         ar.setObject(this);
00171 
00172         ar & *static_cast<NNIndex<Distance>*>(this);
00173 
00174         ar & trees_;
00175 
00176         if (Archive::is_loading::value) {
00177                 tree_roots_.resize(trees_);
00178         }
00179         for (size_t i=0;i<tree_roots_.size();++i) {
00180                 if (Archive::is_loading::value) {
00181                         tree_roots_[i] = new(pool_) Node();
00182                 }
00183                 ar & *tree_roots_[i];
00184         }
00185 
00186         if (Archive::is_loading::value) {
00187             index_params_["algorithm"] = getType();
00188             index_params_["trees"] = trees_;
00189         }
00190     }
00191 
00192 
00193     void saveIndex(FILE* stream)
00194     {
00195         serialization::SaveArchive sa(stream);
00196         sa & *this;
00197     }
00198 
00199 
00200     void loadIndex(FILE* stream)
00201     {
00202         freeIndex();
00203         serialization::LoadArchive la(stream);
00204         la & *this;
00205     }
00206 
00211     int usedMemory() const
00212     {
00213         return int(pool_.usedMemory+pool_.wastedMemory+size_*sizeof(int));  // pool memory and vind array memory
00214     }
00215 
00225     void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00226     {
00227         int maxChecks = searchParams.checks;
00228         float epsError = 1+searchParams.eps;
00229 
00230         if (maxChecks==FLANN_CHECKS_UNLIMITED) {
00231                 if (removed_) {
00232                         getExactNeighbors<true>(result, vec, epsError);
00233                 }
00234                 else {
00235                         getExactNeighbors<false>(result, vec, epsError);
00236                 }
00237         }
00238         else {
00239                 if (removed_) {
00240                         getNeighbors<true>(result, vec, maxChecks, epsError);
00241                 }
00242                 else {
00243                         getNeighbors<false>(result, vec, maxChecks, epsError);
00244                 }
00245         }
00246     }
00247 
00248 protected:
00249 
00253     void buildIndexImpl()
00254     {
00255         // Create a permutable array of indices to the input vectors.
00256         std::vector<int> ind(size_);
00257         for (size_t i = 0; i < size_; ++i) {
00258             ind[i] = int(i);
00259         }
00260 
00261         mean_ = new DistanceType[veclen_];
00262         var_ = new DistanceType[veclen_];
00263 
00264         tree_roots_.resize(trees_);
00265         /* Construct the randomized trees. */
00266         for (int i = 0; i < trees_; i++) {
00267             /* Randomize the order of vectors to allow for unbiased sampling. */
00268             std::random_shuffle(ind.begin(), ind.end());
00269             tree_roots_[i] = divideTree(&ind[0], int(size_) );
00270         }
00271         delete[] mean_;
00272         delete[] var_;
00273     }
00274 
00275     void freeIndex()
00276     {
00277         for (size_t i=0;i<tree_roots_.size();++i) {
00278                 // using placement new, so call destructor explicitly
00279                 if (tree_roots_[i]!=NULL) tree_roots_[i]->~Node();
00280         }
00281         pool_.free();
00282     }
00283 
00284 
00285 private:
00286 
00287     /*--------------------- Internal Data Structures --------------------------*/
00288     struct Node
00289     {
00293         int divfeat;
00297         DistanceType divval;
00301         ElementType* point;
00305                 Node* child1, *child2;
00306                 Node(){
00307                         child1 = NULL;
00308                         child2 = NULL;
00309                 }
00310                 ~Node() {
00311                         if (child1 != NULL) { child1->~Node(); child1 = NULL; }
00312 
00313                         if (child2 != NULL) { child2->~Node(); child2 = NULL; }
00314                 }
00315 
00316     private:
00317         template<typename Archive>
00318         void serialize(Archive& ar)
00319         {
00320                 typedef KDTreeIndex<Distance> Index;
00321                 Index* obj = static_cast<Index*>(ar.getObject());
00322 
00323                 ar & divfeat;
00324                 ar & divval;
00325 
00326                 bool leaf_node = false;
00327                 if (Archive::is_saving::value) {
00328                         leaf_node = ((child1==NULL) && (child2==NULL));
00329                 }
00330                 ar & leaf_node;
00331 
00332                 if (leaf_node) {
00333                         if (Archive::is_loading::value) {
00334                                 point = obj->points_[divfeat];
00335                         }
00336                 }
00337 
00338                 if (!leaf_node) {
00339                                 if (Archive::is_loading::value) {
00340                                         child1 = new(obj->pool_) Node();
00341                                         child2 = new(obj->pool_) Node();
00342                                 }
00343                         ar & *child1;
00344                         ar & *child2;
00345                 }
00346         }
00347         friend struct serialization::access;
00348     };
00349     typedef Node* NodePtr;
00350     typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00351     typedef BranchSt* Branch;
00352 
00353 
00354     void copyTree(NodePtr& dst, const NodePtr& src)
00355     {
00356         dst = new(pool_) Node();
00357         dst->divfeat = src->divfeat;
00358         dst->divval = src->divval;
00359         if (src->child1==NULL && src->child2==NULL) {
00360                 dst->point = points_[dst->divfeat];
00361                 dst->child1 = NULL;
00362                 dst->child2 = NULL;
00363         }
00364         else {
00365                 copyTree(dst->child1, src->child1);
00366                 copyTree(dst->child2, src->child2);
00367         }
00368     }
00369 
00379     NodePtr divideTree(int* ind, int count)
00380     {
00381         NodePtr node = new(pool_) Node(); // allocate memory
00382 
00383         /* If too few exemplars remain, then make this a leaf node. */
00384         if (count == 1) {
00385             node->child1 = node->child2 = NULL;    /* Mark as leaf node. */
00386             node->divfeat = *ind;    /* Store index of this vec. */
00387             node->point = points_[*ind];
00388         }
00389         else {
00390             int idx;
00391             int cutfeat;
00392             DistanceType cutval;
00393             meanSplit(ind, count, idx, cutfeat, cutval);
00394 
00395             node->divfeat = cutfeat;
00396             node->divval = cutval;
00397             node->child1 = divideTree(ind, idx);
00398             node->child2 = divideTree(ind+idx, count-idx);
00399         }
00400 
00401         return node;
00402     }
00403 
00404 
00410     void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval)
00411     {
00412         memset(mean_,0,veclen_*sizeof(DistanceType));
00413         memset(var_,0,veclen_*sizeof(DistanceType));
00414 
00415         /* Compute mean values.  Only the first SAMPLE_MEAN values need to be
00416             sampled to get a good estimate.
00417          */
00418         int cnt = std::min((int)SAMPLE_MEAN+1, count);
00419         for (int j = 0; j < cnt; ++j) {
00420             ElementType* v = points_[ind[j]];
00421             for (size_t k=0; k<veclen_; ++k) {
00422                 mean_[k] += v[k];
00423             }
00424         }
00425         DistanceType div_factor = DistanceType(1)/cnt;
00426         for (size_t k=0; k<veclen_; ++k) {
00427             mean_[k] *= div_factor;
00428         }
00429 
00430         /* Compute variances (no need to divide by count). */
00431         for (int j = 0; j < cnt; ++j) {
00432             ElementType* v = points_[ind[j]];
00433             for (size_t k=0; k<veclen_; ++k) {
00434                 DistanceType dist = v[k] - mean_[k];
00435                 var_[k] += dist * dist;
00436             }
00437         }
00438         /* Select one of the highest variance indices at random. */
00439         cutfeat = selectDivision(var_);
00440         cutval = mean_[cutfeat];
00441 
00442         int lim1, lim2;
00443         planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
00444 
00445         if (lim1>count/2) index = lim1;
00446         else if (lim2<count/2) index = lim2;
00447         else index = count/2;
00448 
00449         /* If either list is empty, it means that all remaining features
00450          * are identical. Split in the middle to maintain a balanced tree.
00451          */
00452         if ((lim1==count)||(lim2==0)) index = count/2;
00453     }
00454 
00455 
00460     int selectDivision(DistanceType* v)
00461     {
00462         int num = 0;
00463         size_t topind[RAND_DIM];
00464 
00465         /* Create a list of the indices of the top RAND_DIM values. */
00466         for (size_t i = 0; i < veclen_; ++i) {
00467             if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) {
00468                 /* Put this element at end of topind. */
00469                 if (num < RAND_DIM) {
00470                     topind[num++] = i;            /* Add to list. */
00471                 }
00472                 else {
00473                     topind[num-1] = i;         /* Replace last element. */
00474                 }
00475                 /* Bubble end value down to right location by repeated swapping. */
00476                 int j = num - 1;
00477                 while (j > 0  &&  v[topind[j]] > v[topind[j-1]]) {
00478                     std::swap(topind[j], topind[j-1]);
00479                     --j;
00480                 }
00481             }
00482         }
00483         /* Select a random integer in range [0,num-1], and return that index. */
00484         int rnd = rand_int(num);
00485         return (int)topind[rnd];
00486     }
00487 
00488 
00498     void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
00499     {
00500         /* Move vector indices for left subtree to front of list. */
00501         int left = 0;
00502         int right = count-1;
00503         for (;; ) {
00504             while (left<=right && points_[ind[left]][cutfeat]<cutval) ++left;
00505             while (left<=right && points_[ind[right]][cutfeat]>=cutval) --right;
00506             if (left>right) break;
00507             std::swap(ind[left], ind[right]); ++left; --right;
00508         }
00509         lim1 = left;
00510         right = count-1;
00511         for (;; ) {
00512             while (left<=right && points_[ind[left]][cutfeat]<=cutval) ++left;
00513             while (left<=right && points_[ind[right]][cutfeat]>cutval) --right;
00514             if (left>right) break;
00515             std::swap(ind[left], ind[right]); ++left; --right;
00516         }
00517         lim2 = left;
00518     }
00519 
00524     template<bool with_removed>
00525     void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError) const
00526     {
00527         //              checkID -= 1;  /* Set a different unique ID for each search. */
00528 
00529         if (trees_ > 1) {
00530             fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
00531         }
00532         if (trees_>0) {
00533             searchLevelExact<with_removed>(result, vec, tree_roots_[0], 0.0, epsError);
00534         }
00535     }
00536 
00542     template<bool with_removed>
00543     void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError) const
00544     {
00545         int i;
00546         BranchSt branch;
00547 
00548         int checkCount = 0;
00549         Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00550         DynamicBitset checked(size_);
00551 
00552         /* Search once through each tree down to root. */
00553         for (i = 0; i < trees_; ++i) {
00554             searchLevel<with_removed>(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
00555         }
00556 
00557         /* Keep searching other branches from heap until finished. */
00558         while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
00559             searchLevel<with_removed>(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked);
00560         }
00561 
00562         delete heap;
00563 
00564     }
00565 
00571     template<bool with_removed>
00572     void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck,
00573                      float epsError, Heap<BranchSt>* heap, DynamicBitset& checked) const
00574     {
00575         if (result_set.worstDist()<mindist) {
00576             //                  printf("Ignoring branch, too far\n");
00577             return;
00578         }
00579 
00580         /* If this is a leaf node, then do check and return. */
00581         if ((node->child1 == NULL)&&(node->child2 == NULL)) {
00582             int index = node->divfeat;
00583             if (with_removed) {
00584                 if (removed_points_.test(index)) return;
00585             }
00586             /*  Do not check same node more than once when searching multiple trees. */
00587             if ( checked.test(index) || ((checkCount>=maxCheck)&& result_set.full()) ) return;
00588             checked.set(index);
00589             checkCount++;
00590 
00591             DistanceType dist = distance_(node->point, vec, veclen_);
00592             result_set.addPoint(dist,index);
00593             return;
00594         }
00595 
00596         /* Which child branch should be taken first? */
00597         ElementType val = vec[node->divfeat];
00598         DistanceType diff = val - node->divval;
00599         NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
00600         NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
00601 
00602         /* Create a branch record for the branch not taken.  Add distance
00603             of this feature boundary (we don't attempt to correct for any
00604             use of this feature in a parent node, which is unlikely to
00605             happen and would have only a small effect).  Don't bother
00606             adding more branches to heap after halfway point, as cost of
00607             adding exceeds their value.
00608          */
00609 
00610         DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
00611         //              if (2 * checkCount < maxCheck  ||  !result.full()) {
00612         if ((new_distsq*epsError < result_set.worstDist())||  !result_set.full()) {
00613             heap->insert( BranchSt(otherChild, new_distsq) );
00614         }
00615 
00616         /* Call recursively to search next level down. */
00617         searchLevel<with_removed>(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
00618     }
00619 
00623     template<bool with_removed>
00624     void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError) const
00625     {
00626         /* If this is a leaf node, then do check and return. */
00627         if ((node->child1 == NULL)&&(node->child2 == NULL)) {
00628             int index = node->divfeat;
00629             if (with_removed) {
00630                 if (removed_points_.test(index)) return; // ignore removed points
00631             }
00632             DistanceType dist = distance_(node->point, vec, veclen_);
00633             result_set.addPoint(dist,index);
00634 
00635             return;
00636         }
00637 
00638         /* Which child branch should be taken first? */
00639         ElementType val = vec[node->divfeat];
00640         DistanceType diff = val - node->divval;
00641         NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
00642         NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
00643 
00644         /* Create a branch record for the branch not taken.  Add distance
00645             of this feature boundary (we don't attempt to correct for any
00646             use of this feature in a parent node, which is unlikely to
00647             happen and would have only a small effect).  Don't bother
00648             adding more branches to heap after halfway point, as cost of
00649             adding exceeds their value.
00650          */
00651 
00652         DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
00653 
00654         /* Call recursively to search next level down. */
00655         searchLevelExact<with_removed>(result_set, vec, bestChild, mindist, epsError);
00656 
00657         if (mindist*epsError<=result_set.worstDist()) {
00658             searchLevelExact<with_removed>(result_set, vec, otherChild, new_distsq, epsError);
00659         }
00660     }
00661     
00662     void addPointToTree(NodePtr node, int ind)
00663     {
00664         ElementType* point = points_[ind];
00665         
00666         if ((node->child1==NULL) && (node->child2==NULL)) {
00667             ElementType* leaf_point = node->point;
00668             ElementType max_span = 0;
00669             size_t div_feat = 0;
00670             for (size_t i=0;i<veclen_;++i) {
00671                 ElementType span = std::abs(point[i]-leaf_point[i]);
00672                 if (span > max_span) {
00673                     max_span = span;
00674                     div_feat = i;
00675                 }
00676             }
00677             NodePtr left = new(pool_) Node();
00678             left->child1 = left->child2 = NULL;
00679             NodePtr right = new(pool_) Node();
00680             right->child1 = right->child2 = NULL;
00681 
00682             if (point[div_feat]<leaf_point[div_feat]) {
00683                 left->divfeat = ind;
00684                 left->point = point;
00685                 right->divfeat = node->divfeat;
00686                 right->point = node->point;
00687             }
00688             else {
00689                 left->divfeat = node->divfeat;
00690                 left->point = node->point;
00691                 right->divfeat = ind;
00692                 right->point = point;
00693             }
00694             node->divfeat = div_feat;
00695             node->divval = (point[div_feat]+leaf_point[div_feat])/2;
00696             node->child1 = left;
00697             node->child2 = right;            
00698         }
00699         else {
00700             if (point[node->divfeat]<node->divval) {
00701                 addPointToTree(node->child1,ind);
00702             }
00703             else {
00704                 addPointToTree(node->child2,ind);                
00705             }
00706         }
00707     }
00708 private:
00709     void swap(KDTreeIndex& other)
00710     {
00711         BaseClass::swap(other);
00712         std::swap(trees_, other.trees_);
00713         std::swap(tree_roots_, other.tree_roots_);
00714         std::swap(pool_, other.pool_);
00715     }
00716 
00717 private:
00718 
00719     enum
00720     {
00726         SAMPLE_MEAN = 100,
00734         RAND_DIM=5
00735     };
00736 
00737 
00741     int trees_;
00742 
00743     DistanceType* mean_;
00744     DistanceType* var_;
00745 
00749     std::vector<NodePtr> tree_roots_;
00750 
00758     PooledAllocator pool_;
00759 
00760     USING_BASECLASS_SYMBOLS
00761 };   // class KDTreeIndex
00762 
00763 }
00764 
00765 #endif //FLANN_KDTREE_INDEX_H_


rtabmap
Author(s): Mathieu Labbe
autogenerated on Sat Jul 23 2016 11:44:16