31 #ifndef RTABMAP_FLANN_KMEANS_INDEX_H_
32 #define RTABMAP_FLANN_KMEANS_INDEX_H_
65 (*this)[
"branching"] = branching;
67 (*this)[
"iterations"] = iterations;
69 (*this)[
"centers_init"] = centers_init;
71 (*this)[
"cb_index"] = cb_index;
82 template <
typename Distance>
83 class KMeansIndex :
public NNIndex<Distance>
108 Distance d = Distance())
178 throw FLANNException(
"Unknown algorithm for choosing initial centers.");
218 size_t old_size =
size_;
226 for (
size_t i=0;
i<points.
rows;++
i) {
233 template<
typename Archive>
246 if (Archive::is_loading::value) {
251 if (Archive::is_loading::value) {
286 findNeighborsWithRemoved<true>(result, vec, searchParams);
289 findNeighborsWithRemoved<false>(
result, vec, searchParams);
303 int numClusters = centers.
rows;
305 throw FLANNException(
"Number of clusters must be at least 1");
309 std::vector<NodePtr> clusters(numClusters);
313 Logger::info(
"Clusters requested: %d, returning %d\n",numClusters, clusterCount);
315 for (
int i=0;
i<clusterCount; ++
i) {
318 centers[
i][
j] = center[
j];
354 template<
typename Archive>
358 Index* obj =
static_cast<Index*
>(ar.getObject());
363 if (Archive::is_loading::value)
point = obj->points_[
index];
392 std::vector<Node*>
childs;
396 std::vector<PointInfo>
points;
412 template<
typename Archive>
418 if (Archive::is_loading::value) {
427 if (Archive::is_saving::value) {
428 childs_size =
childs.size();
432 if (childs_size==0) {
436 if (Archive::is_loading::value) {
437 childs.resize(childs_size);
439 for (
size_t i=0;
i<childs_size;++
i) {
440 if (Archive::is_loading::value) {
454 typedef BranchStruct<NodePtr, DistanceType>
BranchSt;
471 std::copy(src->pivot, src->pivot+
veclen_,
dst->pivot);
472 dst->radius = src->radius;
473 dst->variance = src->variance;
474 dst->size = src->size;
476 if (src->childs.size()==0) {
477 dst->points = src->points;
480 dst->childs.resize(src->childs.size());
481 for (
size_t i=0;
i<src->childs.size();++
i) {
503 for (
size_t i=0;
i<
size; ++
i) {
511 mean[
j] *= div_factor;
516 for (
size_t i=0;
i<
size; ++
i) {
544 node->size = indices_length;
546 if (indices_length < branching) {
547 node->points.resize(indices_length);
548 for (
int i=0;
i<indices_length;++
i) {
552 node->childs.clear();
556 std::vector<int> centers_idx(branching);
558 (*chooseCenters_)(branching,
indices, indices_length, &centers_idx[0], centers_length);
560 if (centers_length<branching) {
561 node->points.resize(indices_length);
562 for (
int i=0;
i<indices_length;++
i) {
566 node->childs.clear();
572 for (
int i=0;
i<centers_length; ++
i) {
574 for (
size_t k=0; k<
veclen_; ++k) {
575 dcenters[
i][k] = double(vec[k]);
579 std::vector<DistanceType> radiuses(branching,0);
580 std::vector<int>
count(branching,0);
583 std::vector<int> belongs_to(indices_length);
584 for (
int i=0;
i<indices_length; ++
i) {
588 for (
int j=1;
j<branching; ++
j) {
590 if (sq_dist>new_sq_dist) {
592 sq_dist = new_sq_dist;
595 if (sq_dist>radiuses[belongs_to[i]]) {
596 radiuses[belongs_to[
i]] = sq_dist;
601 bool converged =
false;
608 for (
int i=0;
i<branching; ++
i) {
609 memset(dcenters[i],0,
sizeof(
double)*
veclen_);
612 for (
int i=0;
i<indices_length; ++
i) {
614 double* center = dcenters[belongs_to[
i]];
615 for (
size_t k=0; k<
veclen_; ++k) {
619 for (
int i=0;
i<branching; ++
i) {
621 double div_factor = 1.0/cnt;
622 for (
size_t k=0; k<
veclen_; ++k) {
623 dcenters[
i][k] *= div_factor;
628 for (
int i=0;
i<indices_length; ++
i) {
630 int new_centroid = 0;
631 for (
int j=1;
j<branching; ++
j) {
633 if (sq_dist>new_sq_dist) {
635 sq_dist = new_sq_dist;
638 if (sq_dist>radiuses[new_centroid]) {
639 radiuses[new_centroid] = sq_dist;
641 if (new_centroid != belongs_to[i]) {
643 count[new_centroid]++;
644 belongs_to[
i] = new_centroid;
650 for (
int i=0;
i<branching; ++
i) {
654 int j = (
i+1)%branching;
655 while (count[j]<=1) {
659 for (
int k=0; k<indices_length; ++k) {
660 if (belongs_to[k]==j) {
673 std::vector<DistanceType*> centers(branching);
675 for (
int i=0;
i<branching; ++
i) {
678 for (
size_t k=0; k<
veclen_; ++k) {
685 node->childs.
resize(branching);
688 for (
int c=0;
c<branching; ++
c) {
692 for (
int i=0;
i<indices_length; ++
i) {
693 if (belongs_to[i]==c) {
696 std::swap(belongs_to[i],belongs_to[end]);
702 node->childs[
c] =
new(
pool_) Node();
703 node->childs[
c]->radius = radiuses[
c];
704 node->childs[
c]->pivot = centers[
c];
705 node->childs[
c]->variance = variance;
710 delete[] dcenters.ptr();
714 template<
bool with_removed>
718 int maxChecks = searchParams.checks;
721 findExactNN<with_removed>(
root_, result, vec);
725 Heap<BranchSt>* heap =
new Heap<BranchSt>((
int)
size_);
728 findNN<with_removed>(
root_, result, vec, checks, maxChecks, heap);
731 while (heap->popMin(branch) && (checks<maxChecks || !
result.full())) {
733 findNN<with_removed>(node, result, vec, checks, maxChecks, heap);
754 template<
bool with_removed>
756 Heap<BranchSt>* heap)
const
768 if ((val>0)&&(val2>0)) {
773 if (node->childs.empty()) {
774 if (checks>=maxChecks) {
775 if (
result.full())
return;
777 for (
int i=0;
i<node->size; ++
i) {
778 PointInfo& point_info = node->points[
i];
779 int index = point_info.index;
784 result.addPoint(dist, index);
790 findNN<with_removed>(node->childs[closest_center],result,vec, checks, maxChecks, heap);
804 std::vector<DistanceType> domain_distances(
branching_);
806 domain_distances[best_index] =
distance_(q, node->childs[best_index]->pivot,
veclen_);
809 if (domain_distances[i]<domain_distances[best_index]) {
816 if (i != best_index) {
817 domain_distances[
i] -=
cb_index_*node->childs[
i]->variance;
823 heap->insert(
BranchSt(node->childs[i],domain_distances[i]));
834 template<
bool with_removed>
847 if ((val>0)&&(val2>0)) {
852 if (node->
childs.empty()) {
853 for (
int i=0;
i<node->
size; ++
i) {
855 int index = point_info.index;
868 findExactNN<with_removed>(node->
childs[sort_indices[
i]],
result,vec);
882 std::vector<DistanceType> domain_distances(
branching_);
887 while (domain_distances[j]<dist && j<i)
j++;
888 for (
int k=i; k>
j; --k) {
889 domain_distances[k] = domain_distances[k-1];
890 sort_indices[k] = sort_indices[k-1];
892 domain_distances[
j] =
dist;
928 int clusterCount = 1;
933 while (clusterCount<clusters_length) {
937 for (
int i=0;
i<clusterCount; ++
i) {
938 if (!clusters[
i]->childs.empty()) {
940 DistanceType variance = meanVariance - clusters[
i]->variance*clusters[
i]->size;
943 variance += clusters[
i]->childs[
j]->variance*clusters[
i]->childs[
j]->size;
945 if (variance<minVariance) {
946 minVariance = variance;
952 if (splitIndex==-1)
break;
953 if ( (
branching_+clusterCount-1) > clusters_length)
break;
955 meanVariance = minVariance;
958 NodePtr toSplit = clusters[splitIndex];
959 clusters[splitIndex] = toSplit->
childs[0];
961 clusters[clusterCount++] = toSplit->
childs[
i];
965 varianceValue = meanVariance/
root->size;
972 if (dist_to_pivot>node->radius) {
973 node->radius = dist_to_pivot;
976 node->variance = (node->size*node->variance+dist_to_pivot)/(node->size+1);
979 if (node->childs.empty()) {
980 PointInfo point_info;
981 point_info.index = index;
982 point_info.point =
point;
983 node->points.push_back(point_info);
985 std::vector<int>
indices(node->points.size());
986 for (
size_t i=0;
i<node->points.size();++
i) {
1000 if (crt_dist<dist) {
1049 PooledAllocator
pool_;
1066 #endif //RTABMAP_FLANN_KMEANS_INDEX_H_