kmeans_index.h
Go to the documentation of this file.
1 /***********************************************************************
2  * Software License Agreement (BSD License)
3  *
4  * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5  * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6  *
7  * THE BSD LICENSE
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * 1. Redistributions of source code must retain the above copyright
14  * notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  * notice, this list of conditions and the following disclaimer in the
17  * documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  *************************************************************************/
30 
31 #ifndef RTABMAP_FLANN_KMEANS_INDEX_H_
32 #define RTABMAP_FLANN_KMEANS_INDEX_H_
33 
34 #include <algorithm>
35 #include <string>
36 #include <map>
37 #include <cassert>
38 #include <limits>
39 #include <cmath>
40 
41 #include "rtflann/general.h"
45 #include "rtflann/util/matrix.h"
47 #include "rtflann/util/heap.h"
48 #include "rtflann/util/allocator.h"
49 #include "rtflann/util/random.h"
50 #include "rtflann/util/saving.h"
51 #include "rtflann/util/logger.h"
52 
53 
54 
55 namespace rtflann
56 {
57 
58 struct KMeansIndexParams : public IndexParams
59 {
60  KMeansIndexParams(int branching = 32, int iterations = 11,
61  flann_centers_init_t centers_init = FLANN_CENTERS_RANDOM, float cb_index = 0.2 )
62  {
63  (*this)["algorithm"] = FLANN_INDEX_KMEANS;
64  // branching factor
65  (*this)["branching"] = branching;
66  // max iterations to perform in one kmeans clustering (kmeans tree)
67  (*this)["iterations"] = iterations;
68  // algorithm used for picking the initial cluster centers for kmeans tree
69  (*this)["centers_init"] = centers_init;
70  // cluster boundary index. Used when searching the kmeans tree
71  (*this)["cb_index"] = cb_index;
72  }
73 };
74 
75 
82 template <typename Distance>
83 class KMeansIndex : public NNIndex<Distance>
84 {
85 public:
86  typedef typename Distance::ElementType ElementType;
87  typedef typename Distance::ResultType DistanceType;
88 
90 
91  typedef bool needs_vector_space_distance;
92 
93 
94 
96  {
97  return FLANN_INDEX_KMEANS;
98  }
99 
107  KMeansIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KMeansIndexParams(),
108  Distance d = Distance())
110  {
111  branching_ = get_param(params,"branching",32);
112  iterations_ = get_param(params,"iterations",11);
113  if (iterations_<0) {
115  }
116  centers_init_ = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
117  cb_index_ = get_param(params,"cb_index",0.4f);
118 
120  setDataset(inputData);
121  }
122 
123 
131  KMeansIndex(const IndexParams& params = KMeansIndexParams(), Distance d = Distance())
133  {
134  branching_ = get_param(params,"branching",32);
135  iterations_ = get_param(params,"iterations",11);
136  if (iterations_<0) {
138  }
139  centers_init_ = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
140  cb_index_ = get_param(params,"cb_index",0.4f);
141 
143  }
144 
145 
146  KMeansIndex(const KMeansIndex& other) : BaseClass(other),
152  {
154 
155  copyTree(root_, other.root_);
156  }
157 
159  {
160  this->swap(other);
161  return *this;
162  }
163 
164 
165  void initCenterChooser()
166  {
167  switch(centers_init_) {
170  break;
173  break;
176  break;
177  default:
178  throw FLANNException("Unknown algorithm for choosing initial centers.");
179  }
180  }
181 
187  virtual ~KMeansIndex()
188  {
189  delete chooseCenters_;
190  freeIndex();
191  }
192 
193  BaseClass* clone() const
194  {
195  return new KMeansIndex(*this);
196  }
197 
198 
199  void set_cb_index( float index)
200  {
201  cb_index_ = index;
202  }
203 
208  int usedMemory() const
209  {
211  }
212 
213  using BaseClass::buildIndex;
214 
215  void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
216  {
217  assert(points.cols==veclen_);
218  size_t old_size = size_;
219 
220  extendDataset(points);
221 
222  if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
223  buildIndex();
224  }
225  else {
226  for (size_t i=0;i<points.rows;++i) {
228  addPointToTree(root_, old_size + i, dist);
229  }
230  }
231  }
232 
233  template<typename Archive>
234  void serialize(Archive& ar)
235  {
236  ar.setObject(this);
237 
238  ar & *static_cast<NNIndex<Distance>*>(this);
239 
240  ar & branching_;
241  ar & iterations_;
242  ar & memoryCounter_;
243  ar & cb_index_;
244  ar & centers_init_;
245 
246  if (Archive::is_loading::value) {
247  root_ = new(pool_) Node();
248  }
249  ar & *root_;
250 
251  if (Archive::is_loading::value) {
252  index_params_["algorithm"] = getType();
253  index_params_["branching"] = branching_;
254  index_params_["iterations"] = iterations_;
255  index_params_["centers_init"] = centers_init_;
256  index_params_["cb_index"] = cb_index_;
257  }
258  }
259 
260  void saveIndex(FILE* stream)
261  {
263  sa & *this;
264  }
265 
266  void loadIndex(FILE* stream)
267  {
268  freeIndex();
269  serialization::LoadArchive la(stream);
270  la & *this;
271  }
272 
283  void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
284  {
285  if (removed_) {
286  findNeighborsWithRemoved<true>(result, vec, searchParams);
287  }
288  else {
289  findNeighborsWithRemoved<false>(result, vec, searchParams);
290  }
291 
292  }
293 
302  {
303  int numClusters = centers.rows;
304  if (numClusters<1) {
305  throw FLANNException("Number of clusters must be at least 1");
306  }
307 
308  DistanceType variance;
309  std::vector<NodePtr> clusters(numClusters);
310 
311  int clusterCount = getMinVarianceClusters(root_, clusters, numClusters, variance);
312 
313  Logger::info("Clusters requested: %d, returning %d\n",numClusters, clusterCount);
314 
315  for (int i=0; i<clusterCount; ++i) {
316  DistanceType* center = clusters[i]->pivot;
317  for (size_t j=0; j<veclen_; ++j) {
318  centers[i][j] = center[j];
319  }
320  }
321 
322  return clusterCount;
323  }
324 
325 protected:
330  {
331  chooseCenters_->setDataSize(veclen_);
332 
333  if (branching_<2) {
334  throw FLANNException("Branching factor must be at least 2");
335  }
336 
337  std::vector<int> indices(size_);
338  for (size_t i=0; i<size_; ++i) {
339  indices[i] = int(i);
340  }
341 
342  root_ = new(pool_) Node();
345  }
346 
347 private:
348 
349  struct PointInfo
350  {
351  size_t index;
353  private:
354  template<typename Archive>
355  void serialize(Archive& ar)
356  {
358  Index* obj = static_cast<Index*>(ar.getObject());
359 
360  ar & index;
361 // ar & point;
362 
363  if (Archive::is_loading::value) point = obj->points_[index];
364  }
365  friend struct serialization::access;
366  };
367 
371  struct Node
372  {
388  int size;
392  std::vector<Node*> childs;
396  std::vector<PointInfo> points;
400 // int level;
401 
402  ~Node()
403  {
404  delete[] pivot;
405  if (!childs.empty()) {
406  for (size_t i=0; i<childs.size(); ++i) {
407  childs[i]->~Node();
408  }
409  }
410  }
411 
412  template<typename Archive>
413  void serialize(Archive& ar)
414  {
416  Index* obj = static_cast<Index*>(ar.getObject());
417 
418  if (Archive::is_loading::value) {
419  pivot = new DistanceType[obj->veclen_];
420  }
421  ar & serialization::make_binary_object(pivot, obj->veclen_*sizeof(DistanceType));
422  ar & radius;
423  ar & variance;
424  ar & size;
425 
426  size_t childs_size;
427  if (Archive::is_saving::value) {
428  childs_size = childs.size();
429  }
430  ar & childs_size;
431 
432  if (childs_size==0) {
433  ar & points;
434  }
435  else {
436  if (Archive::is_loading::value) {
437  childs.resize(childs_size);
438  }
439  for (size_t i=0;i<childs_size;++i) {
440  if (Archive::is_loading::value) {
441  childs[i] = new(obj->pool_) Node();
442  }
443  ar & *childs[i];
444  }
445  }
446  }
447  friend struct serialization::access;
448  };
449  typedef Node* NodePtr;
450 
454  typedef BranchStruct<NodePtr, DistanceType> BranchSt;
455 
456 
460  void freeIndex()
461  {
462  if (root_) root_->~Node();
463  root_ = NULL;
464  pool_.free();
465  }
466 
467  void copyTree(NodePtr& dst, const NodePtr& src)
468  {
469  dst = new(pool_) Node();
470  dst->pivot = new DistanceType[veclen_];
471  std::copy(src->pivot, src->pivot+veclen_, dst->pivot);
472  dst->radius = src->radius;
473  dst->variance = src->variance;
474  dst->size = src->size;
475 
476  if (src->childs.size()==0) {
477  dst->points = src->points;
478  }
479  else {
480  dst->childs.resize(src->childs.size());
481  for (size_t i=0;i<src->childs.size();++i) {
482  copyTree(dst->childs[i], src->childs[i]);
483  }
484  }
485  }
486 
487 
495  void computeNodeStatistics(NodePtr node, const std::vector<int>& indices)
496  {
497  size_t size = indices.size();
498 
501  memset(mean,0,veclen_*sizeof(DistanceType));
502 
503  for (size_t i=0; i<size; ++i) {
504  ElementType* vec = points_[indices[i]];
505  for (size_t j=0; j<veclen_; ++j) {
506  mean[j] += vec[j];
507  }
508  }
509  DistanceType div_factor = DistanceType(1)/size;
510  for (size_t j=0; j<veclen_; ++j) {
511  mean[j] *= div_factor;
512  }
513 
514  DistanceType radius = 0;
515  DistanceType variance = 0;
516  for (size_t i=0; i<size; ++i) {
518  if (dist>radius) {
519  radius = dist;
520  }
521  variance += dist;
522  }
523  variance /= size;
524 
525  node->variance = variance;
526  node->radius = radius;
527  node->pivot = mean;
528  }
529 
530 
542  void computeClustering(NodePtr node, int* indices, int indices_length, int branching)
543  {
544  node->size = indices_length;
545 
546  if (indices_length < branching) {
547  node->points.resize(indices_length);
548  for (int i=0;i<indices_length;++i) {
549  node->points[i].index = indices[i];
550  node->points[i].point = points_[indices[i]];
551  }
552  node->childs.clear();
553  return;
554  }
555 
556  std::vector<int> centers_idx(branching);
557  int centers_length;
558  (*chooseCenters_)(branching, indices, indices_length, &centers_idx[0], centers_length);
559 
560  if (centers_length<branching) {
561  node->points.resize(indices_length);
562  for (int i=0;i<indices_length;++i) {
563  node->points[i].index = indices[i];
564  node->points[i].point = points_[indices[i]];
565  }
566  node->childs.clear();
567  return;
568  }
569 
570 
571  Matrix<double> dcenters(new double[branching*veclen_],branching,veclen_);
572  for (int i=0; i<centers_length; ++i) {
573  ElementType* vec = points_[centers_idx[i]];
574  for (size_t k=0; k<veclen_; ++k) {
575  dcenters[i][k] = double(vec[k]);
576  }
577  }
578 
579  std::vector<DistanceType> radiuses(branching,0);
580  std::vector<int> count(branching,0);
581 
582  // assign points to clusters
583  std::vector<int> belongs_to(indices_length);
584  for (int i=0; i<indices_length; ++i) {
585 
586  DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
587  belongs_to[i] = 0;
588  for (int j=1; j<branching; ++j) {
589  DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
590  if (sq_dist>new_sq_dist) {
591  belongs_to[i] = j;
592  sq_dist = new_sq_dist;
593  }
594  }
595  if (sq_dist>radiuses[belongs_to[i]]) {
596  radiuses[belongs_to[i]] = sq_dist;
597  }
598  count[belongs_to[i]]++;
599  }
600 
601  bool converged = false;
602  int iteration = 0;
603  while (!converged && iteration<iterations_) {
604  converged = true;
605  iteration++;
606 
607  // compute the new cluster centers
608  for (int i=0; i<branching; ++i) {
609  memset(dcenters[i],0,sizeof(double)*veclen_);
610  radiuses[i] = 0;
611  }
612  for (int i=0; i<indices_length; ++i) {
613  ElementType* vec = points_[indices[i]];
614  double* center = dcenters[belongs_to[i]];
615  for (size_t k=0; k<veclen_; ++k) {
616  center[k] += vec[k];
617  }
618  }
619  for (int i=0; i<branching; ++i) {
620  int cnt = count[i];
621  double div_factor = 1.0/cnt;
622  for (size_t k=0; k<veclen_; ++k) {
623  dcenters[i][k] *= div_factor;
624  }
625  }
626 
627  // reassign points to clusters
628  for (int i=0; i<indices_length; ++i) {
629  DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
630  int new_centroid = 0;
631  for (int j=1; j<branching; ++j) {
632  DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
633  if (sq_dist>new_sq_dist) {
634  new_centroid = j;
635  sq_dist = new_sq_dist;
636  }
637  }
638  if (sq_dist>radiuses[new_centroid]) {
639  radiuses[new_centroid] = sq_dist;
640  }
641  if (new_centroid != belongs_to[i]) {
642  count[belongs_to[i]]--;
643  count[new_centroid]++;
644  belongs_to[i] = new_centroid;
645 
646  converged = false;
647  }
648  }
649 
650  for (int i=0; i<branching; ++i) {
651  // if one cluster converges to an empty cluster,
652  // move an element into that cluster
653  if (count[i]==0) {
654  int j = (i+1)%branching;
655  while (count[j]<=1) {
656  j = (j+1)%branching;
657  }
658 
659  for (int k=0; k<indices_length; ++k) {
660  if (belongs_to[k]==j) {
661  belongs_to[k] = i;
662  count[j]--;
663  count[i]++;
664  break;
665  }
666  }
667  converged = false;
668  }
669  }
670 
671  }
672 
673  std::vector<DistanceType*> centers(branching);
674 
675  for (int i=0; i<branching; ++i) {
676  centers[i] = new DistanceType[veclen_];
678  for (size_t k=0; k<veclen_; ++k) {
679  centers[i][k] = (DistanceType)dcenters[i][k];
680  }
681  }
682 
683 
684  // compute kmeans clustering for each of the resulting clusters
685  node->childs.resize(branching);
686  int start = 0;
687  int end = start;
688  for (int c=0; c<branching; ++c) {
689  int s = count[c];
690 
691  DistanceType variance = 0;
692  for (int i=0; i<indices_length; ++i) {
693  if (belongs_to[i]==c) {
694  variance += distance_(centers[c], points_[indices[i]], veclen_);
695  std::swap(indices[i],indices[end]);
696  std::swap(belongs_to[i],belongs_to[end]);
697  end++;
698  }
699  }
700  variance /= s;
701 
702  node->childs[c] = new(pool_) Node();
703  node->childs[c]->radius = radiuses[c];
704  node->childs[c]->pivot = centers[c];
705  node->childs[c]->variance = variance;
706  computeClustering(node->childs[c],indices+start, end-start, branching);
707  start=end;
708  }
709 
710  delete[] dcenters.ptr();
711  }
712 
713 
714  template<bool with_removed>
715  void findNeighborsWithRemoved(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
716  {
717 
718  int maxChecks = searchParams.checks;
719 
720  if (maxChecks==FLANN_CHECKS_UNLIMITED) {
721  findExactNN<with_removed>(root_, result, vec);
722  }
723  else {
724  // Priority queue storing intermediate branches in the best-bin-first search
725  Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
726 
727  int checks = 0;
728  findNN<with_removed>(root_, result, vec, checks, maxChecks, heap);
729 
730  BranchSt branch;
731  while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
732  NodePtr node = branch.node;
733  findNN<with_removed>(node, result, vec, checks, maxChecks, heap);
734  }
735 
736  delete heap;
737  }
738 
739  }
740 
741 
754  template<bool with_removed>
755  void findNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks,
756  Heap<BranchSt>* heap) const
757  {
758  // Ignore those clusters that are too far away
759  {
760  DistanceType bsq = distance_(vec, node->pivot, veclen_);
761  DistanceType rsq = node->radius;
762  DistanceType wsq = result.worstDist();
763 
764  DistanceType val = bsq-rsq-wsq;
765  DistanceType val2 = val*val-4*rsq*wsq;
766 
767  //if (val>0) {
768  if ((val>0)&&(val2>0)) {
769  return;
770  }
771  }
772 
773  if (node->childs.empty()) {
774  if (checks>=maxChecks) {
775  if (result.full()) return;
776  }
777  for (int i=0; i<node->size; ++i) {
778  PointInfo& point_info = node->points[i];
779  int index = point_info.index;
780  if (with_removed) {
781  if (removed_points_.test(index)) continue;
782  }
783  DistanceType dist = distance_(point_info.point, vec, veclen_);
784  result.addPoint(dist, index);
785  ++checks;
786  }
787  }
788  else {
789  int closest_center = exploreNodeBranches(node, vec, heap);
790  findNN<with_removed>(node->childs[closest_center],result,vec, checks, maxChecks, heap);
791  }
792  }
793 
802  int exploreNodeBranches(NodePtr node, const ElementType* q, Heap<BranchSt>* heap) const
803  {
804  std::vector<DistanceType> domain_distances(branching_);
805  int best_index = 0;
806  domain_distances[best_index] = distance_(q, node->childs[best_index]->pivot, veclen_);
807  for (int i=1; i<branching_; ++i) {
808  domain_distances[i] = distance_(q, node->childs[i]->pivot, veclen_);
809  if (domain_distances[i]<domain_distances[best_index]) {
810  best_index = i;
811  }
812  }
813 
814  // float* best_center = node->childs[best_index]->pivot;
815  for (int i=0; i<branching_; ++i) {
816  if (i != best_index) {
817  domain_distances[i] -= cb_index_*node->childs[i]->variance;
818 
819  // float dist_to_border = getDistanceToBorder(node.childs[i].pivot,best_center,q);
820  // if (domain_distances[i]<dist_to_border) {
821  // domain_distances[i] = dist_to_border;
822  // }
823  heap->insert(BranchSt(node->childs[i],domain_distances[i]));
824  }
825  }
826 
827  return best_index;
828  }
829 
830 
834  template<bool with_removed>
835  void findExactNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec) const
836  {
837  // Ignore those clusters that are too far away
838  {
839  DistanceType bsq = distance_(vec, node->pivot, veclen_);
840  DistanceType rsq = node->radius;
841  DistanceType wsq = result.worstDist();
842 
843  DistanceType val = bsq-rsq-wsq;
844  DistanceType val2 = val*val-4*rsq*wsq;
845 
846  // if (val>0) {
847  if ((val>0)&&(val2>0)) {
848  return;
849  }
850  }
851 
852  if (node->childs.empty()) {
853  for (int i=0; i<node->size; ++i) {
854  PointInfo& point_info = node->points[i];
855  int index = point_info.index;
856  if (with_removed) {
857  if (removed_points_.test(index)) continue;
858  }
859  DistanceType dist = distance_(point_info.point, vec, veclen_);
860  result.addPoint(dist, index);
861  }
862  }
863  else {
864  std::vector<int> sort_indices(branching_);
865  getCenterOrdering(node, vec, sort_indices);
866 
867  for (int i=0; i<branching_; ++i) {
868  findExactNN<with_removed>(node->childs[sort_indices[i]],result,vec);
869  }
870 
871  }
872  }
873 
874 
880  void getCenterOrdering(NodePtr node, const ElementType* q, std::vector<int>& sort_indices) const
881  {
882  std::vector<DistanceType> domain_distances(branching_);
883  for (int i=0; i<branching_; ++i) {
884  DistanceType dist = distance_(q, node->childs[i]->pivot, veclen_);
885 
886  int j=0;
887  while (domain_distances[j]<dist && j<i) j++;
888  for (int k=i; k>j; --k) {
889  domain_distances[k] = domain_distances[k-1];
890  sort_indices[k] = sort_indices[k-1];
891  }
892  domain_distances[j] = dist;
893  sort_indices[j] = i;
894  }
895  }
896 
903  {
904  DistanceType sum = 0;
905  DistanceType sum2 = 0;
906 
907  for (int i=0; i<veclen_; ++i) {
909  sum += t*(q[i]-(c[i]+p[i])/2);
910  sum2 += t*t;
911  }
912 
913  return sum*sum/sum2;
914  }
915 
916 
926  int getMinVarianceClusters(NodePtr root, std::vector<NodePtr>& clusters, int clusters_length, DistanceType& varianceValue) const
927  {
928  int clusterCount = 1;
929  clusters[0] = root;
930 
931  DistanceType meanVariance = root->variance*root->size;
932 
933  while (clusterCount<clusters_length) {
935  int splitIndex = -1;
936 
937  for (int i=0; i<clusterCount; ++i) {
938  if (!clusters[i]->childs.empty()) {
939 
940  DistanceType variance = meanVariance - clusters[i]->variance*clusters[i]->size;
941 
942  for (int j=0; j<branching_; ++j) {
943  variance += clusters[i]->childs[j]->variance*clusters[i]->childs[j]->size;
944  }
945  if (variance<minVariance) {
946  minVariance = variance;
947  splitIndex = i;
948  }
949  }
950  }
951 
952  if (splitIndex==-1) break;
953  if ( (branching_+clusterCount-1) > clusters_length) break;
954 
955  meanVariance = minVariance;
956 
957  // split node
958  NodePtr toSplit = clusters[splitIndex];
959  clusters[splitIndex] = toSplit->childs[0];
960  for (int i=1; i<branching_; ++i) {
961  clusters[clusterCount++] = toSplit->childs[i];
962  }
963  }
964 
965  varianceValue = meanVariance/root->size;
966  return clusterCount;
967  }
968 
969  void addPointToTree(NodePtr node, size_t index, DistanceType dist_to_pivot)
970  {
971  ElementType* point = points_[index];
972  if (dist_to_pivot>node->radius) {
973  node->radius = dist_to_pivot;
974  }
975  // if radius changed above, the variance will be an approximation
976  node->variance = (node->size*node->variance+dist_to_pivot)/(node->size+1);
977  node->size++;
978 
979  if (node->childs.empty()) { // leaf node
980  PointInfo point_info;
981  point_info.index = index;
982  point_info.point = point;
983  node->points.push_back(point_info);
984 
985  std::vector<int> indices(node->points.size());
986  for (size_t i=0;i<node->points.size();++i) {
987  indices[i] = node->points[i].index;
988  }
989  computeNodeStatistics(node, indices);
990  if (indices.size()>=size_t(branching_)) {
991  computeClustering(node, &indices[0], indices.size(), branching_);
992  }
993  }
994  else {
995  // find the closest child
996  int closest = 0;
997  DistanceType dist = distance_(node->childs[closest]->pivot, point, veclen_);
998  for (size_t i=1;i<size_t(branching_);++i) {
999  DistanceType crt_dist = distance_(node->childs[i]->pivot, point, veclen_);
1000  if (crt_dist<dist) {
1001  dist = crt_dist;
1002  closest = i;
1003  }
1004  }
1005  addPointToTree(node->childs[closest], index, dist);
1006  }
1007  }
1008 
1009 
1010  void swap(KMeansIndex& other)
1011  {
1012  std::swap(branching_, other.branching_);
1013  std::swap(iterations_, other.iterations_);
1014  std::swap(centers_init_, other.centers_init_);
1015  std::swap(cb_index_, other.cb_index_);
1016  std::swap(root_, other.root_);
1017  std::swap(pool_, other.pool_);
1018  std::swap(memoryCounter_, other.memoryCounter_);
1019  std::swap(chooseCenters_, other.chooseCenters_);
1020  }
1021 
1022 
1023 private:
1025  int branching_;
1026 
1028  int iterations_;
1029 
1032 
1039  float cb_index_;
1040 
1044  NodePtr root_;
1045 
1049  PooledAllocator pool_;
1050 
1054  int memoryCounter_;
1055 
1060 
1062 };
1063 
1064 }
1065 
1066 #endif // RTABMAP_FLANN_KMEANS_INDEX_H_
rtflann::NNIndex::size
size_t size() const
Definition: nn_index.h:229
int
int
rtflann::RandomCenterChooser
Definition: center_chooser.h:105
rtflann::FLANN_CHECKS_UNLIMITED
@ FLANN_CHECKS_UNLIMITED
Definition: defines.h:147
rtflann::KMeansIndex::addPoints
void addPoints(const Matrix< ElementType > &points, float rebuild_threshold=2)
Incrementally add points to the index.
Definition: kmeans_index.h:243
rtflann::KMeansIndex::~KMeansIndex
virtual ~KMeansIndex()
Definition: kmeans_index.h:215
general.h
rtflann::KMeansIndexParams
Definition: kmeans_index.h:86
rtflann::ResultSet
Definition: result_set.h:110
rtflann::NNIndex
Definition: nn_index.h:101
rtflann::KMeansIndex::Node::radius
DistanceType radius
Definition: kmeans_index.h:408
rtflann::serialization::SaveArchive
Definition: serialization.h:376
s
RealScalar s
rtflann::KMeansIndex::PointInfo
Definition: kmeans_index.h:377
rtflann::PooledAllocator::free
void free()
Definition: allocator.h:142
rtflann::KMeansIndex::BaseClass
NNIndex< Distance > BaseClass
Definition: kmeans_index.h:117
rtflann::KMeansIndex::serialize
void serialize(Archive &ar)
Definition: kmeans_index.h:262
logger.h
rtflann::KMeansIndex::cb_index_
float cb_index_
Definition: kmeans_index.h:1067
rtflann::Matrix_::cols
size_t cols
Definition: matrix.h:101
c
Scalar Scalar * c
rtflann::get_param
T get_param(const IndexParams &params, std::string name, const T &default_value)
Definition: params.h:121
rtflann::KMeansIndex::BranchSt
BranchStruct< NodePtr, DistanceType > BranchSt
Definition: kmeans_index.h:482
rtflann::KMeansIndex::computeNodeStatistics
void computeNodeStatistics(NodePtr node, const std::vector< int > &indices)
Definition: kmeans_index.h:523
rtflann::CenterChooser
Definition: center_chooser.h:74
rtflann::FLANNException
Definition: general.h:70
rtflann::KMeansIndex::clone
BaseClass * clone() const
Definition: kmeans_index.h:221
rtflann::KMeansIndex::PointInfo::serialize
void serialize(Archive &ar)
Definition: kmeans_index.h:383
rtflann::KMeansIndex::needs_vector_space_distance
bool needs_vector_space_distance
Definition: kmeans_index.h:119
rtflann::KMeansIndex::saveIndex
void saveIndex(FILE *stream)
Definition: kmeans_index.h:288
count
Index count
end
end
rtflann::KMeansIndex::ElementType
Distance::ElementType ElementType
Definition: kmeans_index.h:114
rtflann::GonzalesCenterChooser
Definition: center_chooser.h:154
nn_index.h
rtflann::NNIndex::points_
std::vector< ElementType * > points_
Definition: nn_index.h:914
rtflann::PooledAllocator::wastedMemory
int wastedMemory
Definition: allocator.h:119
dist.h
rtflann::KMeansIndex::getCenterOrdering
void getCenterOrdering(NodePtr node, const ElementType *q, std::vector< int > &sort_indices) const
Definition: kmeans_index.h:908
rtflann::KMeansIndex::KMeansIndex
KMeansIndex(const Matrix< ElementType > &inputData, const IndexParams &params=KMeansIndexParams(), Distance d=Distance())
Definition: kmeans_index.h:135
rtflann::KMeansIndex::iterations_
int iterations_
Definition: kmeans_index.h:1056
Eigen::PlainObjectBase::resize
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(Index rows, Index cols)
rtflann::KMeansIndex::buildIndexImpl
void buildIndexImpl()
Definition: kmeans_index.h:357
rtflann::KMeansIndex::Node
Definition: kmeans_index.h:399
rtflann::KMeansIndex::findNeighbors
void findNeighbors(ResultSet< DistanceType > &result, const ElementType *vec, const SearchParams &searchParams) const
Definition: kmeans_index.h:311
rtflann::KMeansIndex::Node::access
friend struct serialization::access
Definition: kmeans_index.h:475
rtflann::serialization::access
Definition: serialization.h:29
rtflann::PooledAllocator::usedMemory
int usedMemory
Definition: allocator.h:118
rtflann::KMeansIndex::centers_init_
flann_centers_init_t centers_init_
Definition: kmeans_index.h:1059
indices
indices
rtflann::FLANN_CENTERS_KMEANSPP
@ FLANN_CENTERS_KMEANSPP
Definition: defines.h:99
rtflann::DynamicBitset::test
bool test(size_t index) const
Definition: dynamic_bitset.h:199
rtflann::KMeansIndex::set_cb_index
void set_cb_index(float index)
Definition: kmeans_index.h:227
rtflann::KMeansIndex::initCenterChooser
void initCenterChooser()
Definition: kmeans_index.h:193
rtflann::FLANN_CENTERS_RANDOM
@ FLANN_CENTERS_RANDOM
Definition: defines.h:97
rtflann::KMeansIndex::findNN
void findNN(NodePtr node, ResultSet< DistanceType > &result, const ElementType *vec, int &checks, int maxChecks, Heap< BranchSt > *heap) const
Definition: kmeans_index.h:783
rtflann::KMeansIndex::memoryCounter_
int memoryCounter_
Definition: kmeans_index.h:1082
rtflann::KMeansIndex
Definition: kmeans_index.h:111
point
point
random.h
matrix.h
rtflann::KMeansIndex::PointInfo::point
ElementType * point
Definition: kmeans_index.h:380
rtflann::FLANN_CENTERS_GONZALES
@ FLANN_CENTERS_GONZALES
Definition: defines.h:98
rtflann::NNIndex::size_
size_t size_
Definition: nn_index.h:874
rtflann::KMeansIndex::addPointToTree
void addPointToTree(NodePtr node, size_t index, DistanceType dist_to_pivot)
Definition: kmeans_index.h:997
j
std::ptrdiff_t j
rtflann::KMeansIndex::Node::points
std::vector< PointInfo > points
Definition: kmeans_index.h:424
q
EIGEN_DEVICE_FUNC const Scalar & q
rtflann::KMeansIndex::findNeighborsWithRemoved
void findNeighborsWithRemoved(ResultSet< DistanceType > &result, const ElementType *vec, const SearchParams &searchParams) const
Definition: kmeans_index.h:743
rtflann::KMeansIndex::getDistanceToBorder
DistanceType getDistanceToBorder(DistanceType *p, DistanceType *c, DistanceType *q) const
Definition: kmeans_index.h:930
rtflann::KMeansIndex::exploreNodeBranches
int exploreNodeBranches(NodePtr node, const ElementType *q, Heap< BranchSt > *heap) const
Definition: kmeans_index.h:830
rtflann::KMeansIndex::getClusterCenters
int getClusterCenters(Matrix< DistanceType > &centers)
Definition: kmeans_index.h:329
rtflann::KMeansIndex::chooseCenters_
CenterChooser< Distance > * chooseCenters_
Definition: kmeans_index.h:1087
rtflann::KMeansIndex::operator=
KMeansIndex & operator=(KMeansIndex other)
Definition: kmeans_index.h:186
rtflann::serialization::LoadArchive
Definition: serialization.h:550
Eigen::PlainObjectBase::rows
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT
rtflann::NNIndex::veclen_
size_t veclen_
Definition: nn_index.h:884
glm::max
GLM_FUNC_DECL genType max(genType const &x, genType const &y)
rtflann::NNIndex::removed_points_
DynamicBitset removed_points_
Definition: nn_index.h:899
std::swap
void swap(GeographicLib::NearestNeighbor< dist_t, pos_t, distfun_t > &a, GeographicLib::NearestNeighbor< dist_t, pos_t, distfun_t > &b)
rtflann::KMeansIndex::Node::pivot
DistanceType * pivot
Definition: kmeans_index.h:404
rtflann::Matrix_::rows
size_t rows
Definition: matrix.h:100
rtflann::NNIndex::setDataset
void setDataset(const Matrix< ElementType > &dataset)
Definition: nn_index.h:784
d
d
saving.h
result_set.h
rtflann::KMeansIndex::getMinVarianceClusters
int getMinVarianceClusters(NodePtr root, std::vector< NodePtr > &clusters, int clusters_length, DistanceType &varianceValue) const
Definition: kmeans_index.h:954
params
SmartProjectionParams params(gtsam::HESSIAN, gtsam::ZERO_ON_DEGENERACY)
p
Point3_ p(2)
rtflann::KMeansIndex::PointInfo::index
size_t index
Definition: kmeans_index.h:379
rtflann::KMeansIndex::Node::~Node
~Node()
Definition: kmeans_index.h:430
rtflann::NNIndex::distance_
Distance distance_
Definition: nn_index.h:861
rtflann::NNIndex::removed_
bool removed_
Definition: nn_index.h:894
size_t
std::size_t size_t
rtflann::KMeansppCenterChooser
Definition: center_chooser.h:211
rtflann::FLANN_INDEX_KMEANS
@ FLANN_INDEX_KMEANS
Definition: defines.h:83
rtflann::KMeansIndex::root_
NodePtr root_
Definition: kmeans_index.h:1072
rtflann::KMeansIndex::computeClustering
void computeClustering(NodePtr node, int *indices, int indices_length, int branching)
Definition: kmeans_index.h:570
rtflann::KMeansIndex::DistanceType
Distance::ResultType DistanceType
Definition: kmeans_index.h:115
rtflann::flann_algorithm_t
flann_algorithm_t
Definition: defines.h:79
rtflann::KMeansIndex::findExactNN
void findExactNN(NodePtr node, ResultSet< DistanceType > &result, const ElementType *vec) const
Definition: kmeans_index.h:863
center_chooser.h
allocator.h
rtflann::KMeansIndex::Node::childs
std::vector< Node * > childs
Definition: kmeans_index.h:420
rtflann::serialization::make_binary_object
const binary_object make_binary_object(void *t, size_t size)
Definition: serialization.h:229
mean
Point3 mean(const CONTAINER &points)
rtflann::KMeansIndexParams::KMeansIndexParams
KMeansIndexParams(int branching=32, int iterations=11, flann_centers_init_t centers_init=FLANN_CENTERS_RANDOM, float cb_index=0.2)
Definition: kmeans_index.h:116
Index
struct Index Index
Definition: sqlite3.c:8577
rtflann::KMeansIndex::swap
void swap(KMeansIndex &other)
Definition: kmeans_index.h:1038
rtflann::NNIndex::size_at_build_
size_t size_at_build_
Definition: nn_index.h:879
rtflann::KMeansIndex::loadIndex
void loadIndex(FILE *stream)
Definition: kmeans_index.h:294
rtflann::KMeansIndex::usedMemory
int usedMemory() const
Definition: kmeans_index.h:236
rtflann::KMeansIndex::Node::variance
DistanceType variance
Definition: kmeans_index.h:412
rtflann::KMeansIndex::Node::serialize
void serialize(Archive &ar)
Definition: kmeans_index.h:441
Eigen.Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor >
rtflann::KMeansIndex::branching_
int branching_
Definition: kmeans_index.h:1053
rtflann::KMeansIndex::pool_
PooledAllocator pool_
Definition: kmeans_index.h:1077
rtflann::IndexParams
std::map< std::string, any > IndexParams
Definition: params.h:77
rtflann::NNIndex::buildIndex
virtual void buildIndex()
Definition: nn_index.h:153
NULL
#define NULL
root
root
dist
dist
t
Point2 t(10, 10)
USING_BASECLASS_SYMBOLS
#define USING_BASECLASS_SYMBOLS
Definition: nn_index.h:925
dst
char * dst
Definition: lz4.h:354
rtflann::KMeansIndex::freeIndex
void freeIndex()
Definition: kmeans_index.h:488
rtflann::flann_centers_init_t
flann_centers_init_t
Definition: defines.h:95
rtflann::NNIndex::extendDataset
void extendDataset(const Matrix< ElementType > &new_points)
Definition: nn_index.h:801
rtflann::KMeansIndex::getType
flann_algorithm_t getType() const
Definition: kmeans_index.h:123
rtflann
Definition: all_indices.h:49
i
int i
other
other
rtflann::KMeansIndex::Node::size
int size
Definition: kmeans_index.h:416
rtflann::KMeansIndex::copyTree
void copyTree(NodePtr &dst, const NodePtr &src)
Definition: kmeans_index.h:495
result
RESULT & result
heap.h
rtflann::NNIndex::index_params_
IndexParams index_params_
Definition: nn_index.h:889
rtflann::Matrix< ElementType >
rtflann::KMeansIndex::NodePtr
Node * NodePtr
Definition: kmeans_index.h:477
rtflann::Index
Definition: flann.hpp:104


rtabmap
Author(s): Mathieu Labbe
autogenerated on Sun Dec 1 2024 03:42:47