kdtree_single_index.h
Go to the documentation of this file.
00001 /***********************************************************************
00002  * Software License Agreement (BSD License)
00003  *
00004  * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
00005  * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
00006  *
00007  * THE BSD LICENSE
00008  *
00009  * Redistribution and use in source and binary forms, with or without
00010  * modification, are permitted provided that the following conditions
00011  * are met:
00012  *
00013  * 1. Redistributions of source code must retain the above copyright
00014  *    notice, this list of conditions and the following disclaimer.
00015  * 2. Redistributions in binary form must reproduce the above copyright
00016  *    notice, this list of conditions and the following disclaimer in the
00017  *    documentation and/or other materials provided with the distribution.
00018  *
00019  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00020  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00021  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00022  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00023  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00024  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00025  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00026  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00028  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029  *************************************************************************/
00030 
00031 #ifndef RTABMAP_FLANN_KDTREE_SINGLE_INDEX_H_
00032 #define RTABMAP_FLANN_KDTREE_SINGLE_INDEX_H_
00033 
00034 #include <algorithm>
00035 #include <map>
00036 #include <cassert>
00037 #include <cstring>
00038 
00039 #include "rtflann/general.h"
00040 #include "rtflann/algorithms/nn_index.h"
00041 #include "rtflann/util/matrix.h"
00042 #include "rtflann/util/result_set.h"
00043 #include "rtflann/util/heap.h"
00044 #include "rtflann/util/allocator.h"
00045 #include "rtflann/util/random.h"
00046 #include "rtflann/util/saving.h"
00047 
00048 namespace rtflann
00049 {
00050 
00051 struct KDTreeSingleIndexParams : public IndexParams
00052 {
00053     KDTreeSingleIndexParams(int leaf_max_size = 10, bool reorder = true)
00054     {
00055         (*this)["algorithm"] = FLANN_INDEX_KDTREE_SINGLE;
00056         (*this)["leaf_max_size"] = leaf_max_size;
00057         (*this)["reorder"] = reorder;
00058     }
00059 };
00060 
00061 
00068 template <typename Distance>
00069 class KDTreeSingleIndex : public NNIndex<Distance>
00070 {
00071 public:
00072     typedef typename Distance::ElementType ElementType;
00073     typedef typename Distance::ResultType DistanceType;
00074 
00075     typedef NNIndex<Distance> BaseClass;
00076 
00077     typedef bool needs_kdtree_distance;
00078 
00085     KDTreeSingleIndex(const IndexParams& params = KDTreeSingleIndexParams(), Distance d = Distance() ) :
00086         BaseClass(params, d), root_node_(NULL)
00087     {
00088         leaf_max_size_ = get_param(params,"leaf_max_size",10);
00089         reorder_ = get_param(params, "reorder", true);
00090     }
00091 
00099     KDTreeSingleIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KDTreeSingleIndexParams(),
00100                       Distance d = Distance() ) : BaseClass(params, d), root_node_(NULL)
00101     {
00102         leaf_max_size_ = get_param(params,"leaf_max_size",10);
00103         reorder_ = get_param(params, "reorder", true);
00104 
00105         setDataset(inputData);
00106     }
00107 
00108 
00109     KDTreeSingleIndex(const KDTreeSingleIndex& other) : BaseClass(other),
00110             leaf_max_size_(other.leaf_max_size_),
00111             reorder_(other.reorder_),
00112             vind_(other.vind_),
00113             root_bbox_(other.root_bbox_)
00114     {
00115         if (reorder_) {
00116             data_ = rtflann::Matrix<ElementType>(new ElementType[size_*veclen_], size_, veclen_);
00117             std::copy(other.data_[0], other.data_[0]+size_*veclen_, data_[0]);
00118         }
00119         copyTree(root_node_, other.root_node_);
00120     }
00121 
00122     KDTreeSingleIndex& operator=(KDTreeSingleIndex other)
00123     {
00124         this->swap(other);
00125         return *this;
00126     }
00127     
00131     virtual ~KDTreeSingleIndex()
00132     {
00133         freeIndex();
00134     }
00135     
00136     BaseClass* clone() const
00137     {
00138         return new KDTreeSingleIndex(*this);
00139     }
00140 
00141     using BaseClass::buildIndex;
00142 
00143     void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00144     {
00145         assert(points.cols==veclen_);
00146         extendDataset(points);
00147         buildIndex();
00148     }
00149 
00150     flann_algorithm_t getType() const
00151     {
00152         return FLANN_INDEX_KDTREE_SINGLE;
00153     }
00154 
00155 
00156     template<typename Archive>
00157     void serialize(Archive& ar)
00158     {
00159         ar.setObject(this);
00160 
00161         if (reorder_) index_params_["save_dataset"] = false;
00162 
00163         ar & *static_cast<NNIndex<Distance>*>(this);
00164 
00165         ar & reorder_;
00166         ar & leaf_max_size_;
00167         ar & root_bbox_;
00168         ar & vind_;
00169 
00170         if (reorder_) {
00171             ar & data_;
00172         }
00173 
00174         if (Archive::is_loading::value) {
00175             root_node_ = new(pool_) Node();
00176         }
00177 
00178         ar & *root_node_;
00179 
00180         if (Archive::is_loading::value) {
00181             index_params_["algorithm"] = getType();
00182             index_params_["leaf_max_size"] = leaf_max_size_;
00183             index_params_["reorder"] = reorder_;
00184         }
00185     }
00186 
00187 
00188     void saveIndex(FILE* stream)
00189     {
00190         serialization::SaveArchive sa(stream);
00191         sa & *this;
00192     }
00193 
00194 
00195     void loadIndex(FILE* stream)
00196     {
00197         freeIndex();
00198         serialization::LoadArchive la(stream);
00199         la & *this;
00200     }
00201 
00206     int usedMemory() const
00207     {
00208         return pool_.usedMemory+pool_.wastedMemory+size_*sizeof(int);  // pool memory and vind array memory
00209     }
00210 
00220     void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00221     {
00222         float epsError = 1+searchParams.eps;
00223 
00224         std::vector<DistanceType> dists(veclen_,0);
00225         DistanceType distsq = computeInitialDistances(vec, dists);
00226         if (removed_) {
00227             searchLevel<true>(result, vec, root_node_, distsq, dists, epsError);
00228         }
00229         else {
00230             searchLevel<false>(result, vec, root_node_, distsq, dists, epsError);
00231         }
00232     }
00233 
00234 protected:
00235 
00239     void buildIndexImpl()
00240     {
00241         // Create a permutable array of indices to the input vectors.
00242         vind_.resize(size_);
00243         for (size_t i = 0; i < size_; i++) {
00244             vind_[i] = i;
00245         }
00246 
00247         computeBoundingBox(root_bbox_);
00248         root_node_ = divideTree(0, size_, root_bbox_ );   // construct the tree
00249 
00250         if (reorder_) {
00251             data_ = rtflann::Matrix<ElementType>(new ElementType[size_*veclen_], size_, veclen_);
00252             for (size_t i=0; i<size_; ++i) {
00253                 std::copy(points_[vind_[i]], points_[vind_[i]]+veclen_, data_[i]);
00254             }
00255         }
00256     }
00257 
00258 private:
00259 
00260 
00261     /*--------------------- Internal Data Structures --------------------------*/
00262     struct Node
00263     {
00267         int left, right;
00271         int divfeat;
00275         DistanceType divlow, divhigh;
00279         Node* child1, * child2;
00280         
00281         ~Node()
00282         {
00283             if (child1) child1->~Node();
00284             if (child2) child2->~Node();
00285         }
00286 
00287     private:
00288         template<typename Archive>
00289         void serialize(Archive& ar)
00290         {
00291             typedef KDTreeSingleIndex<Distance> Index;
00292             Index* obj = static_cast<Index*>(ar.getObject());
00293 
00294             ar & left;
00295             ar & right;
00296             ar & divfeat;
00297             ar & divlow;
00298             ar & divhigh;
00299 
00300             bool leaf_node = false;
00301             if (Archive::is_saving::value) {
00302                 leaf_node = ((child1==NULL) && (child2==NULL));
00303             }
00304             ar & leaf_node;
00305 
00306             if (!leaf_node) {
00307                 if (Archive::is_loading::value) {
00308                     child1 = new(obj->pool_) Node();
00309                     child2 = new(obj->pool_) Node();
00310                 }
00311                 ar & *child1;
00312                 ar & *child2;
00313             }
00314         }
00315         friend struct serialization::access;
00316     };
00317     typedef Node* NodePtr;
00318 
00319 
00320     struct Interval
00321     {
00322         DistanceType low, high;
00323         
00324     private:
00325         template <typename Archive>
00326         void serialize(Archive& ar)
00327         {
00328             ar & low;
00329             ar & high;
00330         }
00331         friend struct serialization::access;
00332     };
00333 
00334     typedef std::vector<Interval> BoundingBox;
00335 
00336     typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00337     typedef BranchSt* Branch;
00338 
00339 
00340     
00341     void freeIndex()
00342     {
00343         if (data_.ptr()) {
00344             delete[] data_.ptr();
00345             data_ = rtflann::Matrix<ElementType>();
00346         }
00347         if (root_node_) root_node_->~Node();
00348         pool_.free();
00349     }
00350     
00351     void copyTree(NodePtr& dst, const NodePtr& src)
00352     {
00353         dst = new(pool_) Node();
00354         *dst = *src;
00355         if (src->child1!=NULL && src->child2!=NULL) {
00356             copyTree(dst->child1, src->child1);
00357             copyTree(dst->child2, src->child2);
00358         }
00359     }
00360 
00361 
00362 
00363     void computeBoundingBox(BoundingBox& bbox)
00364     {
00365         bbox.resize(veclen_);
00366         for (size_t i=0; i<veclen_; ++i) {
00367             bbox[i].low = (DistanceType)points_[0][i];
00368             bbox[i].high = (DistanceType)points_[0][i];
00369         }
00370         for (size_t k=1; k<size_; ++k) {
00371             for (size_t i=0; i<veclen_; ++i) {
00372                 if (points_[k][i]<bbox[i].low) bbox[i].low = (DistanceType)points_[k][i];
00373                 if (points_[k][i]>bbox[i].high) bbox[i].high = (DistanceType)points_[k][i];
00374             }
00375         }
00376     }
00377 
00378 
00388     NodePtr divideTree(int left, int right, BoundingBox& bbox)
00389     {
00390         NodePtr node = new (pool_) Node(); // allocate memory
00391 
00392         /* If too few exemplars remain, then make this a leaf node. */
00393         if ( (right-left) <= leaf_max_size_) {
00394             node->child1 = node->child2 = NULL;    /* Mark as leaf node. */
00395             node->left = left;
00396             node->right = right;
00397 
00398             // compute bounding-box of leaf points
00399             for (size_t i=0; i<veclen_; ++i) {
00400                 bbox[i].low = (DistanceType)points_[vind_[left]][i];
00401                 bbox[i].high = (DistanceType)points_[vind_[left]][i];
00402             }
00403             for (int k=left+1; k<right; ++k) {
00404                 for (size_t i=0; i<veclen_; ++i) {
00405                     if (bbox[i].low>points_[vind_[k]][i]) bbox[i].low=(DistanceType)points_[vind_[k]][i];
00406                     if (bbox[i].high<points_[vind_[k]][i]) bbox[i].high=(DistanceType)points_[vind_[k]][i];
00407                 }
00408             }
00409         }
00410         else {
00411             int idx;
00412             int cutfeat;
00413             DistanceType cutval;
00414             middleSplit(&vind_[0]+left, right-left, idx, cutfeat, cutval, bbox);
00415 
00416             node->divfeat = cutfeat;
00417 
00418             BoundingBox left_bbox(bbox);
00419             left_bbox[cutfeat].high = cutval;
00420             node->child1 = divideTree(left, left+idx, left_bbox);
00421 
00422             BoundingBox right_bbox(bbox);
00423             right_bbox[cutfeat].low = cutval;
00424             node->child2 = divideTree(left+idx, right, right_bbox);
00425 
00426             node->divlow = left_bbox[cutfeat].high;
00427             node->divhigh = right_bbox[cutfeat].low;
00428 
00429             for (size_t i=0; i<veclen_; ++i) {
00430                 bbox[i].low = std::min(left_bbox[i].low, right_bbox[i].low);
00431                 bbox[i].high = std::max(left_bbox[i].high, right_bbox[i].high);
00432             }
00433         }
00434 
00435         return node;
00436     }
00437 
00438     void computeMinMax(int* ind, int count, int dim, ElementType& min_elem, ElementType& max_elem)
00439     {
00440         min_elem = points_[ind[0]][dim];
00441         max_elem = points_[ind[0]][dim];
00442         for (int i=1; i<count; ++i) {
00443             ElementType val = points_[ind[i]][dim];
00444             if (val<min_elem) min_elem = val;
00445             if (val>max_elem) max_elem = val;
00446         }
00447     }
00448 
00449     void middleSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
00450     {
00451         // find the largest span from the approximate bounding box
00452         ElementType max_span = bbox[0].high-bbox[0].low;
00453         cutfeat = 0;
00454         cutval = (bbox[0].high+bbox[0].low)/2;
00455         for (size_t i=1; i<veclen_; ++i) {
00456             ElementType span = bbox[i].high-bbox[i].low;
00457             if (span>max_span) {
00458                 max_span = span;
00459                 cutfeat = i;
00460                 cutval = (bbox[i].high+bbox[i].low)/2;
00461             }
00462         }
00463 
00464         // compute exact span on the found dimension
00465         ElementType min_elem, max_elem;
00466         computeMinMax(ind, count, cutfeat, min_elem, max_elem);
00467         cutval = (min_elem+max_elem)/2;
00468         max_span = max_elem - min_elem;
00469 
00470         // check if a dimension of a largest span exists
00471         size_t k = cutfeat;
00472         for (size_t i=0; i<veclen_; ++i) {
00473             if (i==k) continue;
00474             ElementType span = bbox[i].high-bbox[i].low;
00475             if (span>max_span) {
00476                 computeMinMax(ind, count, i, min_elem, max_elem);
00477                 span = max_elem - min_elem;
00478                 if (span>max_span) {
00479                     max_span = span;
00480                     cutfeat = i;
00481                     cutval = (min_elem+max_elem)/2;
00482                 }
00483             }
00484         }
00485         int lim1, lim2;
00486         planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
00487 
00488         if (lim1>count/2) index = lim1;
00489         else if (lim2<count/2) index = lim2;
00490         else index = count/2;
00491         
00492         assert(index > 0 && index < count);
00493     }
00494 
00495 
00496     void middleSplit_(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
00497     {
00498         const float eps_val=0.00001f;
00499         DistanceType max_span = bbox[0].high-bbox[0].low;
00500         for (size_t i=1; i<veclen_; ++i) {
00501             DistanceType span = bbox[i].high-bbox[i].low;
00502             if (span>max_span) {
00503                 max_span = span;
00504             }
00505         }
00506         DistanceType max_spread = -1;
00507         cutfeat = 0;
00508         for (size_t i=0; i<veclen_; ++i) {
00509             DistanceType span = bbox[i].high-bbox[i].low;
00510             if (span>(DistanceType)((1-eps_val)*max_span)) {
00511                 ElementType min_elem, max_elem;
00512                 computeMinMax(ind, count, cutfeat, min_elem, max_elem);
00513                 DistanceType spread = (DistanceType)(max_elem-min_elem);
00514                 if (spread>max_spread) {
00515                     cutfeat = i;
00516                     max_spread = spread;
00517                 }
00518             }
00519         }
00520         // split in the middle
00521         DistanceType split_val = (bbox[cutfeat].low+bbox[cutfeat].high)/2;
00522         ElementType min_elem, max_elem;
00523         computeMinMax(ind, count, cutfeat, min_elem, max_elem);
00524 
00525         if (split_val<min_elem) cutval = (DistanceType)min_elem;
00526         else if (split_val>max_elem) cutval = (DistanceType)max_elem;
00527         else cutval = split_val;
00528 
00529         int lim1, lim2;
00530         planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
00531 
00532         if (lim1>count/2) index = lim1;
00533         else if (lim2<count/2) index = lim2;
00534         else index = count/2;
00535         
00536         assert(index > 0 && index < count);
00537     }
00538 
00539 
00549     void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
00550     {
00551         int left = 0;
00552         int right = count-1;
00553         for (;; ) {
00554             while (left<=right && points_[ind[left]][cutfeat]<cutval) ++left;
00555             while (left<=right && points_[ind[right]][cutfeat]>=cutval) --right;
00556             if (left>right) break;
00557             std::swap(ind[left], ind[right]); ++left; --right;
00558         }
00559 
00560         lim1 = left;
00561         right = count-1;
00562         for (;; ) {
00563             while (left<=right && points_[ind[left]][cutfeat]<=cutval) ++left;
00564             while (left<=right && points_[ind[right]][cutfeat]>cutval) --right;
00565             if (left>right) break;
00566             std::swap(ind[left], ind[right]); ++left; --right;
00567         }
00568         lim2 = left;
00569     }
00570 
00571     DistanceType computeInitialDistances(const ElementType* vec, std::vector<DistanceType>& dists) const
00572     {
00573         DistanceType distsq = 0.0;
00574 
00575         for (size_t i = 0; i < veclen_; ++i) {
00576             if (vec[i] < root_bbox_[i].low) {
00577                 dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].low, i);
00578                 distsq += dists[i];
00579             }
00580             if (vec[i] > root_bbox_[i].high) {
00581                 dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].high, i);
00582                 distsq += dists[i];
00583             }
00584         }
00585 
00586         return distsq;
00587     }
00588 
00592     template <bool with_removed>
00593     void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindistsq,
00594                      std::vector<DistanceType>& dists, const float epsError) const
00595     {
00596         /* If this is a leaf node, then do check and return. */
00597         if ((node->child1 == NULL)&&(node->child2 == NULL)) {
00598             DistanceType worst_dist = result_set.worstDist();
00599             for (int i=node->left; i<node->right; ++i) {
00600                 if (with_removed) {
00601                     if (removed_points_.test(vind_[i])) continue;
00602                 }
00603                 ElementType* point = reorder_ ? data_[i] : points_[vind_[i]];
00604                 DistanceType dist = distance_(vec, point, veclen_, worst_dist);
00605                 if (dist<worst_dist) {
00606                     result_set.addPoint(dist,vind_[i]);
00607                 }
00608             }
00609             return;
00610         }
00611 
00612         /* Which child branch should be taken first? */
00613         int idx = node->divfeat;
00614         ElementType val = vec[idx];
00615         DistanceType diff1 = val - node->divlow;
00616         DistanceType diff2 = val - node->divhigh;
00617 
00618         NodePtr bestChild;
00619         NodePtr otherChild;
00620         DistanceType cut_dist;
00621         if ((diff1+diff2)<0) {
00622             bestChild = node->child1;
00623             otherChild = node->child2;
00624             cut_dist = distance_.accum_dist(val, node->divhigh, idx);
00625         }
00626         else {
00627             bestChild = node->child2;
00628             otherChild = node->child1;
00629             cut_dist = distance_.accum_dist( val, node->divlow, idx);
00630         }
00631 
00632         /* Call recursively to search next level down. */
00633         searchLevel<with_removed>(result_set, vec, bestChild, mindistsq, dists, epsError);
00634 
00635         DistanceType dst = dists[idx];
00636         mindistsq = mindistsq + cut_dist - dst;
00637         dists[idx] = cut_dist;
00638         if (mindistsq*epsError<=result_set.worstDist()) {
00639             searchLevel<with_removed>(result_set, vec, otherChild, mindistsq, dists, epsError);
00640         }
00641         dists[idx] = dst;
00642     }
00643 
00644     
00645     void swap(KDTreeSingleIndex& other)
00646     {
00647         BaseClass::swap(other);
00648         std::swap(leaf_max_size_, other.leaf_max_size_);
00649         std::swap(reorder_, other.reorder_);
00650         std::swap(vind_, other.vind_);
00651         std::swap(data_, other.data_);
00652         std::swap(root_node_, other.root_node_);
00653         std::swap(root_bbox_, other.root_bbox_);
00654         std::swap(pool_, other.pool_);
00655     }
00656     
00657 private:
00658 
00659 
00660 
00661     int leaf_max_size_;
00662     
00663     
00664     bool reorder_;
00665 
00669     std::vector<int> vind_;
00670 
00671     Matrix<ElementType> data_;
00672 
00676     NodePtr root_node_;
00677 
00681     BoundingBox root_bbox_;
00682 
00690     PooledAllocator pool_;
00691 
00692     USING_BASECLASS_SYMBOLS
00693 
00694 };   // class KDTreeSingleIndex
00695 
00696 }
00697 
00698 #endif //FLANN_KDTREE_SINGLE_INDEX_H_


rtabmap
Author(s): Mathieu Labbe
autogenerated on Thu Jun 6 2019 21:59:20