31 #ifndef RTABMAP_FLANN_KMEANS_INDEX_H_ 32 #define RTABMAP_FLANN_KMEANS_INDEX_H_ 65 (*this)[
"branching"] = branching;
67 (*this)[
"iterations"] = iterations;
69 (*this)[
"centers_init"] = centers_init;
71 (*this)[
"cb_index"] = cb_index;
82 template <
typename Distance>
108 Distance d = Distance())
109 : BaseClass(params,
d), root_(
NULL), memoryCounter_(0)
111 branching_ =
get_param(params,
"branching",32);
112 iterations_ =
get_param(params,
"iterations",11);
117 cb_index_ =
get_param(params,
"cb_index",0.4
f);
120 setDataset(inputData);
132 : BaseClass(params,
d), root_(
NULL), memoryCounter_(0)
134 branching_ =
get_param(params,
"branching",32);
135 iterations_ =
get_param(params,
"iterations",11);
140 cb_index_ =
get_param(params,
"cb_index",0.4
f);
147 branching_(other.branching_),
148 iterations_(other.iterations_),
149 centers_init_(other.centers_init_),
150 cb_index_(other.cb_index_),
151 memoryCounter_(other.memoryCounter_)
155 copyTree(root_, other.
root_);
167 switch(centers_init_) {
178 throw FLANNException(
"Unknown algorithm for choosing initial centers.");
189 delete chooseCenters_;
210 return pool_.usedMemory+pool_.wastedMemory+memoryCounter_;
213 using BaseClass::buildIndex;
217 assert(points.
cols==veclen_);
218 size_t old_size = size_;
220 extendDataset(points);
222 if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
226 for (
size_t i=0;i<points.
rows;++i) {
227 DistanceType dist = distance_(root_->pivot, points[i], veclen_);
228 addPointToTree(root_, old_size + i, dist);
233 template<
typename Archive>
246 if (Archive::is_loading::value) {
247 root_ =
new(pool_)
Node();
251 if (Archive::is_loading::value) {
252 index_params_[
"algorithm"] = getType();
253 index_params_[
"branching"] = branching_;
254 index_params_[
"iterations"] = iterations_;
255 index_params_[
"centers_init"] = centers_init_;
256 index_params_[
"cb_index"] = cb_index_;
286 findNeighborsWithRemoved<true>(result, vec, searchParams);
289 findNeighborsWithRemoved<false>(result, vec, searchParams);
303 int numClusters = centers.
rows;
308 DistanceType variance;
309 std::vector<NodePtr> clusters(numClusters);
311 int clusterCount = getMinVarianceClusters(root_, clusters, numClusters, variance);
313 Logger::info(
"Clusters requested: %d, returning %d\n",numClusters, clusterCount);
315 for (
int i=0; i<clusterCount; ++i) {
316 DistanceType* center = clusters[i]->pivot;
317 for (
size_t j=0; j<veclen_; ++j) {
318 centers[i][j] = center[j];
331 chooseCenters_->setDataSize(veclen_);
337 std::vector<int> indices(size_);
338 for (
size_t i=0; i<size_; ++i) {
342 root_ =
new(pool_)
Node();
343 computeNodeStatistics(root_, indices);
344 computeClustering(root_, &indices[0], (
int)size_, branching_);
354 template<
typename Archive>
358 Index* obj =
static_cast<Index*
>(ar.getObject());
363 if (Archive::is_loading::value) point = obj->points_[index];
405 if (!childs.empty()) {
406 for (
size_t i=0; i<childs.size(); ++i) {
412 template<
typename Archive>
416 Index* obj =
static_cast<Index*
>(ar.getObject());
418 if (Archive::is_loading::value) {
419 pivot =
new DistanceType[obj->veclen_];
427 if (Archive::is_saving::value) {
428 childs_size = childs.size();
432 if (childs_size==0) {
436 if (Archive::is_loading::value) {
437 childs.resize(childs_size);
439 for (
size_t i=0;i<childs_size;++i) {
440 if (Archive::is_loading::value) {
441 childs[i] =
new(obj->pool_)
Node();
462 if (root_) root_->~Node();
469 dst =
new(pool_)
Node();
470 dst->
pivot =
new DistanceType[veclen_];
476 if (src->
childs.size()==0) {
481 for (
size_t i=0;i<src->
childs.size();++i) {
497 size_t size = indices.size();
499 DistanceType* mean =
new DistanceType[veclen_];
500 memoryCounter_ += int(veclen_*
sizeof(DistanceType));
501 memset(mean,0,veclen_*
sizeof(DistanceType));
503 for (
size_t i=0; i<size; ++i) {
504 ElementType* vec = points_[indices[i]];
505 for (
size_t j=0; j<veclen_; ++j) {
509 DistanceType div_factor = DistanceType(1)/size;
510 for (
size_t j=0; j<veclen_; ++j) {
511 mean[j] *= div_factor;
514 DistanceType radius = 0;
515 DistanceType variance = 0;
516 for (
size_t i=0; i<size; ++i) {
517 DistanceType dist = distance_(mean, points_[indices[i]], veclen_);
544 node->
size = indices_length;
546 if (indices_length < branching) {
547 node->
points.resize(indices_length);
548 for (
int i=0;i<indices_length;++i) {
549 node->
points[i].index = indices[i];
550 node->
points[i].point = points_[indices[i]];
556 std::vector<int> centers_idx(branching);
558 (*chooseCenters_)(branching, indices, indices_length, &centers_idx[0], centers_length);
560 if (centers_length<branching) {
561 node->
points.resize(indices_length);
562 for (
int i=0;i<indices_length;++i) {
563 node->
points[i].index = indices[i];
564 node->
points[i].point = points_[indices[i]];
571 Matrix<double> dcenters(
new double[branching*veclen_],branching,veclen_);
572 for (
int i=0; i<centers_length; ++i) {
573 ElementType* vec = points_[centers_idx[i]];
574 for (
size_t k=0; k<veclen_; ++k) {
575 dcenters[i][k] = double(vec[k]);
579 std::vector<DistanceType> radiuses(branching,0);
580 std::vector<int> count(branching,0);
583 std::vector<int> belongs_to(indices_length);
584 for (
int i=0; i<indices_length; ++i) {
586 DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
588 for (
int j=1; j<branching; ++j) {
589 DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
590 if (sq_dist>new_sq_dist) {
592 sq_dist = new_sq_dist;
595 if (sq_dist>radiuses[belongs_to[i]]) {
596 radiuses[belongs_to[i]] = sq_dist;
598 count[belongs_to[i]]++;
601 bool converged =
false;
603 while (!converged && iteration<iterations_) {
608 for (
int i=0; i<branching; ++i) {
609 memset(dcenters[i],0,
sizeof(
double)*veclen_);
612 for (
int i=0; i<indices_length; ++i) {
613 ElementType* vec = points_[indices[i]];
614 double* center = dcenters[belongs_to[i]];
615 for (
size_t k=0; k<veclen_; ++k) {
619 for (
int i=0; i<branching; ++i) {
621 double div_factor = 1.0/cnt;
622 for (
size_t k=0; k<veclen_; ++k) {
623 dcenters[i][k] *= div_factor;
628 for (
int i=0; i<indices_length; ++i) {
629 DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
630 int new_centroid = 0;
631 for (
int j=1; j<branching; ++j) {
632 DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
633 if (sq_dist>new_sq_dist) {
635 sq_dist = new_sq_dist;
638 if (sq_dist>radiuses[new_centroid]) {
639 radiuses[new_centroid] = sq_dist;
641 if (new_centroid != belongs_to[i]) {
642 count[belongs_to[i]]--;
643 count[new_centroid]++;
644 belongs_to[i] = new_centroid;
650 for (
int i=0; i<branching; ++i) {
654 int j = (i+1)%branching;
655 while (count[j]<=1) {
659 for (
int k=0; k<indices_length; ++k) {
660 if (belongs_to[k]==j) {
673 std::vector<DistanceType*> centers(branching);
675 for (
int i=0; i<branching; ++i) {
676 centers[i] =
new DistanceType[veclen_];
677 memoryCounter_ += veclen_*
sizeof(DistanceType);
678 for (
size_t k=0; k<veclen_; ++k) {
679 centers[i][k] = (DistanceType)dcenters[i][k];
685 node->
childs.resize(branching);
688 for (
int c=0; c<branching; ++c) {
691 DistanceType variance = 0;
692 for (
int i=0; i<indices_length; ++i) {
693 if (belongs_to[i]==c) {
694 variance += distance_(centers[c], points_[indices[i]], veclen_);
695 std::swap(indices[i],indices[end]);
696 std::swap(belongs_to[i],belongs_to[end]);
703 node->
childs[c]->radius = radiuses[c];
704 node->
childs[c]->pivot = centers[c];
705 node->
childs[c]->variance = variance;
706 computeClustering(node->
childs[c],indices+start, end-start, branching);
710 delete[] dcenters.
ptr();
714 template<
bool with_removed>
718 int maxChecks = searchParams.
checks;
721 findExactNN<with_removed>(root_, result, vec);
728 findNN<with_removed>(root_, result, vec, checks, maxChecks, heap);
731 while (heap->
popMin(branch) && (checks<maxChecks || !result.full())) {
732 NodePtr node = branch.node;
733 findNN<with_removed>(node, result, vec, checks, maxChecks, heap);
754 template<
bool with_removed>
760 DistanceType bsq = distance_(vec, node->
pivot, veclen_);
761 DistanceType rsq = node->
radius;
764 DistanceType val = bsq-rsq-wsq;
765 DistanceType val2 = val*val-4*rsq*wsq;
768 if ((val>0)&&(val2>0)) {
773 if (node->
childs.empty()) {
774 if (checks>=maxChecks) {
775 if (result.
full())
return;
777 for (
int i=0; i<node->
size; ++i) {
779 int index = point_info.
index;
781 if (removed_points_.test(index))
continue;
783 DistanceType dist = distance_(point_info.
point, vec, veclen_);
789 int closest_center = exploreNodeBranches(node, vec, heap);
790 findNN<with_removed>(node->
childs[closest_center],result,vec, checks, maxChecks, heap);
804 std::vector<DistanceType> domain_distances(branching_);
806 domain_distances[best_index] = distance_(q, node->
childs[best_index]->pivot, veclen_);
807 for (
int i=1; i<branching_; ++i) {
808 domain_distances[i] = distance_(q, node->
childs[i]->pivot, veclen_);
809 if (domain_distances[i]<domain_distances[best_index]) {
815 for (
int i=0; i<branching_; ++i) {
816 if (i != best_index) {
817 domain_distances[i] -= cb_index_*node->
childs[i]->variance;
823 heap->
insert(BranchSt(node->
childs[i],domain_distances[i]));
834 template<
bool with_removed>
839 DistanceType bsq = distance_(vec, node->
pivot, veclen_);
840 DistanceType rsq = node->
radius;
843 DistanceType val = bsq-rsq-wsq;
844 DistanceType val2 = val*val-4*rsq*wsq;
847 if ((val>0)&&(val2>0)) {
852 if (node->
childs.empty()) {
853 for (
int i=0; i<node->
size; ++i) {
855 int index = point_info.
index;
857 if (removed_points_.test(index))
continue;
859 DistanceType dist = distance_(point_info.
point, vec, veclen_);
864 std::vector<int> sort_indices(branching_);
865 getCenterOrdering(node, vec, sort_indices);
867 for (
int i=0; i<branching_; ++i) {
868 findExactNN<with_removed>(node->
childs[sort_indices[i]],result,vec);
880 void getCenterOrdering(NodePtr node,
const ElementType* q, std::vector<int>& sort_indices)
const 882 std::vector<DistanceType> domain_distances(branching_);
883 for (
int i=0; i<branching_; ++i) {
884 DistanceType dist = distance_(q, node->
childs[i]->pivot, veclen_);
887 while (domain_distances[j]<dist && j<i) j++;
888 for (
int k=i; k>j; --k) {
889 domain_distances[k] = domain_distances[k-1];
890 sort_indices[k] = sort_indices[k-1];
892 domain_distances[j] = dist;
904 DistanceType sum = 0;
905 DistanceType sum2 = 0;
907 for (
int i=0; i<veclen_; ++i) {
908 DistanceType t = c[i]-p[i];
909 sum += t*(q[i]-(c[i]+p[i])/2);
926 int getMinVarianceClusters(NodePtr root, std::vector<NodePtr>& clusters,
int clusters_length, DistanceType& varianceValue)
const 928 int clusterCount = 1;
933 while (clusterCount<clusters_length) {
937 for (
int i=0; i<clusterCount; ++i) {
938 if (!clusters[i]->childs.empty()) {
940 DistanceType variance = meanVariance - clusters[i]->variance*clusters[i]->size;
942 for (
int j=0; j<branching_; ++j) {
943 variance += clusters[i]->childs[j]->variance*clusters[i]->childs[j]->size;
945 if (variance<minVariance) {
946 minVariance = variance;
952 if (splitIndex==-1)
break;
953 if ( (branching_+clusterCount-1) > clusters_length)
break;
955 meanVariance = minVariance;
958 NodePtr toSplit = clusters[splitIndex];
959 clusters[splitIndex] = toSplit->
childs[0];
960 for (
int i=1; i<branching_; ++i) {
961 clusters[clusterCount++] = toSplit->
childs[i];
965 varianceValue = meanVariance/root->
size;
971 ElementType* point = points_[index];
972 if (dist_to_pivot>node->
radius) {
973 node->
radius = dist_to_pivot;
979 if (node->
childs.empty()) {
981 point_info.
index = index;
982 point_info.
point = point;
983 node->
points.push_back(point_info);
985 std::vector<int> indices(node->
points.size());
986 for (
size_t i=0;i<node->
points.size();++i) {
987 indices[i] = node->
points[i].index;
989 computeNodeStatistics(node, indices);
990 if (indices.size()>=size_t(branching_)) {
991 computeClustering(node, &indices[0], indices.
size(), branching_);
997 DistanceType dist = distance_(node->
childs[closest]->pivot, point, veclen_);
998 for (
size_t i=1;i<size_t(branching_);++i) {
999 DistanceType crt_dist = distance_(node->
childs[i]->pivot, point, veclen_);
1000 if (crt_dist<dist) {
1005 addPointToTree(node->
childs[closest], index, dist);
1016 std::swap(root_, other.
root_);
1017 std::swap(pool_, other.
pool_);
1066 #endif //FLANN_KMEANS_INDEX_H_ void findNeighborsWithRemoved(ResultSet< DistanceType > &result, const ElementType *vec, const SearchParams &searchParams) const
DistanceType getDistanceToBorder(DistanceType *p, DistanceType *c, DistanceType *q) const
std::map< std::string, any > IndexParams
Distance::ResultType DistanceType
T get_param(const IndexParams &params, std::string name, const T &default_value)
flann_centers_init_t centers_init_
KMeansIndex & operator=(KMeansIndex other)
int getClusterCenters(Matrix< DistanceType > &centers)
std::vector< Node * > childs
void insert(const T &value)
#define USING_BASECLASS_SYMBOLS
void serialize(Archive &ar)
void computeClustering(NodePtr node, int *indices, int indices_length, int branching)
void swap(KMeansIndex &other)
void saveIndex(FILE *stream)
flann_algorithm_t getType() const
KMeansIndex(const KMeansIndex &other)
void loadIndex(FILE *stream)
int exploreNodeBranches(NodePtr node, const ElementType *q, Heap< BranchSt > *heap) const
const binary_object make_binary_object(void *t, size_t size)
std::vector< PointInfo > points
Distance::ElementType ElementType
KMeansIndex(const IndexParams &params=KMeansIndexParams(), Distance d=Distance())
void findNeighbors(ResultSet< DistanceType > &result, const ElementType *vec, const SearchParams &searchParams) const
void serialize(Archive &ar)
void addPoints(const Matrix< ElementType > &points, float rebuild_threshold=2)
Incrementally add points to the index.
virtual DistanceType worstDist() const =0
void copyTree(NodePtr &dst, const NodePtr &src)
void set_cb_index(float index)
BaseClass * clone() const
GLM_FUNC_DECL genType max(genType const &x, genType const &y)
virtual bool full() const =0
void getCenterOrdering(NodePtr node, const ElementType *q, std::vector< int > &sort_indices) const
CenterChooser< Distance > * chooseCenters_
NNIndex< Distance > BaseClass
void findNN(NodePtr node, ResultSet< DistanceType > &result, const ElementType *vec, int &checks, int maxChecks, Heap< BranchSt > *heap) const
void addPointToTree(NodePtr node, size_t index, DistanceType dist_to_pivot)
KMeansIndexParams(int branching=32, int iterations=11, flann_centers_init_t centers_init=FLANN_CENTERS_RANDOM, float cb_index=0.2)
static void freeIndex(sqlite3 *db, Index *p)
void computeNodeStatistics(NodePtr node, const std::vector< int > &indices)
bool needs_vector_space_distance
void findExactNN(NodePtr node, ResultSet< DistanceType > &result, const ElementType *vec) const
void serialize(Archive &ar)
BranchStruct< NodePtr, DistanceType > BranchSt
KMeansIndex(const Matrix< ElementType > &inputData, const IndexParams &params=KMeansIndexParams(), Distance d=Distance())
virtual void addPoint(DistanceType dist, size_t index)=0
int getMinVarianceClusters(NodePtr root, std::vector< NodePtr > &clusters, int clusters_length, DistanceType &varianceValue) const