00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef RTABMAP_FLANN_KMEANS_INDEX_H_
00032 #define RTABMAP_FLANN_KMEANS_INDEX_H_
00033
00034 #include <algorithm>
00035 #include <string>
00036 #include <map>
00037 #include <cassert>
00038 #include <limits>
00039 #include <cmath>
00040
00041 #include "rtflann/general.h"
00042 #include "rtflann/algorithms/nn_index.h"
00043 #include "rtflann/algorithms/dist.h"
00044 #include "rtflann/algorithms/center_chooser.h"
00045 #include "rtflann/util/matrix.h"
00046 #include "rtflann/util/result_set.h"
00047 #include "rtflann/util/heap.h"
00048 #include "rtflann/util/allocator.h"
00049 #include "rtflann/util/random.h"
00050 #include "rtflann/util/saving.h"
00051 #include "rtflann/util/logger.h"
00052
00053
00054
00055 namespace rtflann
00056 {
00057
00058 struct KMeansIndexParams : public IndexParams
00059 {
00060 KMeansIndexParams(int branching = 32, int iterations = 11,
00061 flann_centers_init_t centers_init = FLANN_CENTERS_RANDOM, float cb_index = 0.2 )
00062 {
00063 (*this)["algorithm"] = FLANN_INDEX_KMEANS;
00064
00065 (*this)["branching"] = branching;
00066
00067 (*this)["iterations"] = iterations;
00068
00069 (*this)["centers_init"] = centers_init;
00070
00071 (*this)["cb_index"] = cb_index;
00072 }
00073 };
00074
00075
00082 template <typename Distance>
00083 class KMeansIndex : public NNIndex<Distance>
00084 {
00085 public:
00086 typedef typename Distance::ElementType ElementType;
00087 typedef typename Distance::ResultType DistanceType;
00088
00089 typedef NNIndex<Distance> BaseClass;
00090
00091 typedef bool needs_vector_space_distance;
00092
00093
00094
00095 flann_algorithm_t getType() const
00096 {
00097 return FLANN_INDEX_KMEANS;
00098 }
00099
00107 KMeansIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KMeansIndexParams(),
00108 Distance d = Distance())
00109 : BaseClass(params,d), root_(NULL), memoryCounter_(0)
00110 {
00111 branching_ = get_param(params,"branching",32);
00112 iterations_ = get_param(params,"iterations",11);
00113 if (iterations_<0) {
00114 iterations_ = (std::numeric_limits<int>::max)();
00115 }
00116 centers_init_ = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
00117 cb_index_ = get_param(params,"cb_index",0.4f);
00118
00119 initCenterChooser();
00120 setDataset(inputData);
00121 }
00122
00123
00131 KMeansIndex(const IndexParams& params = KMeansIndexParams(), Distance d = Distance())
00132 : BaseClass(params, d), root_(NULL), memoryCounter_(0)
00133 {
00134 branching_ = get_param(params,"branching",32);
00135 iterations_ = get_param(params,"iterations",11);
00136 if (iterations_<0) {
00137 iterations_ = (std::numeric_limits<int>::max)();
00138 }
00139 centers_init_ = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
00140 cb_index_ = get_param(params,"cb_index",0.4f);
00141
00142 initCenterChooser();
00143 }
00144
00145
00146 KMeansIndex(const KMeansIndex& other) : BaseClass(other),
00147 branching_(other.branching_),
00148 iterations_(other.iterations_),
00149 centers_init_(other.centers_init_),
00150 cb_index_(other.cb_index_),
00151 memoryCounter_(other.memoryCounter_)
00152 {
00153 initCenterChooser();
00154
00155 copyTree(root_, other.root_);
00156 }
00157
00158 KMeansIndex& operator=(KMeansIndex other)
00159 {
00160 this->swap(other);
00161 return *this;
00162 }
00163
00164
00165 void initCenterChooser()
00166 {
00167 switch(centers_init_) {
00168 case FLANN_CENTERS_RANDOM:
00169 chooseCenters_ = new RandomCenterChooser<Distance>(distance_, points_);
00170 break;
00171 case FLANN_CENTERS_GONZALES:
00172 chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_, points_);
00173 break;
00174 case FLANN_CENTERS_KMEANSPP:
00175 chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_, points_);
00176 break;
00177 default:
00178 throw FLANNException("Unknown algorithm for choosing initial centers.");
00179 }
00180 }
00181
00187 virtual ~KMeansIndex()
00188 {
00189 delete chooseCenters_;
00190 freeIndex();
00191 }
00192
00193 BaseClass* clone() const
00194 {
00195 return new KMeansIndex(*this);
00196 }
00197
00198
00199 void set_cb_index( float index)
00200 {
00201 cb_index_ = index;
00202 }
00203
00208 int usedMemory() const
00209 {
00210 return pool_.usedMemory+pool_.wastedMemory+memoryCounter_;
00211 }
00212
00213 using BaseClass::buildIndex;
00214
00215 void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00216 {
00217 assert(points.cols==veclen_);
00218 size_t old_size = size_;
00219
00220 extendDataset(points);
00221
00222 if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
00223 buildIndex();
00224 }
00225 else {
00226 for (size_t i=0;i<points.rows;++i) {
00227 DistanceType dist = distance_(root_->pivot, points[i], veclen_);
00228 addPointToTree(root_, old_size + i, dist);
00229 }
00230 }
00231 }
00232
00233 template<typename Archive>
00234 void serialize(Archive& ar)
00235 {
00236 ar.setObject(this);
00237
00238 ar & *static_cast<NNIndex<Distance>*>(this);
00239
00240 ar & branching_;
00241 ar & iterations_;
00242 ar & memoryCounter_;
00243 ar & cb_index_;
00244 ar & centers_init_;
00245
00246 if (Archive::is_loading::value) {
00247 root_ = new(pool_) Node();
00248 }
00249 ar & *root_;
00250
00251 if (Archive::is_loading::value) {
00252 index_params_["algorithm"] = getType();
00253 index_params_["branching"] = branching_;
00254 index_params_["iterations"] = iterations_;
00255 index_params_["centers_init"] = centers_init_;
00256 index_params_["cb_index"] = cb_index_;
00257 }
00258 }
00259
00260 void saveIndex(FILE* stream)
00261 {
00262 serialization::SaveArchive sa(stream);
00263 sa & *this;
00264 }
00265
00266 void loadIndex(FILE* stream)
00267 {
00268 freeIndex();
00269 serialization::LoadArchive la(stream);
00270 la & *this;
00271 }
00272
00283 void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00284 {
00285 if (removed_) {
00286 findNeighborsWithRemoved<true>(result, vec, searchParams);
00287 }
00288 else {
00289 findNeighborsWithRemoved<false>(result, vec, searchParams);
00290 }
00291
00292 }
00293
00301 int getClusterCenters(Matrix<DistanceType>& centers)
00302 {
00303 int numClusters = centers.rows;
00304 if (numClusters<1) {
00305 throw FLANNException("Number of clusters must be at least 1");
00306 }
00307
00308 DistanceType variance;
00309 std::vector<NodePtr> clusters(numClusters);
00310
00311 int clusterCount = getMinVarianceClusters(root_, clusters, numClusters, variance);
00312
00313 Logger::info("Clusters requested: %d, returning %d\n",numClusters, clusterCount);
00314
00315 for (int i=0; i<clusterCount; ++i) {
00316 DistanceType* center = clusters[i]->pivot;
00317 for (size_t j=0; j<veclen_; ++j) {
00318 centers[i][j] = center[j];
00319 }
00320 }
00321
00322 return clusterCount;
00323 }
00324
00325 protected:
00329 void buildIndexImpl()
00330 {
00331 chooseCenters_->setDataSize(veclen_);
00332
00333 if (branching_<2) {
00334 throw FLANNException("Branching factor must be at least 2");
00335 }
00336
00337 std::vector<int> indices(size_);
00338 for (size_t i=0; i<size_; ++i) {
00339 indices[i] = int(i);
00340 }
00341
00342 root_ = new(pool_) Node();
00343 computeNodeStatistics(root_, indices);
00344 computeClustering(root_, &indices[0], (int)size_, branching_);
00345 }
00346
00347 private:
00348
00349 struct PointInfo
00350 {
00351 size_t index;
00352 ElementType* point;
00353 private:
00354 template<typename Archive>
00355 void serialize(Archive& ar)
00356 {
00357 typedef KMeansIndex<Distance> Index;
00358 Index* obj = static_cast<Index*>(ar.getObject());
00359
00360 ar & index;
00361
00362
00363 if (Archive::is_loading::value) point = obj->points_[index];
00364 }
00365 friend struct serialization::access;
00366 };
00367
00371 struct Node
00372 {
00376 DistanceType* pivot;
00380 DistanceType radius;
00384 DistanceType variance;
00388 int size;
00392 std::vector<Node*> childs;
00396 std::vector<PointInfo> points;
00400
00401
00402 ~Node()
00403 {
00404 delete[] pivot;
00405 if (!childs.empty()) {
00406 for (size_t i=0; i<childs.size(); ++i) {
00407 childs[i]->~Node();
00408 }
00409 }
00410 }
00411
00412 template<typename Archive>
00413 void serialize(Archive& ar)
00414 {
00415 typedef KMeansIndex<Distance> Index;
00416 Index* obj = static_cast<Index*>(ar.getObject());
00417
00418 if (Archive::is_loading::value) {
00419 pivot = new DistanceType[obj->veclen_];
00420 }
00421 ar & serialization::make_binary_object(pivot, obj->veclen_*sizeof(DistanceType));
00422 ar & radius;
00423 ar & variance;
00424 ar & size;
00425
00426 size_t childs_size;
00427 if (Archive::is_saving::value) {
00428 childs_size = childs.size();
00429 }
00430 ar & childs_size;
00431
00432 if (childs_size==0) {
00433 ar & points;
00434 }
00435 else {
00436 if (Archive::is_loading::value) {
00437 childs.resize(childs_size);
00438 }
00439 for (size_t i=0;i<childs_size;++i) {
00440 if (Archive::is_loading::value) {
00441 childs[i] = new(obj->pool_) Node();
00442 }
00443 ar & *childs[i];
00444 }
00445 }
00446 }
00447 friend struct serialization::access;
00448 };
00449 typedef Node* NodePtr;
00450
00454 typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00455
00456
00460 void freeIndex()
00461 {
00462 if (root_) root_->~Node();
00463 root_ = NULL;
00464 pool_.free();
00465 }
00466
00467 void copyTree(NodePtr& dst, const NodePtr& src)
00468 {
00469 dst = new(pool_) Node();
00470 dst->pivot = new DistanceType[veclen_];
00471 std::copy(src->pivot, src->pivot+veclen_, dst->pivot);
00472 dst->radius = src->radius;
00473 dst->variance = src->variance;
00474 dst->size = src->size;
00475
00476 if (src->childs.size()==0) {
00477 dst->points = src->points;
00478 }
00479 else {
00480 dst->childs.resize(src->childs.size());
00481 for (size_t i=0;i<src->childs.size();++i) {
00482 copyTree(dst->childs[i], src->childs[i]);
00483 }
00484 }
00485 }
00486
00487
00495 void computeNodeStatistics(NodePtr node, const std::vector<int>& indices)
00496 {
00497 size_t size = indices.size();
00498
00499 DistanceType* mean = new DistanceType[veclen_];
00500 memoryCounter_ += int(veclen_*sizeof(DistanceType));
00501 memset(mean,0,veclen_*sizeof(DistanceType));
00502
00503 for (size_t i=0; i<size; ++i) {
00504 ElementType* vec = points_[indices[i]];
00505 for (size_t j=0; j<veclen_; ++j) {
00506 mean[j] += vec[j];
00507 }
00508 }
00509 DistanceType div_factor = DistanceType(1)/size;
00510 for (size_t j=0; j<veclen_; ++j) {
00511 mean[j] *= div_factor;
00512 }
00513
00514 DistanceType radius = 0;
00515 DistanceType variance = 0;
00516 for (size_t i=0; i<size; ++i) {
00517 DistanceType dist = distance_(mean, points_[indices[i]], veclen_);
00518 if (dist>radius) {
00519 radius = dist;
00520 }
00521 variance += dist;
00522 }
00523 variance /= size;
00524
00525 node->variance = variance;
00526 node->radius = radius;
00527 node->pivot = mean;
00528 }
00529
00530
00542 void computeClustering(NodePtr node, int* indices, int indices_length, int branching)
00543 {
00544 node->size = indices_length;
00545
00546 if (indices_length < branching) {
00547 node->points.resize(indices_length);
00548 for (int i=0;i<indices_length;++i) {
00549 node->points[i].index = indices[i];
00550 node->points[i].point = points_[indices[i]];
00551 }
00552 node->childs.clear();
00553 return;
00554 }
00555
00556 std::vector<int> centers_idx(branching);
00557 int centers_length;
00558 (*chooseCenters_)(branching, indices, indices_length, ¢ers_idx[0], centers_length);
00559
00560 if (centers_length<branching) {
00561 node->points.resize(indices_length);
00562 for (int i=0;i<indices_length;++i) {
00563 node->points[i].index = indices[i];
00564 node->points[i].point = points_[indices[i]];
00565 }
00566 node->childs.clear();
00567 return;
00568 }
00569
00570
00571 Matrix<double> dcenters(new double[branching*veclen_],branching,veclen_);
00572 for (int i=0; i<centers_length; ++i) {
00573 ElementType* vec = points_[centers_idx[i]];
00574 for (size_t k=0; k<veclen_; ++k) {
00575 dcenters[i][k] = double(vec[k]);
00576 }
00577 }
00578
00579 std::vector<DistanceType> radiuses(branching,0);
00580 std::vector<int> count(branching,0);
00581
00582
00583 std::vector<int> belongs_to(indices_length);
00584 for (int i=0; i<indices_length; ++i) {
00585
00586 DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
00587 belongs_to[i] = 0;
00588 for (int j=1; j<branching; ++j) {
00589 DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
00590 if (sq_dist>new_sq_dist) {
00591 belongs_to[i] = j;
00592 sq_dist = new_sq_dist;
00593 }
00594 }
00595 if (sq_dist>radiuses[belongs_to[i]]) {
00596 radiuses[belongs_to[i]] = sq_dist;
00597 }
00598 count[belongs_to[i]]++;
00599 }
00600
00601 bool converged = false;
00602 int iteration = 0;
00603 while (!converged && iteration<iterations_) {
00604 converged = true;
00605 iteration++;
00606
00607
00608 for (int i=0; i<branching; ++i) {
00609 memset(dcenters[i],0,sizeof(double)*veclen_);
00610 radiuses[i] = 0;
00611 }
00612 for (int i=0; i<indices_length; ++i) {
00613 ElementType* vec = points_[indices[i]];
00614 double* center = dcenters[belongs_to[i]];
00615 for (size_t k=0; k<veclen_; ++k) {
00616 center[k] += vec[k];
00617 }
00618 }
00619 for (int i=0; i<branching; ++i) {
00620 int cnt = count[i];
00621 double div_factor = 1.0/cnt;
00622 for (size_t k=0; k<veclen_; ++k) {
00623 dcenters[i][k] *= div_factor;
00624 }
00625 }
00626
00627
00628 for (int i=0; i<indices_length; ++i) {
00629 DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
00630 int new_centroid = 0;
00631 for (int j=1; j<branching; ++j) {
00632 DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
00633 if (sq_dist>new_sq_dist) {
00634 new_centroid = j;
00635 sq_dist = new_sq_dist;
00636 }
00637 }
00638 if (sq_dist>radiuses[new_centroid]) {
00639 radiuses[new_centroid] = sq_dist;
00640 }
00641 if (new_centroid != belongs_to[i]) {
00642 count[belongs_to[i]]--;
00643 count[new_centroid]++;
00644 belongs_to[i] = new_centroid;
00645
00646 converged = false;
00647 }
00648 }
00649
00650 for (int i=0; i<branching; ++i) {
00651
00652
00653 if (count[i]==0) {
00654 int j = (i+1)%branching;
00655 while (count[j]<=1) {
00656 j = (j+1)%branching;
00657 }
00658
00659 for (int k=0; k<indices_length; ++k) {
00660 if (belongs_to[k]==j) {
00661 belongs_to[k] = i;
00662 count[j]--;
00663 count[i]++;
00664 break;
00665 }
00666 }
00667 converged = false;
00668 }
00669 }
00670
00671 }
00672
00673 std::vector<DistanceType*> centers(branching);
00674
00675 for (int i=0; i<branching; ++i) {
00676 centers[i] = new DistanceType[veclen_];
00677 memoryCounter_ += veclen_*sizeof(DistanceType);
00678 for (size_t k=0; k<veclen_; ++k) {
00679 centers[i][k] = (DistanceType)dcenters[i][k];
00680 }
00681 }
00682
00683
00684
00685 node->childs.resize(branching);
00686 int start = 0;
00687 int end = start;
00688 for (int c=0; c<branching; ++c) {
00689 int s = count[c];
00690
00691 DistanceType variance = 0;
00692 for (int i=0; i<indices_length; ++i) {
00693 if (belongs_to[i]==c) {
00694 variance += distance_(centers[c], points_[indices[i]], veclen_);
00695 std::swap(indices[i],indices[end]);
00696 std::swap(belongs_to[i],belongs_to[end]);
00697 end++;
00698 }
00699 }
00700 variance /= s;
00701
00702 node->childs[c] = new(pool_) Node();
00703 node->childs[c]->radius = radiuses[c];
00704 node->childs[c]->pivot = centers[c];
00705 node->childs[c]->variance = variance;
00706 computeClustering(node->childs[c],indices+start, end-start, branching);
00707 start=end;
00708 }
00709
00710 delete[] dcenters.ptr();
00711 }
00712
00713
00714 template<bool with_removed>
00715 void findNeighborsWithRemoved(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00716 {
00717
00718 int maxChecks = searchParams.checks;
00719
00720 if (maxChecks==FLANN_CHECKS_UNLIMITED) {
00721 findExactNN<with_removed>(root_, result, vec);
00722 }
00723 else {
00724
00725 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00726
00727 int checks = 0;
00728 findNN<with_removed>(root_, result, vec, checks, maxChecks, heap);
00729
00730 BranchSt branch;
00731 while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
00732 NodePtr node = branch.node;
00733 findNN<with_removed>(node, result, vec, checks, maxChecks, heap);
00734 }
00735
00736 delete heap;
00737 }
00738
00739 }
00740
00741
00754 template<bool with_removed>
00755 void findNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks,
00756 Heap<BranchSt>* heap) const
00757 {
00758
00759 {
00760 DistanceType bsq = distance_(vec, node->pivot, veclen_);
00761 DistanceType rsq = node->radius;
00762 DistanceType wsq = result.worstDist();
00763
00764 DistanceType val = bsq-rsq-wsq;
00765 DistanceType val2 = val*val-4*rsq*wsq;
00766
00767
00768 if ((val>0)&&(val2>0)) {
00769 return;
00770 }
00771 }
00772
00773 if (node->childs.empty()) {
00774 if (checks>=maxChecks) {
00775 if (result.full()) return;
00776 }
00777 for (int i=0; i<node->size; ++i) {
00778 PointInfo& point_info = node->points[i];
00779 int index = point_info.index;
00780 if (with_removed) {
00781 if (removed_points_.test(index)) continue;
00782 }
00783 DistanceType dist = distance_(point_info.point, vec, veclen_);
00784 result.addPoint(dist, index);
00785 ++checks;
00786 }
00787 }
00788 else {
00789 int closest_center = exploreNodeBranches(node, vec, heap);
00790 findNN<with_removed>(node->childs[closest_center],result,vec, checks, maxChecks, heap);
00791 }
00792 }
00793
00802 int exploreNodeBranches(NodePtr node, const ElementType* q, Heap<BranchSt>* heap) const
00803 {
00804 std::vector<DistanceType> domain_distances(branching_);
00805 int best_index = 0;
00806 domain_distances[best_index] = distance_(q, node->childs[best_index]->pivot, veclen_);
00807 for (int i=1; i<branching_; ++i) {
00808 domain_distances[i] = distance_(q, node->childs[i]->pivot, veclen_);
00809 if (domain_distances[i]<domain_distances[best_index]) {
00810 best_index = i;
00811 }
00812 }
00813
00814
00815 for (int i=0; i<branching_; ++i) {
00816 if (i != best_index) {
00817 domain_distances[i] -= cb_index_*node->childs[i]->variance;
00818
00819
00820
00821
00822
00823 heap->insert(BranchSt(node->childs[i],domain_distances[i]));
00824 }
00825 }
00826
00827 return best_index;
00828 }
00829
00830
00834 template<bool with_removed>
00835 void findExactNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec) const
00836 {
00837
00838 {
00839 DistanceType bsq = distance_(vec, node->pivot, veclen_);
00840 DistanceType rsq = node->radius;
00841 DistanceType wsq = result.worstDist();
00842
00843 DistanceType val = bsq-rsq-wsq;
00844 DistanceType val2 = val*val-4*rsq*wsq;
00845
00846
00847 if ((val>0)&&(val2>0)) {
00848 return;
00849 }
00850 }
00851
00852 if (node->childs.empty()) {
00853 for (int i=0; i<node->size; ++i) {
00854 PointInfo& point_info = node->points[i];
00855 int index = point_info.index;
00856 if (with_removed) {
00857 if (removed_points_.test(index)) continue;
00858 }
00859 DistanceType dist = distance_(point_info.point, vec, veclen_);
00860 result.addPoint(dist, index);
00861 }
00862 }
00863 else {
00864 std::vector<int> sort_indices(branching_);
00865 getCenterOrdering(node, vec, sort_indices);
00866
00867 for (int i=0; i<branching_; ++i) {
00868 findExactNN<with_removed>(node->childs[sort_indices[i]],result,vec);
00869 }
00870
00871 }
00872 }
00873
00874
00880 void getCenterOrdering(NodePtr node, const ElementType* q, std::vector<int>& sort_indices) const
00881 {
00882 std::vector<DistanceType> domain_distances(branching_);
00883 for (int i=0; i<branching_; ++i) {
00884 DistanceType dist = distance_(q, node->childs[i]->pivot, veclen_);
00885
00886 int j=0;
00887 while (domain_distances[j]<dist && j<i) j++;
00888 for (int k=i; k>j; --k) {
00889 domain_distances[k] = domain_distances[k-1];
00890 sort_indices[k] = sort_indices[k-1];
00891 }
00892 domain_distances[j] = dist;
00893 sort_indices[j] = i;
00894 }
00895 }
00896
00902 DistanceType getDistanceToBorder(DistanceType* p, DistanceType* c, DistanceType* q) const
00903 {
00904 DistanceType sum = 0;
00905 DistanceType sum2 = 0;
00906
00907 for (int i=0; i<veclen_; ++i) {
00908 DistanceType t = c[i]-p[i];
00909 sum += t*(q[i]-(c[i]+p[i])/2);
00910 sum2 += t*t;
00911 }
00912
00913 return sum*sum/sum2;
00914 }
00915
00916
00926 int getMinVarianceClusters(NodePtr root, std::vector<NodePtr>& clusters, int clusters_length, DistanceType& varianceValue) const
00927 {
00928 int clusterCount = 1;
00929 clusters[0] = root;
00930
00931 DistanceType meanVariance = root->variance*root->size;
00932
00933 while (clusterCount<clusters_length) {
00934 DistanceType minVariance = (std::numeric_limits<DistanceType>::max)();
00935 int splitIndex = -1;
00936
00937 for (int i=0; i<clusterCount; ++i) {
00938 if (!clusters[i]->childs.empty()) {
00939
00940 DistanceType variance = meanVariance - clusters[i]->variance*clusters[i]->size;
00941
00942 for (int j=0; j<branching_; ++j) {
00943 variance += clusters[i]->childs[j]->variance*clusters[i]->childs[j]->size;
00944 }
00945 if (variance<minVariance) {
00946 minVariance = variance;
00947 splitIndex = i;
00948 }
00949 }
00950 }
00951
00952 if (splitIndex==-1) break;
00953 if ( (branching_+clusterCount-1) > clusters_length) break;
00954
00955 meanVariance = minVariance;
00956
00957
00958 NodePtr toSplit = clusters[splitIndex];
00959 clusters[splitIndex] = toSplit->childs[0];
00960 for (int i=1; i<branching_; ++i) {
00961 clusters[clusterCount++] = toSplit->childs[i];
00962 }
00963 }
00964
00965 varianceValue = meanVariance/root->size;
00966 return clusterCount;
00967 }
00968
00969 void addPointToTree(NodePtr node, size_t index, DistanceType dist_to_pivot)
00970 {
00971 ElementType* point = points_[index];
00972 if (dist_to_pivot>node->radius) {
00973 node->radius = dist_to_pivot;
00974 }
00975
00976 node->variance = (node->size*node->variance+dist_to_pivot)/(node->size+1);
00977 node->size++;
00978
00979 if (node->childs.empty()) {
00980 PointInfo point_info;
00981 point_info.index = index;
00982 point_info.point = point;
00983 node->points.push_back(point_info);
00984
00985 std::vector<int> indices(node->points.size());
00986 for (size_t i=0;i<node->points.size();++i) {
00987 indices[i] = node->points[i].index;
00988 }
00989 computeNodeStatistics(node, indices);
00990 if (indices.size()>=size_t(branching_)) {
00991 computeClustering(node, &indices[0], indices.size(), branching_);
00992 }
00993 }
00994 else {
00995
00996 int closest = 0;
00997 DistanceType dist = distance_(node->childs[closest]->pivot, point, veclen_);
00998 for (size_t i=1;i<size_t(branching_);++i) {
00999 DistanceType crt_dist = distance_(node->childs[i]->pivot, point, veclen_);
01000 if (crt_dist<dist) {
01001 dist = crt_dist;
01002 closest = i;
01003 }
01004 }
01005 addPointToTree(node->childs[closest], index, dist);
01006 }
01007 }
01008
01009
01010 void swap(KMeansIndex& other)
01011 {
01012 std::swap(branching_, other.branching_);
01013 std::swap(iterations_, other.iterations_);
01014 std::swap(centers_init_, other.centers_init_);
01015 std::swap(cb_index_, other.cb_index_);
01016 std::swap(root_, other.root_);
01017 std::swap(pool_, other.pool_);
01018 std::swap(memoryCounter_, other.memoryCounter_);
01019 std::swap(chooseCenters_, other.chooseCenters_);
01020 }
01021
01022
01023 private:
01025 int branching_;
01026
01028 int iterations_;
01029
01031 flann_centers_init_t centers_init_;
01032
01039 float cb_index_;
01040
01044 NodePtr root_;
01045
01049 PooledAllocator pool_;
01050
01054 int memoryCounter_;
01055
01059 CenterChooser<Distance>* chooseCenters_;
01060
01061 USING_BASECLASS_SYMBOLS
01062 };
01063
01064 }
01065
01066 #endif // RTABMAP_FLANN_KMEANS_INDEX_H_