kmeans_index.h
Go to the documentation of this file.
00001 /***********************************************************************
00002  * Software License Agreement (BSD License)
00003  *
00004  * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
00005  * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
00006  *
00007  * THE BSD LICENSE
00008  *
00009  * Redistribution and use in source and binary forms, with or without
00010  * modification, are permitted provided that the following conditions
00011  * are met:
00012  *
00013  * 1. Redistributions of source code must retain the above copyright
00014  *    notice, this list of conditions and the following disclaimer.
00015  * 2. Redistributions in binary form must reproduce the above copyright
00016  *    notice, this list of conditions and the following disclaimer in the
00017  *    documentation and/or other materials provided with the distribution.
00018  *
00019  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00020  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00021  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00022  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00023  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00024  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00025  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00026  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00028  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029  *************************************************************************/
00030 
00031 #ifndef RTABMAP_FLANN_KMEANS_INDEX_H_
00032 #define RTABMAP_FLANN_KMEANS_INDEX_H_
00033 
00034 #include <algorithm>
00035 #include <string>
00036 #include <map>
00037 #include <cassert>
00038 #include <limits>
00039 #include <cmath>
00040 
00041 #include "rtflann/general.h"
00042 #include "rtflann/algorithms/nn_index.h"
00043 #include "rtflann/algorithms/dist.h"
00044 #include "rtflann/algorithms/center_chooser.h"
00045 #include "rtflann/util/matrix.h"
00046 #include "rtflann/util/result_set.h"
00047 #include "rtflann/util/heap.h"
00048 #include "rtflann/util/allocator.h"
00049 #include "rtflann/util/random.h"
00050 #include "rtflann/util/saving.h"
00051 #include "rtflann/util/logger.h"
00052 
00053 
00054 
00055 namespace rtflann
00056 {
00057 
00058 struct KMeansIndexParams : public IndexParams
00059 {
00060     KMeansIndexParams(int branching = 32, int iterations = 11,
00061                       flann_centers_init_t centers_init = FLANN_CENTERS_RANDOM, float cb_index = 0.2 )
00062     {
00063         (*this)["algorithm"] = FLANN_INDEX_KMEANS;
00064         // branching factor
00065         (*this)["branching"] = branching;
00066         // max iterations to perform in one kmeans clustering (kmeans tree)
00067         (*this)["iterations"] = iterations;
00068         // algorithm used for picking the initial cluster centers for kmeans tree
00069         (*this)["centers_init"] = centers_init;
00070         // cluster boundary index. Used when searching the kmeans tree
00071         (*this)["cb_index"] = cb_index;
00072     }
00073 };
00074 
00075 
00082 template <typename Distance>
00083 class KMeansIndex : public NNIndex<Distance>
00084 {
00085 public:
00086     typedef typename Distance::ElementType ElementType;
00087     typedef typename Distance::ResultType DistanceType;
00088 
00089     typedef NNIndex<Distance> BaseClass;
00090 
00091     typedef bool needs_vector_space_distance;
00092 
00093 
00094 
00095     flann_algorithm_t getType() const
00096     {
00097         return FLANN_INDEX_KMEANS;
00098     }
00099 
00107     KMeansIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KMeansIndexParams(),
00108                 Distance d = Distance())
00109         : BaseClass(params,d), root_(NULL), memoryCounter_(0)
00110     {
00111         branching_ = get_param(params,"branching",32);
00112         iterations_ = get_param(params,"iterations",11);
00113         if (iterations_<0) {
00114             iterations_ = (std::numeric_limits<int>::max)();
00115         }
00116         centers_init_  = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
00117         cb_index_  = get_param(params,"cb_index",0.4f);
00118 
00119         initCenterChooser();
00120         setDataset(inputData);
00121     }
00122 
00123 
00131     KMeansIndex(const IndexParams& params = KMeansIndexParams(), Distance d = Distance())
00132         : BaseClass(params, d), root_(NULL), memoryCounter_(0)
00133     {
00134         branching_ = get_param(params,"branching",32);
00135         iterations_ = get_param(params,"iterations",11);
00136         if (iterations_<0) {
00137             iterations_ = (std::numeric_limits<int>::max)();
00138         }
00139         centers_init_  = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
00140         cb_index_  = get_param(params,"cb_index",0.4f);
00141 
00142         initCenterChooser();
00143     }
00144 
00145 
00146     KMeansIndex(const KMeansIndex& other) : BaseClass(other),
00147                 branching_(other.branching_),
00148                 iterations_(other.iterations_),
00149                 centers_init_(other.centers_init_),
00150                 cb_index_(other.cb_index_),
00151                 memoryCounter_(other.memoryCounter_)
00152     {
00153         initCenterChooser();
00154 
00155         copyTree(root_, other.root_);
00156     }
00157 
00158     KMeansIndex& operator=(KMeansIndex other)
00159     {
00160         this->swap(other);
00161         return *this;
00162     }
00163 
00164 
00165     void initCenterChooser()
00166     {
00167         switch(centers_init_) {
00168         case FLANN_CENTERS_RANDOM:
00169                 chooseCenters_ = new RandomCenterChooser<Distance>(distance_, points_);
00170                 break;
00171         case FLANN_CENTERS_GONZALES:
00172                 chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_, points_);
00173                 break;
00174         case FLANN_CENTERS_KMEANSPP:
00175             chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_, points_);
00176                 break;
00177         default:
00178             throw FLANNException("Unknown algorithm for choosing initial centers.");
00179         }
00180     }
00181 
00187     virtual ~KMeansIndex()
00188     {
00189         delete chooseCenters_;
00190         freeIndex();
00191     }
00192 
00193     BaseClass* clone() const
00194     {
00195         return new KMeansIndex(*this);
00196     }
00197 
00198 
00199     void set_cb_index( float index)
00200     {
00201         cb_index_ = index;
00202     }
00203 
00208     int usedMemory() const
00209     {
00210         return pool_.usedMemory+pool_.wastedMemory+memoryCounter_;
00211     }
00212 
00213     using BaseClass::buildIndex;
00214 
00215     void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00216     {
00217         assert(points.cols==veclen_);
00218         size_t old_size = size_;
00219 
00220         extendDataset(points);
00221         
00222         if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
00223             buildIndex();
00224         }
00225         else {
00226             for (size_t i=0;i<points.rows;++i) {
00227                 DistanceType dist = distance_(root_->pivot, points[i], veclen_);
00228                 addPointToTree(root_, old_size + i, dist);
00229             }            
00230         }
00231     }
00232 
00233     template<typename Archive>
00234     void serialize(Archive& ar)
00235     {
00236         ar.setObject(this);
00237 
00238         ar & *static_cast<NNIndex<Distance>*>(this);
00239 
00240         ar & branching_;
00241         ar & iterations_;
00242         ar & memoryCounter_;
00243         ar & cb_index_;
00244         ar & centers_init_;
00245 
00246         if (Archive::is_loading::value) {
00247                 root_ = new(pool_) Node();
00248         }
00249         ar & *root_;
00250 
00251         if (Archive::is_loading::value) {
00252             index_params_["algorithm"] = getType();
00253             index_params_["branching"] = branching_;
00254             index_params_["iterations"] = iterations_;
00255             index_params_["centers_init"] = centers_init_;
00256             index_params_["cb_index"] = cb_index_;
00257         }
00258     }
00259 
00260     void saveIndex(FILE* stream)
00261     {
00262         serialization::SaveArchive sa(stream);
00263         sa & *this;
00264     }
00265 
00266     void loadIndex(FILE* stream)
00267     {
00268         freeIndex();
00269         serialization::LoadArchive la(stream);
00270         la & *this;
00271     }
00272 
00283     void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00284     {
00285         if (removed_) {
00286                 findNeighborsWithRemoved<true>(result, vec, searchParams);
00287         }
00288         else {
00289                 findNeighborsWithRemoved<false>(result, vec, searchParams);
00290         }
00291 
00292     }
00293 
00301     int getClusterCenters(Matrix<DistanceType>& centers)
00302     {
00303         int numClusters = centers.rows;
00304         if (numClusters<1) {
00305             throw FLANNException("Number of clusters must be at least 1");
00306         }
00307 
00308         DistanceType variance;
00309         std::vector<NodePtr> clusters(numClusters);
00310 
00311         int clusterCount = getMinVarianceClusters(root_, clusters, numClusters, variance);
00312 
00313         Logger::info("Clusters requested: %d, returning %d\n",numClusters, clusterCount);
00314 
00315         for (int i=0; i<clusterCount; ++i) {
00316             DistanceType* center = clusters[i]->pivot;
00317             for (size_t j=0; j<veclen_; ++j) {
00318                 centers[i][j] = center[j];
00319             }
00320         }
00321 
00322         return clusterCount;
00323     }
00324 
00325 protected:
00329     void buildIndexImpl()
00330     {
00331         chooseCenters_->setDataSize(veclen_);
00332 
00333         if (branching_<2) {
00334             throw FLANNException("Branching factor must be at least 2");
00335         }
00336 
00337         std::vector<int> indices(size_);
00338         for (size_t i=0; i<size_; ++i) {
00339                 indices[i] = int(i);
00340         }
00341 
00342         root_ = new(pool_) Node();
00343         computeNodeStatistics(root_, indices);
00344         computeClustering(root_, &indices[0], (int)size_, branching_);
00345     }
00346 
00347 private:
00348 
00349     struct PointInfo
00350     {
00351         size_t index;
00352         ElementType* point;
00353     private:
00354         template<typename Archive>
00355         void serialize(Archive& ar)
00356         {
00357                 typedef KMeansIndex<Distance> Index;
00358                 Index* obj = static_cast<Index*>(ar.getObject());
00359 
00360                 ar & index;
00361 //              ar & point;
00362 
00363                         if (Archive::is_loading::value) point = obj->points_[index];
00364         }
00365         friend struct serialization::access;
00366     };
00367 
00371     struct Node
00372     {
00376         DistanceType* pivot;
00380         DistanceType radius;
00384         DistanceType variance;
00388         int size;
00392         std::vector<Node*> childs;
00396         std::vector<PointInfo> points;
00400 //        int level;
00401 
00402         ~Node()
00403         {
00404             delete[] pivot;
00405             if (!childs.empty()) {
00406                 for (size_t i=0; i<childs.size(); ++i) {
00407                     childs[i]->~Node();
00408                 }
00409             }
00410         }
00411 
00412         template<typename Archive>
00413         void serialize(Archive& ar)
00414         {
00415                 typedef KMeansIndex<Distance> Index;
00416                 Index* obj = static_cast<Index*>(ar.getObject());
00417 
00418                 if (Archive::is_loading::value) {
00419                         pivot = new DistanceType[obj->veclen_];
00420                 }
00421                 ar & serialization::make_binary_object(pivot, obj->veclen_*sizeof(DistanceType));
00422                 ar & radius;
00423                 ar & variance;
00424                 ar & size;
00425 
00426                 size_t childs_size;
00427                 if (Archive::is_saving::value) {
00428                         childs_size = childs.size();
00429                 }
00430                 ar & childs_size;
00431 
00432                 if (childs_size==0) {
00433                         ar & points;
00434                 }
00435                 else {
00436                         if (Archive::is_loading::value) {
00437                                 childs.resize(childs_size);
00438                         }
00439                         for (size_t i=0;i<childs_size;++i) {
00440                                 if (Archive::is_loading::value) {
00441                                         childs[i] = new(obj->pool_) Node();
00442                                 }
00443                                 ar & *childs[i];
00444                         }
00445                 }
00446         }
00447         friend struct serialization::access;
00448     };
00449     typedef Node* NodePtr;
00450 
00454     typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00455 
00456 
00460     void freeIndex()
00461     {
00462         if (root_) root_->~Node();
00463         root_ = NULL;
00464         pool_.free();
00465     }
00466 
00467     void copyTree(NodePtr& dst, const NodePtr& src)
00468     {
00469         dst = new(pool_) Node();
00470         dst->pivot = new DistanceType[veclen_];
00471         std::copy(src->pivot, src->pivot+veclen_, dst->pivot);
00472         dst->radius = src->radius;
00473         dst->variance = src->variance;
00474         dst->size = src->size;
00475 
00476         if (src->childs.size()==0) {
00477                 dst->points = src->points;
00478         }
00479         else {
00480                 dst->childs.resize(src->childs.size());
00481                 for (size_t i=0;i<src->childs.size();++i) {
00482                         copyTree(dst->childs[i], src->childs[i]);
00483                 }
00484         }
00485     }
00486 
00487 
00495     void computeNodeStatistics(NodePtr node, const std::vector<int>& indices)
00496     {
00497         size_t size = indices.size();
00498 
00499         DistanceType* mean = new DistanceType[veclen_];
00500         memoryCounter_ += int(veclen_*sizeof(DistanceType));
00501         memset(mean,0,veclen_*sizeof(DistanceType));
00502 
00503         for (size_t i=0; i<size; ++i) {
00504             ElementType* vec = points_[indices[i]];
00505             for (size_t j=0; j<veclen_; ++j) {
00506                 mean[j] += vec[j];
00507             }
00508         }
00509         DistanceType div_factor = DistanceType(1)/size;
00510         for (size_t j=0; j<veclen_; ++j) {
00511             mean[j] *= div_factor;
00512         }
00513         
00514         DistanceType radius = 0;
00515         DistanceType variance = 0;
00516         for (size_t i=0; i<size; ++i) {
00517             DistanceType dist = distance_(mean, points_[indices[i]], veclen_);
00518             if (dist>radius) {
00519                 radius = dist;
00520             }
00521             variance += dist;
00522         }        
00523         variance /= size;
00524 
00525         node->variance = variance;
00526         node->radius = radius;
00527         node->pivot = mean;
00528     }
00529 
00530 
00542     void computeClustering(NodePtr node, int* indices, int indices_length, int branching)
00543     {
00544         node->size = indices_length;
00545 
00546         if (indices_length < branching) {
00547             node->points.resize(indices_length);
00548             for (int i=0;i<indices_length;++i) {
00549                 node->points[i].index = indices[i];
00550                 node->points[i].point = points_[indices[i]];
00551             }
00552             node->childs.clear();
00553             return;
00554         }
00555 
00556         std::vector<int> centers_idx(branching);
00557         int centers_length;
00558         (*chooseCenters_)(branching, indices, indices_length, &centers_idx[0], centers_length);
00559 
00560         if (centers_length<branching) {
00561             node->points.resize(indices_length);
00562             for (int i=0;i<indices_length;++i) {
00563                 node->points[i].index = indices[i];
00564                 node->points[i].point = points_[indices[i]];
00565             }
00566             node->childs.clear();
00567             return;
00568         }
00569 
00570 
00571         Matrix<double> dcenters(new double[branching*veclen_],branching,veclen_);
00572         for (int i=0; i<centers_length; ++i) {
00573             ElementType* vec = points_[centers_idx[i]];
00574             for (size_t k=0; k<veclen_; ++k) {
00575                 dcenters[i][k] = double(vec[k]);
00576             }
00577         }
00578 
00579         std::vector<DistanceType> radiuses(branching,0);
00580         std::vector<int> count(branching,0);
00581 
00582         //      assign points to clusters
00583         std::vector<int> belongs_to(indices_length);
00584         for (int i=0; i<indices_length; ++i) {
00585 
00586             DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
00587             belongs_to[i] = 0;
00588             for (int j=1; j<branching; ++j) {
00589                 DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
00590                 if (sq_dist>new_sq_dist) {
00591                     belongs_to[i] = j;
00592                     sq_dist = new_sq_dist;
00593                 }
00594             }
00595             if (sq_dist>radiuses[belongs_to[i]]) {
00596                 radiuses[belongs_to[i]] = sq_dist;
00597             }
00598             count[belongs_to[i]]++;
00599         }
00600 
00601         bool converged = false;
00602         int iteration = 0;
00603         while (!converged && iteration<iterations_) {
00604             converged = true;
00605             iteration++;
00606 
00607             // compute the new cluster centers
00608             for (int i=0; i<branching; ++i) {
00609                 memset(dcenters[i],0,sizeof(double)*veclen_);
00610                 radiuses[i] = 0;
00611             }
00612             for (int i=0; i<indices_length; ++i) {
00613                 ElementType* vec = points_[indices[i]];
00614                 double* center = dcenters[belongs_to[i]];
00615                 for (size_t k=0; k<veclen_; ++k) {
00616                     center[k] += vec[k];
00617                 }
00618             }
00619             for (int i=0; i<branching; ++i) {
00620                 int cnt = count[i];
00621                 double div_factor = 1.0/cnt;
00622                 for (size_t k=0; k<veclen_; ++k) {
00623                     dcenters[i][k] *= div_factor;
00624                 }
00625             }
00626 
00627             // reassign points to clusters
00628             for (int i=0; i<indices_length; ++i) {
00629                 DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
00630                 int new_centroid = 0;
00631                 for (int j=1; j<branching; ++j) {
00632                     DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
00633                     if (sq_dist>new_sq_dist) {
00634                         new_centroid = j;
00635                         sq_dist = new_sq_dist;
00636                     }
00637                 }
00638                 if (sq_dist>radiuses[new_centroid]) {
00639                     radiuses[new_centroid] = sq_dist;
00640                 }
00641                 if (new_centroid != belongs_to[i]) {
00642                     count[belongs_to[i]]--;
00643                     count[new_centroid]++;
00644                     belongs_to[i] = new_centroid;
00645 
00646                     converged = false;
00647                 }
00648             }
00649 
00650             for (int i=0; i<branching; ++i) {
00651                 // if one cluster converges to an empty cluster,
00652                 // move an element into that cluster
00653                 if (count[i]==0) {
00654                     int j = (i+1)%branching;
00655                     while (count[j]<=1) {
00656                         j = (j+1)%branching;
00657                     }
00658 
00659                     for (int k=0; k<indices_length; ++k) {
00660                         if (belongs_to[k]==j) {
00661                             belongs_to[k] = i;
00662                             count[j]--;
00663                             count[i]++;
00664                             break;
00665                         }
00666                     }
00667                     converged = false;
00668                 }
00669             }
00670 
00671         }
00672 
00673         std::vector<DistanceType*> centers(branching);
00674 
00675         for (int i=0; i<branching; ++i) {
00676             centers[i] = new DistanceType[veclen_];
00677             memoryCounter_ += veclen_*sizeof(DistanceType);
00678             for (size_t k=0; k<veclen_; ++k) {
00679                 centers[i][k] = (DistanceType)dcenters[i][k];
00680             }
00681         }
00682 
00683 
00684         // compute kmeans clustering for each of the resulting clusters
00685         node->childs.resize(branching);
00686         int start = 0;
00687         int end = start;
00688         for (int c=0; c<branching; ++c) {
00689             int s = count[c];
00690 
00691             DistanceType variance = 0;
00692             for (int i=0; i<indices_length; ++i) {
00693                 if (belongs_to[i]==c) {
00694                     variance += distance_(centers[c], points_[indices[i]], veclen_);
00695                     std::swap(indices[i],indices[end]);
00696                     std::swap(belongs_to[i],belongs_to[end]);
00697                     end++;
00698                 }
00699             }
00700             variance /= s;
00701 
00702             node->childs[c] = new(pool_) Node();
00703             node->childs[c]->radius = radiuses[c];
00704             node->childs[c]->pivot = centers[c];
00705             node->childs[c]->variance = variance;
00706             computeClustering(node->childs[c],indices+start, end-start, branching);
00707             start=end;
00708         }
00709 
00710         delete[] dcenters.ptr();
00711     }
00712 
00713 
00714     template<bool with_removed>
00715     void findNeighborsWithRemoved(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00716     {
00717 
00718         int maxChecks = searchParams.checks;
00719 
00720         if (maxChecks==FLANN_CHECKS_UNLIMITED) {
00721             findExactNN<with_removed>(root_, result, vec);
00722         }
00723         else {
00724             // Priority queue storing intermediate branches in the best-bin-first search
00725             Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00726 
00727             int checks = 0;
00728             findNN<with_removed>(root_, result, vec, checks, maxChecks, heap);
00729 
00730             BranchSt branch;
00731             while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
00732                 NodePtr node = branch.node;
00733                 findNN<with_removed>(node, result, vec, checks, maxChecks, heap);
00734             }
00735 
00736             delete heap;
00737         }
00738 
00739     }
00740 
00741 
00754     template<bool with_removed>
00755     void findNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks,
00756                 Heap<BranchSt>* heap) const
00757     {
00758         // Ignore those clusters that are too far away
00759         {
00760             DistanceType bsq = distance_(vec, node->pivot, veclen_);
00761             DistanceType rsq = node->radius;
00762             DistanceType wsq = result.worstDist();
00763 
00764             DistanceType val = bsq-rsq-wsq;
00765             DistanceType val2 = val*val-4*rsq*wsq;
00766 
00767             //if (val>0) {
00768             if ((val>0)&&(val2>0)) {
00769                 return;
00770             }
00771         }
00772 
00773         if (node->childs.empty()) {
00774             if (checks>=maxChecks) {
00775                 if (result.full()) return;
00776             }
00777             for (int i=0; i<node->size; ++i) {
00778                 PointInfo& point_info = node->points[i];
00779                 int index = point_info.index;
00780                 if (with_removed) {
00781                         if (removed_points_.test(index)) continue;
00782                 }
00783                 DistanceType dist = distance_(point_info.point, vec, veclen_);
00784                 result.addPoint(dist, index);
00785                 ++checks;
00786             }
00787         }
00788         else {
00789             int closest_center = exploreNodeBranches(node, vec, heap);
00790             findNN<with_removed>(node->childs[closest_center],result,vec, checks, maxChecks, heap);
00791         }
00792     }
00793 
00802     int exploreNodeBranches(NodePtr node, const ElementType* q, Heap<BranchSt>* heap) const
00803     {
00804         std::vector<DistanceType> domain_distances(branching_);
00805         int best_index = 0;
00806         domain_distances[best_index] = distance_(q, node->childs[best_index]->pivot, veclen_);
00807         for (int i=1; i<branching_; ++i) {
00808             domain_distances[i] = distance_(q, node->childs[i]->pivot, veclen_);
00809             if (domain_distances[i]<domain_distances[best_index]) {
00810                 best_index = i;
00811             }
00812         }
00813 
00814         //              float* best_center = node->childs[best_index]->pivot;
00815         for (int i=0; i<branching_; ++i) {
00816             if (i != best_index) {
00817                 domain_distances[i] -= cb_index_*node->childs[i]->variance;
00818 
00819                 //                              float dist_to_border = getDistanceToBorder(node.childs[i].pivot,best_center,q);
00820                 //                              if (domain_distances[i]<dist_to_border) {
00821                 //                                      domain_distances[i] = dist_to_border;
00822                 //                              }
00823                 heap->insert(BranchSt(node->childs[i],domain_distances[i]));
00824             }
00825         }
00826 
00827         return best_index;
00828     }
00829 
00830 
00834     template<bool with_removed>
00835     void findExactNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec) const
00836     {
00837         // Ignore those clusters that are too far away
00838         {
00839             DistanceType bsq = distance_(vec, node->pivot, veclen_);
00840             DistanceType rsq = node->radius;
00841             DistanceType wsq = result.worstDist();
00842 
00843             DistanceType val = bsq-rsq-wsq;
00844             DistanceType val2 = val*val-4*rsq*wsq;
00845 
00846             //                  if (val>0) {
00847             if ((val>0)&&(val2>0)) {
00848                 return;
00849             }
00850         }
00851 
00852         if (node->childs.empty()) {
00853             for (int i=0; i<node->size; ++i) {
00854                 PointInfo& point_info = node->points[i];
00855                 int index = point_info.index;
00856                 if (with_removed) {
00857                         if (removed_points_.test(index)) continue;
00858                 }
00859                 DistanceType dist = distance_(point_info.point, vec, veclen_);
00860                 result.addPoint(dist, index);
00861             }
00862         }
00863         else {
00864             std::vector<int> sort_indices(branching_);
00865             getCenterOrdering(node, vec, sort_indices);
00866 
00867             for (int i=0; i<branching_; ++i) {
00868                 findExactNN<with_removed>(node->childs[sort_indices[i]],result,vec);
00869             }
00870 
00871         }
00872     }
00873 
00874 
00880     void getCenterOrdering(NodePtr node, const ElementType* q, std::vector<int>& sort_indices) const
00881     {
00882         std::vector<DistanceType> domain_distances(branching_);
00883         for (int i=0; i<branching_; ++i) {
00884             DistanceType dist = distance_(q, node->childs[i]->pivot, veclen_);
00885 
00886             int j=0;
00887             while (domain_distances[j]<dist && j<i) j++;
00888             for (int k=i; k>j; --k) {
00889                 domain_distances[k] = domain_distances[k-1];
00890                 sort_indices[k] = sort_indices[k-1];
00891             }
00892             domain_distances[j] = dist;
00893             sort_indices[j] = i;
00894         }
00895     }
00896 
00902     DistanceType getDistanceToBorder(DistanceType* p, DistanceType* c, DistanceType* q) const
00903     {
00904         DistanceType sum = 0;
00905         DistanceType sum2 = 0;
00906 
00907         for (int i=0; i<veclen_; ++i) {
00908             DistanceType t = c[i]-p[i];
00909             sum += t*(q[i]-(c[i]+p[i])/2);
00910             sum2 += t*t;
00911         }
00912 
00913         return sum*sum/sum2;
00914     }
00915 
00916 
00926     int getMinVarianceClusters(NodePtr root, std::vector<NodePtr>& clusters, int clusters_length, DistanceType& varianceValue) const
00927     {
00928         int clusterCount = 1;
00929         clusters[0] = root;
00930 
00931         DistanceType meanVariance = root->variance*root->size;
00932 
00933         while (clusterCount<clusters_length) {
00934             DistanceType minVariance = (std::numeric_limits<DistanceType>::max)();
00935             int splitIndex = -1;
00936 
00937             for (int i=0; i<clusterCount; ++i) {
00938                 if (!clusters[i]->childs.empty()) {
00939 
00940                     DistanceType variance = meanVariance - clusters[i]->variance*clusters[i]->size;
00941 
00942                     for (int j=0; j<branching_; ++j) {
00943                         variance += clusters[i]->childs[j]->variance*clusters[i]->childs[j]->size;
00944                     }
00945                     if (variance<minVariance) {
00946                         minVariance = variance;
00947                         splitIndex = i;
00948                     }
00949                 }
00950             }
00951 
00952             if (splitIndex==-1) break;
00953             if ( (branching_+clusterCount-1) > clusters_length) break;
00954 
00955             meanVariance = minVariance;
00956 
00957             // split node
00958             NodePtr toSplit = clusters[splitIndex];
00959             clusters[splitIndex] = toSplit->childs[0];
00960             for (int i=1; i<branching_; ++i) {
00961                 clusters[clusterCount++] = toSplit->childs[i];
00962             }
00963         }
00964 
00965         varianceValue = meanVariance/root->size;
00966         return clusterCount;
00967     }
00968     
00969     void addPointToTree(NodePtr node, size_t index, DistanceType dist_to_pivot)
00970     {
00971         ElementType* point = points_[index];
00972         if (dist_to_pivot>node->radius) {
00973             node->radius = dist_to_pivot;
00974         }
00975         // if radius changed above, the variance will be an approximation
00976         node->variance = (node->size*node->variance+dist_to_pivot)/(node->size+1);
00977         node->size++;
00978         
00979         if (node->childs.empty()) { // leaf node
00980                 PointInfo point_info;
00981                 point_info.index = index;
00982                 point_info.point = point;
00983                 node->points.push_back(point_info);
00984 
00985             std::vector<int> indices(node->points.size());
00986             for (size_t i=0;i<node->points.size();++i) {
00987                 indices[i] = node->points[i].index;
00988             }
00989             computeNodeStatistics(node, indices);
00990             if (indices.size()>=size_t(branching_)) {
00991                 computeClustering(node, &indices[0], indices.size(), branching_);
00992             }
00993         }
00994         else {            
00995             // find the closest child
00996             int closest = 0;
00997             DistanceType dist = distance_(node->childs[closest]->pivot, point, veclen_);
00998             for (size_t i=1;i<size_t(branching_);++i) {
00999                 DistanceType crt_dist = distance_(node->childs[i]->pivot, point, veclen_);
01000                 if (crt_dist<dist) {
01001                     dist = crt_dist;
01002                     closest = i;
01003                 }
01004             }
01005             addPointToTree(node->childs[closest], index, dist);
01006         }                
01007     }
01008 
01009 
01010     void swap(KMeansIndex& other)
01011     {
01012         std::swap(branching_, other.branching_);
01013         std::swap(iterations_, other.iterations_);
01014         std::swap(centers_init_, other.centers_init_);
01015         std::swap(cb_index_, other.cb_index_);
01016         std::swap(root_, other.root_);
01017         std::swap(pool_, other.pool_);
01018         std::swap(memoryCounter_, other.memoryCounter_);
01019         std::swap(chooseCenters_, other.chooseCenters_);
01020     }
01021 
01022 
01023 private:
01025     int branching_;
01026 
01028     int iterations_;
01029 
01031     flann_centers_init_t centers_init_;
01032 
01039     float cb_index_;
01040     
01044     NodePtr root_;
01045 
01049     PooledAllocator pool_;
01050 
01054     int memoryCounter_;
01055 
01059     CenterChooser<Distance>* chooseCenters_;
01060 
01061     USING_BASECLASS_SYMBOLS
01062 };
01063 
01064 }
01065 
01066 #endif //FLANN_KMEANS_INDEX_H_


rtabmap
Author(s): Mathieu Labbe
autogenerated on Thu Jun 6 2019 21:59:20