00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef RTABMAP_FLANN_KDTREE_INDEX_H_
00032 #define RTABMAP_FLANN_KDTREE_INDEX_H_
00033
00034 #include <algorithm>
00035 #include <map>
00036 #include <cassert>
00037 #include <cstring>
00038 #include <stdarg.h>
00039 #include <cmath>
00040
00041 #include "rtflann/general.h"
00042 #include "rtflann/algorithms/nn_index.h"
00043 #include "rtflann/util/dynamic_bitset.h"
00044 #include "rtflann/util/matrix.h"
00045 #include "rtflann/util/result_set.h"
00046 #include "rtflann/util/heap.h"
00047 #include "rtflann/util/allocator.h"
00048 #include "rtflann/util/random.h"
00049 #include "rtflann/util/saving.h"
00050
00051
00052 namespace rtflann
00053 {
00054
00055 struct KDTreeIndexParams : public IndexParams
00056 {
00057 KDTreeIndexParams(int trees = 4)
00058 {
00059 (*this)["algorithm"] = FLANN_INDEX_KDTREE;
00060 (*this)["trees"] = trees;
00061 }
00062 };
00063
00064
00071 template <typename Distance>
00072 class KDTreeIndex : public NNIndex<Distance>
00073 {
00074 public:
00075 typedef typename Distance::ElementType ElementType;
00076 typedef typename Distance::ResultType DistanceType;
00077
00078 typedef NNIndex<Distance> BaseClass;
00079
00080 typedef bool needs_kdtree_distance;
00081
00082 private:
00083
00084 struct Node
00085 {
00089 int divfeat;
00093 DistanceType divval;
00097 ElementType* point;
00101 Node* child1, *child2;
00102 Node(){
00103 child1 = NULL;
00104 child2 = NULL;
00105 }
00106 ~Node() {
00107 if (child1 != NULL) { child1->~Node(); child1 = NULL; }
00108
00109 if (child2 != NULL) { child2->~Node(); child2 = NULL; }
00110 }
00111
00112 private:
00113 template<typename Archive>
00114 void serialize(Archive& ar)
00115 {
00116 typedef KDTreeIndex<Distance> Index;
00117 Index* obj = static_cast<Index*>(ar.getObject());
00118
00119 ar & divfeat;
00120 ar & divval;
00121
00122 bool leaf_node = false;
00123 if (Archive::is_saving::value) {
00124 leaf_node = ((child1==NULL) && (child2==NULL));
00125 }
00126 ar & leaf_node;
00127
00128 if (leaf_node) {
00129 if (Archive::is_loading::value) {
00130 point = obj->points_[divfeat];
00131 }
00132 }
00133
00134 if (!leaf_node) {
00135 if (Archive::is_loading::value) {
00136 child1 = new(obj->pool_) Node();
00137 child2 = new(obj->pool_) Node();
00138 }
00139 ar & *child1;
00140 ar & *child2;
00141 }
00142 }
00143 friend struct serialization::access;
00144 };
00145
00146 typedef Node* NodePtr;
00147 typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00148 typedef BranchSt* Branch;
00149
00150 public:
00151
00159 KDTreeIndex(const IndexParams& params = KDTreeIndexParams(), Distance d = Distance() ) :
00160 BaseClass(params, d), mean_(NULL), var_(NULL)
00161 {
00162 trees_ = get_param(index_params_,"trees",4);
00163 }
00164
00165
00173 KDTreeIndex(const Matrix<ElementType>& dataset, const IndexParams& params = KDTreeIndexParams(),
00174 Distance d = Distance() ) : BaseClass(params,d ), mean_(NULL), var_(NULL)
00175 {
00176 trees_ = get_param(index_params_,"trees",4);
00177
00178 setDataset(dataset);
00179 }
00180
00181 KDTreeIndex(const KDTreeIndex& other) : BaseClass(other),
00182 trees_(other.trees_)
00183 {
00184 tree_roots_.resize(other.tree_roots_.size());
00185 for (size_t i=0;i<tree_roots_.size();++i) {
00186 copyTree(tree_roots_[i], other.tree_roots_[i]);
00187 }
00188 }
00189
00190 KDTreeIndex& operator=(KDTreeIndex other)
00191 {
00192 this->swap(other);
00193 return *this;
00194 }
00195
00199 virtual ~KDTreeIndex()
00200 {
00201 freeIndex();
00202 }
00203
00204 BaseClass* clone() const
00205 {
00206 return new KDTreeIndex(*this);
00207 }
00208
00209 using BaseClass::buildIndex;
00210
00211 void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00212 {
00213 assert(points.cols==veclen_);
00214
00215 size_t old_size = size_;
00216 extendDataset(points);
00217
00218 if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
00219 buildIndex();
00220 }
00221 else {
00222 for (size_t i=old_size;i<size_;++i) {
00223 for (int j = 0; j < trees_; j++) {
00224 addPointToTree(tree_roots_[j], i);
00225 }
00226 }
00227 }
00228 }
00229
00230 flann_algorithm_t getType() const
00231 {
00232 return FLANN_INDEX_KDTREE;
00233 }
00234
00235
00236 template<typename Archive>
00237 void serialize(Archive& ar)
00238 {
00239 ar.setObject(this);
00240
00241 ar & *static_cast<NNIndex<Distance>*>(this);
00242
00243 ar & trees_;
00244
00245 if (Archive::is_loading::value) {
00246 tree_roots_.resize(trees_);
00247 }
00248 for (size_t i=0;i<tree_roots_.size();++i) {
00249 if (Archive::is_loading::value) {
00250 tree_roots_[i] = new(pool_) Node();
00251 }
00252 ar & *tree_roots_[i];
00253 }
00254
00255 if (Archive::is_loading::value) {
00256 index_params_["algorithm"] = getType();
00257 index_params_["trees"] = trees_;
00258 }
00259 }
00260
00261
00262 void saveIndex(FILE* stream)
00263 {
00264 serialization::SaveArchive sa(stream);
00265 sa & *this;
00266 }
00267
00268
00269 void loadIndex(FILE* stream)
00270 {
00271 freeIndex();
00272 serialization::LoadArchive la(stream);
00273 la & *this;
00274 }
00275
00280 int usedMemory() const
00281 {
00282 return int(pool_.usedMemory+pool_.wastedMemory+size_*sizeof(int));
00283 }
00284
00294 void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00295 {
00296 int maxChecks = searchParams.checks;
00297 float epsError = 1+searchParams.eps;
00298
00299 if (maxChecks==FLANN_CHECKS_UNLIMITED) {
00300 if (removed_) {
00301 getExactNeighbors<true>(result, vec, epsError);
00302 }
00303 else {
00304 getExactNeighbors<false>(result, vec, epsError);
00305 }
00306 }
00307 else {
00308 if (removed_) {
00309 getNeighbors<true>(result, vec, maxChecks, epsError);
00310 }
00311 else {
00312 getNeighbors<false>(result, vec, maxChecks, epsError);
00313 }
00314 }
00315 }
00316
00317 #ifdef ANDROID
00318
00328 void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams, Heap<BranchSt>* heap) const
00329 {
00330 int maxChecks = searchParams.checks;
00331 float epsError = 1+searchParams.eps;
00332
00333 if (maxChecks==FLANN_CHECKS_UNLIMITED) {
00334 if (removed_) {
00335 getExactNeighbors<true>(result, vec, epsError);
00336 }
00337 else {
00338 getExactNeighbors<false>(result, vec, epsError);
00339 }
00340 }
00341 else {
00342 if (removed_) {
00343 getNeighbors<true>(result, vec, maxChecks, epsError, heap);
00344 }
00345 else {
00346 getNeighbors<false>(result, vec, maxChecks, epsError, heap);
00347 }
00348 }
00349 }
00350
00359 virtual int knnSearch(const Matrix<ElementType>& queries,
00360 Matrix<size_t>& indices,
00361 Matrix<DistanceType>& dists,
00362 size_t knn,
00363 const SearchParams& params) const
00364 {
00365 assert(queries.cols == veclen());
00366 assert(indices.rows >= queries.rows);
00367 assert(dists.rows >= queries.rows);
00368 assert(indices.cols >= knn);
00369 assert(dists.cols >= knn);
00370 bool use_heap;
00371
00372 if (params.use_heap==FLANN_Undefined) {
00373 use_heap = (knn>KNN_HEAP_THRESHOLD)?true:false;
00374 }
00375 else {
00376 use_heap = (params.use_heap==FLANN_True)?true:false;
00377 }
00378 int count = 0;
00379
00380 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00381
00382 if (use_heap) {
00383
00384 {
00385 KNNResultSet2<DistanceType> resultSet(knn);
00386
00387 for (int i = 0; i < (int)queries.rows; i++) {
00388 resultSet.clear();
00389 findNeighbors(resultSet, queries[i], params, heap);
00390 size_t n = std::min(resultSet.size(), knn);
00391 resultSet.copy(indices[i], dists[i], n, params.sorted);
00392 indices_to_ids(indices[i], indices[i], n);
00393 count += n;
00394 }
00395 }
00396 }
00397 else {
00398 std::vector<double> times(queries.rows);
00399
00400 {
00401 KNNSimpleResultSet<DistanceType> resultSet(knn);
00402
00403 for (int i = 0; i < (int)queries.rows; i++) {
00404 resultSet.clear();
00405 findNeighbors(resultSet, queries[i], params, heap);
00406 size_t n = std::min(resultSet.size(), knn);
00407 resultSet.copy(indices[i], dists[i], n, params.sorted);
00408 indices_to_ids(indices[i], indices[i], n);
00409 count += n;
00410 }
00411 }
00412 std::sort(times.begin(), times.end());
00413 }
00414 delete heap;
00415 return count;
00416 }
00417
00418
00427 virtual int knnSearch(const Matrix<ElementType>& queries,
00428 std::vector< std::vector<size_t> >& indices,
00429 std::vector<std::vector<DistanceType> >& dists,
00430 size_t knn,
00431 const SearchParams& params) const
00432 {
00433 assert(queries.cols == veclen());
00434 bool use_heap;
00435 if (params.use_heap==FLANN_Undefined) {
00436 use_heap = (knn>KNN_HEAP_THRESHOLD)?true:false;
00437 }
00438 else {
00439 use_heap = (params.use_heap==FLANN_True)?true:false;
00440 }
00441
00442 if (indices.size() < queries.rows ) indices.resize(queries.rows);
00443 if (dists.size() < queries.rows ) dists.resize(queries.rows);
00444
00445 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00446
00447 int count = 0;
00448 if (use_heap) {
00449
00450 {
00451 KNNResultSet2<DistanceType> resultSet(knn);
00452
00453 for (int i = 0; i < (int)queries.rows; i++) {
00454 resultSet.clear();
00455 findNeighbors(resultSet, queries[i], params, heap);
00456 size_t n = std::min(resultSet.size(), knn);
00457 indices[i].resize(n);
00458 dists[i].resize(n);
00459 if (n>0) {
00460 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00461 indices_to_ids(&indices[i][0], &indices[i][0], n);
00462 }
00463 count += n;
00464 }
00465 }
00466 }
00467 else {
00468
00469 {
00470 KNNSimpleResultSet<DistanceType> resultSet(knn);
00471
00472 for (int i = 0; i < (int)queries.rows; i++) {
00473 resultSet.clear();
00474 findNeighbors(resultSet, queries[i], params, heap);
00475 size_t n = std::min(resultSet.size(), knn);
00476 indices[i].resize(n);
00477 dists[i].resize(n);
00478 if (n>0) {
00479 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00480 indices_to_ids(&indices[i][0], &indices[i][0], n);
00481 }
00482 count += n;
00483 }
00484 }
00485 }
00486 delete heap;
00487
00488 return count;
00489 }
00490
00500 virtual int radiusSearch(const Matrix<ElementType>& queries,
00501 Matrix<size_t>& indices,
00502 Matrix<DistanceType>& dists,
00503 float radius,
00504 const SearchParams& params) const
00505 {
00506 assert(queries.cols == veclen());
00507 int count = 0;
00508 size_t num_neighbors = std::min(indices.cols, dists.cols);
00509 int max_neighbors = params.max_neighbors;
00510 if (max_neighbors<0) max_neighbors = num_neighbors;
00511 else max_neighbors = std::min(max_neighbors,(int)num_neighbors);
00512
00513 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00514
00515 if (max_neighbors==0) {
00516
00517 {
00518 CountRadiusResultSet<DistanceType> resultSet(radius);
00519
00520 for (int i = 0; i < (int)queries.rows; i++) {
00521 resultSet.clear();
00522 findNeighbors(resultSet, queries[i], params, heap);
00523 count += resultSet.size();
00524 }
00525 }
00526 }
00527 else {
00528
00529
00530 if (params.max_neighbors<0 && (num_neighbors>=this->size())) {
00531
00532 {
00533 RadiusResultSet<DistanceType> resultSet(radius);
00534
00535 for (int i = 0; i < (int)queries.rows; i++) {
00536 resultSet.clear();
00537 findNeighbors(resultSet, queries[i], params, heap);
00538 size_t n = resultSet.size();
00539 count += n;
00540 if (n>num_neighbors) n = num_neighbors;
00541 resultSet.copy(indices[i], dists[i], n, params.sorted);
00542
00543
00544 if (n<indices.cols) indices[i][n] = size_t(-1);
00545 if (n<dists.cols) dists[i][n] = std::numeric_limits<DistanceType>::infinity();
00546 indices_to_ids(indices[i], indices[i], n);
00547 }
00548 }
00549 }
00550 else {
00551
00552
00553 {
00554 KNNRadiusResultSet<DistanceType> resultSet(radius, max_neighbors);
00555
00556 for (int i = 0; i < (int)queries.rows; i++) {
00557 resultSet.clear();
00558 findNeighbors(resultSet, queries[i], params, heap);
00559 size_t n = resultSet.size();
00560 count += n;
00561 if ((int)n>max_neighbors) n = max_neighbors;
00562 resultSet.copy(indices[i], dists[i], n, params.sorted);
00563
00564
00565 if (n<indices.cols) indices[i][n] = size_t(-1);
00566 if (n<dists.cols) dists[i][n] = std::numeric_limits<DistanceType>::infinity();
00567 indices_to_ids(indices[i], indices[i], n);
00568 }
00569 }
00570 }
00571 }
00572 delete heap;
00573 return count;
00574 }
00575
00585 virtual int radiusSearch(const Matrix<ElementType>& queries,
00586 std::vector< std::vector<size_t> >& indices,
00587 std::vector<std::vector<DistanceType> >& dists,
00588 float radius,
00589 const SearchParams& params) const
00590 {
00591 assert(queries.cols == veclen());
00592 int count = 0;
00593
00594 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00595
00596
00597 if (params.max_neighbors==0) {
00598
00599 {
00600 CountRadiusResultSet<DistanceType> resultSet(radius);
00601
00602 for (int i = 0; i < (int)queries.rows; i++) {
00603 resultSet.clear();
00604 findNeighbors(resultSet, queries[i], params, heap);
00605 count += resultSet.size();
00606 }
00607 }
00608 }
00609 else {
00610 if (indices.size() < queries.rows ) indices.resize(queries.rows);
00611 if (dists.size() < queries.rows ) dists.resize(queries.rows);
00612
00613 if (params.max_neighbors<0) {
00614
00615
00616 {
00617 RadiusResultSet<DistanceType> resultSet(radius);
00618
00619 for (int i = 0; i < (int)queries.rows; i++) {
00620 resultSet.clear();
00621 findNeighbors(resultSet, queries[i], params, heap);
00622 size_t n = resultSet.size();
00623 count += n;
00624 indices[i].resize(n);
00625 dists[i].resize(n);
00626 if (n > 0) {
00627 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00628 indices_to_ids(&indices[i][0], &indices[i][0], n);
00629 }
00630 }
00631 }
00632 }
00633 else {
00634
00635
00636 {
00637 KNNRadiusResultSet<DistanceType> resultSet(radius, params.max_neighbors);
00638
00639 for (int i = 0; i < (int)queries.rows; i++) {
00640 resultSet.clear();
00641 findNeighbors(resultSet, queries[i], params, heap);
00642 size_t n = resultSet.size();
00643 count += n;
00644 if ((int)n>params.max_neighbors) n = params.max_neighbors;
00645 indices[i].resize(n);
00646 dists[i].resize(n);
00647 if (n > 0) {
00648 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00649 indices_to_ids(&indices[i][0], &indices[i][0], n);
00650 }
00651 }
00652 }
00653 }
00654 }
00655 delete heap;
00656 return count;
00657 }
00658 #endif
00659
00660 protected:
00661
00665 void buildIndexImpl()
00666 {
00667
00668 std::vector<int> ind(size_);
00669 for (size_t i = 0; i < size_; ++i) {
00670 ind[i] = int(i);
00671 }
00672
00673 mean_ = new DistanceType[veclen_];
00674 var_ = new DistanceType[veclen_];
00675
00676 tree_roots_.resize(trees_);
00677
00678 for (int i = 0; i < trees_; i++) {
00679
00680 std::random_shuffle(ind.begin(), ind.end());
00681 tree_roots_[i] = divideTree(&ind[0], int(size_) );
00682 }
00683 delete[] mean_;
00684 delete[] var_;
00685 }
00686
00687 void freeIndex()
00688 {
00689 for (size_t i=0;i<tree_roots_.size();++i) {
00690
00691 if (tree_roots_[i]!=NULL) tree_roots_[i]->~Node();
00692 }
00693 pool_.free();
00694 }
00695
00696
00697 private:
00698
00699 void copyTree(NodePtr& dst, const NodePtr& src)
00700 {
00701 dst = new(pool_) Node();
00702 dst->divfeat = src->divfeat;
00703 dst->divval = src->divval;
00704 if (src->child1==NULL && src->child2==NULL) {
00705 dst->point = points_[dst->divfeat];
00706 dst->child1 = NULL;
00707 dst->child2 = NULL;
00708 }
00709 else {
00710 copyTree(dst->child1, src->child1);
00711 copyTree(dst->child2, src->child2);
00712 }
00713 }
00714
00724 NodePtr divideTree(int* ind, int count)
00725 {
00726 NodePtr node = new(pool_) Node();
00727
00728
00729 if (count == 1) {
00730 node->child1 = node->child2 = NULL;
00731 node->divfeat = *ind;
00732 node->point = points_[*ind];
00733 }
00734 else {
00735 int idx;
00736 int cutfeat;
00737 DistanceType cutval;
00738 meanSplit(ind, count, idx, cutfeat, cutval);
00739
00740 node->divfeat = cutfeat;
00741 node->divval = cutval;
00742 node->child1 = divideTree(ind, idx);
00743 node->child2 = divideTree(ind+idx, count-idx);
00744 }
00745
00746 return node;
00747 }
00748
00749
00755 void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval)
00756 {
00757 memset(mean_,0,veclen_*sizeof(DistanceType));
00758 memset(var_,0,veclen_*sizeof(DistanceType));
00759
00760
00761
00762
00763 int cnt = std::min((int)SAMPLE_MEAN+1, count);
00764 for (int j = 0; j < cnt; ++j) {
00765 ElementType* v = points_[ind[j]];
00766 for (size_t k=0; k<veclen_; ++k) {
00767 mean_[k] += v[k];
00768 }
00769 }
00770 DistanceType div_factor = DistanceType(1)/cnt;
00771 for (size_t k=0; k<veclen_; ++k) {
00772 mean_[k] *= div_factor;
00773 }
00774
00775
00776 for (int j = 0; j < cnt; ++j) {
00777 ElementType* v = points_[ind[j]];
00778 for (size_t k=0; k<veclen_; ++k) {
00779 DistanceType dist = v[k] - mean_[k];
00780 var_[k] += dist * dist;
00781 }
00782 }
00783
00784 cutfeat = selectDivision(var_);
00785 cutval = mean_[cutfeat];
00786
00787 int lim1, lim2;
00788 planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
00789
00790 if (lim1>count/2) index = lim1;
00791 else if (lim2<count/2) index = lim2;
00792 else index = count/2;
00793
00794
00795
00796
00797 if ((lim1==count)||(lim2==0)) index = count/2;
00798 }
00799
00800
00805 int selectDivision(DistanceType* v)
00806 {
00807 int num = 0;
00808 size_t topind[RAND_DIM];
00809
00810
00811 for (size_t i = 0; i < veclen_; ++i) {
00812 if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) {
00813
00814 if (num < RAND_DIM) {
00815 topind[num++] = i;
00816 }
00817 else {
00818 topind[num-1] = i;
00819 }
00820
00821 int j = num - 1;
00822 while (j > 0 && v[topind[j]] > v[topind[j-1]]) {
00823 std::swap(topind[j], topind[j-1]);
00824 --j;
00825 }
00826 }
00827 }
00828
00829 int rnd = rand_int(num);
00830 return (int)topind[rnd];
00831 }
00832
00833
00843 void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
00844 {
00845
00846 int left = 0;
00847 int right = count-1;
00848 for (;; ) {
00849 while (left<=right && points_[ind[left]][cutfeat]<cutval) ++left;
00850 while (left<=right && points_[ind[right]][cutfeat]>=cutval) --right;
00851 if (left>right) break;
00852 std::swap(ind[left], ind[right]); ++left; --right;
00853 }
00854 lim1 = left;
00855 right = count-1;
00856 for (;; ) {
00857 while (left<=right && points_[ind[left]][cutfeat]<=cutval) ++left;
00858 while (left<=right && points_[ind[right]][cutfeat]>cutval) --right;
00859 if (left>right) break;
00860 std::swap(ind[left], ind[right]); ++left; --right;
00861 }
00862 lim2 = left;
00863 }
00864
00869 template<bool with_removed>
00870 void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError) const
00871 {
00872
00873
00874 if (trees_ > 1) {
00875 fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
00876 }
00877 if (trees_>0) {
00878 searchLevelExact<with_removed>(result, vec, tree_roots_[0], 0.0, epsError);
00879 }
00880 }
00881
00887 template<bool with_removed>
00888 void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError) const
00889 {
00890 int i;
00891 BranchSt branch;
00892
00893 int checkCount = 0;
00894 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00895 DynamicBitset checked(size_);
00896
00897
00898 for (i = 0; i < trees_; ++i) {
00899 searchLevel<with_removed>(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
00900 }
00901
00902
00903 while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
00904 searchLevel<with_removed>(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked);
00905 }
00906
00907 delete heap;
00908
00909 }
00910
00911 #ifdef ANDROID
00912
00917 template<bool with_removed>
00918 void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError, Heap<BranchSt>* heap) const
00919 {
00920 int i;
00921 BranchSt branch;
00922
00923 int checkCount = 0;
00924 DynamicBitset checked(size_);
00925 heap->clear();
00926
00927
00928 for (i = 0; i < trees_; ++i) {
00929 searchLevel<with_removed>(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
00930 }
00931
00932
00933 while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
00934 searchLevel<with_removed>(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked);
00935 }
00936 }
00937 #endif
00938
00939
00945 template<bool with_removed>
00946 void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck,
00947 float epsError, Heap<BranchSt>* heap, DynamicBitset& checked) const
00948 {
00949 if (result_set.worstDist()<mindist) {
00950
00951 return;
00952 }
00953
00954
00955 if ((node->child1 == NULL)&&(node->child2 == NULL)) {
00956 int index = node->divfeat;
00957 if (with_removed) {
00958 if (removed_points_.test(index)) return;
00959 }
00960
00961 if ( checked.test(index) || ((checkCount>=maxCheck)&& result_set.full()) ) return;
00962 checked.set(index);
00963 checkCount++;
00964
00965 DistanceType dist = distance_(node->point, vec, veclen_);
00966 result_set.addPoint(dist,index);
00967 return;
00968 }
00969
00970
00971 ElementType val = vec[node->divfeat];
00972 DistanceType diff = val - node->divval;
00973 NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
00974 NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
00975
00976
00977
00978
00979
00980
00981
00982
00983
00984 DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
00985
00986 if ((new_distsq*epsError < result_set.worstDist())|| !result_set.full()) {
00987 heap->insert( BranchSt(otherChild, new_distsq) );
00988 }
00989
00990
00991 searchLevel<with_removed>(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
00992 }
00993
00997 template<bool with_removed>
00998 void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError) const
00999 {
01000
01001 if ((node->child1 == NULL)&&(node->child2 == NULL)) {
01002 int index = node->divfeat;
01003 if (with_removed) {
01004 if (removed_points_.test(index)) return;
01005 }
01006 DistanceType dist = distance_(node->point, vec, veclen_);
01007 result_set.addPoint(dist,index);
01008
01009 return;
01010 }
01011
01012
01013 ElementType val = vec[node->divfeat];
01014 DistanceType diff = val - node->divval;
01015 NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
01016 NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
01017
01018
01019
01020
01021
01022
01023
01024
01025
01026 DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
01027
01028
01029 searchLevelExact<with_removed>(result_set, vec, bestChild, mindist, epsError);
01030
01031 if (mindist*epsError<=result_set.worstDist()) {
01032 searchLevelExact<with_removed>(result_set, vec, otherChild, new_distsq, epsError);
01033 }
01034 }
01035
01036 void addPointToTree(NodePtr node, int ind)
01037 {
01038 ElementType* point = points_[ind];
01039
01040 if ((node->child1==NULL) && (node->child2==NULL)) {
01041 ElementType* leaf_point = node->point;
01042 ElementType max_span = 0;
01043 size_t div_feat = 0;
01044 for (size_t i=0;i<veclen_;++i) {
01045 ElementType span = std::abs(point[i]-leaf_point[i]);
01046 if (span > max_span) {
01047 max_span = span;
01048 div_feat = i;
01049 }
01050 }
01051 NodePtr left = new(pool_) Node();
01052 left->child1 = left->child2 = NULL;
01053 NodePtr right = new(pool_) Node();
01054 right->child1 = right->child2 = NULL;
01055
01056 if (point[div_feat]<leaf_point[div_feat]) {
01057 left->divfeat = ind;
01058 left->point = point;
01059 right->divfeat = node->divfeat;
01060 right->point = node->point;
01061 }
01062 else {
01063 left->divfeat = node->divfeat;
01064 left->point = node->point;
01065 right->divfeat = ind;
01066 right->point = point;
01067 }
01068 node->divfeat = div_feat;
01069 node->divval = (point[div_feat]+leaf_point[div_feat])/2;
01070 node->child1 = left;
01071 node->child2 = right;
01072 }
01073 else {
01074 if (point[node->divfeat]<node->divval) {
01075 addPointToTree(node->child1,ind);
01076 }
01077 else {
01078 addPointToTree(node->child2,ind);
01079 }
01080 }
01081 }
01082 private:
01083 void swap(KDTreeIndex& other)
01084 {
01085 BaseClass::swap(other);
01086 std::swap(trees_, other.trees_);
01087 std::swap(tree_roots_, other.tree_roots_);
01088 std::swap(pool_, other.pool_);
01089 }
01090
01091 private:
01092
01093 enum
01094 {
01100 SAMPLE_MEAN = 100,
01108 RAND_DIM=5
01109 };
01110
01111
01115 int trees_;
01116
01117 DistanceType* mean_;
01118 DistanceType* var_;
01119
01123 std::vector<NodePtr> tree_roots_;
01124
01132 PooledAllocator pool_;
01133
01134 USING_BASECLASS_SYMBOLS
01135 };
01136
01137 }
01138
01139 #endif //FLANN_KDTREE_INDEX_H_