kdtree_index.h
Go to the documentation of this file.
00001 /***********************************************************************
00002  * Software License Agreement (BSD License)
00003  *
00004  * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
00005  * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
00006  *
00007  * THE BSD LICENSE
00008  *
00009  * Redistribution and use in source and binary forms, with or without
00010  * modification, are permitted provided that the following conditions
00011  * are met:
00012  *
00013  * 1. Redistributions of source code must retain the above copyright
00014  *    notice, this list of conditions and the following disclaimer.
00015  * 2. Redistributions in binary form must reproduce the above copyright
00016  *    notice, this list of conditions and the following disclaimer in the
00017  *    documentation and/or other materials provided with the distribution.
00018  *
00019  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00020  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00021  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00022  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00023  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00024  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00025  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00026  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00028  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029  *************************************************************************/
00030 
00031 #ifndef RTABMAP_FLANN_KDTREE_INDEX_H_
00032 #define RTABMAP_FLANN_KDTREE_INDEX_H_
00033 
00034 #include <algorithm>
00035 #include <map>
00036 #include <cassert>
00037 #include <cstring>
00038 #include <stdarg.h>
00039 #include <cmath>
00040 
00041 #include "rtflann/general.h"
00042 #include "rtflann/algorithms/nn_index.h"
00043 #include "rtflann/util/dynamic_bitset.h"
00044 #include "rtflann/util/matrix.h"
00045 #include "rtflann/util/result_set.h"
00046 #include "rtflann/util/heap.h"
00047 #include "rtflann/util/allocator.h"
00048 #include "rtflann/util/random.h"
00049 #include "rtflann/util/saving.h"
00050 
00051 
00052 namespace rtflann
00053 {
00054 
00055 struct KDTreeIndexParams : public IndexParams
00056 {
00057     KDTreeIndexParams(int trees = 4)
00058     {
00059         (*this)["algorithm"] = FLANN_INDEX_KDTREE;
00060         (*this)["trees"] = trees;
00061     }
00062 };
00063 
00064 
00071 template <typename Distance>
00072 class KDTreeIndex : public NNIndex<Distance>
00073 {
00074 public:
00075     typedef typename Distance::ElementType ElementType;
00076     typedef typename Distance::ResultType DistanceType;
00077 
00078     typedef NNIndex<Distance> BaseClass;
00079 
00080     typedef bool needs_kdtree_distance;
00081 
00082 private:
00083          /*--------------------- Internal Data Structures --------------------------*/
00084         struct Node
00085         {
00089                 int divfeat;
00093                 DistanceType divval;
00097                 ElementType* point;
00101                 Node* child1, *child2;
00102                 Node(){
00103                         child1 = NULL;
00104                         child2 = NULL;
00105                 }
00106                 ~Node() {
00107                         if (child1 != NULL) { child1->~Node(); child1 = NULL; }
00108 
00109                         if (child2 != NULL) { child2->~Node(); child2 = NULL; }
00110                 }
00111 
00112         private:
00113                 template<typename Archive>
00114                 void serialize(Archive& ar)
00115                 {
00116                         typedef KDTreeIndex<Distance> Index;
00117                         Index* obj = static_cast<Index*>(ar.getObject());
00118 
00119                         ar & divfeat;
00120                         ar & divval;
00121 
00122                         bool leaf_node = false;
00123                         if (Archive::is_saving::value) {
00124                                 leaf_node = ((child1==NULL) && (child2==NULL));
00125                         }
00126                         ar & leaf_node;
00127 
00128                         if (leaf_node) {
00129                                 if (Archive::is_loading::value) {
00130                                         point = obj->points_[divfeat];
00131                                 }
00132                         }
00133 
00134                         if (!leaf_node) {
00135                                 if (Archive::is_loading::value) {
00136                                         child1 = new(obj->pool_) Node();
00137                                         child2 = new(obj->pool_) Node();
00138                                 }
00139                                 ar & *child1;
00140                                 ar & *child2;
00141                         }
00142                 }
00143                 friend struct serialization::access;
00144         };
00145 
00146         typedef Node* NodePtr;
00147         typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00148         typedef BranchSt* Branch;
00149 
00150 public:
00151 
00159     KDTreeIndex(const IndexParams& params = KDTreeIndexParams(), Distance d = Distance() ) :
00160         BaseClass(params, d), mean_(NULL), var_(NULL)
00161     {
00162         trees_ = get_param(index_params_,"trees",4);
00163     }
00164 
00165 
00173     KDTreeIndex(const Matrix<ElementType>& dataset, const IndexParams& params = KDTreeIndexParams(),
00174                 Distance d = Distance() ) : BaseClass(params,d ), mean_(NULL), var_(NULL)
00175     {
00176         trees_ = get_param(index_params_,"trees",4);
00177 
00178         setDataset(dataset);
00179     }
00180 
00181     KDTreeIndex(const KDTreeIndex& other) : BaseClass(other),
00182                 trees_(other.trees_)
00183     {
00184         tree_roots_.resize(other.tree_roots_.size());
00185         for (size_t i=0;i<tree_roots_.size();++i) {
00186                 copyTree(tree_roots_[i], other.tree_roots_[i]);
00187         }
00188     }
00189 
00190     KDTreeIndex& operator=(KDTreeIndex other)
00191     {
00192         this->swap(other);
00193         return *this;
00194     }
00195 
00199     virtual ~KDTreeIndex()
00200     {
00201         freeIndex();
00202     }
00203 
00204     BaseClass* clone() const
00205     {
00206         return new KDTreeIndex(*this);
00207     }
00208 
00209     using BaseClass::buildIndex;
00210     
00211     void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00212     {
00213         assert(points.cols==veclen_);
00214 
00215         size_t old_size = size_;
00216         extendDataset(points);
00217         
00218         if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
00219             buildIndex();
00220         }
00221         else {
00222             for (size_t i=old_size;i<size_;++i) {
00223                 for (int j = 0; j < trees_; j++) {
00224                     addPointToTree(tree_roots_[j], i);
00225                 }
00226             }
00227         }        
00228     }
00229 
00230     flann_algorithm_t getType() const
00231     {
00232         return FLANN_INDEX_KDTREE;
00233     }
00234 
00235 
00236     template<typename Archive>
00237     void serialize(Archive& ar)
00238     {
00239         ar.setObject(this);
00240 
00241         ar & *static_cast<NNIndex<Distance>*>(this);
00242 
00243         ar & trees_;
00244 
00245         if (Archive::is_loading::value) {
00246                 tree_roots_.resize(trees_);
00247         }
00248         for (size_t i=0;i<tree_roots_.size();++i) {
00249                 if (Archive::is_loading::value) {
00250                         tree_roots_[i] = new(pool_) Node();
00251                 }
00252                 ar & *tree_roots_[i];
00253         }
00254 
00255         if (Archive::is_loading::value) {
00256             index_params_["algorithm"] = getType();
00257             index_params_["trees"] = trees_;
00258         }
00259     }
00260 
00261 
00262     void saveIndex(FILE* stream)
00263     {
00264         serialization::SaveArchive sa(stream);
00265         sa & *this;
00266     }
00267 
00268 
00269     void loadIndex(FILE* stream)
00270     {
00271         freeIndex();
00272         serialization::LoadArchive la(stream);
00273         la & *this;
00274     }
00275 
00280     int usedMemory() const
00281     {
00282         return int(pool_.usedMemory+pool_.wastedMemory+size_*sizeof(int));  // pool memory and vind array memory
00283     }
00284 
00294     void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00295     {
00296         int maxChecks = searchParams.checks;
00297         float epsError = 1+searchParams.eps;
00298 
00299         if (maxChecks==FLANN_CHECKS_UNLIMITED) {
00300                 if (removed_) {
00301                         getExactNeighbors<true>(result, vec, epsError);
00302                 }
00303                 else {
00304                         getExactNeighbors<false>(result, vec, epsError);
00305                 }
00306         }
00307         else {
00308                 if (removed_) {
00309                         getNeighbors<true>(result, vec, maxChecks, epsError);
00310                 }
00311                 else {
00312                         getNeighbors<false>(result, vec, maxChecks, epsError);
00313                 }
00314         }
00315     }
00316 
00317 #ifdef ANDROID
00318 
00328         void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams, Heap<BranchSt>* heap) const
00329         {
00330                 int maxChecks = searchParams.checks;
00331                 float epsError = 1+searchParams.eps;
00332 
00333                 if (maxChecks==FLANN_CHECKS_UNLIMITED) {
00334                         if (removed_) {
00335                                 getExactNeighbors<true>(result, vec, epsError);
00336                         }
00337                         else {
00338                                 getExactNeighbors<false>(result, vec, epsError);
00339                         }
00340                 }
00341                 else {
00342                         if (removed_) {
00343                                 getNeighbors<true>(result, vec, maxChecks, epsError, heap);
00344                         }
00345                         else {
00346                                 getNeighbors<false>(result, vec, maxChecks, epsError, heap);
00347                         }
00348                 }
00349         }
00350 
00359         virtual int knnSearch(const Matrix<ElementType>& queries,
00360                         Matrix<size_t>& indices,
00361                         Matrix<DistanceType>& dists,
00362                         size_t knn,
00363                         const SearchParams& params) const
00364         {
00365                 assert(queries.cols == veclen());
00366                 assert(indices.rows >= queries.rows);
00367                 assert(dists.rows >= queries.rows);
00368                 assert(indices.cols >= knn);
00369                 assert(dists.cols >= knn);
00370                 bool use_heap;
00371 
00372                 if (params.use_heap==FLANN_Undefined) {
00373                         use_heap = (knn>KNN_HEAP_THRESHOLD)?true:false;
00374                 }
00375                 else {
00376                         use_heap = (params.use_heap==FLANN_True)?true:false;
00377                 }
00378                 int count = 0;
00379 
00380                 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00381 
00382                 if (use_heap) {
00383         //#pragma omp parallel num_threads(params.cores)
00384                         {
00385                                 KNNResultSet2<DistanceType> resultSet(knn);
00386         //#pragma omp for schedule(static) reduction(+:count)
00387                                 for (int i = 0; i < (int)queries.rows; i++) {
00388                                         resultSet.clear();
00389                                         findNeighbors(resultSet, queries[i], params, heap);
00390                                         size_t n = std::min(resultSet.size(), knn);
00391                                         resultSet.copy(indices[i], dists[i], n, params.sorted);
00392                                         indices_to_ids(indices[i], indices[i], n);
00393                                         count += n;
00394                                 }
00395                         }
00396                 }
00397                 else {
00398                         std::vector<double> times(queries.rows);
00399         //#pragma omp parallel num_threads(params.cores)
00400                         {
00401                                 KNNSimpleResultSet<DistanceType> resultSet(knn);
00402         //#pragma omp for schedule(static) reduction(+:count)
00403                                 for (int i = 0; i < (int)queries.rows; i++) {
00404                                         resultSet.clear();
00405                                         findNeighbors(resultSet, queries[i], params, heap);
00406                                         size_t n = std::min(resultSet.size(), knn);
00407                                         resultSet.copy(indices[i], dists[i], n, params.sorted);
00408                                         indices_to_ids(indices[i], indices[i], n);
00409                                         count += n;
00410                                 }
00411                         }
00412                         std::sort(times.begin(), times.end());
00413                 }
00414                 delete heap;
00415                 return count;
00416         }
00417 
00418 
00427         virtual int knnSearch(const Matrix<ElementType>& queries,
00428                                         std::vector< std::vector<size_t> >& indices,
00429                                         std::vector<std::vector<DistanceType> >& dists,
00430                                         size_t knn,
00431                                         const SearchParams& params) const
00432         {
00433                 assert(queries.cols == veclen());
00434                 bool use_heap;
00435                 if (params.use_heap==FLANN_Undefined) {
00436                         use_heap = (knn>KNN_HEAP_THRESHOLD)?true:false;
00437                 }
00438                 else {
00439                         use_heap = (params.use_heap==FLANN_True)?true:false;
00440                 }
00441 
00442                 if (indices.size() < queries.rows ) indices.resize(queries.rows);
00443                 if (dists.size() < queries.rows ) dists.resize(queries.rows);
00444 
00445                 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00446 
00447                 int count = 0;
00448                 if (use_heap) {
00449         //#pragma omp parallel num_threads(params.cores)
00450                         {
00451                                 KNNResultSet2<DistanceType> resultSet(knn);
00452         //#pragma omp for schedule(static) reduction(+:count)
00453                                 for (int i = 0; i < (int)queries.rows; i++) {
00454                                         resultSet.clear();
00455                                         findNeighbors(resultSet, queries[i], params, heap);
00456                                         size_t n = std::min(resultSet.size(), knn);
00457                                         indices[i].resize(n);
00458                                         dists[i].resize(n);
00459                                         if (n>0) {
00460                                                 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00461                                                 indices_to_ids(&indices[i][0], &indices[i][0], n);
00462                                         }
00463                                         count += n;
00464                                 }
00465                         }
00466                 }
00467                 else {
00468         //#pragma omp parallel num_threads(params.cores)
00469                         {
00470                                 KNNSimpleResultSet<DistanceType> resultSet(knn);
00471         //#pragma omp for schedule(static) reduction(+:count)
00472                                 for (int i = 0; i < (int)queries.rows; i++) {
00473                                         resultSet.clear();
00474                                         findNeighbors(resultSet, queries[i], params, heap);
00475                                         size_t n = std::min(resultSet.size(), knn);
00476                                         indices[i].resize(n);
00477                                         dists[i].resize(n);
00478                                         if (n>0) {
00479                                                 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00480                                                 indices_to_ids(&indices[i][0], &indices[i][0], n);
00481                                         }
00482                                         count += n;
00483                                 }
00484                         }
00485                 }
00486                 delete heap;
00487 
00488                 return count;
00489         }
00490 
00500         virtual int radiusSearch(const Matrix<ElementType>& queries,
00501                         Matrix<size_t>& indices,
00502                         Matrix<DistanceType>& dists,
00503                         float radius,
00504                         const SearchParams& params) const
00505         {
00506                 assert(queries.cols == veclen());
00507                 int count = 0;
00508                 size_t num_neighbors = std::min(indices.cols, dists.cols);
00509                 int max_neighbors = params.max_neighbors;
00510                 if (max_neighbors<0) max_neighbors = num_neighbors;
00511                 else max_neighbors = std::min(max_neighbors,(int)num_neighbors);
00512 
00513                 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00514 
00515                 if (max_neighbors==0) {
00516         //#pragma omp parallel num_threads(params.cores)
00517                         {
00518                                 CountRadiusResultSet<DistanceType> resultSet(radius);
00519         //#pragma omp for schedule(static) reduction(+:count)
00520                                 for (int i = 0; i < (int)queries.rows; i++) {
00521                                         resultSet.clear();
00522                                         findNeighbors(resultSet, queries[i], params, heap);
00523                                         count += resultSet.size();
00524                                 }
00525                         }
00526                 }
00527                 else {
00528                         // explicitly indicated to use unbounded radius result set
00529                         // and we know there'll be enough room for resulting indices and dists
00530                         if (params.max_neighbors<0 && (num_neighbors>=this->size())) {
00531         //#pragma omp parallel num_threads(params.cores)
00532                                 {
00533                                         RadiusResultSet<DistanceType> resultSet(radius);
00534         //#pragma omp for schedule(static) reduction(+:count)
00535                                         for (int i = 0; i < (int)queries.rows; i++) {
00536                                                 resultSet.clear();
00537                                                 findNeighbors(resultSet, queries[i], params, heap);
00538                                                 size_t n = resultSet.size();
00539                                                 count += n;
00540                                                 if (n>num_neighbors) n = num_neighbors;
00541                                                 resultSet.copy(indices[i], dists[i], n, params.sorted);
00542 
00543                                                 // mark the next element in the output buffers as unused
00544                                                 if (n<indices.cols) indices[i][n] = size_t(-1);
00545                                                 if (n<dists.cols) dists[i][n] = std::numeric_limits<DistanceType>::infinity();
00546                                                 indices_to_ids(indices[i], indices[i], n);
00547                                         }
00548                                 }
00549                         }
00550                         else {
00551                                 // number of neighbors limited to max_neighbors
00552         //#pragma omp parallel num_threads(params.cores)
00553                                 {
00554                                         KNNRadiusResultSet<DistanceType> resultSet(radius, max_neighbors);
00555         //#pragma omp for schedule(static) reduction(+:count)
00556                                         for (int i = 0; i < (int)queries.rows; i++) {
00557                                                 resultSet.clear();
00558                                                 findNeighbors(resultSet, queries[i], params, heap);
00559                                                 size_t n = resultSet.size();
00560                                                 count += n;
00561                                                 if ((int)n>max_neighbors) n = max_neighbors;
00562                                                 resultSet.copy(indices[i], dists[i], n, params.sorted);
00563 
00564                                                 // mark the next element in the output buffers as unused
00565                                                 if (n<indices.cols) indices[i][n] = size_t(-1);
00566                                                 if (n<dists.cols) dists[i][n] = std::numeric_limits<DistanceType>::infinity();
00567                                                 indices_to_ids(indices[i], indices[i], n);
00568                                         }
00569                                 }
00570                         }
00571                 }
00572                 delete heap;
00573                 return count;
00574         }
00575 
00585         virtual int radiusSearch(const Matrix<ElementType>& queries,
00586                         std::vector< std::vector<size_t> >& indices,
00587                         std::vector<std::vector<DistanceType> >& dists,
00588                         float radius,
00589                         const SearchParams& params) const
00590         {
00591                 assert(queries.cols == veclen());
00592                 int count = 0;
00593 
00594                 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00595 
00596                 // just count neighbors
00597                 if (params.max_neighbors==0) {
00598         //#pragma omp parallel num_threads(params.cores)
00599                         {
00600                                 CountRadiusResultSet<DistanceType> resultSet(radius);
00601         //#pragma omp for schedule(static) reduction(+:count)
00602                                 for (int i = 0; i < (int)queries.rows; i++) {
00603                                         resultSet.clear();
00604                                         findNeighbors(resultSet, queries[i], params, heap);
00605                                         count += resultSet.size();
00606                                 }
00607                         }
00608                 }
00609                 else {
00610                         if (indices.size() < queries.rows ) indices.resize(queries.rows);
00611                         if (dists.size() < queries.rows ) dists.resize(queries.rows);
00612 
00613                         if (params.max_neighbors<0) {
00614                                 // search for all neighbors
00615         //#pragma omp parallel num_threads(params.cores)
00616                                 {
00617                                         RadiusResultSet<DistanceType> resultSet(radius);
00618         //#pragma omp for schedule(static) reduction(+:count)
00619                                         for (int i = 0; i < (int)queries.rows; i++) {
00620                                                 resultSet.clear();
00621                                                 findNeighbors(resultSet, queries[i], params, heap);
00622                                                 size_t n = resultSet.size();
00623                                                 count += n;
00624                                                 indices[i].resize(n);
00625                                                 dists[i].resize(n);
00626                                                 if (n > 0) {
00627                                                         resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00628                                                         indices_to_ids(&indices[i][0], &indices[i][0], n);
00629                                                 }
00630                                         }
00631                                 }
00632                         }
00633                         else {
00634                                 // number of neighbors limited to max_neighbors
00635         //#pragma omp parallel num_threads(params.cores)
00636                                 {
00637                                         KNNRadiusResultSet<DistanceType> resultSet(radius, params.max_neighbors);
00638         //#pragma omp for schedule(static) reduction(+:count)
00639                                         for (int i = 0; i < (int)queries.rows; i++) {
00640                                                 resultSet.clear();
00641                                                 findNeighbors(resultSet, queries[i], params, heap);
00642                                                 size_t n = resultSet.size();
00643                                                 count += n;
00644                                                 if ((int)n>params.max_neighbors) n = params.max_neighbors;
00645                                                 indices[i].resize(n);
00646                                                 dists[i].resize(n);
00647                                                 if (n > 0) {
00648                                                         resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00649                                                         indices_to_ids(&indices[i][0], &indices[i][0], n);
00650                                                 }
00651                                         }
00652                                 }
00653                         }
00654                 }
00655                 delete heap;
00656                 return count;
00657         }
00658 #endif
00659 
00660 protected:
00661 
00665     void buildIndexImpl()
00666     {
00667         // Create a permutable array of indices to the input vectors.
00668         std::vector<int> ind(size_);
00669         for (size_t i = 0; i < size_; ++i) {
00670             ind[i] = int(i);
00671         }
00672 
00673         mean_ = new DistanceType[veclen_];
00674         var_ = new DistanceType[veclen_];
00675 
00676         tree_roots_.resize(trees_);
00677         /* Construct the randomized trees. */
00678         for (int i = 0; i < trees_; i++) {
00679             /* Randomize the order of vectors to allow for unbiased sampling. */
00680             std::random_shuffle(ind.begin(), ind.end());
00681             tree_roots_[i] = divideTree(&ind[0], int(size_) );
00682         }
00683         delete[] mean_;
00684         delete[] var_;
00685     }
00686 
00687     void freeIndex()
00688     {
00689         for (size_t i=0;i<tree_roots_.size();++i) {
00690                 // using placement new, so call destructor explicitly
00691                 if (tree_roots_[i]!=NULL) tree_roots_[i]->~Node();
00692         }
00693         pool_.free();
00694     }
00695 
00696 
00697 private:
00698 
00699     void copyTree(NodePtr& dst, const NodePtr& src)
00700     {
00701         dst = new(pool_) Node();
00702         dst->divfeat = src->divfeat;
00703         dst->divval = src->divval;
00704         if (src->child1==NULL && src->child2==NULL) {
00705                 dst->point = points_[dst->divfeat];
00706                 dst->child1 = NULL;
00707                 dst->child2 = NULL;
00708         }
00709         else {
00710                 copyTree(dst->child1, src->child1);
00711                 copyTree(dst->child2, src->child2);
00712         }
00713     }
00714 
00724     NodePtr divideTree(int* ind, int count)
00725     {
00726         NodePtr node = new(pool_) Node(); // allocate memory
00727 
00728         /* If too few exemplars remain, then make this a leaf node. */
00729         if (count == 1) {
00730             node->child1 = node->child2 = NULL;    /* Mark as leaf node. */
00731             node->divfeat = *ind;    /* Store index of this vec. */
00732             node->point = points_[*ind];
00733         }
00734         else {
00735             int idx;
00736             int cutfeat;
00737             DistanceType cutval;
00738             meanSplit(ind, count, idx, cutfeat, cutval);
00739 
00740             node->divfeat = cutfeat;
00741             node->divval = cutval;
00742             node->child1 = divideTree(ind, idx);
00743             node->child2 = divideTree(ind+idx, count-idx);
00744         }
00745 
00746         return node;
00747     }
00748 
00749 
00755     void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval)
00756     {
00757         memset(mean_,0,veclen_*sizeof(DistanceType));
00758         memset(var_,0,veclen_*sizeof(DistanceType));
00759 
00760         /* Compute mean values.  Only the first SAMPLE_MEAN values need to be
00761             sampled to get a good estimate.
00762          */
00763         int cnt = std::min((int)SAMPLE_MEAN+1, count);
00764         for (int j = 0; j < cnt; ++j) {
00765             ElementType* v = points_[ind[j]];
00766             for (size_t k=0; k<veclen_; ++k) {
00767                 mean_[k] += v[k];
00768             }
00769         }
00770         DistanceType div_factor = DistanceType(1)/cnt;
00771         for (size_t k=0; k<veclen_; ++k) {
00772             mean_[k] *= div_factor;
00773         }
00774 
00775         /* Compute variances (no need to divide by count). */
00776         for (int j = 0; j < cnt; ++j) {
00777             ElementType* v = points_[ind[j]];
00778             for (size_t k=0; k<veclen_; ++k) {
00779                 DistanceType dist = v[k] - mean_[k];
00780                 var_[k] += dist * dist;
00781             }
00782         }
00783         /* Select one of the highest variance indices at random. */
00784         cutfeat = selectDivision(var_);
00785         cutval = mean_[cutfeat];
00786 
00787         int lim1, lim2;
00788         planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
00789 
00790         if (lim1>count/2) index = lim1;
00791         else if (lim2<count/2) index = lim2;
00792         else index = count/2;
00793 
00794         /* If either list is empty, it means that all remaining features
00795          * are identical. Split in the middle to maintain a balanced tree.
00796          */
00797         if ((lim1==count)||(lim2==0)) index = count/2;
00798     }
00799 
00800 
00805     int selectDivision(DistanceType* v)
00806     {
00807         int num = 0;
00808         size_t topind[RAND_DIM];
00809 
00810         /* Create a list of the indices of the top RAND_DIM values. */
00811         for (size_t i = 0; i < veclen_; ++i) {
00812             if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) {
00813                 /* Put this element at end of topind. */
00814                 if (num < RAND_DIM) {
00815                     topind[num++] = i;            /* Add to list. */
00816                 }
00817                 else {
00818                     topind[num-1] = i;         /* Replace last element. */
00819                 }
00820                 /* Bubble end value down to right location by repeated swapping. */
00821                 int j = num - 1;
00822                 while (j > 0  &&  v[topind[j]] > v[topind[j-1]]) {
00823                     std::swap(topind[j], topind[j-1]);
00824                     --j;
00825                 }
00826             }
00827         }
00828         /* Select a random integer in range [0,num-1], and return that index. */
00829         int rnd = rand_int(num);
00830         return (int)topind[rnd];
00831     }
00832 
00833 
00843     void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
00844     {
00845         /* Move vector indices for left subtree to front of list. */
00846         int left = 0;
00847         int right = count-1;
00848         for (;; ) {
00849             while (left<=right && points_[ind[left]][cutfeat]<cutval) ++left;
00850             while (left<=right && points_[ind[right]][cutfeat]>=cutval) --right;
00851             if (left>right) break;
00852             std::swap(ind[left], ind[right]); ++left; --right;
00853         }
00854         lim1 = left;
00855         right = count-1;
00856         for (;; ) {
00857             while (left<=right && points_[ind[left]][cutfeat]<=cutval) ++left;
00858             while (left<=right && points_[ind[right]][cutfeat]>cutval) --right;
00859             if (left>right) break;
00860             std::swap(ind[left], ind[right]); ++left; --right;
00861         }
00862         lim2 = left;
00863     }
00864 
00869     template<bool with_removed>
00870     void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError) const
00871     {
00872         //              checkID -= 1;  /* Set a different unique ID for each search. */
00873 
00874         if (trees_ > 1) {
00875             fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
00876         }
00877         if (trees_>0) {
00878             searchLevelExact<with_removed>(result, vec, tree_roots_[0], 0.0, epsError);
00879         }
00880     }
00881 
00887     template<bool with_removed>
00888     void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError) const
00889     {
00890         int i;
00891         BranchSt branch;
00892 
00893         int checkCount = 0;
00894         Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00895         DynamicBitset checked(size_);
00896 
00897         /* Search once through each tree down to root. */
00898         for (i = 0; i < trees_; ++i) {
00899             searchLevel<with_removed>(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
00900         }
00901 
00902         /* Keep searching other branches from heap until finished. */
00903         while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
00904             searchLevel<with_removed>(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked);
00905         }
00906 
00907         delete heap;
00908 
00909     }
00910 
00911 #ifdef ANDROID
00912 
00917         template<bool with_removed>
00918         void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError, Heap<BranchSt>* heap) const
00919         {
00920                 int i;
00921                 BranchSt branch;
00922 
00923                 int checkCount = 0;
00924                 DynamicBitset checked(size_);
00925                 heap->clear();
00926 
00927                 /* Search once through each tree down to root. */
00928                 for (i = 0; i < trees_; ++i) {
00929                         searchLevel<with_removed>(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
00930                 }
00931 
00932                 /* Keep searching other branches from heap until finished. */
00933                 while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
00934                         searchLevel<with_removed>(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked);
00935                 }
00936         }
00937 #endif
00938 
00939 
00945     template<bool with_removed>
00946     void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck,
00947                      float epsError, Heap<BranchSt>* heap, DynamicBitset& checked) const
00948     {
00949         if (result_set.worstDist()<mindist) {
00950             //                  printf("Ignoring branch, too far\n");
00951             return;
00952         }
00953 
00954         /* If this is a leaf node, then do check and return. */
00955         if ((node->child1 == NULL)&&(node->child2 == NULL)) {
00956             int index = node->divfeat;
00957             if (with_removed) {
00958                 if (removed_points_.test(index)) return;
00959             }
00960             /*  Do not check same node more than once when searching multiple trees. */
00961             if ( checked.test(index) || ((checkCount>=maxCheck)&& result_set.full()) ) return;
00962             checked.set(index);
00963             checkCount++;
00964 
00965             DistanceType dist = distance_(node->point, vec, veclen_);
00966             result_set.addPoint(dist,index);
00967             return;
00968         }
00969 
00970         /* Which child branch should be taken first? */
00971         ElementType val = vec[node->divfeat];
00972         DistanceType diff = val - node->divval;
00973         NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
00974         NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
00975 
00976         /* Create a branch record for the branch not taken.  Add distance
00977             of this feature boundary (we don't attempt to correct for any
00978             use of this feature in a parent node, which is unlikely to
00979             happen and would have only a small effect).  Don't bother
00980             adding more branches to heap after halfway point, as cost of
00981             adding exceeds their value.
00982          */
00983 
00984         DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
00985         //              if (2 * checkCount < maxCheck  ||  !result.full()) {
00986         if ((new_distsq*epsError < result_set.worstDist())||  !result_set.full()) {
00987             heap->insert( BranchSt(otherChild, new_distsq) );
00988         }
00989 
00990         /* Call recursively to search next level down. */
00991         searchLevel<with_removed>(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
00992     }
00993 
00997     template<bool with_removed>
00998     void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError) const
00999     {
01000         /* If this is a leaf node, then do check and return. */
01001         if ((node->child1 == NULL)&&(node->child2 == NULL)) {
01002             int index = node->divfeat;
01003             if (with_removed) {
01004                 if (removed_points_.test(index)) return; // ignore removed points
01005             }
01006             DistanceType dist = distance_(node->point, vec, veclen_);
01007             result_set.addPoint(dist,index);
01008 
01009             return;
01010         }
01011 
01012         /* Which child branch should be taken first? */
01013         ElementType val = vec[node->divfeat];
01014         DistanceType diff = val - node->divval;
01015         NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
01016         NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
01017 
01018         /* Create a branch record for the branch not taken.  Add distance
01019             of this feature boundary (we don't attempt to correct for any
01020             use of this feature in a parent node, which is unlikely to
01021             happen and would have only a small effect).  Don't bother
01022             adding more branches to heap after halfway point, as cost of
01023             adding exceeds their value.
01024          */
01025 
01026         DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
01027 
01028         /* Call recursively to search next level down. */
01029         searchLevelExact<with_removed>(result_set, vec, bestChild, mindist, epsError);
01030 
01031         if (mindist*epsError<=result_set.worstDist()) {
01032             searchLevelExact<with_removed>(result_set, vec, otherChild, new_distsq, epsError);
01033         }
01034     }
01035     
01036     void addPointToTree(NodePtr node, int ind)
01037     {
01038         ElementType* point = points_[ind];
01039         
01040         if ((node->child1==NULL) && (node->child2==NULL)) {
01041             ElementType* leaf_point = node->point;
01042             ElementType max_span = 0;
01043             size_t div_feat = 0;
01044             for (size_t i=0;i<veclen_;++i) {
01045                 ElementType span = std::abs(point[i]-leaf_point[i]);
01046                 if (span > max_span) {
01047                     max_span = span;
01048                     div_feat = i;
01049                 }
01050             }
01051             NodePtr left = new(pool_) Node();
01052             left->child1 = left->child2 = NULL;
01053             NodePtr right = new(pool_) Node();
01054             right->child1 = right->child2 = NULL;
01055 
01056             if (point[div_feat]<leaf_point[div_feat]) {
01057                 left->divfeat = ind;
01058                 left->point = point;
01059                 right->divfeat = node->divfeat;
01060                 right->point = node->point;
01061             }
01062             else {
01063                 left->divfeat = node->divfeat;
01064                 left->point = node->point;
01065                 right->divfeat = ind;
01066                 right->point = point;
01067             }
01068             node->divfeat = div_feat;
01069             node->divval = (point[div_feat]+leaf_point[div_feat])/2;
01070             node->child1 = left;
01071             node->child2 = right;            
01072         }
01073         else {
01074             if (point[node->divfeat]<node->divval) {
01075                 addPointToTree(node->child1,ind);
01076             }
01077             else {
01078                 addPointToTree(node->child2,ind);                
01079             }
01080         }
01081     }
01082 private:
01083     void swap(KDTreeIndex& other)
01084     {
01085         BaseClass::swap(other);
01086         std::swap(trees_, other.trees_);
01087         std::swap(tree_roots_, other.tree_roots_);
01088         std::swap(pool_, other.pool_);
01089     }
01090 
01091 private:
01092 
01093     enum  // Constants used by meanSplit() and selectDivision() when choosing a split.
01094     {
01100         SAMPLE_MEAN = 100,  ///< meanSplit() estimates mean/variance from at most SAMPLE_MEAN+1 points.
01108         RAND_DIM=5  ///< selectDivision() keeps this many top-valued dimensions and picks one of them via rand_int().
01109     };
01110 
01111 
01115     int trees_;  ///< Number of trees; the search routines use tree_roots_[0 .. trees_-1].
01116 
01117     DistanceType* mean_;  ///< Scratch buffer: meanSplit() clears and fills its first veclen_ entries with per-dimension means.
01118     DistanceType* var_;  ///< Scratch buffer: meanSplit() clears and fills its first veclen_ entries with per-dimension squared-deviation sums.
01119 
01123     std::vector<NodePtr> tree_roots_;  ///< Root node of each tree, indexed by tree number.
01124 
01132     PooledAllocator pool_;  ///< Allocator handed to placement new when addPointToTree() creates nodes.
01133 
01134     USING_BASECLASS_SYMBOLS  // NOTE(review): macro defined outside this excerpt; presumably makes base-class members visible here - confirm.
01135 };   // class KDTreeIndex
01136 
01137 }
01138 
01139 #endif //RTABMAP_FLANN_KDTREE_INDEX_H_


rtabmap
Author(s): Mathieu Labbe
autogenerated on Thu Jun 6 2019 21:59:20