kmeans_index.h
Go to the documentation of this file.
1 /***********************************************************************
2  * Software License Agreement (BSD License)
3  *
4  * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5  * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6  *
7  * THE BSD LICENSE
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * 1. Redistributions of source code must retain the above copyright
14  * notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  * notice, this list of conditions and the following disclaimer in the
17  * documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  *************************************************************************/
30 
31 #ifndef RTABMAP_FLANN_KMEANS_INDEX_H_
32 #define RTABMAP_FLANN_KMEANS_INDEX_H_
33 
34 #include <algorithm>
35 #include <string>
36 #include <map>
37 #include <cassert>
38 #include <limits>
39 #include <cmath>
40 
41 #include "rtflann/general.h"
45 #include "rtflann/util/matrix.h"
47 #include "rtflann/util/heap.h"
48 #include "rtflann/util/allocator.h"
49 #include "rtflann/util/random.h"
50 #include "rtflann/util/saving.h"
51 #include "rtflann/util/logger.h"
52 
53 
54 
55 namespace rtflann
56 {
57 
59 {
60  KMeansIndexParams(int branching = 32, int iterations = 11,
61  flann_centers_init_t centers_init = FLANN_CENTERS_RANDOM, float cb_index = 0.2 )
62  {
63  (*this)["algorithm"] = FLANN_INDEX_KMEANS;
64  // branching factor
65  (*this)["branching"] = branching;
66  // max iterations to perform in one kmeans clustering (kmeans tree)
67  (*this)["iterations"] = iterations;
68  // algorithm used for picking the initial cluster centers for kmeans tree
69  (*this)["centers_init"] = centers_init;
70  // cluster boundary index. Used when searching the kmeans tree
71  (*this)["cb_index"] = cb_index;
72  }
73 };
74 
75 
82 template <typename Distance>
83 class KMeansIndex : public NNIndex<Distance>
84 {
85 public:
86  typedef typename Distance::ElementType ElementType;
87  typedef typename Distance::ResultType DistanceType;
88 
90 
92 
93 
94 
96  {
97  return FLANN_INDEX_KMEANS;
98  }
99 
107  KMeansIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KMeansIndexParams(),
108  Distance d = Distance())
109  : BaseClass(params,d), root_(NULL), memoryCounter_(0)
110  {
111  branching_ = get_param(params,"branching",32);
112  iterations_ = get_param(params,"iterations",11);
113  if (iterations_<0) {
114  iterations_ = (std::numeric_limits<int>::max)();
115  }
116  centers_init_ = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
117  cb_index_ = get_param(params,"cb_index",0.4f);
118 
119  initCenterChooser();
120  setDataset(inputData);
121  }
122 
123 
131  KMeansIndex(const IndexParams& params = KMeansIndexParams(), Distance d = Distance())
132  : BaseClass(params, d), root_(NULL), memoryCounter_(0)
133  {
134  branching_ = get_param(params,"branching",32);
135  iterations_ = get_param(params,"iterations",11);
136  if (iterations_<0) {
137  iterations_ = (std::numeric_limits<int>::max)();
138  }
139  centers_init_ = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
140  cb_index_ = get_param(params,"cb_index",0.4f);
141 
142  initCenterChooser();
143  }
144 
145 
146  KMeansIndex(const KMeansIndex& other) : BaseClass(other),
147  branching_(other.branching_),
148  iterations_(other.iterations_),
149  centers_init_(other.centers_init_),
150  cb_index_(other.cb_index_),
151  memoryCounter_(other.memoryCounter_)
152  {
153  initCenterChooser();
154 
155  copyTree(root_, other.root_);
156  }
157 
159  {
160  this->swap(other);
161  return *this;
162  }
163 
164 
166  {
167  switch(centers_init_) {
169  chooseCenters_ = new RandomCenterChooser<Distance>(distance_, points_);
170  break;
172  chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_, points_);
173  break;
175  chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_, points_);
176  break;
177  default:
178  throw FLANNException("Unknown algorithm for choosing initial centers.");
179  }
180  }
181 
187  virtual ~KMeansIndex()
188  {
189  delete chooseCenters_;
190  freeIndex();
191  }
192 
193  BaseClass* clone() const
194  {
195  return new KMeansIndex(*this);
196  }
197 
198 
199  void set_cb_index( float index)
200  {
201  cb_index_ = index;
202  }
203 
208  int usedMemory() const
209  {
210  return pool_.usedMemory+pool_.wastedMemory+memoryCounter_;
211  }
212 
213  using BaseClass::buildIndex;
214 
215  void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
216  {
217  assert(points.cols==veclen_);
218  size_t old_size = size_;
219 
220  extendDataset(points);
221 
222  if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
223  buildIndex();
224  }
225  else {
226  for (size_t i=0;i<points.rows;++i) {
227  DistanceType dist = distance_(root_->pivot, points[i], veclen_);
228  addPointToTree(root_, old_size + i, dist);
229  }
230  }
231  }
232 
233  template<typename Archive>
234  void serialize(Archive& ar)
235  {
236  ar.setObject(this);
237 
238  ar & *static_cast<NNIndex<Distance>*>(this);
239 
240  ar & branching_;
241  ar & iterations_;
242  ar & memoryCounter_;
243  ar & cb_index_;
244  ar & centers_init_;
245 
246  if (Archive::is_loading::value) {
247  root_ = new(pool_) Node();
248  }
249  ar & *root_;
250 
251  if (Archive::is_loading::value) {
252  index_params_["algorithm"] = getType();
253  index_params_["branching"] = branching_;
254  index_params_["iterations"] = iterations_;
255  index_params_["centers_init"] = centers_init_;
256  index_params_["cb_index"] = cb_index_;
257  }
258  }
259 
260  void saveIndex(FILE* stream)
261  {
262  serialization::SaveArchive sa(stream);
263  sa & *this;
264  }
265 
266  void loadIndex(FILE* stream)
267  {
268  freeIndex();
269  serialization::LoadArchive la(stream);
270  la & *this;
271  }
272 
283  void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
284  {
285  if (removed_) {
286  findNeighborsWithRemoved<true>(result, vec, searchParams);
287  }
288  else {
289  findNeighborsWithRemoved<false>(result, vec, searchParams);
290  }
291 
292  }
293 
302  {
303  int numClusters = centers.rows;
304  if (numClusters<1) {
305  throw FLANNException("Number of clusters must be at least 1");
306  }
307 
308  DistanceType variance;
309  std::vector<NodePtr> clusters(numClusters);
310 
311  int clusterCount = getMinVarianceClusters(root_, clusters, numClusters, variance);
312 
313  Logger::info("Clusters requested: %d, returning %d\n",numClusters, clusterCount);
314 
315  for (int i=0; i<clusterCount; ++i) {
316  DistanceType* center = clusters[i]->pivot;
317  for (size_t j=0; j<veclen_; ++j) {
318  centers[i][j] = center[j];
319  }
320  }
321 
322  return clusterCount;
323  }
324 
325 protected:
330  {
331  chooseCenters_->setDataSize(veclen_);
332 
333  if (branching_<2) {
334  throw FLANNException("Branching factor must be at least 2");
335  }
336 
337  std::vector<int> indices(size_);
338  for (size_t i=0; i<size_; ++i) {
339  indices[i] = int(i);
340  }
341 
342  root_ = new(pool_) Node();
343  computeNodeStatistics(root_, indices);
344  computeClustering(root_, &indices[0], (int)size_, branching_);
345  }
346 
347 private:
348 
349  struct PointInfo
350  {
351  size_t index;
352  ElementType* point;
353  private:
354  template<typename Archive>
355  void serialize(Archive& ar)
356  {
358  Index* obj = static_cast<Index*>(ar.getObject());
359 
360  ar & index;
361 // ar & point;
362 
363  if (Archive::is_loading::value) point = obj->points_[index];
364  }
365  friend struct serialization::access;
366  };
367 
371  struct Node
372  {
376  DistanceType* pivot;
380  DistanceType radius;
384  DistanceType variance;
388  int size;
392  std::vector<Node*> childs;
396  std::vector<PointInfo> points;
400 // int level;
401 
403  {
404  delete[] pivot;
405  if (!childs.empty()) {
406  for (size_t i=0; i<childs.size(); ++i) {
407  childs[i]->~Node();
408  }
409  }
410  }
411 
412  template<typename Archive>
413  void serialize(Archive& ar)
414  {
416  Index* obj = static_cast<Index*>(ar.getObject());
417 
418  if (Archive::is_loading::value) {
419  pivot = new DistanceType[obj->veclen_];
420  }
421  ar & serialization::make_binary_object(pivot, obj->veclen_*sizeof(DistanceType));
422  ar & radius;
423  ar & variance;
424  ar & size;
425 
426  size_t childs_size;
427  if (Archive::is_saving::value) {
428  childs_size = childs.size();
429  }
430  ar & childs_size;
431 
432  if (childs_size==0) {
433  ar & points;
434  }
435  else {
436  if (Archive::is_loading::value) {
437  childs.resize(childs_size);
438  }
439  for (size_t i=0;i<childs_size;++i) {
440  if (Archive::is_loading::value) {
441  childs[i] = new(obj->pool_) Node();
442  }
443  ar & *childs[i];
444  }
445  }
446  }
447  friend struct serialization::access;
448  };
449  typedef Node* NodePtr;
450 
455 
456 
460  void freeIndex()
461  {
462  if (root_) root_->~Node();
463  root_ = NULL;
464  pool_.free();
465  }
466 
467  void copyTree(NodePtr& dst, const NodePtr& src)
468  {
469  dst = new(pool_) Node();
470  dst->pivot = new DistanceType[veclen_];
471  std::copy(src->pivot, src->pivot+veclen_, dst->pivot);
472  dst->radius = src->radius;
473  dst->variance = src->variance;
474  dst->size = src->size;
475 
476  if (src->childs.size()==0) {
477  dst->points = src->points;
478  }
479  else {
480  dst->childs.resize(src->childs.size());
481  for (size_t i=0;i<src->childs.size();++i) {
482  copyTree(dst->childs[i], src->childs[i]);
483  }
484  }
485  }
486 
487 
495  void computeNodeStatistics(NodePtr node, const std::vector<int>& indices)
496  {
497  size_t size = indices.size();
498 
499  DistanceType* mean = new DistanceType[veclen_];
500  memoryCounter_ += int(veclen_*sizeof(DistanceType));
501  memset(mean,0,veclen_*sizeof(DistanceType));
502 
503  for (size_t i=0; i<size; ++i) {
504  ElementType* vec = points_[indices[i]];
505  for (size_t j=0; j<veclen_; ++j) {
506  mean[j] += vec[j];
507  }
508  }
509  DistanceType div_factor = DistanceType(1)/size;
510  for (size_t j=0; j<veclen_; ++j) {
511  mean[j] *= div_factor;
512  }
513 
514  DistanceType radius = 0;
515  DistanceType variance = 0;
516  for (size_t i=0; i<size; ++i) {
517  DistanceType dist = distance_(mean, points_[indices[i]], veclen_);
518  if (dist>radius) {
519  radius = dist;
520  }
521  variance += dist;
522  }
523  variance /= size;
524 
525  node->variance = variance;
526  node->radius = radius;
527  node->pivot = mean;
528  }
529 
530 
542  void computeClustering(NodePtr node, int* indices, int indices_length, int branching)
543  {
544  node->size = indices_length;
545 
546  if (indices_length < branching) {
547  node->points.resize(indices_length);
548  for (int i=0;i<indices_length;++i) {
549  node->points[i].index = indices[i];
550  node->points[i].point = points_[indices[i]];
551  }
552  node->childs.clear();
553  return;
554  }
555 
556  std::vector<int> centers_idx(branching);
557  int centers_length;
558  (*chooseCenters_)(branching, indices, indices_length, &centers_idx[0], centers_length);
559 
560  if (centers_length<branching) {
561  node->points.resize(indices_length);
562  for (int i=0;i<indices_length;++i) {
563  node->points[i].index = indices[i];
564  node->points[i].point = points_[indices[i]];
565  }
566  node->childs.clear();
567  return;
568  }
569 
570 
571  Matrix<double> dcenters(new double[branching*veclen_],branching,veclen_);
572  for (int i=0; i<centers_length; ++i) {
573  ElementType* vec = points_[centers_idx[i]];
574  for (size_t k=0; k<veclen_; ++k) {
575  dcenters[i][k] = double(vec[k]);
576  }
577  }
578 
579  std::vector<DistanceType> radiuses(branching,0);
580  std::vector<int> count(branching,0);
581 
582  // assign points to clusters
583  std::vector<int> belongs_to(indices_length);
584  for (int i=0; i<indices_length; ++i) {
585 
586  DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
587  belongs_to[i] = 0;
588  for (int j=1; j<branching; ++j) {
589  DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
590  if (sq_dist>new_sq_dist) {
591  belongs_to[i] = j;
592  sq_dist = new_sq_dist;
593  }
594  }
595  if (sq_dist>radiuses[belongs_to[i]]) {
596  radiuses[belongs_to[i]] = sq_dist;
597  }
598  count[belongs_to[i]]++;
599  }
600 
601  bool converged = false;
602  int iteration = 0;
603  while (!converged && iteration<iterations_) {
604  converged = true;
605  iteration++;
606 
607  // compute the new cluster centers
608  for (int i=0; i<branching; ++i) {
609  memset(dcenters[i],0,sizeof(double)*veclen_);
610  radiuses[i] = 0;
611  }
612  for (int i=0; i<indices_length; ++i) {
613  ElementType* vec = points_[indices[i]];
614  double* center = dcenters[belongs_to[i]];
615  for (size_t k=0; k<veclen_; ++k) {
616  center[k] += vec[k];
617  }
618  }
619  for (int i=0; i<branching; ++i) {
620  int cnt = count[i];
621  double div_factor = 1.0/cnt;
622  for (size_t k=0; k<veclen_; ++k) {
623  dcenters[i][k] *= div_factor;
624  }
625  }
626 
627  // reassign points to clusters
628  for (int i=0; i<indices_length; ++i) {
629  DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
630  int new_centroid = 0;
631  for (int j=1; j<branching; ++j) {
632  DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
633  if (sq_dist>new_sq_dist) {
634  new_centroid = j;
635  sq_dist = new_sq_dist;
636  }
637  }
638  if (sq_dist>radiuses[new_centroid]) {
639  radiuses[new_centroid] = sq_dist;
640  }
641  if (new_centroid != belongs_to[i]) {
642  count[belongs_to[i]]--;
643  count[new_centroid]++;
644  belongs_to[i] = new_centroid;
645 
646  converged = false;
647  }
648  }
649 
650  for (int i=0; i<branching; ++i) {
651  // if one cluster converges to an empty cluster,
652  // move an element into that cluster
653  if (count[i]==0) {
654  int j = (i+1)%branching;
655  while (count[j]<=1) {
656  j = (j+1)%branching;
657  }
658 
659  for (int k=0; k<indices_length; ++k) {
660  if (belongs_to[k]==j) {
661  belongs_to[k] = i;
662  count[j]--;
663  count[i]++;
664  break;
665  }
666  }
667  converged = false;
668  }
669  }
670 
671  }
672 
673  std::vector<DistanceType*> centers(branching);
674 
675  for (int i=0; i<branching; ++i) {
676  centers[i] = new DistanceType[veclen_];
677  memoryCounter_ += veclen_*sizeof(DistanceType);
678  for (size_t k=0; k<veclen_; ++k) {
679  centers[i][k] = (DistanceType)dcenters[i][k];
680  }
681  }
682 
683 
684  // compute kmeans clustering for each of the resulting clusters
685  node->childs.resize(branching);
686  int start = 0;
687  int end = start;
688  for (int c=0; c<branching; ++c) {
689  int s = count[c];
690 
691  DistanceType variance = 0;
692  for (int i=0; i<indices_length; ++i) {
693  if (belongs_to[i]==c) {
694  variance += distance_(centers[c], points_[indices[i]], veclen_);
695  std::swap(indices[i],indices[end]);
696  std::swap(belongs_to[i],belongs_to[end]);
697  end++;
698  }
699  }
700  variance /= s;
701 
702  node->childs[c] = new(pool_) Node();
703  node->childs[c]->radius = radiuses[c];
704  node->childs[c]->pivot = centers[c];
705  node->childs[c]->variance = variance;
706  computeClustering(node->childs[c],indices+start, end-start, branching);
707  start=end;
708  }
709 
710  delete[] dcenters.ptr();
711  }
712 
713 
714  template<bool with_removed>
715  void findNeighborsWithRemoved(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
716  {
717 
718  int maxChecks = searchParams.checks;
719 
720  if (maxChecks==FLANN_CHECKS_UNLIMITED) {
721  findExactNN<with_removed>(root_, result, vec);
722  }
723  else {
724  // Priority queue storing intermediate branches in the best-bin-first search
725  Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
726 
727  int checks = 0;
728  findNN<with_removed>(root_, result, vec, checks, maxChecks, heap);
729 
730  BranchSt branch;
731  while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
732  NodePtr node = branch.node;
733  findNN<with_removed>(node, result, vec, checks, maxChecks, heap);
734  }
735 
736  delete heap;
737  }
738 
739  }
740 
741 
754  template<bool with_removed>
755  void findNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks,
756  Heap<BranchSt>* heap) const
757  {
758  // Ignore those clusters that are too far away
759  {
760  DistanceType bsq = distance_(vec, node->pivot, veclen_);
761  DistanceType rsq = node->radius;
762  DistanceType wsq = result.worstDist();
763 
764  DistanceType val = bsq-rsq-wsq;
765  DistanceType val2 = val*val-4*rsq*wsq;
766 
767  //if (val>0) {
768  if ((val>0)&&(val2>0)) {
769  return;
770  }
771  }
772 
773  if (node->childs.empty()) {
774  if (checks>=maxChecks) {
775  if (result.full()) return;
776  }
777  for (int i=0; i<node->size; ++i) {
778  PointInfo& point_info = node->points[i];
779  int index = point_info.index;
780  if (with_removed) {
781  if (removed_points_.test(index)) continue;
782  }
783  DistanceType dist = distance_(point_info.point, vec, veclen_);
784  result.addPoint(dist, index);
785  ++checks;
786  }
787  }
788  else {
789  int closest_center = exploreNodeBranches(node, vec, heap);
790  findNN<with_removed>(node->childs[closest_center],result,vec, checks, maxChecks, heap);
791  }
792  }
793 
802  int exploreNodeBranches(NodePtr node, const ElementType* q, Heap<BranchSt>* heap) const
803  {
804  std::vector<DistanceType> domain_distances(branching_);
805  int best_index = 0;
806  domain_distances[best_index] = distance_(q, node->childs[best_index]->pivot, veclen_);
807  for (int i=1; i<branching_; ++i) {
808  domain_distances[i] = distance_(q, node->childs[i]->pivot, veclen_);
809  if (domain_distances[i]<domain_distances[best_index]) {
810  best_index = i;
811  }
812  }
813 
814  // float* best_center = node->childs[best_index]->pivot;
815  for (int i=0; i<branching_; ++i) {
816  if (i != best_index) {
817  domain_distances[i] -= cb_index_*node->childs[i]->variance;
818 
819  // float dist_to_border = getDistanceToBorder(node.childs[i].pivot,best_center,q);
820  // if (domain_distances[i]<dist_to_border) {
821  // domain_distances[i] = dist_to_border;
822  // }
823  heap->insert(BranchSt(node->childs[i],domain_distances[i]));
824  }
825  }
826 
827  return best_index;
828  }
829 
830 
834  template<bool with_removed>
835  void findExactNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec) const
836  {
837  // Ignore those clusters that are too far away
838  {
839  DistanceType bsq = distance_(vec, node->pivot, veclen_);
840  DistanceType rsq = node->radius;
841  DistanceType wsq = result.worstDist();
842 
843  DistanceType val = bsq-rsq-wsq;
844  DistanceType val2 = val*val-4*rsq*wsq;
845 
846  // if (val>0) {
847  if ((val>0)&&(val2>0)) {
848  return;
849  }
850  }
851 
852  if (node->childs.empty()) {
853  for (int i=0; i<node->size; ++i) {
854  PointInfo& point_info = node->points[i];
855  int index = point_info.index;
856  if (with_removed) {
857  if (removed_points_.test(index)) continue;
858  }
859  DistanceType dist = distance_(point_info.point, vec, veclen_);
860  result.addPoint(dist, index);
861  }
862  }
863  else {
864  std::vector<int> sort_indices(branching_);
865  getCenterOrdering(node, vec, sort_indices);
866 
867  for (int i=0; i<branching_; ++i) {
868  findExactNN<with_removed>(node->childs[sort_indices[i]],result,vec);
869  }
870 
871  }
872  }
873 
874 
880  void getCenterOrdering(NodePtr node, const ElementType* q, std::vector<int>& sort_indices) const
881  {
882  std::vector<DistanceType> domain_distances(branching_);
883  for (int i=0; i<branching_; ++i) {
884  DistanceType dist = distance_(q, node->childs[i]->pivot, veclen_);
885 
886  int j=0;
887  while (domain_distances[j]<dist && j<i) j++;
888  for (int k=i; k>j; --k) {
889  domain_distances[k] = domain_distances[k-1];
890  sort_indices[k] = sort_indices[k-1];
891  }
892  domain_distances[j] = dist;
893  sort_indices[j] = i;
894  }
895  }
896 
902  DistanceType getDistanceToBorder(DistanceType* p, DistanceType* c, DistanceType* q) const
903  {
904  DistanceType sum = 0;
905  DistanceType sum2 = 0;
906 
907  for (int i=0; i<veclen_; ++i) {
908  DistanceType t = c[i]-p[i];
909  sum += t*(q[i]-(c[i]+p[i])/2);
910  sum2 += t*t;
911  }
912 
913  return sum*sum/sum2;
914  }
915 
916 
926  int getMinVarianceClusters(NodePtr root, std::vector<NodePtr>& clusters, int clusters_length, DistanceType& varianceValue) const
927  {
928  int clusterCount = 1;
929  clusters[0] = root;
930 
931  DistanceType meanVariance = root->variance*root->size;
932 
933  while (clusterCount<clusters_length) {
934  DistanceType minVariance = (std::numeric_limits<DistanceType>::max)();
935  int splitIndex = -1;
936 
937  for (int i=0; i<clusterCount; ++i) {
938  if (!clusters[i]->childs.empty()) {
939 
940  DistanceType variance = meanVariance - clusters[i]->variance*clusters[i]->size;
941 
942  for (int j=0; j<branching_; ++j) {
943  variance += clusters[i]->childs[j]->variance*clusters[i]->childs[j]->size;
944  }
945  if (variance<minVariance) {
946  minVariance = variance;
947  splitIndex = i;
948  }
949  }
950  }
951 
952  if (splitIndex==-1) break;
953  if ( (branching_+clusterCount-1) > clusters_length) break;
954 
955  meanVariance = minVariance;
956 
957  // split node
958  NodePtr toSplit = clusters[splitIndex];
959  clusters[splitIndex] = toSplit->childs[0];
960  for (int i=1; i<branching_; ++i) {
961  clusters[clusterCount++] = toSplit->childs[i];
962  }
963  }
964 
965  varianceValue = meanVariance/root->size;
966  return clusterCount;
967  }
968 
969  void addPointToTree(NodePtr node, size_t index, DistanceType dist_to_pivot)
970  {
971  ElementType* point = points_[index];
972  if (dist_to_pivot>node->radius) {
973  node->radius = dist_to_pivot;
974  }
975  // if radius changed above, the variance will be an approximation
976  node->variance = (node->size*node->variance+dist_to_pivot)/(node->size+1);
977  node->size++;
978 
979  if (node->childs.empty()) { // leaf node
980  PointInfo point_info;
981  point_info.index = index;
982  point_info.point = point;
983  node->points.push_back(point_info);
984 
985  std::vector<int> indices(node->points.size());
986  for (size_t i=0;i<node->points.size();++i) {
987  indices[i] = node->points[i].index;
988  }
989  computeNodeStatistics(node, indices);
990  if (indices.size()>=size_t(branching_)) {
991  computeClustering(node, &indices[0], indices.size(), branching_);
992  }
993  }
994  else {
995  // find the closest child
996  int closest = 0;
997  DistanceType dist = distance_(node->childs[closest]->pivot, point, veclen_);
998  for (size_t i=1;i<size_t(branching_);++i) {
999  DistanceType crt_dist = distance_(node->childs[i]->pivot, point, veclen_);
1000  if (crt_dist<dist) {
1001  dist = crt_dist;
1002  closest = i;
1003  }
1004  }
1005  addPointToTree(node->childs[closest], index, dist);
1006  }
1007  }
1008 
1009 
1010  void swap(KMeansIndex& other)
1011  {
1012  std::swap(branching_, other.branching_);
1013  std::swap(iterations_, other.iterations_);
1014  std::swap(centers_init_, other.centers_init_);
1015  std::swap(cb_index_, other.cb_index_);
1016  std::swap(root_, other.root_);
1017  std::swap(pool_, other.pool_);
1018  std::swap(memoryCounter_, other.memoryCounter_);
1019  std::swap(chooseCenters_, other.chooseCenters_);
1020  }
1021 
1022 
1023 private:
1026 
1029 
1032 
1039  float cb_index_;
1040 
1044  NodePtr root_;
1045 
1050 
1055 
1060 
1062 };
1063 
1064 }
1065 
1066 #endif //RTABMAP_FLANN_KMEANS_INDEX_H_
void getCenterOrdering(NodePtr node, const ElementType *q, std::vector< int > &sort_indices) const
Definition: kmeans_index.h:880
d
#define NULL
std::map< std::string, any > IndexParams
Definition: params.h:51
Distance::ResultType DistanceType
Definition: kmeans_index.h:87
void findExactNN(NodePtr node, ResultSet< DistanceType > &result, const ElementType *vec) const
Definition: kmeans_index.h:835
flann_centers_init_t
Definition: defines.h:95
T get_param(const IndexParams &params, std::string name, const T &default_value)
Definition: params.h:95
flann_centers_init_t centers_init_
void findNN(NodePtr node, ResultSet< DistanceType > &result, const ElementType *vec, int &checks, int maxChecks, Heap< BranchSt > *heap) const
Definition: kmeans_index.h:755
int exploreNodeBranches(NodePtr node, const ElementType *q, Heap< BranchSt > *heap) const
Definition: kmeans_index.h:802
KMeansIndex & operator=(KMeansIndex other)
Definition: kmeans_index.h:158
f
int getClusterCenters(Matrix< DistanceType > &centers)
Definition: kmeans_index.h:301
char * dst
Definition: lz4.h:354
std::vector< Node * > childs
Definition: kmeans_index.h:392
void swap(linb::any &lhs, linb::any &rhs) noexcept
size_t rows
Definition: matrix.h:72
int getMinVarianceClusters(NodePtr root, std::vector< NodePtr > &clusters, int clusters_length, DistanceType &varianceValue) const
Definition: kmeans_index.h:926
void insert(const T &value)
Definition: heap.h:135
#define USING_BASECLASS_SYMBOLS
Definition: nn_index.h:897
void serialize(Archive &ar)
Definition: kmeans_index.h:413
void computeClustering(NodePtr node, int *indices, int indices_length, int branching)
Definition: kmeans_index.h:542
void swap(KMeansIndex &other)
T * ptr() const
Definition: matrix.h:127
size_t cols
Definition: matrix.h:73
void saveIndex(FILE *stream)
Definition: kmeans_index.h:260
KMeansIndex(const KMeansIndex &other)
Definition: kmeans_index.h:146
void loadIndex(FILE *stream)
Definition: kmeans_index.h:266
flann_algorithm_t
Definition: defines.h:79
bool popMin(T &value)
Definition: heap.h:158
int usedMemory() const
Definition: kmeans_index.h:208
void findNeighbors(ResultSet< DistanceType > &result, const ElementType *vec, const SearchParams &searchParams) const
Definition: kmeans_index.h:283
const binary_object make_binary_object(void *t, size_t size)
PooledAllocator pool_
std::vector< PointInfo > points
Definition: kmeans_index.h:396
Distance::ElementType ElementType
Definition: kmeans_index.h:86
struct Index Index
Definition: sqlite3.c:8577
KMeansIndex(const IndexParams &params=KMeansIndexParams(), Distance d=Distance())
Definition: kmeans_index.h:131
params
void addPoints(const Matrix< ElementType > &points, float rebuild_threshold=2)
Incrementally add points to the index.
Definition: kmeans_index.h:215
virtual DistanceType worstDist() const =0
flann_algorithm_t getType() const
Definition: kmeans_index.h:95
void copyTree(NodePtr &dst, const NodePtr &src)
Definition: kmeans_index.h:467
void set_cb_index(float index)
Definition: kmeans_index.h:199
GLM_FUNC_DECL genType max(genType const &x, genType const &y)
virtual bool full() const =0
CenterChooser< Distance > * chooseCenters_
NNIndex< Distance > BaseClass
Definition: kmeans_index.h:89
dist
void addPointToTree(NodePtr node, size_t index, DistanceType dist_to_pivot)
Definition: kmeans_index.h:969
KMeansIndexParams(int branching=32, int iterations=11, flann_centers_init_t centers_init=FLANN_CENTERS_RANDOM, float cb_index=0.2)
Definition: kmeans_index.h:60
static void freeIndex(sqlite3 *db, Index *p)
Definition: sqlite3.c:84627
void findNeighborsWithRemoved(ResultSet< DistanceType > &result, const ElementType *vec, const SearchParams &searchParams) const
Definition: kmeans_index.h:715
void computeNodeStatistics(NodePtr node, const std::vector< int > &indices)
Definition: kmeans_index.h:495
DistanceType getDistanceToBorder(DistanceType *p, DistanceType *c, DistanceType *q) const
Definition: kmeans_index.h:902
bool needs_vector_space_distance
Definition: kmeans_index.h:91
void serialize(Archive &ar)
Definition: kmeans_index.h:234
BranchStruct< NodePtr, DistanceType > BranchSt
Definition: kmeans_index.h:454
KMeansIndex(const Matrix< ElementType > &inputData, const IndexParams &params=KMeansIndexParams(), Distance d=Distance())
Definition: kmeans_index.h:107
virtual void addPoint(DistanceType dist, size_t index)=0
BaseClass * clone() const
Definition: kmeans_index.h:193
end


rtabmap
Author(s): Mathieu Labbe
autogenerated on Mon Jan 23 2023 03:37:28