hierarchical_clustering_index.h
Go to the documentation of this file.
00001 /***********************************************************************
00002  * Software License Agreement (BSD License)
00003  *
00004  * Copyright 2008-2011  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
00005  * Copyright 2008-2011  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
00006  *
00007  * THE BSD LICENSE
00008  *
00009  * Redistribution and use in source and binary forms, with or without
00010  * modification, are permitted provided that the following conditions
00011  * are met:
00012  *
00013  * 1. Redistributions of source code must retain the above copyright
00014  *    notice, this list of conditions and the following disclaimer.
00015  * 2. Redistributions in binary form must reproduce the above copyright
00016  *    notice, this list of conditions and the following disclaimer in the
00017  *    documentation and/or other materials provided with the distribution.
00018  *
00019  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00020  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00021  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00022  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00023  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00024  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00025  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00026  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00028  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029  *************************************************************************/
00030 
00031 #ifndef RTABMAP_FLANN_HIERARCHICAL_CLUSTERING_INDEX_H_
00032 #define RTABMAP_FLANN_HIERARCHICAL_CLUSTERING_INDEX_H_
00033 
00034 #include <algorithm>
00035 #include <string>
00036 #include <map>
00037 #include <cassert>
00038 #include <limits>
00039 #include <cmath>
00040 
00041 #ifndef SIZE_MAX
00042 #define SIZE_MAX ((size_t) -1)
00043 #endif
00044 
00045 #include "rtflann/general.h"
00046 #include "rtflann/algorithms/nn_index.h"
00047 #include "rtflann/algorithms/dist.h"
00048 #include "rtflann/util/matrix.h"
00049 #include "rtflann/util/result_set.h"
00050 #include "rtflann/util/heap.h"
00051 #include "rtflann/util/allocator.h"
00052 #include "rtflann/util/random.h"
00053 #include "rtflann/util/saving.h"
00054 #include "rtflann/util/serialization.h"
00055 
00056 namespace rtflann
00057 {
00058 
00059 struct HierarchicalClusteringIndexParams : public IndexParams
00060 {
00061     HierarchicalClusteringIndexParams(int branching = 32,
00062                                       flann_centers_init_t centers_init = FLANN_CENTERS_RANDOM,
00063                                       int trees = 4, int leaf_max_size = 100)
00064     {
00065         (*this)["algorithm"] = FLANN_INDEX_HIERARCHICAL;
00066         // The branching factor used in the hierarchical clustering
00067         (*this)["branching"] = branching;
00068         // Algorithm used for picking the initial cluster centers
00069         (*this)["centers_init"] = centers_init;
00070         // number of parallel trees to build
00071         (*this)["trees"] = trees;
00072         // maximum leaf size
00073         (*this)["leaf_max_size"] = leaf_max_size;
00074     }
00075 };
00076 
00077 
00078 
00085 template <typename Distance>
00086 class HierarchicalClusteringIndex : public NNIndex<Distance>
00087 {
00088 public:
00089     typedef typename Distance::ElementType ElementType;
00090     typedef typename Distance::ResultType DistanceType;
00091 
00092     typedef NNIndex<Distance> BaseClass;
00093 
00100     HierarchicalClusteringIndex(const IndexParams& index_params = HierarchicalClusteringIndexParams(), Distance d = Distance())
00101         : BaseClass(index_params, d)
00102     {
00103         memoryCounter_ = 0;
00104 
00105         branching_ = get_param(index_params_,"branching",32);
00106         centers_init_ = get_param(index_params_,"centers_init", FLANN_CENTERS_RANDOM);
00107         trees_ = get_param(index_params_,"trees",4);
00108         leaf_max_size_ = get_param(index_params_,"leaf_max_size",100);
00109 
00110         initCenterChooser();
00111     }
00112 
00113 
00121     HierarchicalClusteringIndex(const Matrix<ElementType>& inputData, const IndexParams& index_params = HierarchicalClusteringIndexParams(),
00122                                 Distance d = Distance())
00123         : BaseClass(index_params, d)
00124     {
00125         memoryCounter_ = 0;
00126 
00127         branching_ = get_param(index_params_,"branching",32);
00128         centers_init_ = get_param(index_params_,"centers_init", FLANN_CENTERS_RANDOM);
00129         trees_ = get_param(index_params_,"trees",4);
00130         leaf_max_size_ = get_param(index_params_,"leaf_max_size",100);
00131 
00132         initCenterChooser();
00133         
00134         setDataset(inputData);
00135 
00136         chooseCenters_->setDataSize(veclen_);
00137     }
00138 
00139 
00140     HierarchicalClusteringIndex(const HierarchicalClusteringIndex& other) : BaseClass(other),
00141                 memoryCounter_(other.memoryCounter_),
00142                 branching_(other.branching_),
00143                 trees_(other.trees_),
00144                 centers_init_(other.centers_init_),
00145                 leaf_max_size_(other.leaf_max_size_)
00146 
00147     {
00148         initCenterChooser();
00149         tree_roots_.resize(other.tree_roots_.size());
00150         for (size_t i=0;i<tree_roots_.size();++i) {
00151                 copyTree(tree_roots_[i], other.tree_roots_[i]);
00152         }
00153     }
00154 
00155     HierarchicalClusteringIndex& operator=(HierarchicalClusteringIndex other)
00156     {
00157         this->swap(other);
00158         return *this;
00159     }
00160 
00161 
00162     void initCenterChooser()
00163     {
00164         switch(centers_init_) {
00165         case FLANN_CENTERS_RANDOM:
00166                 chooseCenters_ = new RandomCenterChooser<Distance>(distance_, points_);
00167                 break;
00168         case FLANN_CENTERS_GONZALES:
00169                 chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_, points_);
00170                 break;
00171         case FLANN_CENTERS_KMEANSPP:
00172             chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_, points_);
00173                 break;
00174         case FLANN_CENTERS_GROUPWISE:
00175             chooseCenters_ = new GroupWiseCenterChooser<Distance>(distance_, points_);
00176             break;
00177         default:
00178             throw FLANNException("Unknown algorithm for choosing initial centers.");
00179         }
00180     }
00181 
00187     virtual ~HierarchicalClusteringIndex()
00188     {
00189         delete chooseCenters_;
00190         freeIndex();
00191     }
00192 
00193     BaseClass* clone() const
00194     {
00195         return new HierarchicalClusteringIndex(*this);
00196     }
00197 
00202     int usedMemory() const
00203     {
00204         return pool_.usedMemory+pool_.wastedMemory+memoryCounter_;
00205     }
00206     
00207     using BaseClass::buildIndex;
00208 
00209     void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
00210     {
00211         assert(points.cols==veclen_);
00212         size_t old_size = size_;
00213 
00214         extendDataset(points);
00215         
00216         if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
00217             buildIndex();
00218         }
00219         else {
00220             for (size_t i=0;i<points.rows;++i) {
00221                 for (int j = 0; j < trees_; j++) {
00222                     addPointToTree(tree_roots_[j], old_size + i);
00223                 }
00224             }            
00225         }
00226     }
00227 
00228 
00229     flann_algorithm_t getType() const
00230     {
00231         return FLANN_INDEX_HIERARCHICAL;
00232     }
00233 
00234 
00235     template<typename Archive>
00236     void serialize(Archive& ar)
00237     {
00238         ar.setObject(this);
00239 
00240         ar & *static_cast<NNIndex<Distance>*>(this);
00241 
00242         ar & branching_;
00243         ar & trees_;
00244         ar & centers_init_;
00245         ar & leaf_max_size_;
00246 
00247         if (Archive::is_loading::value) {
00248                 tree_roots_.resize(trees_);
00249         }
00250         for (size_t i=0;i<tree_roots_.size();++i) {
00251                 if (Archive::is_loading::value) {
00252                         tree_roots_[i] = new(pool_) Node();
00253                 }
00254                 ar & *tree_roots_[i];
00255         }
00256 
00257         if (Archive::is_loading::value) {
00258             index_params_["algorithm"] = getType();
00259             index_params_["branching"] = branching_;
00260             index_params_["trees"] = trees_;
00261             index_params_["centers_init"] = centers_init_;
00262             index_params_["leaf_size"] = leaf_max_size_;
00263         }
00264     }
00265 
00266     void saveIndex(FILE* stream)
00267     {
00268         serialization::SaveArchive sa(stream);
00269         sa & *this;
00270     }
00271 
00272 
00273     void loadIndex(FILE* stream)
00274     {
00275         serialization::LoadArchive la(stream);
00276         la & *this;
00277     }
00278 
00279 
00290     void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00291     {
00292         if (removed_) {
00293                 findNeighborsWithRemoved<true>(result, vec, searchParams);
00294         }
00295         else {
00296                 findNeighborsWithRemoved<false>(result, vec, searchParams);
00297         }
00298     }
00299 
00300 protected:
00301 
00305     void buildIndexImpl()
00306     {
00307         chooseCenters_->setDataSize(veclen_);
00308 
00309         if (branching_<2) {
00310             throw FLANNException("Branching factor must be at least 2");
00311         }
00312         tree_roots_.resize(trees_);
00313         std::vector<int> indices(size_);
00314         for (int i=0; i<trees_; ++i) {
00315             for (size_t j=0; j<size_; ++j) {
00316                 indices[j] = j;
00317             }
00318             tree_roots_[i] = new(pool_) Node();
00319             computeClustering(tree_roots_[i], &indices[0], size_);
00320         }
00321     }
00322 
00323 private:
00324 
00325     struct PointInfo
00326     {
00328         size_t index;
00330         ElementType* point;
00331 
00332     private:
00333         template<typename Archive>
00334         void serialize(Archive& ar)
00335         {
00336                 typedef HierarchicalClusteringIndex<Distance> Index;
00337                 Index* obj = static_cast<Index*>(ar.getObject());
00338 
00339                 ar & index;
00340 //              ar & point;
00341 
00342                         if (Archive::is_loading::value) {
00343                                 point = obj->points_[index];
00344                         }
00345         }
00346         friend struct serialization::access;
00347     };
00348 
00352     struct Node
00353     {
00357         ElementType* pivot;
00358         size_t pivot_index;
00362         std::vector<Node*> childs;
00366         std::vector<PointInfo> points;
00367 
00368                 Node(){
00369                         pivot = NULL;
00370                         pivot_index = SIZE_MAX;
00371                 }
00376         ~Node()
00377         {
00378                 for(size_t i=0; i<childs.size(); i++){
00379                         childs[i]->~Node();
00380                                 pivot = NULL;
00381                                 pivot_index = -1;
00382                 }
00383         };
00384 
00385     private:
00386         template<typename Archive>
00387         void serialize(Archive& ar)
00388         {
00389                 typedef HierarchicalClusteringIndex<Distance> Index;
00390                 Index* obj = static_cast<Index*>(ar.getObject());
00391                 ar & pivot_index;
00392                 if (Archive::is_loading::value) {
00393                                 if (pivot_index != SIZE_MAX)
00394                                         pivot = obj->points_[pivot_index];
00395                                 else
00396                                         pivot = NULL;
00397                 }
00398                 size_t childs_size;
00399                 if (Archive::is_saving::value) {
00400                         childs_size = childs.size();
00401                 }
00402                 ar & childs_size;
00403 
00404                 if (childs_size==0) {
00405                         ar & points;
00406                 }
00407                 else {
00408                         if (Archive::is_loading::value) {
00409                                 childs.resize(childs_size);
00410                         }
00411                         for (size_t i=0;i<childs_size;++i) {
00412                                 if (Archive::is_loading::value) {
00413                                         childs[i] = new(obj->pool_) Node();
00414                                 }
00415                                 ar & *childs[i];
00416                         }
00417                 }
00418 
00419         }
00420         friend struct serialization::access;
00421     };
00422     typedef Node* NodePtr;
00423 
00424 
00425 
00429     typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00430 
00431 
00436     void freeIndex(){
00437         for (size_t i=0; i<tree_roots_.size(); ++i) {
00438                 tree_roots_[i]->~Node();
00439         }
00440         pool_.free();
00441     }
00442 
00443     void copyTree(NodePtr& dst, const NodePtr& src)
00444     {
00445         dst = new(pool_) Node();
00446         dst->pivot_index = src->pivot_index;
00447         dst->pivot = points_[dst->pivot_index];
00448 
00449         if (src->childs.size()==0) {
00450                 dst->points = src->points;
00451         }
00452         else {
00453                 dst->childs.resize(src->childs.size());
00454                 for (size_t i=0;i<src->childs.size();++i) {
00455                         copyTree(dst->childs[i], src->childs[i]);
00456                 }
00457         }
00458     }
00459 
00460 
00461 
00462     void computeLabels(int* indices, int indices_length,  int* centers, int centers_length, int* labels, DistanceType& cost)
00463     {
00464         cost = 0;
00465         for (int i=0; i<indices_length; ++i) {
00466             ElementType* point = points_[indices[i]];
00467             DistanceType dist = distance_(point, points_[centers[0]], veclen_);
00468             labels[i] = 0;
00469             for (int j=1; j<centers_length; ++j) {
00470                 DistanceType new_dist = distance_(point, points_[centers[j]], veclen_);
00471                 if (dist>new_dist) {
00472                     labels[i] = j;
00473                     dist = new_dist;
00474                 }
00475             }
00476             cost += dist;
00477         }
00478     }
00479 
00490     void computeClustering(NodePtr node, int* indices, int indices_length)
00491     {
00492         if (indices_length < leaf_max_size_) { // leaf node
00493             node->points.resize(indices_length);
00494             for (int i=0;i<indices_length;++i) {
00495                 node->points[i].index = indices[i];
00496                 node->points[i].point = points_[indices[i]];
00497             }
00498             node->childs.clear();
00499             return;
00500         }
00501 
00502         std::vector<int> centers(branching_);
00503         std::vector<int> labels(indices_length);
00504 
00505         int centers_length;
00506         (*chooseCenters_)(branching_, indices, indices_length, &centers[0], centers_length);
00507 
00508         if (centers_length<branching_) {
00509             node->points.resize(indices_length);
00510             for (int i=0;i<indices_length;++i) {
00511                 node->points[i].index = indices[i];
00512                 node->points[i].point = points_[indices[i]];
00513             }
00514             node->childs.clear();
00515             return;
00516         }
00517 
00518 
00519         //  assign points to clusters
00520         DistanceType cost;
00521         computeLabels(indices, indices_length, &centers[0], centers_length, &labels[0], cost);
00522 
00523         node->childs.resize(branching_);
00524         int start = 0;
00525         int end = start;
00526         for (int i=0; i<branching_; ++i) {
00527             for (int j=0; j<indices_length; ++j) {
00528                 if (labels[j]==i) {
00529                     std::swap(indices[j],indices[end]);
00530                     std::swap(labels[j],labels[end]);
00531                     end++;
00532                 }
00533             }
00534 
00535             node->childs[i] = new(pool_) Node();
00536             node->childs[i]->pivot_index = centers[i];
00537             node->childs[i]->pivot = points_[centers[i]];
00538             node->childs[i]->points.clear();
00539             computeClustering(node->childs[i],indices+start, end-start);
00540             start=end;
00541         }
00542     }
00543 
00544 
00545     template<bool with_removed>
00546     void findNeighborsWithRemoved(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
00547     {
00548         int maxChecks = searchParams.checks;
00549 
00550         // Priority queue storing intermediate branches in the best-bin-first search
00551         Heap<BranchSt>* heap = new Heap<BranchSt>(size_);
00552 
00553         DynamicBitset checked(size_);
00554         int checks = 0;
00555         for (int i=0; i<trees_; ++i) {
00556             findNN<with_removed>(tree_roots_[i], result, vec, checks, maxChecks, heap, checked);
00557         }
00558 
00559         BranchSt branch;
00560         while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
00561             NodePtr node = branch.node;
00562             findNN<with_removed>(node, result, vec, checks, maxChecks, heap, checked);
00563         }
00564 
00565         delete heap;
00566     }
00567 
00568 
00581     template<bool with_removed>
00582     void findNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks,
00583                 Heap<BranchSt>* heap,  DynamicBitset& checked) const
00584     {
00585         if (node->childs.empty()) {
00586             if (checks>=maxChecks) {
00587                 if (result.full()) return;
00588             }
00589 
00590             for (size_t i=0; i<node->points.size(); ++i) {
00591                 PointInfo& pointInfo = node->points[i];
00592                 if (with_removed) {
00593                         if (removed_points_.test(pointInfo.index)) continue;
00594                 }
00595                 if (checked.test(pointInfo.index)) continue;
00596                 DistanceType dist = distance_(pointInfo.point, vec, veclen_);
00597                 result.addPoint(dist, pointInfo.index);
00598                 checked.set(pointInfo.index);
00599                 ++checks;
00600             }
00601         }
00602         else {
00603             DistanceType* domain_distances = new DistanceType[branching_];
00604             int best_index = 0;
00605             domain_distances[best_index] = distance_(vec, node->childs[best_index]->pivot, veclen_);
00606             for (int i=1; i<branching_; ++i) {
00607                 domain_distances[i] = distance_(vec, node->childs[i]->pivot, veclen_);
00608                 if (domain_distances[i]<domain_distances[best_index]) {
00609                     best_index = i;
00610                 }
00611             }
00612             for (int i=0; i<branching_; ++i) {
00613                 if (i!=best_index) {
00614                     heap->insert(BranchSt(node->childs[i],domain_distances[i]));
00615                 }
00616             }
00617             delete[] domain_distances;
00618             findNN<with_removed>(node->childs[best_index],result,vec, checks, maxChecks, heap, checked);
00619         }
00620     }
00621     
00622     void addPointToTree(NodePtr node, size_t index)
00623     {
00624         ElementType* point = points_[index];
00625         
00626         if (node->childs.empty()) { // leaf node
00627                 PointInfo pointInfo;
00628                 pointInfo.point = point;
00629                 pointInfo.index = index;
00630             node->points.push_back(pointInfo);
00631 
00632             if (node->points.size()>=size_t(branching_)) {
00633                 std::vector<int> indices(node->points.size());
00634 
00635                 for (size_t i=0;i<node->points.size();++i) {
00636                         indices[i] = node->points[i].index;
00637                 }
00638                 computeClustering(node, &indices[0], indices.size());
00639             }
00640         }
00641         else {            
00642             // find the closest child
00643             int closest = 0;
00644             ElementType* center = node->childs[closest]->pivot;
00645             DistanceType dist = distance_(center, point, veclen_);
00646             for (size_t i=1;i<size_t(branching_);++i) {
00647                 center = node->childs[i]->pivot;
00648                 DistanceType crt_dist = distance_(center, point, veclen_);
00649                 if (crt_dist<dist) {
00650                     dist = crt_dist;
00651                     closest = i;
00652                 }
00653             }
00654             addPointToTree(node->childs[closest], index);
00655         }                
00656     }
00657 
00658     void swap(HierarchicalClusteringIndex& other)
00659     {
00660         BaseClass::swap(other);
00661 
00662         std::swap(tree_roots_, other.tree_roots_);
00663         std::swap(pool_, other.pool_);
00664         std::swap(memoryCounter_, other.memoryCounter_);
00665         std::swap(branching_, other.branching_);
00666         std::swap(trees_, other.trees_);
00667         std::swap(centers_init_, other.centers_init_);
00668         std::swap(leaf_max_size_, other.leaf_max_size_);
00669         std::swap(chooseCenters_, other.chooseCenters_);
00670     }
00671 
00672 private:
00673 
00677     std::vector<Node*> tree_roots_;
00678 
00686     PooledAllocator pool_;
00687 
00691     int memoryCounter_;
00692 
00697     int branching_;
00698     
00702     int trees_;
00703     
00707     flann_centers_init_t centers_init_;
00708     
00712     int leaf_max_size_;
00713     
00717     CenterChooser<Distance>* chooseCenters_;
00718 
00719     USING_BASECLASS_SYMBOLS
00720 };
00721 
00722 }
00723 
00724 #endif /* FLANN_HIERARCHICAL_CLUSTERING_INDEX_H_ */


rtabmap
Author(s): Mathieu Labbe
autogenerated on Thu Jun 6 2019 21:59:20