00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef RTABMAP_FLANN_NNINDEX_H
00032 #define RTABMAP_FLANN_NNINDEX_H
00033
00034 #include <vector>
00035
00036 #include "rtflann/general.h"
00037 #include "rtflann/util/matrix.h"
00038 #include "rtflann/util/params.h"
00039 #include "rtflann/util/result_set.h"
00040 #include "rtflann/util/dynamic_bitset.h"
00041 #include "rtflann/util/saving.h"
00042
00043 namespace rtflann
00044 {
00045
// Cut-over used by NNIndex::knnSearch() when SearchParams::use_heap is
// FLANN_Undefined: a heap-based result set (KNNResultSet2) is used only when
// knn is greater than this value; smaller knn use KNNSimpleResultSet.
#define KNN_HEAP_THRESHOLD 250
00047
00048
00049 class IndexBase
00050 {
00051 public:
00052 virtual ~IndexBase() {};
00053
00054 virtual size_t veclen() const = 0;
00055
00056 virtual size_t size() const = 0;
00057
00058 virtual flann_algorithm_t getType() const = 0;
00059
00060 virtual int usedMemory() const = 0;
00061
00062 virtual IndexParams getParameters() const = 0;
00063
00064 virtual void loadIndex(FILE* stream) = 0;
00065
00066 virtual void saveIndex(FILE* stream) = 0;
00067 };
00068
00072 template <typename Distance>
00073 class NNIndex : public IndexBase
00074 {
00075 public:
00076 typedef typename Distance::ElementType ElementType;
00077 typedef typename Distance::ResultType DistanceType;
00078
00079 NNIndex(Distance d) : distance_(d), last_id_(0), size_(0), size_at_build_(0), veclen_(0),
00080 removed_(false), removed_count_(0), data_ptr_(NULL)
00081 {
00082 }
00083
00084 NNIndex(const IndexParams& params, Distance d) : distance_(d), last_id_(0), size_(0), size_at_build_(0), veclen_(0),
00085 index_params_(params), removed_(false), removed_count_(0), data_ptr_(NULL)
00086 {
00087 }
00088
00089 NNIndex(const NNIndex& other) :
00090 distance_(other.distance_),
00091 last_id_(other.last_id_),
00092 size_(other.size_),
00093 size_at_build_(other.size_at_build_),
00094 veclen_(other.veclen_),
00095 index_params_(other.index_params_),
00096 removed_(other.removed_),
00097 removed_points_(other.removed_points_),
00098 removed_count_(other.removed_count_),
00099 ids_(other.ids_),
00100 points_(other.points_),
00101 data_ptr_(NULL)
00102 {
00103 if (other.data_ptr_) {
00104 data_ptr_ = new ElementType[size_*veclen_];
00105 std::copy(other.data_ptr_, other.data_ptr_+size_*veclen_, data_ptr_);
00106 for (size_t i=0;i<size_;++i) {
00107 points_[i] = data_ptr_ + i*veclen_;
00108 }
00109 }
00110 }
00111
00112 virtual ~NNIndex()
00113 {
00114 if (data_ptr_) {
00115 delete[] data_ptr_;
00116 }
00117 }
00118
00119
00120 virtual NNIndex* clone() const = 0;
00121
00125 virtual void buildIndex()
00126 {
00127 freeIndex();
00128 cleanRemovedPoints();
00129
00130
00131 buildIndexImpl();
00132
00133 size_at_build_ = size_;
00134
00135 }
00136
00141 virtual void buildIndex(const Matrix<ElementType>& dataset)
00142 {
00143 setDataset(dataset);
00144 this->buildIndex();
00145 }
00146
00152 virtual void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00153 {
00154 throw FLANNException("Functionality not supported by this index");
00155 }
00156
00161 virtual void removePoint(size_t id)
00162 {
00163 if (!removed_) {
00164 ids_.resize(size_);
00165 for (size_t i=0;i<size_;++i) {
00166 ids_[i] = i;
00167 }
00168 removed_points_.resize(size_);
00169 removed_points_.reset();
00170 last_id_ = size_;
00171 removed_ = true;
00172 }
00173
00174 size_t point_index = id_to_index(id);
00175 if (point_index!=size_t(-1) && !removed_points_.test(point_index)) {
00176 removed_points_.set(point_index);
00177 removed_count_++;
00178 }
00179 }
00180
00181
00187 virtual ElementType* getPoint(size_t id)
00188 {
00189 size_t index = id_to_index(id);
00190 if (index!=size_t(-1)) {
00191 return points_[index];
00192 }
00193 else {
00194 return NULL;
00195 }
00196 }
00197
00201 inline size_t size() const
00202 {
00203 return size_ - removed_count_;
00204 }
00205
00206 inline size_t removedCount() const
00207 {
00208 return removed_count_;
00209 }
00210
00211 inline size_t sizeAtBuild() const
00212 {
00213 return size_at_build_;
00214 }
00215
00219 inline size_t veclen() const
00220 {
00221 return veclen_;
00222 }
00223
00229 IndexParams getParameters() const
00230 {
00231 return index_params_;
00232 }
00233
00234
00235 template<typename Archive>
00236 void serialize(Archive& ar)
00237 {
00238 IndexHeader header;
00239
00240 if (Archive::is_saving::value) {
00241 header.h.data_type = flann_datatype_value<ElementType>::value;
00242 header.h.index_type = getType();
00243 header.h.rows = size_;
00244 header.h.cols = veclen_;
00245 }
00246 ar & header;
00247
00248
00249 if (Archive::is_loading::value) {
00250 if (strncmp(header.h.signature,
00251 FLANN_SIGNATURE_,
00252 strlen(FLANN_SIGNATURE_) - strlen("v0.0")) != 0) {
00253 throw FLANNException("Invalid index file, wrong signature");
00254 }
00255
00256 if (header.h.data_type != flann_datatype_value<ElementType>::value) {
00257 throw FLANNException("Datatype of saved index is different than of the one to be created.");
00258 }
00259
00260 if (header.h.index_type != getType()) {
00261 throw FLANNException("Saved index type is different then the current index type.");
00262 }
00263
00264
00265 }
00266
00267 ar & size_;
00268 ar & veclen_;
00269 ar & size_at_build_;
00270
00271 bool save_dataset;
00272 if (Archive::is_saving::value) {
00273 save_dataset = get_param(index_params_,"save_dataset", false);
00274 }
00275 ar & save_dataset;
00276
00277 if (save_dataset) {
00278 if (Archive::is_loading::value) {
00279 if (data_ptr_) {
00280 delete[] data_ptr_;
00281 }
00282 data_ptr_ = new ElementType[size_*veclen_];
00283 points_.resize(size_);
00284 for (size_t i=0;i<size_;++i) {
00285 points_[i] = data_ptr_ + i*veclen_;
00286 }
00287 }
00288 for (size_t i=0;i<size_;++i) {
00289 ar & serialization::make_binary_object (points_[i], veclen_*sizeof(ElementType));
00290 }
00291 } else {
00292 if (points_.size()!=size_) {
00293 throw FLANNException("Saved index does not contain the dataset and no dataset was provided.");
00294 }
00295 }
00296
00297 ar & last_id_;
00298 ar & ids_;
00299 ar & removed_;
00300 if (removed_) {
00301 ar & removed_points_;
00302 }
00303 ar & removed_count_;
00304 }
00305
00306
00315 virtual int knnSearch(const Matrix<ElementType>& queries,
00316 Matrix<size_t>& indices,
00317 Matrix<DistanceType>& dists,
00318 size_t knn,
00319 const SearchParams& params) const
00320 {
00321 assert(queries.cols == veclen());
00322 assert(indices.rows >= queries.rows);
00323 assert(dists.rows >= queries.rows);
00324 assert(indices.cols >= knn);
00325 assert(dists.cols >= knn);
00326 bool use_heap;
00327
00328 if (params.use_heap==FLANN_Undefined) {
00329 use_heap = (knn>KNN_HEAP_THRESHOLD)?true:false;
00330 }
00331 else {
00332 use_heap = (params.use_heap==FLANN_True)?true:false;
00333 }
00334 int count = 0;
00335
00336 if (use_heap) {
00337 #pragma omp parallel num_threads(params.cores)
00338 {
00339 KNNResultSet2<DistanceType> resultSet(knn);
00340 #pragma omp for schedule(static) reduction(+:count)
00341 for (int i = 0; i < (int)queries.rows; i++) {
00342 resultSet.clear();
00343 findNeighbors(resultSet, queries[i], params);
00344 size_t n = std::min(resultSet.size(), knn);
00345 resultSet.copy(indices[i], dists[i], n, params.sorted);
00346 indices_to_ids(indices[i], indices[i], n);
00347 count += n;
00348 }
00349 }
00350 }
00351 else {
00352 #pragma omp parallel num_threads(params.cores)
00353 {
00354 KNNSimpleResultSet<DistanceType> resultSet(knn);
00355 #pragma omp for schedule(static) reduction(+:count)
00356 for (int i = 0; i < (int)queries.rows; i++) {
00357 resultSet.clear();
00358 findNeighbors(resultSet, queries[i], params);
00359 size_t n = std::min(resultSet.size(), knn);
00360 resultSet.copy(indices[i], dists[i], n, params.sorted);
00361 indices_to_ids(indices[i], indices[i], n);
00362 count += n;
00363 }
00364 }
00365 }
00366 return count;
00367 }
00368
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388
00389
00390
00391
00392
00393
00394
00395
00396
00405 virtual int knnSearch(const Matrix<ElementType>& queries,
00406 std::vector< std::vector<size_t> >& indices,
00407 std::vector<std::vector<DistanceType> >& dists,
00408 size_t knn,
00409 const SearchParams& params) const
00410 {
00411 assert(queries.cols == veclen());
00412 bool use_heap;
00413 if (params.use_heap==FLANN_Undefined) {
00414 use_heap = (knn>KNN_HEAP_THRESHOLD)?true:false;
00415 }
00416 else {
00417 use_heap = (params.use_heap==FLANN_True)?true:false;
00418 }
00419
00420 if (indices.size() < queries.rows ) indices.resize(queries.rows);
00421 if (dists.size() < queries.rows ) dists.resize(queries.rows);
00422
00423 int count = 0;
00424 if (use_heap) {
00425 #pragma omp parallel num_threads(params.cores)
00426 {
00427 KNNResultSet2<DistanceType> resultSet(knn);
00428 #pragma omp for schedule(static) reduction(+:count)
00429 for (int i = 0; i < (int)queries.rows; i++) {
00430 resultSet.clear();
00431 findNeighbors(resultSet, queries[i], params);
00432 size_t n = std::min(resultSet.size(), knn);
00433 indices[i].resize(n);
00434 dists[i].resize(n);
00435 if (n>0) {
00436 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00437 indices_to_ids(&indices[i][0], &indices[i][0], n);
00438 }
00439 count += n;
00440 }
00441 }
00442 }
00443 else {
00444 #pragma omp parallel num_threads(params.cores)
00445 {
00446 KNNSimpleResultSet<DistanceType> resultSet(knn);
00447 #pragma omp for schedule(static) reduction(+:count)
00448 for (int i = 0; i < (int)queries.rows; i++) {
00449 resultSet.clear();
00450 findNeighbors(resultSet, queries[i], params);
00451 size_t n = std::min(resultSet.size(), knn);
00452 indices[i].resize(n);
00453 dists[i].resize(n);
00454 if (n>0) {
00455 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00456 indices_to_ids(&indices[i][0], &indices[i][0], n);
00457 }
00458 count += n;
00459 }
00460 }
00461 }
00462
00463 return count;
00464 }
00465
00466
00476 int knnSearch(const Matrix<ElementType>& queries,
00477 std::vector< std::vector<int> >& indices,
00478 std::vector<std::vector<DistanceType> >& dists,
00479 size_t knn,
00480 const SearchParams& params) const
00481 {
00482 std::vector<std::vector<size_t> > indices_;
00483 int result = knnSearch(queries, indices_, dists, knn, params);
00484
00485 indices.resize(indices_.size());
00486 for (size_t i=0;i<indices_.size();++i) {
00487 indices[i].assign(indices_[i].begin(), indices_[i].end());
00488 }
00489 return result;
00490 }
00491
00501 virtual int radiusSearch(const Matrix<ElementType>& queries,
00502 Matrix<size_t>& indices,
00503 Matrix<DistanceType>& dists,
00504 float radius,
00505 const SearchParams& params) const
00506 {
00507 assert(queries.cols == veclen());
00508 int count = 0;
00509 size_t num_neighbors = std::min(indices.cols, dists.cols);
00510 int max_neighbors = params.max_neighbors;
00511 if (max_neighbors<0) max_neighbors = num_neighbors;
00512 else max_neighbors = std::min(max_neighbors,(int)num_neighbors);
00513
00514 if (max_neighbors==0) {
00515 #pragma omp parallel num_threads(params.cores)
00516 {
00517 CountRadiusResultSet<DistanceType> resultSet(radius);
00518 #pragma omp for schedule(static) reduction(+:count)
00519 for (int i = 0; i < (int)queries.rows; i++) {
00520 resultSet.clear();
00521 findNeighbors(resultSet, queries[i], params);
00522 count += resultSet.size();
00523 }
00524 }
00525 }
00526 else {
00527
00528
00529 if (params.max_neighbors<0 && (num_neighbors>=size())) {
00530 #pragma omp parallel num_threads(params.cores)
00531 {
00532 RadiusResultSet<DistanceType> resultSet(radius);
00533 #pragma omp for schedule(static) reduction(+:count)
00534 for (int i = 0; i < (int)queries.rows; i++) {
00535 resultSet.clear();
00536 findNeighbors(resultSet, queries[i], params);
00537 size_t n = resultSet.size();
00538 count += n;
00539 if (n>num_neighbors) n = num_neighbors;
00540 resultSet.copy(indices[i], dists[i], n, params.sorted);
00541
00542
00543 if (n<indices.cols) indices[i][n] = size_t(-1);
00544 if (n<dists.cols) dists[i][n] = std::numeric_limits<DistanceType>::infinity();
00545 indices_to_ids(indices[i], indices[i], n);
00546 }
00547 }
00548 }
00549 else {
00550
00551 #pragma omp parallel num_threads(params.cores)
00552 {
00553 KNNRadiusResultSet<DistanceType> resultSet(radius, max_neighbors);
00554 #pragma omp for schedule(static) reduction(+:count)
00555 for (int i = 0; i < (int)queries.rows; i++) {
00556 resultSet.clear();
00557 findNeighbors(resultSet, queries[i], params);
00558 size_t n = resultSet.size();
00559 count += n;
00560 if ((int)n>max_neighbors) n = max_neighbors;
00561 resultSet.copy(indices[i], dists[i], n, params.sorted);
00562
00563
00564 if (n<indices.cols) indices[i][n] = size_t(-1);
00565 if (n<dists.cols) dists[i][n] = std::numeric_limits<DistanceType>::infinity();
00566 indices_to_ids(indices[i], indices[i], n);
00567 }
00568 }
00569 }
00570 }
00571 return count;
00572 }
00573
00574
00584 int radiusSearch(const Matrix<ElementType>& queries,
00585 Matrix<int>& indices,
00586 Matrix<DistanceType>& dists,
00587 float radius,
00588 const SearchParams& params) const
00589 {
00590 rtflann::Matrix<size_t> indices_(new size_t[indices.rows*indices.cols], indices.rows, indices.cols);
00591 int result = radiusSearch(queries, indices_, dists, radius, params);
00592
00593 for (size_t i=0;i<indices.rows;++i) {
00594 for (size_t j=0;j<indices.cols;++j) {
00595 indices[i][j] = indices_[i][j];
00596 }
00597 }
00598 delete[] indices_.ptr();
00599 return result;
00600 }
00601
00611 virtual int radiusSearch(const Matrix<ElementType>& queries,
00612 std::vector< std::vector<size_t> >& indices,
00613 std::vector<std::vector<DistanceType> >& dists,
00614 float radius,
00615 const SearchParams& params) const
00616 {
00617 assert(queries.cols == veclen());
00618 int count = 0;
00619
00620 if (params.max_neighbors==0) {
00621 #pragma omp parallel num_threads(params.cores)
00622 {
00623 CountRadiusResultSet<DistanceType> resultSet(radius);
00624 #pragma omp for schedule(static) reduction(+:count)
00625 for (int i = 0; i < (int)queries.rows; i++) {
00626 resultSet.clear();
00627 findNeighbors(resultSet, queries[i], params);
00628 count += resultSet.size();
00629 }
00630 }
00631 }
00632 else {
00633 if (indices.size() < queries.rows ) indices.resize(queries.rows);
00634 if (dists.size() < queries.rows ) dists.resize(queries.rows);
00635
00636 if (params.max_neighbors<0) {
00637
00638 #pragma omp parallel num_threads(params.cores)
00639 {
00640 RadiusResultSet<DistanceType> resultSet(radius);
00641 #pragma omp for schedule(static) reduction(+:count)
00642 for (int i = 0; i < (int)queries.rows; i++) {
00643 resultSet.clear();
00644 findNeighbors(resultSet, queries[i], params);
00645 size_t n = resultSet.size();
00646 count += n;
00647 indices[i].resize(n);
00648 dists[i].resize(n);
00649 if (n > 0) {
00650 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00651 indices_to_ids(&indices[i][0], &indices[i][0], n);
00652 }
00653 }
00654 }
00655 }
00656 else {
00657
00658 #pragma omp parallel num_threads(params.cores)
00659 {
00660 KNNRadiusResultSet<DistanceType> resultSet(radius, params.max_neighbors);
00661 #pragma omp for schedule(static) reduction(+:count)
00662 for (int i = 0; i < (int)queries.rows; i++) {
00663 resultSet.clear();
00664 findNeighbors(resultSet, queries[i], params);
00665 size_t n = resultSet.size();
00666 count += n;
00667 if ((int)n>params.max_neighbors) n = params.max_neighbors;
00668 indices[i].resize(n);
00669 dists[i].resize(n);
00670 if (n > 0) {
00671 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
00672 indices_to_ids(&indices[i][0], &indices[i][0], n);
00673 }
00674 }
00675 }
00676 }
00677 }
00678 return count;
00679 }
00680
00690 int radiusSearch(const Matrix<ElementType>& queries,
00691 std::vector< std::vector<int> >& indices,
00692 std::vector<std::vector<DistanceType> >& dists,
00693 float radius,
00694 const SearchParams& params) const
00695 {
00696 std::vector<std::vector<size_t> > indices_;
00697 int result = radiusSearch(queries, indices_, dists, radius, params);
00698
00699 indices.resize(indices_.size());
00700 for (size_t i=0;i<indices_.size();++i) {
00701 indices[i].assign(indices_[i].begin(), indices_[i].end());
00702 }
00703 return result;
00704 }
00705
00706
00707 virtual void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const = 0;
00708
00709 protected:
00710
00711 virtual void freeIndex() = 0;
00712
00713 virtual void buildIndexImpl() = 0;
00714
00715 size_t id_to_index(size_t id)
00716 {
00717 if (ids_.size()==0) {
00718 return id;
00719 }
00720 size_t point_index = size_t(-1);
00721 if (id < ids_.size() && ids_[id]==id) {
00722 return id;
00723 }
00724 else {
00725
00726 size_t start = 0;
00727 size_t end = ids_.size();
00728
00729 while (start<end) {
00730 size_t mid = (start+end)/2;
00731 if (ids_[mid]==id) {
00732 point_index = mid;
00733 break;
00734 }
00735 else if (ids_[mid]<id) {
00736 start = mid + 1;
00737 }
00738 else {
00739 end = mid;
00740 }
00741 }
00742 }
00743 return point_index;
00744 }
00745
00746
00747 void indices_to_ids(const size_t* in, size_t* out, size_t size) const
00748 {
00749 if (removed_) {
00750 for (size_t i=0;i<size;++i) {
00751 out[i] = ids_[in[i]];
00752 }
00753 }
00754 }
00755
00756 void setDataset(const Matrix<ElementType>& dataset)
00757 {
00758 size_ = dataset.rows;
00759 veclen_ = dataset.cols;
00760 last_id_ = 0;
00761
00762 ids_.clear();
00763 removed_points_.clear();
00764 removed_ = false;
00765 removed_count_ = 0;
00766
00767 points_.resize(size_);
00768 for (size_t i=0;i<size_;++i) {
00769 points_[i] = dataset[i];
00770 }
00771 }
00772
00773 void extendDataset(const Matrix<ElementType>& new_points)
00774 {
00775 size_t new_size = size_ + new_points.rows;
00776 if (removed_) {
00777 removed_points_.resize(new_size);
00778 ids_.resize(new_size);
00779 }
00780 points_.resize(new_size);
00781 for (size_t i=size_;i<new_size;++i) {
00782 points_[i] = new_points[i-size_];
00783 if (removed_) {
00784 ids_[i] = last_id_++;
00785 removed_points_.reset(i);
00786 }
00787 }
00788 size_ = new_size;
00789 }
00790
00791
00792 void cleanRemovedPoints()
00793 {
00794 if (!removed_) return;
00795
00796 size_t last_idx = 0;
00797 for (size_t i=0;i<size_;++i) {
00798 if (!removed_points_.test(i)) {
00799 points_[last_idx] = points_[i];
00800 ids_[last_idx] = ids_[i];
00801 removed_points_.reset(last_idx);
00802 ++last_idx;
00803 }
00804 }
00805 points_.resize(last_idx);
00806 ids_.resize(last_idx);
00807 removed_points_.resize(last_idx);
00808 size_ = last_idx;
00809 removed_count_ = 0;
00810 }
00811
00812 void swap(NNIndex& other)
00813 {
00814 std::swap(distance_, other.distance_);
00815 std::swap(last_id_, other.last_id_);
00816 std::swap(size_, other.size_);
00817 std::swap(size_at_build_, other.size_at_build_);
00818 std::swap(veclen_, other.veclen_);
00819 std::swap(index_params_, other.index_params_);
00820 std::swap(removed_, other.removed_);
00821 std::swap(removed_points_, other.removed_points_);
00822 std::swap(removed_count_, other.removed_count_);
00823 std::swap(ids_, other.ids_);
00824 std::swap(points_, other.points_);
00825 std::swap(data_ptr_, other.data_ptr_);
00826 }
00827
00828 protected:
00829
00833 Distance distance_;
00834
00835
00841 size_t last_id_;
00842
00846 size_t size_;
00847
00851 size_t size_at_build_;
00852
00856 size_t veclen_;
00857
00861 IndexParams index_params_;
00862
00866 bool removed_;
00867
00871 DynamicBitset removed_points_;
00872
00876 size_t removed_count_;
00877
00881 std::vector<size_t> ids_;
00882
00886 std::vector<ElementType*> points_;
00887
00891 ElementType* data_ptr_;
00892
00893
00894 };
00895
00896
// Helper for classes derived from NNIndex<Distance>. Members of a dependent
// base class are not found by unqualified name lookup inside a template, so
// expanding this macro in a derived class body brings the listed NNIndex
// members into scope. It expects the derived class's template parameter to be
// named Distance.
#define USING_BASECLASS_SYMBOLS \
using NNIndex<Distance>::distance_;\
using NNIndex<Distance>::size_;\
using NNIndex<Distance>::size_at_build_;\
using NNIndex<Distance>::veclen_;\
using NNIndex<Distance>::index_params_;\
using NNIndex<Distance>::removed_points_;\
using NNIndex<Distance>::ids_;\
using NNIndex<Distance>::removed_;\
using NNIndex<Distance>::points_;\
using NNIndex<Distance>::extendDataset;\
using NNIndex<Distance>::setDataset;\
using NNIndex<Distance>::cleanRemovedPoints;\
using NNIndex<Distance>::indices_to_ids;
00911
00912
00913
00914 }
00915
00916
#endif // RTABMAP_FLANN_NNINDEX_H