00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef RTABMAP_FLANN_KDTREE_INDEX_H_
00032 #define RTABMAP_FLANN_KDTREE_INDEX_H_
00033
00034 #include <algorithm>
00035 #include <map>
00036 #include <cassert>
00037 #include <cstring>
00038 #include <stdarg.h>
00039 #include <cmath>
00040
00041 #include "rtflann/general.h"
00042 #include "rtflann/algorithms/nn_index.h"
00043 #include "rtflann/util/dynamic_bitset.h"
00044 #include "rtflann/util/matrix.h"
00045 #include "rtflann/util/result_set.h"
00046 #include "rtflann/util/heap.h"
00047 #include "rtflann/util/allocator.h"
00048 #include "rtflann/util/random.h"
00049 #include "rtflann/util/saving.h"
00050
00051
00052 namespace rtflann
00053 {
00054
00055 struct KDTreeIndexParams : public IndexParams
00056 {
00057 KDTreeIndexParams(int trees = 4)
00058 {
00059 (*this)["algorithm"] = FLANN_INDEX_KDTREE;
00060 (*this)["trees"] = trees;
00061 }
00062 };
00063
00064
00071 template <typename Distance>
00072 class KDTreeIndex : public NNIndex<Distance>
00073 {
00074 public:
00075 typedef typename Distance::ElementType ElementType;
00076 typedef typename Distance::ResultType DistanceType;
00077
00078 typedef NNIndex<Distance> BaseClass;
00079
00080 typedef bool needs_kdtree_distance;
00081
00082
00090 KDTreeIndex(const IndexParams& params = KDTreeIndexParams(), Distance d = Distance() ) :
00091 BaseClass(params, d), mean_(NULL), var_(NULL)
00092 {
00093 trees_ = get_param(index_params_,"trees",4);
00094 }
00095
00096
00104 KDTreeIndex(const Matrix<ElementType>& dataset, const IndexParams& params = KDTreeIndexParams(),
00105 Distance d = Distance() ) : BaseClass(params,d ), mean_(NULL), var_(NULL)
00106 {
00107 trees_ = get_param(index_params_,"trees",4);
00108
00109 setDataset(dataset);
00110 }
00111
00112 KDTreeIndex(const KDTreeIndex& other) : BaseClass(other),
00113 trees_(other.trees_)
00114 {
00115 tree_roots_.resize(other.tree_roots_.size());
00116 for (size_t i=0;i<tree_roots_.size();++i) {
00117 copyTree(tree_roots_[i], other.tree_roots_[i]);
00118 }
00119 }
00120
00121 KDTreeIndex& operator=(KDTreeIndex other)
00122 {
00123 this->swap(other);
00124 return *this;
00125 }
00126
00130 virtual ~KDTreeIndex()
00131 {
00132 freeIndex();
00133 }
00134
00135 BaseClass* clone() const
00136 {
00137 return new KDTreeIndex(*this);
00138 }
00139
00140 using BaseClass::buildIndex;
00141
00142 void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00143 {
00144 assert(points.cols==veclen_);
00145
00146 size_t old_size = size_;
00147 extendDataset(points);
00148
00149 if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
00150 buildIndex();
00151 }
00152 else {
00153 for (size_t i=old_size;i<size_;++i) {
00154 for (int j = 0; j < trees_; j++) {
00155 addPointToTree(tree_roots_[j], i);
00156 }
00157 }
00158 }
00159 }
00160
00161 flann_algorithm_t getType() const
00162 {
00163 return FLANN_INDEX_KDTREE;
00164 }
00165
00166
00167 template<typename Archive>
00168 void serialize(Archive& ar)
00169 {
00170 ar.setObject(this);
00171
00172 ar & *static_cast<NNIndex<Distance>*>(this);
00173
00174 ar & trees_;
00175
00176 if (Archive::is_loading::value) {
00177 tree_roots_.resize(trees_);
00178 }
00179 for (size_t i=0;i<tree_roots_.size();++i) {
00180 if (Archive::is_loading::value) {
00181 tree_roots_[i] = new(pool_) Node();
00182 }
00183 ar & *tree_roots_[i];
00184 }
00185
00186 if (Archive::is_loading::value) {
00187 index_params_["algorithm"] = getType();
00188 index_params_["trees"] = trees_;
00189 }
00190 }
00191
00192
00193 void saveIndex(FILE* stream)
00194 {
00195 serialization::SaveArchive sa(stream);
00196 sa & *this;
00197 }
00198
00199
00200 void loadIndex(FILE* stream)
00201 {
00202 freeIndex();
00203 serialization::LoadArchive la(stream);
00204 la & *this;
00205 }
00206
00211 int usedMemory() const
00212 {
00213 return int(pool_.usedMemory+pool_.wastedMemory+size_*sizeof(int));
00214 }
00215
00225 void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00226 {
00227 int maxChecks = searchParams.checks;
00228 float epsError = 1+searchParams.eps;
00229
00230 if (maxChecks==FLANN_CHECKS_UNLIMITED) {
00231 if (removed_) {
00232 getExactNeighbors<true>(result, vec, epsError);
00233 }
00234 else {
00235 getExactNeighbors<false>(result, vec, epsError);
00236 }
00237 }
00238 else {
00239 if (removed_) {
00240 getNeighbors<true>(result, vec, maxChecks, epsError);
00241 }
00242 else {
00243 getNeighbors<false>(result, vec, maxChecks, epsError);
00244 }
00245 }
00246 }
00247
00248 protected:
00249
00253 void buildIndexImpl()
00254 {
00255
00256 std::vector<int> ind(size_);
00257 for (size_t i = 0; i < size_; ++i) {
00258 ind[i] = int(i);
00259 }
00260
00261 mean_ = new DistanceType[veclen_];
00262 var_ = new DistanceType[veclen_];
00263
00264 tree_roots_.resize(trees_);
00265
00266 for (int i = 0; i < trees_; i++) {
00267
00268 std::random_shuffle(ind.begin(), ind.end());
00269 tree_roots_[i] = divideTree(&ind[0], int(size_) );
00270 }
00271 delete[] mean_;
00272 delete[] var_;
00273 }
00274
00275 void freeIndex()
00276 {
00277 for (size_t i=0;i<tree_roots_.size();++i) {
00278
00279 if (tree_roots_[i]!=NULL) tree_roots_[i]->~Node();
00280 }
00281 pool_.free();
00282 }
00283
00284
00285 private:
00286
00287
00288 struct Node
00289 {
00293 int divfeat;
00297 DistanceType divval;
00301 ElementType* point;
00305 Node* child1, *child2;
00306 Node(){
00307 child1 = NULL;
00308 child2 = NULL;
00309 }
00310 ~Node() {
00311 if (child1 != NULL) { child1->~Node(); child1 = NULL; }
00312
00313 if (child2 != NULL) { child2->~Node(); child2 = NULL; }
00314 }
00315
00316 private:
00317 template<typename Archive>
00318 void serialize(Archive& ar)
00319 {
00320 typedef KDTreeIndex<Distance> Index;
00321 Index* obj = static_cast<Index*>(ar.getObject());
00322
00323 ar & divfeat;
00324 ar & divval;
00325
00326 bool leaf_node = false;
00327 if (Archive::is_saving::value) {
00328 leaf_node = ((child1==NULL) && (child2==NULL));
00329 }
00330 ar & leaf_node;
00331
00332 if (leaf_node) {
00333 if (Archive::is_loading::value) {
00334 point = obj->points_[divfeat];
00335 }
00336 }
00337
00338 if (!leaf_node) {
00339 if (Archive::is_loading::value) {
00340 child1 = new(obj->pool_) Node();
00341 child2 = new(obj->pool_) Node();
00342 }
00343 ar & *child1;
00344 ar & *child2;
00345 }
00346 }
00347 friend struct serialization::access;
00348 };
00349 typedef Node* NodePtr;
00350 typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00351 typedef BranchSt* Branch;
00352
00353
00354 void copyTree(NodePtr& dst, const NodePtr& src)
00355 {
00356 dst = new(pool_) Node();
00357 dst->divfeat = src->divfeat;
00358 dst->divval = src->divval;
00359 if (src->child1==NULL && src->child2==NULL) {
00360 dst->point = points_[dst->divfeat];
00361 dst->child1 = NULL;
00362 dst->child2 = NULL;
00363 }
00364 else {
00365 copyTree(dst->child1, src->child1);
00366 copyTree(dst->child2, src->child2);
00367 }
00368 }
00369
00379 NodePtr divideTree(int* ind, int count)
00380 {
00381 NodePtr node = new(pool_) Node();
00382
00383
00384 if (count == 1) {
00385 node->child1 = node->child2 = NULL;
00386 node->divfeat = *ind;
00387 node->point = points_[*ind];
00388 }
00389 else {
00390 int idx;
00391 int cutfeat;
00392 DistanceType cutval;
00393 meanSplit(ind, count, idx, cutfeat, cutval);
00394
00395 node->divfeat = cutfeat;
00396 node->divval = cutval;
00397 node->child1 = divideTree(ind, idx);
00398 node->child2 = divideTree(ind+idx, count-idx);
00399 }
00400
00401 return node;
00402 }
00403
00404
00410 void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval)
00411 {
00412 memset(mean_,0,veclen_*sizeof(DistanceType));
00413 memset(var_,0,veclen_*sizeof(DistanceType));
00414
00415
00416
00417
00418 int cnt = std::min((int)SAMPLE_MEAN+1, count);
00419 for (int j = 0; j < cnt; ++j) {
00420 ElementType* v = points_[ind[j]];
00421 for (size_t k=0; k<veclen_; ++k) {
00422 mean_[k] += v[k];
00423 }
00424 }
00425 DistanceType div_factor = DistanceType(1)/cnt;
00426 for (size_t k=0; k<veclen_; ++k) {
00427 mean_[k] *= div_factor;
00428 }
00429
00430
00431 for (int j = 0; j < cnt; ++j) {
00432 ElementType* v = points_[ind[j]];
00433 for (size_t k=0; k<veclen_; ++k) {
00434 DistanceType dist = v[k] - mean_[k];
00435 var_[k] += dist * dist;
00436 }
00437 }
00438
00439 cutfeat = selectDivision(var_);
00440 cutval = mean_[cutfeat];
00441
00442 int lim1, lim2;
00443 planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
00444
00445 if (lim1>count/2) index = lim1;
00446 else if (lim2<count/2) index = lim2;
00447 else index = count/2;
00448
00449
00450
00451
00452 if ((lim1==count)||(lim2==0)) index = count/2;
00453 }
00454
00455
00460 int selectDivision(DistanceType* v)
00461 {
00462 int num = 0;
00463 size_t topind[RAND_DIM];
00464
00465
00466 for (size_t i = 0; i < veclen_; ++i) {
00467 if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) {
00468
00469 if (num < RAND_DIM) {
00470 topind[num++] = i;
00471 }
00472 else {
00473 topind[num-1] = i;
00474 }
00475
00476 int j = num - 1;
00477 while (j > 0 && v[topind[j]] > v[topind[j-1]]) {
00478 std::swap(topind[j], topind[j-1]);
00479 --j;
00480 }
00481 }
00482 }
00483
00484 int rnd = rand_int(num);
00485 return (int)topind[rnd];
00486 }
00487
00488
00498 void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
00499 {
00500
00501 int left = 0;
00502 int right = count-1;
00503 for (;; ) {
00504 while (left<=right && points_[ind[left]][cutfeat]<cutval) ++left;
00505 while (left<=right && points_[ind[right]][cutfeat]>=cutval) --right;
00506 if (left>right) break;
00507 std::swap(ind[left], ind[right]); ++left; --right;
00508 }
00509 lim1 = left;
00510 right = count-1;
00511 for (;; ) {
00512 while (left<=right && points_[ind[left]][cutfeat]<=cutval) ++left;
00513 while (left<=right && points_[ind[right]][cutfeat]>cutval) --right;
00514 if (left>right) break;
00515 std::swap(ind[left], ind[right]); ++left; --right;
00516 }
00517 lim2 = left;
00518 }
00519
00524 template<bool with_removed>
00525 void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError) const
00526 {
00527
00528
00529 if (trees_ > 1) {
00530 fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
00531 }
00532 if (trees_>0) {
00533 searchLevelExact<with_removed>(result, vec, tree_roots_[0], 0.0, epsError);
00534 }
00535 }
00536
00542 template<bool with_removed>
00543 void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError) const
00544 {
00545 int i;
00546 BranchSt branch;
00547
00548 int checkCount = 0;
00549 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00550 DynamicBitset checked(size_);
00551
00552
00553 for (i = 0; i < trees_; ++i) {
00554 searchLevel<with_removed>(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
00555 }
00556
00557
00558 while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
00559 searchLevel<with_removed>(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked);
00560 }
00561
00562 delete heap;
00563
00564 }
00565
00571 template<bool with_removed>
00572 void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck,
00573 float epsError, Heap<BranchSt>* heap, DynamicBitset& checked) const
00574 {
00575 if (result_set.worstDist()<mindist) {
00576
00577 return;
00578 }
00579
00580
00581 if ((node->child1 == NULL)&&(node->child2 == NULL)) {
00582 int index = node->divfeat;
00583 if (with_removed) {
00584 if (removed_points_.test(index)) return;
00585 }
00586
00587 if ( checked.test(index) || ((checkCount>=maxCheck)&& result_set.full()) ) return;
00588 checked.set(index);
00589 checkCount++;
00590
00591 DistanceType dist = distance_(node->point, vec, veclen_);
00592 result_set.addPoint(dist,index);
00593 return;
00594 }
00595
00596
00597 ElementType val = vec[node->divfeat];
00598 DistanceType diff = val - node->divval;
00599 NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
00600 NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
00601
00602
00603
00604
00605
00606
00607
00608
00609
00610 DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
00611
00612 if ((new_distsq*epsError < result_set.worstDist())|| !result_set.full()) {
00613 heap->insert( BranchSt(otherChild, new_distsq) );
00614 }
00615
00616
00617 searchLevel<with_removed>(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
00618 }
00619
00623 template<bool with_removed>
00624 void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError) const
00625 {
00626
00627 if ((node->child1 == NULL)&&(node->child2 == NULL)) {
00628 int index = node->divfeat;
00629 if (with_removed) {
00630 if (removed_points_.test(index)) return;
00631 }
00632 DistanceType dist = distance_(node->point, vec, veclen_);
00633 result_set.addPoint(dist,index);
00634
00635 return;
00636 }
00637
00638
00639 ElementType val = vec[node->divfeat];
00640 DistanceType diff = val - node->divval;
00641 NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
00642 NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
00643
00644
00645
00646
00647
00648
00649
00650
00651
00652 DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
00653
00654
00655 searchLevelExact<with_removed>(result_set, vec, bestChild, mindist, epsError);
00656
00657 if (mindist*epsError<=result_set.worstDist()) {
00658 searchLevelExact<with_removed>(result_set, vec, otherChild, new_distsq, epsError);
00659 }
00660 }
00661
00662 void addPointToTree(NodePtr node, int ind)
00663 {
00664 ElementType* point = points_[ind];
00665
00666 if ((node->child1==NULL) && (node->child2==NULL)) {
00667 ElementType* leaf_point = node->point;
00668 ElementType max_span = 0;
00669 size_t div_feat = 0;
00670 for (size_t i=0;i<veclen_;++i) {
00671 ElementType span = std::abs(point[i]-leaf_point[i]);
00672 if (span > max_span) {
00673 max_span = span;
00674 div_feat = i;
00675 }
00676 }
00677 NodePtr left = new(pool_) Node();
00678 left->child1 = left->child2 = NULL;
00679 NodePtr right = new(pool_) Node();
00680 right->child1 = right->child2 = NULL;
00681
00682 if (point[div_feat]<leaf_point[div_feat]) {
00683 left->divfeat = ind;
00684 left->point = point;
00685 right->divfeat = node->divfeat;
00686 right->point = node->point;
00687 }
00688 else {
00689 left->divfeat = node->divfeat;
00690 left->point = node->point;
00691 right->divfeat = ind;
00692 right->point = point;
00693 }
00694 node->divfeat = div_feat;
00695 node->divval = (point[div_feat]+leaf_point[div_feat])/2;
00696 node->child1 = left;
00697 node->child2 = right;
00698 }
00699 else {
00700 if (point[node->divfeat]<node->divval) {
00701 addPointToTree(node->child1,ind);
00702 }
00703 else {
00704 addPointToTree(node->child2,ind);
00705 }
00706 }
00707 }
00708 private:
00709 void swap(KDTreeIndex& other)
00710 {
00711 BaseClass::swap(other);
00712 std::swap(trees_, other.trees_);
00713 std::swap(tree_roots_, other.tree_roots_);
00714 std::swap(pool_, other.pool_);
00715 }
00716
00717 private:
00718
00719 enum
00720 {
00726 SAMPLE_MEAN = 100,
00734 RAND_DIM=5
00735 };
00736
00737
00741 int trees_;
00742
00743 DistanceType* mean_;
00744 DistanceType* var_;
00745
00749 std::vector<NodePtr> tree_roots_;
00750
00758 PooledAllocator pool_;
00759
00760 USING_BASECLASS_SYMBOLS
00761 };
00762
00763 }
00764
00765 #endif // RTABMAP_FLANN_KDTREE_INDEX_H_