31 #ifndef RTABMAP_FLANN_KDTREE_INDEX_H_ 32 #define RTABMAP_FLANN_KDTREE_INDEX_H_ 60 (*this)[
"trees"] = trees;
71 template <
typename Distance>
113 template<
typename Archive>
117 Index* obj =
static_cast<Index*
>(ar.getObject());
122 bool leaf_node =
false;
123 if (Archive::is_saving::value) {
124 leaf_node = ((child1==
NULL) && (child2==
NULL));
129 if (Archive::is_loading::value) {
130 point = obj->points_[divfeat];
135 if (Archive::is_loading::value) {
136 child1 =
new(obj->pool_)
Node();
137 child2 =
new(obj->pool_)
Node();
160 BaseClass(params,
d), mean_(
NULL), var_(
NULL)
162 trees_ =
get_param(index_params_,
"trees",4);
174 Distance d = Distance() ) : BaseClass(params,
d ), mean_(
NULL), var_(
NULL)
176 trees_ =
get_param(index_params_,
"trees",4);
185 for (
size_t i=0;i<tree_roots_.size();++i) {
209 using BaseClass::buildIndex;
213 assert(points.
cols==veclen_);
215 size_t old_size = size_;
216 extendDataset(points);
218 if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
222 for (
size_t i=old_size;i<size_;++i) {
223 for (
int j = 0; j < trees_; j++) {
224 addPointToTree(tree_roots_[j], i);
236 template<
typename Archive>
245 if (Archive::is_loading::value) {
246 tree_roots_.resize(trees_);
248 for (
size_t i=0;i<tree_roots_.size();++i) {
249 if (Archive::is_loading::value) {
250 tree_roots_[i] =
new(pool_)
Node();
252 ar & *tree_roots_[i];
255 if (Archive::is_loading::value) {
256 index_params_[
"algorithm"] = getType();
257 index_params_[
"trees"] = trees_;
282 return int(pool_.usedMemory+pool_.wastedMemory+size_*
sizeof(
int));
296 int maxChecks = searchParams.
checks;
297 float epsError = 1+searchParams.
eps;
301 getExactNeighbors<true>(result, vec, epsError);
304 getExactNeighbors<false>(result, vec, epsError);
309 getNeighbors<true>(result, vec, maxChecks, epsError);
312 getNeighbors<false>(result, vec, maxChecks, epsError);
330 int maxChecks = searchParams.
checks;
331 float epsError = 1+searchParams.
eps;
335 getExactNeighbors<true>(result, vec, epsError);
338 getExactNeighbors<false>(result, vec, epsError);
343 getNeighbors<true>(result, vec, maxChecks, epsError, heap);
346 getNeighbors<false>(result, vec, maxChecks, epsError, heap);
365 assert(queries.
cols == veclen());
366 assert(indices.
rows >= queries.
rows);
368 assert(indices.
cols >= knn);
369 assert(dists.
cols >= knn);
387 for (
int i = 0; i < (int)queries.
rows; i++) {
389 findNeighbors(resultSet, queries[i], params, heap);
391 resultSet.
copy(indices[i], dists[i], n, params.
sorted);
392 indices_to_ids(indices[i], indices[i], n);
398 std::vector<double> times(queries.
rows);
403 for (
int i = 0; i < (int)queries.
rows; i++) {
405 findNeighbors(resultSet, queries[i], params, heap);
407 resultSet.
copy(indices[i], dists[i], n, params.
sorted);
408 indices_to_ids(indices[i], indices[i], n);
412 std::sort(times.begin(), times.end());
428 std::vector< std::vector<size_t> >& indices,
429 std::vector<std::vector<DistanceType> >& dists,
433 assert(queries.
cols == veclen());
442 if (indices.size() < queries.
rows ) indices.resize(queries.
rows);
443 if (dists.size() < queries.
rows ) dists.resize(queries.
rows);
453 for (
int i = 0; i < (int)queries.
rows; i++) {
455 findNeighbors(resultSet, queries[i], params, heap);
457 indices[i].resize(n);
460 resultSet.
copy(&indices[i][0], &dists[i][0], n, params.
sorted);
461 indices_to_ids(&indices[i][0], &indices[i][0], n);
472 for (
int i = 0; i < (int)queries.
rows; i++) {
474 findNeighbors(resultSet, queries[i], params, heap);
476 indices[i].resize(n);
479 resultSet.
copy(&indices[i][0], &dists[i][0], n, params.
sorted);
480 indices_to_ids(&indices[i][0], &indices[i][0], n);
506 assert(queries.
cols == veclen());
510 if (max_neighbors<0) max_neighbors = num_neighbors;
511 else max_neighbors =
std::min(max_neighbors,(
int)num_neighbors);
515 if (max_neighbors==0) {
520 for (
int i = 0; i < (int)queries.
rows; i++) {
522 findNeighbors(resultSet, queries[i], params, heap);
523 count += resultSet.
size();
530 if (params.
max_neighbors<0 && (num_neighbors>=this->size())) {
535 for (
int i = 0; i < (int)queries.
rows; i++) {
537 findNeighbors(resultSet, queries[i], params, heap);
538 size_t n = resultSet.
size();
540 if (n>num_neighbors) n = num_neighbors;
541 resultSet.
copy(indices[i], dists[i], n, params.
sorted);
544 if (n<indices.
cols) indices[i][n] = size_t(-1);
545 if (n<dists.
cols) dists[i][n] = std::numeric_limits<DistanceType>::infinity();
546 indices_to_ids(indices[i], indices[i], n);
556 for (
int i = 0; i < (int)queries.
rows; i++) {
558 findNeighbors(resultSet, queries[i], params, heap);
559 size_t n = resultSet.
size();
561 if ((
int)n>max_neighbors) n = max_neighbors;
562 resultSet.
copy(indices[i], dists[i], n, params.
sorted);
565 if (n<indices.
cols) indices[i][n] = size_t(-1);
566 if (n<dists.
cols) dists[i][n] = std::numeric_limits<DistanceType>::infinity();
567 indices_to_ids(indices[i], indices[i], n);
586 std::vector< std::vector<size_t> >& indices,
587 std::vector<std::vector<DistanceType> >& dists,
591 assert(queries.
cols == veclen());
602 for (
int i = 0; i < (int)queries.
rows; i++) {
604 findNeighbors(resultSet, queries[i], params, heap);
605 count += resultSet.
size();
610 if (indices.size() < queries.
rows ) indices.resize(queries.
rows);
611 if (dists.size() < queries.
rows ) dists.resize(queries.
rows);
619 for (
int i = 0; i < (int)queries.
rows; i++) {
621 findNeighbors(resultSet, queries[i], params, heap);
622 size_t n = resultSet.
size();
624 indices[i].resize(n);
627 resultSet.
copy(&indices[i][0], &dists[i][0], n, params.
sorted);
628 indices_to_ids(&indices[i][0], &indices[i][0], n);
639 for (
int i = 0; i < (int)queries.
rows; i++) {
641 findNeighbors(resultSet, queries[i], params, heap);
642 size_t n = resultSet.
size();
645 indices[i].resize(n);
648 resultSet.
copy(&indices[i][0], &dists[i][0], n, params.
sorted);
649 indices_to_ids(&indices[i][0], &indices[i][0], n);
668 std::vector<int> ind(size_);
669 for (
size_t i = 0; i < size_; ++i) {
673 mean_ =
new DistanceType[veclen_];
674 var_ =
new DistanceType[veclen_];
676 tree_roots_.resize(trees_);
678 for (
int i = 0; i < trees_; i++) {
680 std::random_shuffle(ind.begin(), ind.end());
681 tree_roots_[i] = divideTree(&ind[0],
int(size_) );
689 for (
size_t i=0;i<tree_roots_.size();++i) {
691 if (tree_roots_[i]!=
NULL) tree_roots_[i]->~Node();
701 dst =
new(pool_)
Node();
726 NodePtr node =
new(pool_)
Node();
732 node->
point = points_[*ind];
738 meanSplit(ind, count, idx, cutfeat, cutval);
742 node->
child1 = divideTree(ind, idx);
743 node->
child2 = divideTree(ind+idx, count-idx);
755 void meanSplit(
int* ind,
int count,
int& index,
int& cutfeat, DistanceType& cutval)
757 memset(mean_,0,veclen_*
sizeof(DistanceType));
758 memset(var_,0,veclen_*
sizeof(DistanceType));
763 int cnt =
std::min((
int)SAMPLE_MEAN+1, count);
764 for (
int j = 0; j < cnt; ++j) {
765 ElementType* v = points_[ind[j]];
766 for (
size_t k=0; k<veclen_; ++k) {
770 DistanceType div_factor = DistanceType(1)/cnt;
771 for (
size_t k=0; k<veclen_; ++k) {
772 mean_[k] *= div_factor;
776 for (
int j = 0; j < cnt; ++j) {
777 ElementType* v = points_[ind[j]];
778 for (
size_t k=0; k<veclen_; ++k) {
779 DistanceType dist = v[k] - mean_[k];
780 var_[k] += dist * dist;
784 cutfeat = selectDivision(var_);
785 cutval = mean_[cutfeat];
788 planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
790 if (lim1>count/2) index = lim1;
791 else if (lim2<count/2) index = lim2;
792 else index = count/2;
797 if ((lim1==count)||(lim2==0)) index = count/2;
808 size_t topind[RAND_DIM];
811 for (
size_t i = 0; i < veclen_; ++i) {
812 if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) {
814 if (num < RAND_DIM) {
822 while (j > 0 && v[topind[j]] > v[topind[j-1]]) {
823 std::swap(topind[j], topind[j-1]);
830 return (
int)topind[rnd];
843 void planeSplit(
int* ind,
int count,
int cutfeat, DistanceType cutval,
int& lim1,
int& lim2)
849 while (left<=right && points_[ind[left]][cutfeat]<cutval) ++left;
850 while (left<=right && points_[ind[right]][cutfeat]>=cutval) --right;
851 if (left>right)
break;
852 std::swap(ind[left], ind[right]); ++left; --right;
857 while (left<=right && points_[ind[left]][cutfeat]<=cutval) ++left;
858 while (left<=right && points_[ind[right]][cutfeat]>cutval) --right;
859 if (left>right)
break;
860 std::swap(ind[left], ind[right]); ++left; --right;
869 template<
bool with_removed>
875 fprintf(stderr,
"It doesn't make any sense to use more than one tree for exact search");
878 searchLevelExact<with_removed>(result, vec, tree_roots_[0], 0.0, epsError);
887 template<
bool with_removed>
898 for (i = 0; i < trees_; ++i) {
899 searchLevel<with_removed>(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
903 while ( heap->
popMin(branch) && (checkCount < maxCheck || !result.
full() )) {
904 searchLevel<with_removed>(result, vec, branch.
node, branch.
mindist, checkCount, maxCheck, epsError, heap, checked);
917 template<
bool with_removed>
928 for (i = 0; i < trees_; ++i) {
929 searchLevel<with_removed>(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
933 while ( heap->
popMin(branch) && (checkCount < maxCheck || !result.
full() )) {
934 searchLevel<with_removed>(result, vec, branch.
node, branch.
mindist, checkCount, maxCheck, epsError, heap, checked);
945 template<
bool with_removed>
958 if (removed_points_.test(index))
return;
961 if ( checked.
test(index) || ((checkCount>=maxCheck)&& result_set.
full()) )
return;
965 DistanceType dist = distance_(node->
point, vec, veclen_);
971 ElementType val = vec[node->
divfeat];
972 DistanceType diff = val - node->
divval;
973 NodePtr bestChild = (diff < 0) ? node->
child1 : node->
child2;
974 NodePtr otherChild = (diff < 0) ? node->
child2 : node->
child1;
984 DistanceType new_distsq = mindist + distance_.accum_dist(val, node->
divval, node->
divfeat);
986 if ((new_distsq*epsError < result_set.
worstDist())|| !result_set.
full()) {
987 heap->
insert( BranchSt(otherChild, new_distsq) );
991 searchLevel<with_removed>(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
997 template<
bool with_removed>
1004 if (removed_points_.test(index))
return;
1006 DistanceType dist = distance_(node->
point, vec, veclen_);
1013 ElementType val = vec[node->
divfeat];
1014 DistanceType diff = val - node->
divval;
1015 NodePtr bestChild = (diff < 0) ? node->
child1 : node->
child2;
1016 NodePtr otherChild = (diff < 0) ? node->
child2 : node->
child1;
1026 DistanceType new_distsq = mindist + distance_.accum_dist(val, node->
divval, node->
divfeat);
1029 searchLevelExact<with_removed>(result_set, vec, bestChild, mindist, epsError);
1031 if (mindist*epsError<=result_set.worstDist()) {
1032 searchLevelExact<with_removed>(result_set, vec, otherChild, new_distsq, epsError);
1038 ElementType* point = points_[ind];
1041 ElementType* leaf_point = node->
point;
1042 ElementType max_span = 0;
1043 size_t div_feat = 0;
1044 for (
size_t i=0;i<veclen_;++i) {
1045 ElementType span =
std::abs(point[i]-leaf_point[i]);
1046 if (span > max_span) {
1051 NodePtr left =
new(pool_)
Node();
1053 NodePtr right =
new(pool_)
Node();
1056 if (point[div_feat]<leaf_point[div_feat]) {
1058 left->
point = point;
1066 right->
point = point;
1069 node->
divval = (point[div_feat]+leaf_point[div_feat])/2;
1075 addPointToTree(node->
child1,ind);
1078 addPointToTree(node->
child2,ind);
1085 BaseClass::swap(other);
1086 std::swap(trees_, other.
trees_);
1088 std::swap(pool_, other.
pool_);
1139 #endif //FLANN_KDTREE_INDEX_H_
std::map< std::string, any > IndexParams
void addPoints(const Matrix< ElementType > &points, float rebuild_threshold=2)
Incrementally add points to the index.
void copy(size_t *indices, DistanceType *dists, size_t num_elements, bool sorted=true)
void copy(size_t *indices, DistanceType *dists, size_t num_elements, bool sorted=true)
T get_param(const IndexParams ¶ms, std::string name, const T &default_value)
GLM_FUNC_DECL genType min(genType const &x, genType const &y)
void loadIndex(FILE *stream)
KDTreeIndexParams(int trees=4)
NNIndex< Distance > BaseClass
KDTreeIndex & operator=(KDTreeIndex other)
KDTreeIndex(const KDTreeIndex &other)
void swap(KDTreeIndex &other)
void findNeighbors(ResultSet< DistanceType > &result, const ElementType *vec, const SearchParams &searchParams) const
int rand_int(int high=RAND_MAX, int low=0)
void insert(const T &value)
void searchLevelExact(ResultSet< DistanceType > &result_set, const ElementType *vec, const NodePtr node, DistanceType mindist, const float epsError) const
#define USING_BASECLASS_SYMBOLS
std::vector< NodePtr > tree_roots_
void copyTree(NodePtr &dst, const NodePtr &src)
BranchStruct< NodePtr, DistanceType > BranchSt
void getExactNeighbors(ResultSet< DistanceType > &result, const ElementType *vec, float epsError) const
bool needs_kdtree_distance
bool test(size_t index) const
GLM_FUNC_DECL genType abs(genType const &x)
KDTreeIndex(const Matrix< ElementType > &dataset, const IndexParams ¶ms=KDTreeIndexParams(), Distance d=Distance())
void serialize(Archive &ar)
flann_algorithm_t getType() const
#define KNN_HEAP_THRESHOLD
void getNeighbors(ResultSet< DistanceType > &result, const ElementType *vec, int maxCheck, float epsError) const
void planeSplit(int *ind, int count, int cutfeat, DistanceType cutval, int &lim1, int &lim2)
virtual DistanceType worstDist() const =0
void copy(size_t *indices, DistanceType *dists, size_t num_elements, bool sorted=true)
int selectDivision(DistanceType *v)
Distance::ResultType DistanceType
virtual bool full() const =0
void searchLevel(ResultSet< DistanceType > &result_set, const ElementType *vec, NodePtr node, DistanceType mindist, int &checkCount, int maxCheck, float epsError, Heap< BranchSt > *heap, DynamicBitset &checked) const
static void freeIndex(sqlite3 *db, Index *p)
void addPointToTree(NodePtr node, int ind)
void serialize(Archive &ar)
BaseClass * clone() const
void meanSplit(int *ind, int count, int &index, int &cutfeat, DistanceType &cutval)
void copy(size_t *indices, DistanceType *dists, size_t num_elements, bool sorted=true)
NodePtr divideTree(int *ind, int count)
KDTreeIndex(const IndexParams ¶ms=KDTreeIndexParams(), Distance d=Distance())
virtual void addPoint(DistanceType dist, size_t index)=0
void saveIndex(FILE *stream)
Distance::ElementType ElementType