00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef RTABMAP_FLANN_KDTREE_SINGLE_INDEX_H_
00032 #define RTABMAP_FLANN_KDTREE_SINGLE_INDEX_H_
00033
00034 #include <algorithm>
00035 #include <map>
00036 #include <cassert>
00037 #include <cstring>
00038
00039 #include "rtflann/general.h"
00040 #include "rtflann/algorithms/nn_index.h"
00041 #include "rtflann/util/matrix.h"
00042 #include "rtflann/util/result_set.h"
00043 #include "rtflann/util/heap.h"
00044 #include "rtflann/util/allocator.h"
00045 #include "rtflann/util/random.h"
00046 #include "rtflann/util/saving.h"
00047
00048 namespace rtflann
00049 {
00050
00051 struct KDTreeSingleIndexParams : public IndexParams
00052 {
00053 KDTreeSingleIndexParams(int leaf_max_size = 10, bool reorder = true)
00054 {
00055 (*this)["algorithm"] = FLANN_INDEX_KDTREE_SINGLE;
00056 (*this)["leaf_max_size"] = leaf_max_size;
00057 (*this)["reorder"] = reorder;
00058 }
00059 };
00060
00061
00068 template <typename Distance>
00069 class KDTreeSingleIndex : public NNIndex<Distance>
00070 {
00071 public:
00072 typedef typename Distance::ElementType ElementType;
00073 typedef typename Distance::ResultType DistanceType;
00074
00075 typedef NNIndex<Distance> BaseClass;
00076
00077 typedef bool needs_kdtree_distance;
00078
00085 KDTreeSingleIndex(const IndexParams& params = KDTreeSingleIndexParams(), Distance d = Distance() ) :
00086 BaseClass(params, d), root_node_(NULL)
00087 {
00088 leaf_max_size_ = get_param(params,"leaf_max_size",10);
00089 reorder_ = get_param(params, "reorder", true);
00090 }
00091
00099 KDTreeSingleIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KDTreeSingleIndexParams(),
00100 Distance d = Distance() ) : BaseClass(params, d), root_node_(NULL)
00101 {
00102 leaf_max_size_ = get_param(params,"leaf_max_size",10);
00103 reorder_ = get_param(params, "reorder", true);
00104
00105 setDataset(inputData);
00106 }
00107
00108
00109 KDTreeSingleIndex(const KDTreeSingleIndex& other) : BaseClass(other),
00110 leaf_max_size_(other.leaf_max_size_),
00111 reorder_(other.reorder_),
00112 vind_(other.vind_),
00113 root_bbox_(other.root_bbox_)
00114 {
00115 if (reorder_) {
00116 data_ = rtflann::Matrix<ElementType>(new ElementType[size_*veclen_], size_, veclen_);
00117 std::copy(other.data_[0], other.data_[0]+size_*veclen_, data_[0]);
00118 }
00119 copyTree(root_node_, other.root_node_);
00120 }
00121
00122 KDTreeSingleIndex& operator=(KDTreeSingleIndex other)
00123 {
00124 this->swap(other);
00125 return *this;
00126 }
00127
00131 virtual ~KDTreeSingleIndex()
00132 {
00133 freeIndex();
00134 }
00135
00136 BaseClass* clone() const
00137 {
00138 return new KDTreeSingleIndex(*this);
00139 }
00140
00141 using BaseClass::buildIndex;
00142
00143 void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00144 {
00145 assert(points.cols==veclen_);
00146 extendDataset(points);
00147 buildIndex();
00148 }
00149
00150 flann_algorithm_t getType() const
00151 {
00152 return FLANN_INDEX_KDTREE_SINGLE;
00153 }
00154
00155
00156 template<typename Archive>
00157 void serialize(Archive& ar)
00158 {
00159 ar.setObject(this);
00160
00161 if (reorder_) index_params_["save_dataset"] = false;
00162
00163 ar & *static_cast<NNIndex<Distance>*>(this);
00164
00165 ar & reorder_;
00166 ar & leaf_max_size_;
00167 ar & root_bbox_;
00168 ar & vind_;
00169
00170 if (reorder_) {
00171 ar & data_;
00172 }
00173
00174 if (Archive::is_loading::value) {
00175 root_node_ = new(pool_) Node();
00176 }
00177
00178 ar & *root_node_;
00179
00180 if (Archive::is_loading::value) {
00181 index_params_["algorithm"] = getType();
00182 index_params_["leaf_max_size"] = leaf_max_size_;
00183 index_params_["reorder"] = reorder_;
00184 }
00185 }
00186
00187
00188 void saveIndex(FILE* stream)
00189 {
00190 serialization::SaveArchive sa(stream);
00191 sa & *this;
00192 }
00193
00194
00195 void loadIndex(FILE* stream)
00196 {
00197 freeIndex();
00198 serialization::LoadArchive la(stream);
00199 la & *this;
00200 }
00201
00206 int usedMemory() const
00207 {
00208 return pool_.usedMemory+pool_.wastedMemory+size_*sizeof(int);
00209 }
00210
00220 void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00221 {
00222 float epsError = 1+searchParams.eps;
00223
00224 std::vector<DistanceType> dists(veclen_,0);
00225 DistanceType distsq = computeInitialDistances(vec, dists);
00226 if (removed_) {
00227 searchLevel<true>(result, vec, root_node_, distsq, dists, epsError);
00228 }
00229 else {
00230 searchLevel<false>(result, vec, root_node_, distsq, dists, epsError);
00231 }
00232 }
00233
00234 protected:
00235
00239 void buildIndexImpl()
00240 {
00241
00242 vind_.resize(size_);
00243 for (size_t i = 0; i < size_; i++) {
00244 vind_[i] = i;
00245 }
00246
00247 computeBoundingBox(root_bbox_);
00248 root_node_ = divideTree(0, size_, root_bbox_ );
00249
00250 if (reorder_) {
00251 data_ = rtflann::Matrix<ElementType>(new ElementType[size_*veclen_], size_, veclen_);
00252 for (size_t i=0; i<size_; ++i) {
00253 std::copy(points_[vind_[i]], points_[vind_[i]]+veclen_, data_[i]);
00254 }
00255 }
00256 }
00257
00258 private:
00259
00260
00261
00262 struct Node
00263 {
00267 int left, right;
00271 int divfeat;
00275 DistanceType divlow, divhigh;
00279 Node* child1, * child2;
00280
00281 ~Node()
00282 {
00283 if (child1) child1->~Node();
00284 if (child2) child2->~Node();
00285 }
00286
00287 private:
00288 template<typename Archive>
00289 void serialize(Archive& ar)
00290 {
00291 typedef KDTreeSingleIndex<Distance> Index;
00292 Index* obj = static_cast<Index*>(ar.getObject());
00293
00294 ar & left;
00295 ar & right;
00296 ar & divfeat;
00297 ar & divlow;
00298 ar & divhigh;
00299
00300 bool leaf_node = false;
00301 if (Archive::is_saving::value) {
00302 leaf_node = ((child1==NULL) && (child2==NULL));
00303 }
00304 ar & leaf_node;
00305
00306 if (!leaf_node) {
00307 if (Archive::is_loading::value) {
00308 child1 = new(obj->pool_) Node();
00309 child2 = new(obj->pool_) Node();
00310 }
00311 ar & *child1;
00312 ar & *child2;
00313 }
00314 }
00315 friend struct serialization::access;
00316 };
00317 typedef Node* NodePtr;
00318
00319
00320 struct Interval
00321 {
00322 DistanceType low, high;
00323
00324 private:
00325 template <typename Archive>
00326 void serialize(Archive& ar)
00327 {
00328 ar & low;
00329 ar & high;
00330 }
00331 friend struct serialization::access;
00332 };
00333
00334 typedef std::vector<Interval> BoundingBox;
00335
00336 typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00337 typedef BranchSt* Branch;
00338
00339
00340
00341 void freeIndex()
00342 {
00343 if (data_.ptr()) {
00344 delete[] data_.ptr();
00345 data_ = rtflann::Matrix<ElementType>();
00346 }
00347 if (root_node_) root_node_->~Node();
00348 pool_.free();
00349 }
00350
00351 void copyTree(NodePtr& dst, const NodePtr& src)
00352 {
00353 dst = new(pool_) Node();
00354 *dst = *src;
00355 if (src->child1!=NULL && src->child2!=NULL) {
00356 copyTree(dst->child1, src->child1);
00357 copyTree(dst->child2, src->child2);
00358 }
00359 }
00360
00361
00362
00363 void computeBoundingBox(BoundingBox& bbox)
00364 {
00365 bbox.resize(veclen_);
00366 for (size_t i=0; i<veclen_; ++i) {
00367 bbox[i].low = (DistanceType)points_[0][i];
00368 bbox[i].high = (DistanceType)points_[0][i];
00369 }
00370 for (size_t k=1; k<size_; ++k) {
00371 for (size_t i=0; i<veclen_; ++i) {
00372 if (points_[k][i]<bbox[i].low) bbox[i].low = (DistanceType)points_[k][i];
00373 if (points_[k][i]>bbox[i].high) bbox[i].high = (DistanceType)points_[k][i];
00374 }
00375 }
00376 }
00377
00378
00388 NodePtr divideTree(int left, int right, BoundingBox& bbox)
00389 {
00390 NodePtr node = new (pool_) Node();
00391
00392
00393 if ( (right-left) <= leaf_max_size_) {
00394 node->child1 = node->child2 = NULL;
00395 node->left = left;
00396 node->right = right;
00397
00398
00399 for (size_t i=0; i<veclen_; ++i) {
00400 bbox[i].low = (DistanceType)points_[vind_[left]][i];
00401 bbox[i].high = (DistanceType)points_[vind_[left]][i];
00402 }
00403 for (int k=left+1; k<right; ++k) {
00404 for (size_t i=0; i<veclen_; ++i) {
00405 if (bbox[i].low>points_[vind_[k]][i]) bbox[i].low=(DistanceType)points_[vind_[k]][i];
00406 if (bbox[i].high<points_[vind_[k]][i]) bbox[i].high=(DistanceType)points_[vind_[k]][i];
00407 }
00408 }
00409 }
00410 else {
00411 int idx;
00412 int cutfeat;
00413 DistanceType cutval;
00414 middleSplit(&vind_[0]+left, right-left, idx, cutfeat, cutval, bbox);
00415
00416 node->divfeat = cutfeat;
00417
00418 BoundingBox left_bbox(bbox);
00419 left_bbox[cutfeat].high = cutval;
00420 node->child1 = divideTree(left, left+idx, left_bbox);
00421
00422 BoundingBox right_bbox(bbox);
00423 right_bbox[cutfeat].low = cutval;
00424 node->child2 = divideTree(left+idx, right, right_bbox);
00425
00426 node->divlow = left_bbox[cutfeat].high;
00427 node->divhigh = right_bbox[cutfeat].low;
00428
00429 for (size_t i=0; i<veclen_; ++i) {
00430 bbox[i].low = std::min(left_bbox[i].low, right_bbox[i].low);
00431 bbox[i].high = std::max(left_bbox[i].high, right_bbox[i].high);
00432 }
00433 }
00434
00435 return node;
00436 }
00437
00438 void computeMinMax(int* ind, int count, int dim, ElementType& min_elem, ElementType& max_elem)
00439 {
00440 min_elem = points_[ind[0]][dim];
00441 max_elem = points_[ind[0]][dim];
00442 for (int i=1; i<count; ++i) {
00443 ElementType val = points_[ind[i]][dim];
00444 if (val<min_elem) min_elem = val;
00445 if (val>max_elem) max_elem = val;
00446 }
00447 }
00448
00449 void middleSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
00450 {
00451
00452 ElementType max_span = bbox[0].high-bbox[0].low;
00453 cutfeat = 0;
00454 cutval = (bbox[0].high+bbox[0].low)/2;
00455 for (size_t i=1; i<veclen_; ++i) {
00456 ElementType span = bbox[i].high-bbox[i].low;
00457 if (span>max_span) {
00458 max_span = span;
00459 cutfeat = i;
00460 cutval = (bbox[i].high+bbox[i].low)/2;
00461 }
00462 }
00463
00464
00465 ElementType min_elem, max_elem;
00466 computeMinMax(ind, count, cutfeat, min_elem, max_elem);
00467 cutval = (min_elem+max_elem)/2;
00468 max_span = max_elem - min_elem;
00469
00470
00471 size_t k = cutfeat;
00472 for (size_t i=0; i<veclen_; ++i) {
00473 if (i==k) continue;
00474 ElementType span = bbox[i].high-bbox[i].low;
00475 if (span>max_span) {
00476 computeMinMax(ind, count, i, min_elem, max_elem);
00477 span = max_elem - min_elem;
00478 if (span>max_span) {
00479 max_span = span;
00480 cutfeat = i;
00481 cutval = (min_elem+max_elem)/2;
00482 }
00483 }
00484 }
00485 int lim1, lim2;
00486 planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
00487
00488 if (lim1>count/2) index = lim1;
00489 else if (lim2<count/2) index = lim2;
00490 else index = count/2;
00491
00492 assert(index > 0 && index < count);
00493 }
00494
00495
00496 void middleSplit_(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
00497 {
00498 const float eps_val=0.00001f;
00499 DistanceType max_span = bbox[0].high-bbox[0].low;
00500 for (size_t i=1; i<veclen_; ++i) {
00501 DistanceType span = bbox[i].high-bbox[i].low;
00502 if (span>max_span) {
00503 max_span = span;
00504 }
00505 }
00506 DistanceType max_spread = -1;
00507 cutfeat = 0;
00508 for (size_t i=0; i<veclen_; ++i) {
00509 DistanceType span = bbox[i].high-bbox[i].low;
00510 if (span>(DistanceType)((1-eps_val)*max_span)) {
00511 ElementType min_elem, max_elem;
00512 computeMinMax(ind, count, cutfeat, min_elem, max_elem);
00513 DistanceType spread = (DistanceType)(max_elem-min_elem);
00514 if (spread>max_spread) {
00515 cutfeat = i;
00516 max_spread = spread;
00517 }
00518 }
00519 }
00520
00521 DistanceType split_val = (bbox[cutfeat].low+bbox[cutfeat].high)/2;
00522 ElementType min_elem, max_elem;
00523 computeMinMax(ind, count, cutfeat, min_elem, max_elem);
00524
00525 if (split_val<min_elem) cutval = (DistanceType)min_elem;
00526 else if (split_val>max_elem) cutval = (DistanceType)max_elem;
00527 else cutval = split_val;
00528
00529 int lim1, lim2;
00530 planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
00531
00532 if (lim1>count/2) index = lim1;
00533 else if (lim2<count/2) index = lim2;
00534 else index = count/2;
00535
00536 assert(index > 0 && index < count);
00537 }
00538
00539
00549 void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
00550 {
00551 int left = 0;
00552 int right = count-1;
00553 for (;; ) {
00554 while (left<=right && points_[ind[left]][cutfeat]<cutval) ++left;
00555 while (left<=right && points_[ind[right]][cutfeat]>=cutval) --right;
00556 if (left>right) break;
00557 std::swap(ind[left], ind[right]); ++left; --right;
00558 }
00559
00560 lim1 = left;
00561 right = count-1;
00562 for (;; ) {
00563 while (left<=right && points_[ind[left]][cutfeat]<=cutval) ++left;
00564 while (left<=right && points_[ind[right]][cutfeat]>cutval) --right;
00565 if (left>right) break;
00566 std::swap(ind[left], ind[right]); ++left; --right;
00567 }
00568 lim2 = left;
00569 }
00570
00571 DistanceType computeInitialDistances(const ElementType* vec, std::vector<DistanceType>& dists) const
00572 {
00573 DistanceType distsq = 0.0;
00574
00575 for (size_t i = 0; i < veclen_; ++i) {
00576 if (vec[i] < root_bbox_[i].low) {
00577 dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].low, i);
00578 distsq += dists[i];
00579 }
00580 if (vec[i] > root_bbox_[i].high) {
00581 dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].high, i);
00582 distsq += dists[i];
00583 }
00584 }
00585
00586 return distsq;
00587 }
00588
00592 template <bool with_removed>
00593 void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindistsq,
00594 std::vector<DistanceType>& dists, const float epsError) const
00595 {
00596
00597 if ((node->child1 == NULL)&&(node->child2 == NULL)) {
00598 DistanceType worst_dist = result_set.worstDist();
00599 for (int i=node->left; i<node->right; ++i) {
00600 if (with_removed) {
00601 if (removed_points_.test(vind_[i])) continue;
00602 }
00603 ElementType* point = reorder_ ? data_[i] : points_[vind_[i]];
00604 DistanceType dist = distance_(vec, point, veclen_, worst_dist);
00605 if (dist<worst_dist) {
00606 result_set.addPoint(dist,vind_[i]);
00607 }
00608 }
00609 return;
00610 }
00611
00612
00613 int idx = node->divfeat;
00614 ElementType val = vec[idx];
00615 DistanceType diff1 = val - node->divlow;
00616 DistanceType diff2 = val - node->divhigh;
00617
00618 NodePtr bestChild;
00619 NodePtr otherChild;
00620 DistanceType cut_dist;
00621 if ((diff1+diff2)<0) {
00622 bestChild = node->child1;
00623 otherChild = node->child2;
00624 cut_dist = distance_.accum_dist(val, node->divhigh, idx);
00625 }
00626 else {
00627 bestChild = node->child2;
00628 otherChild = node->child1;
00629 cut_dist = distance_.accum_dist( val, node->divlow, idx);
00630 }
00631
00632
00633 searchLevel<with_removed>(result_set, vec, bestChild, mindistsq, dists, epsError);
00634
00635 DistanceType dst = dists[idx];
00636 mindistsq = mindistsq + cut_dist - dst;
00637 dists[idx] = cut_dist;
00638 if (mindistsq*epsError<=result_set.worstDist()) {
00639 searchLevel<with_removed>(result_set, vec, otherChild, mindistsq, dists, epsError);
00640 }
00641 dists[idx] = dst;
00642 }
00643
00644
00645 void swap(KDTreeSingleIndex& other)
00646 {
00647 BaseClass::swap(other);
00648 std::swap(leaf_max_size_, other.leaf_max_size_);
00649 std::swap(reorder_, other.reorder_);
00650 std::swap(vind_, other.vind_);
00651 std::swap(data_, other.data_);
00652 std::swap(root_node_, other.root_node_);
00653 std::swap(root_bbox_, other.root_bbox_);
00654 std::swap(pool_, other.pool_);
00655 }
00656
00657 private:
00658
00659
00660
00661 int leaf_max_size_;
00662
00663
00664 bool reorder_;
00665
00669 std::vector<int> vind_;
00670
00671 Matrix<ElementType> data_;
00672
00676 NodePtr root_node_;
00677
00681 BoundingBox root_bbox_;
00682
00690 PooledAllocator pool_;
00691
00692 USING_BASECLASS_SYMBOLS
00693
00694 };
00695
00696 }
00697
#endif // RTABMAP_FLANN_KDTREE_SINGLE_INDEX_H_