00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef RTABMAP_FLANN_HIERARCHICAL_CLUSTERING_INDEX_H_
00032 #define RTABMAP_FLANN_HIERARCHICAL_CLUSTERING_INDEX_H_
00033
00034 #include <algorithm>
00035 #include <string>
00036 #include <map>
00037 #include <cassert>
00038 #include <limits>
00039 #include <cmath>
00040
00041 #ifndef SIZE_MAX
00042 #define SIZE_MAX ((size_t) -1)
00043 #endif
00044
00045 #include "rtflann/general.h"
00046 #include "rtflann/algorithms/nn_index.h"
00047 #include "rtflann/algorithms/dist.h"
00048 #include "rtflann/util/matrix.h"
00049 #include "rtflann/util/result_set.h"
00050 #include "rtflann/util/heap.h"
00051 #include "rtflann/util/allocator.h"
00052 #include "rtflann/util/random.h"
00053 #include "rtflann/util/saving.h"
00054 #include "rtflann/util/serialization.h"
00055
00056 namespace rtflann
00057 {
00058
00059 struct HierarchicalClusteringIndexParams : public IndexParams
00060 {
00061 HierarchicalClusteringIndexParams(int branching = 32,
00062 flann_centers_init_t centers_init = FLANN_CENTERS_RANDOM,
00063 int trees = 4, int leaf_max_size = 100)
00064 {
00065 (*this)["algorithm"] = FLANN_INDEX_HIERARCHICAL;
00066
00067 (*this)["branching"] = branching;
00068
00069 (*this)["centers_init"] = centers_init;
00070
00071 (*this)["trees"] = trees;
00072
00073 (*this)["leaf_max_size"] = leaf_max_size;
00074 }
00075 };
00076
00077
00078
00085 template <typename Distance>
00086 class HierarchicalClusteringIndex : public NNIndex<Distance>
00087 {
00088 public:
00089 typedef typename Distance::ElementType ElementType;
00090 typedef typename Distance::ResultType DistanceType;
00091
00092 typedef NNIndex<Distance> BaseClass;
00093
00100 HierarchicalClusteringIndex(const IndexParams& index_params = HierarchicalClusteringIndexParams(), Distance d = Distance())
00101 : BaseClass(index_params, d)
00102 {
00103 memoryCounter_ = 0;
00104
00105 branching_ = get_param(index_params_,"branching",32);
00106 centers_init_ = get_param(index_params_,"centers_init", FLANN_CENTERS_RANDOM);
00107 trees_ = get_param(index_params_,"trees",4);
00108 leaf_max_size_ = get_param(index_params_,"leaf_max_size",100);
00109
00110 initCenterChooser();
00111 }
00112
00113
00121 HierarchicalClusteringIndex(const Matrix<ElementType>& inputData, const IndexParams& index_params = HierarchicalClusteringIndexParams(),
00122 Distance d = Distance())
00123 : BaseClass(index_params, d)
00124 {
00125 memoryCounter_ = 0;
00126
00127 branching_ = get_param(index_params_,"branching",32);
00128 centers_init_ = get_param(index_params_,"centers_init", FLANN_CENTERS_RANDOM);
00129 trees_ = get_param(index_params_,"trees",4);
00130 leaf_max_size_ = get_param(index_params_,"leaf_max_size",100);
00131
00132 initCenterChooser();
00133
00134 setDataset(inputData);
00135
00136 chooseCenters_->setDataSize(veclen_);
00137 }
00138
00139
00140 HierarchicalClusteringIndex(const HierarchicalClusteringIndex& other) : BaseClass(other),
00141 memoryCounter_(other.memoryCounter_),
00142 branching_(other.branching_),
00143 trees_(other.trees_),
00144 centers_init_(other.centers_init_),
00145 leaf_max_size_(other.leaf_max_size_)
00146
00147 {
00148 initCenterChooser();
00149 tree_roots_.resize(other.tree_roots_.size());
00150 for (size_t i=0;i<tree_roots_.size();++i) {
00151 copyTree(tree_roots_[i], other.tree_roots_[i]);
00152 }
00153 }
00154
00155 HierarchicalClusteringIndex& operator=(HierarchicalClusteringIndex other)
00156 {
00157 this->swap(other);
00158 return *this;
00159 }
00160
00161
00162 void initCenterChooser()
00163 {
00164 switch(centers_init_) {
00165 case FLANN_CENTERS_RANDOM:
00166 chooseCenters_ = new RandomCenterChooser<Distance>(distance_, points_);
00167 break;
00168 case FLANN_CENTERS_GONZALES:
00169 chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_, points_);
00170 break;
00171 case FLANN_CENTERS_KMEANSPP:
00172 chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_, points_);
00173 break;
00174 case FLANN_CENTERS_GROUPWISE:
00175 chooseCenters_ = new GroupWiseCenterChooser<Distance>(distance_, points_);
00176 break;
00177 default:
00178 throw FLANNException("Unknown algorithm for choosing initial centers.");
00179 }
00180 }
00181
00187 virtual ~HierarchicalClusteringIndex()
00188 {
00189 delete chooseCenters_;
00190 freeIndex();
00191 }
00192
00193 BaseClass* clone() const
00194 {
00195 return new HierarchicalClusteringIndex(*this);
00196 }
00197
00202 int usedMemory() const
00203 {
00204 return pool_.usedMemory+pool_.wastedMemory+memoryCounter_;
00205 }
00206
00207 using BaseClass::buildIndex;
00208
00209 void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00210 {
00211 assert(points.cols==veclen_);
00212 size_t old_size = size_;
00213
00214 extendDataset(points);
00215
00216 if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
00217 buildIndex();
00218 }
00219 else {
00220 for (size_t i=0;i<points.rows;++i) {
00221 for (int j = 0; j < trees_; j++) {
00222 addPointToTree(tree_roots_[j], old_size + i);
00223 }
00224 }
00225 }
00226 }
00227
00228
00229 flann_algorithm_t getType() const
00230 {
00231 return FLANN_INDEX_HIERARCHICAL;
00232 }
00233
00234
00235 template<typename Archive>
00236 void serialize(Archive& ar)
00237 {
00238 ar.setObject(this);
00239
00240 ar & *static_cast<NNIndex<Distance>*>(this);
00241
00242 ar & branching_;
00243 ar & trees_;
00244 ar & centers_init_;
00245 ar & leaf_max_size_;
00246
00247 if (Archive::is_loading::value) {
00248 tree_roots_.resize(trees_);
00249 }
00250 for (size_t i=0;i<tree_roots_.size();++i) {
00251 if (Archive::is_loading::value) {
00252 tree_roots_[i] = new(pool_) Node();
00253 }
00254 ar & *tree_roots_[i];
00255 }
00256
00257 if (Archive::is_loading::value) {
00258 index_params_["algorithm"] = getType();
00259 index_params_["branching"] = branching_;
00260 index_params_["trees"] = trees_;
00261 index_params_["centers_init"] = centers_init_;
00262 index_params_["leaf_size"] = leaf_max_size_;
00263 }
00264 }
00265
00266 void saveIndex(FILE* stream)
00267 {
00268 serialization::SaveArchive sa(stream);
00269 sa & *this;
00270 }
00271
00272
00273 void loadIndex(FILE* stream)
00274 {
00275 serialization::LoadArchive la(stream);
00276 la & *this;
00277 }
00278
00279
00290 void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00291 {
00292 if (removed_) {
00293 findNeighborsWithRemoved<true>(result, vec, searchParams);
00294 }
00295 else {
00296 findNeighborsWithRemoved<false>(result, vec, searchParams);
00297 }
00298 }
00299
00300 protected:
00301
00305 void buildIndexImpl()
00306 {
00307 chooseCenters_->setDataSize(veclen_);
00308
00309 if (branching_<2) {
00310 throw FLANNException("Branching factor must be at least 2");
00311 }
00312 tree_roots_.resize(trees_);
00313 std::vector<int> indices(size_);
00314 for (int i=0; i<trees_; ++i) {
00315 for (size_t j=0; j<size_; ++j) {
00316 indices[j] = j;
00317 }
00318 tree_roots_[i] = new(pool_) Node();
00319 computeClustering(tree_roots_[i], &indices[0], size_);
00320 }
00321 }
00322
00323 private:
00324
00325 struct PointInfo
00326 {
00328 size_t index;
00330 ElementType* point;
00331
00332 private:
00333 template<typename Archive>
00334 void serialize(Archive& ar)
00335 {
00336 typedef HierarchicalClusteringIndex<Distance> Index;
00337 Index* obj = static_cast<Index*>(ar.getObject());
00338
00339 ar & index;
00340
00341
00342 if (Archive::is_loading::value) {
00343 point = obj->points_[index];
00344 }
00345 }
00346 friend struct serialization::access;
00347 };
00348
00352 struct Node
00353 {
00357 ElementType* pivot;
00358 size_t pivot_index;
00362 std::vector<Node*> childs;
00366 std::vector<PointInfo> points;
00367
00368 Node(){
00369 pivot = NULL;
00370 pivot_index = SIZE_MAX;
00371 }
00376 ~Node()
00377 {
00378 for(size_t i=0; i<childs.size(); i++){
00379 childs[i]->~Node();
00380 pivot = NULL;
00381 pivot_index = -1;
00382 }
00383 };
00384
00385 private:
00386 template<typename Archive>
00387 void serialize(Archive& ar)
00388 {
00389 typedef HierarchicalClusteringIndex<Distance> Index;
00390 Index* obj = static_cast<Index*>(ar.getObject());
00391 ar & pivot_index;
00392 if (Archive::is_loading::value) {
00393 if (pivot_index != SIZE_MAX)
00394 pivot = obj->points_[pivot_index];
00395 else
00396 pivot = NULL;
00397 }
00398 size_t childs_size;
00399 if (Archive::is_saving::value) {
00400 childs_size = childs.size();
00401 }
00402 ar & childs_size;
00403
00404 if (childs_size==0) {
00405 ar & points;
00406 }
00407 else {
00408 if (Archive::is_loading::value) {
00409 childs.resize(childs_size);
00410 }
00411 for (size_t i=0;i<childs_size;++i) {
00412 if (Archive::is_loading::value) {
00413 childs[i] = new(obj->pool_) Node();
00414 }
00415 ar & *childs[i];
00416 }
00417 }
00418
00419 }
00420 friend struct serialization::access;
00421 };
00422 typedef Node* NodePtr;
00423
00424
00425
00429 typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00430
00431
00436 void freeIndex(){
00437 for (size_t i=0; i<tree_roots_.size(); ++i) {
00438 tree_roots_[i]->~Node();
00439 }
00440 pool_.free();
00441 }
00442
00443 void copyTree(NodePtr& dst, const NodePtr& src)
00444 {
00445 dst = new(pool_) Node();
00446 dst->pivot_index = src->pivot_index;
00447 dst->pivot = points_[dst->pivot_index];
00448
00449 if (src->childs.size()==0) {
00450 dst->points = src->points;
00451 }
00452 else {
00453 dst->childs.resize(src->childs.size());
00454 for (size_t i=0;i<src->childs.size();++i) {
00455 copyTree(dst->childs[i], src->childs[i]);
00456 }
00457 }
00458 }
00459
00460
00461
00462 void computeLabels(int* indices, int indices_length, int* centers, int centers_length, int* labels, DistanceType& cost)
00463 {
00464 cost = 0;
00465 for (int i=0; i<indices_length; ++i) {
00466 ElementType* point = points_[indices[i]];
00467 DistanceType dist = distance_(point, points_[centers[0]], veclen_);
00468 labels[i] = 0;
00469 for (int j=1; j<centers_length; ++j) {
00470 DistanceType new_dist = distance_(point, points_[centers[j]], veclen_);
00471 if (dist>new_dist) {
00472 labels[i] = j;
00473 dist = new_dist;
00474 }
00475 }
00476 cost += dist;
00477 }
00478 }
00479
00490 void computeClustering(NodePtr node, int* indices, int indices_length)
00491 {
00492 if (indices_length < leaf_max_size_) {
00493 node->points.resize(indices_length);
00494 for (int i=0;i<indices_length;++i) {
00495 node->points[i].index = indices[i];
00496 node->points[i].point = points_[indices[i]];
00497 }
00498 node->childs.clear();
00499 return;
00500 }
00501
00502 std::vector<int> centers(branching_);
00503 std::vector<int> labels(indices_length);
00504
00505 int centers_length;
00506 (*chooseCenters_)(branching_, indices, indices_length, ¢ers[0], centers_length);
00507
00508 if (centers_length<branching_) {
00509 node->points.resize(indices_length);
00510 for (int i=0;i<indices_length;++i) {
00511 node->points[i].index = indices[i];
00512 node->points[i].point = points_[indices[i]];
00513 }
00514 node->childs.clear();
00515 return;
00516 }
00517
00518
00519
00520 DistanceType cost;
00521 computeLabels(indices, indices_length, ¢ers[0], centers_length, &labels[0], cost);
00522
00523 node->childs.resize(branching_);
00524 int start = 0;
00525 int end = start;
00526 for (int i=0; i<branching_; ++i) {
00527 for (int j=0; j<indices_length; ++j) {
00528 if (labels[j]==i) {
00529 std::swap(indices[j],indices[end]);
00530 std::swap(labels[j],labels[end]);
00531 end++;
00532 }
00533 }
00534
00535 node->childs[i] = new(pool_) Node();
00536 node->childs[i]->pivot_index = centers[i];
00537 node->childs[i]->pivot = points_[centers[i]];
00538 node->childs[i]->points.clear();
00539 computeClustering(node->childs[i],indices+start, end-start);
00540 start=end;
00541 }
00542 }
00543
00544
00545 template<bool with_removed>
00546 void findNeighborsWithRemoved(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00547 {
00548 int maxChecks = searchParams.checks;
00549
00550
00551 Heap<BranchSt>* heap = new Heap<BranchSt>(size_);
00552
00553 DynamicBitset checked(size_);
00554 int checks = 0;
00555 for (int i=0; i<trees_; ++i) {
00556 findNN<with_removed>(tree_roots_[i], result, vec, checks, maxChecks, heap, checked);
00557 }
00558
00559 BranchSt branch;
00560 while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
00561 NodePtr node = branch.node;
00562 findNN<with_removed>(node, result, vec, checks, maxChecks, heap, checked);
00563 }
00564
00565 delete heap;
00566 }
00567
00568
00581 template<bool with_removed>
00582 void findNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks,
00583 Heap<BranchSt>* heap, DynamicBitset& checked) const
00584 {
00585 if (node->childs.empty()) {
00586 if (checks>=maxChecks) {
00587 if (result.full()) return;
00588 }
00589
00590 for (size_t i=0; i<node->points.size(); ++i) {
00591 PointInfo& pointInfo = node->points[i];
00592 if (with_removed) {
00593 if (removed_points_.test(pointInfo.index)) continue;
00594 }
00595 if (checked.test(pointInfo.index)) continue;
00596 DistanceType dist = distance_(pointInfo.point, vec, veclen_);
00597 result.addPoint(dist, pointInfo.index);
00598 checked.set(pointInfo.index);
00599 ++checks;
00600 }
00601 }
00602 else {
00603 DistanceType* domain_distances = new DistanceType[branching_];
00604 int best_index = 0;
00605 domain_distances[best_index] = distance_(vec, node->childs[best_index]->pivot, veclen_);
00606 for (int i=1; i<branching_; ++i) {
00607 domain_distances[i] = distance_(vec, node->childs[i]->pivot, veclen_);
00608 if (domain_distances[i]<domain_distances[best_index]) {
00609 best_index = i;
00610 }
00611 }
00612 for (int i=0; i<branching_; ++i) {
00613 if (i!=best_index) {
00614 heap->insert(BranchSt(node->childs[i],domain_distances[i]));
00615 }
00616 }
00617 delete[] domain_distances;
00618 findNN<with_removed>(node->childs[best_index],result,vec, checks, maxChecks, heap, checked);
00619 }
00620 }
00621
00622 void addPointToTree(NodePtr node, size_t index)
00623 {
00624 ElementType* point = points_[index];
00625
00626 if (node->childs.empty()) {
00627 PointInfo pointInfo;
00628 pointInfo.point = point;
00629 pointInfo.index = index;
00630 node->points.push_back(pointInfo);
00631
00632 if (node->points.size()>=size_t(branching_)) {
00633 std::vector<int> indices(node->points.size());
00634
00635 for (size_t i=0;i<node->points.size();++i) {
00636 indices[i] = node->points[i].index;
00637 }
00638 computeClustering(node, &indices[0], indices.size());
00639 }
00640 }
00641 else {
00642
00643 int closest = 0;
00644 ElementType* center = node->childs[closest]->pivot;
00645 DistanceType dist = distance_(center, point, veclen_);
00646 for (size_t i=1;i<size_t(branching_);++i) {
00647 center = node->childs[i]->pivot;
00648 DistanceType crt_dist = distance_(center, point, veclen_);
00649 if (crt_dist<dist) {
00650 dist = crt_dist;
00651 closest = i;
00652 }
00653 }
00654 addPointToTree(node->childs[closest], index);
00655 }
00656 }
00657
00658 void swap(HierarchicalClusteringIndex& other)
00659 {
00660 BaseClass::swap(other);
00661
00662 std::swap(tree_roots_, other.tree_roots_);
00663 std::swap(pool_, other.pool_);
00664 std::swap(memoryCounter_, other.memoryCounter_);
00665 std::swap(branching_, other.branching_);
00666 std::swap(trees_, other.trees_);
00667 std::swap(centers_init_, other.centers_init_);
00668 std::swap(leaf_max_size_, other.leaf_max_size_);
00669 std::swap(chooseCenters_, other.chooseCenters_);
00670 }
00671
00672 private:
00673
00677 std::vector<Node*> tree_roots_;
00678
00686 PooledAllocator pool_;
00687
00691 int memoryCounter_;
00692
00697 int branching_;
00698
00702 int trees_;
00703
00707 flann_centers_init_t centers_init_;
00708
00712 int leaf_max_size_;
00713
00717 CenterChooser<Distance>* chooseCenters_;
00718
00719 USING_BASECLASS_SYMBOLS
00720 };
00721
00722 }
00723
00724 #endif