00001
00002
00003
00004
00005
00006
00007
00008 #ifndef RTABMAP_CENTER_CHOOSER_H_
00009 #define RTABMAP_CENTER_CHOOSER_H_
00010
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

#include "rtflann/util/matrix.h"
00012
00013 namespace rtflann
00014 {
00015
00016 template <typename Distance, typename ElementType>
00017 struct squareDistance
00018 {
00019 typedef typename Distance::ResultType ResultType;
00020 ResultType operator()( ResultType dist ) { return dist*dist; }
00021 };
00022
00023
00024 template <typename ElementType>
00025 struct squareDistance<L2_Simple<ElementType>, ElementType>
00026 {
00027 typedef typename L2_Simple<ElementType>::ResultType ResultType;
00028 ResultType operator()( ResultType dist ) { return dist; }
00029 };
00030
00031 template <typename ElementType>
00032 struct squareDistance<L2_3D<ElementType>, ElementType>
00033 {
00034 typedef typename L2_3D<ElementType>::ResultType ResultType;
00035 ResultType operator()( ResultType dist ) { return dist; }
00036 };
00037
00038 template <typename ElementType>
00039 struct squareDistance<L2<ElementType>, ElementType>
00040 {
00041 typedef typename L2<ElementType>::ResultType ResultType;
00042 ResultType operator()( ResultType dist ) { return dist; }
00043 };
00044
00045
00046 template <typename ElementType>
00047 struct squareDistance<HellingerDistance<ElementType>, ElementType>
00048 {
00049 typedef typename HellingerDistance<ElementType>::ResultType ResultType;
00050 ResultType operator()( ResultType dist ) { return dist; }
00051 };
00052
00053
00054 template <typename ElementType>
00055 struct squareDistance<ChiSquareDistance<ElementType>, ElementType>
00056 {
00057 typedef typename ChiSquareDistance<ElementType>::ResultType ResultType;
00058 ResultType operator()( ResultType dist ) { return dist; }
00059 };
00060
00061
00062 template <typename Distance>
00063 typename Distance::ResultType ensureSquareDistance( typename Distance::ResultType dist )
00064 {
00065 typedef typename Distance::ElementType ElementType;
00066
00067 squareDistance<Distance, ElementType> dummy;
00068 return dummy( dist );
00069 }
00070
00071
00072
/**
 * Abstract strategy for picking initial cluster centers among a subset of points.
 *
 * The chooser keeps a copy of the distance functor but only a *reference* to the
 * caller's point array, so that vector must outlive the chooser. The size passed
 * to the distance functor is set with setDataSize() and is 0 until then.
 */
template <typename Distance>
class CenterChooser
{
public:
    typedef typename Distance::ElementType ElementType;
    typedef typename Distance::ResultType DistanceType;

    /**
     * @param distance distance functor (copied)
     * @param points   dataset rows; stored by reference, not copied
     */
    CenterChooser(const Distance& distance, const std::vector<ElementType*>& points)
        : distance_(distance), points_(points), cols_(0) {}

    virtual ~CenterChooser() {}

    /** Sets the number of columns passed to the distance functor for each point. */
    void setDataSize(size_t cols) { cols_ = cols; }

    /**
     * Chooses up to @p k centers among the points selected by @p indices.
     *
     * @param k              number of centers requested
     * @param indices        indices into the point array to choose from
     * @param indices_length number of entries in @p indices
     * @param centers        output; receives the chosen point indices (needs room for k entries)
     * @param centers_length output; number of centers actually chosen (may be less than k)
     */
    virtual void operator()(int k, int* indices, int indices_length, int* centers, int& centers_length) = 0;

protected:
    const Distance distance_;                  ///< distance functor (own copy)
    const std::vector<ElementType*>& points_;  ///< dataset rows (non-owning reference)
    size_t cols_;                              ///< columns per point; 0 until setDataSize()
};
00102
00103
00104 template <typename Distance>
00105 class RandomCenterChooser : public CenterChooser<Distance>
00106 {
00107 public:
00108 typedef typename Distance::ElementType ElementType;
00109 typedef typename Distance::ResultType DistanceType;
00110 using CenterChooser<Distance>::points_;
00111 using CenterChooser<Distance>::distance_;
00112 using CenterChooser<Distance>::cols_;
00113
00114 RandomCenterChooser(const Distance& distance, const std::vector<ElementType*>& points) :
00115 CenterChooser<Distance>(distance, points) {}
00116
00117 void operator()(int k, int* indices, int indices_length, int* centers, int& centers_length)
00118 {
00119 UniqueRandom r(indices_length);
00120
00121 int index;
00122 for (index=0; index<k; ++index) {
00123 bool duplicate = true;
00124 int rnd;
00125 while (duplicate) {
00126 duplicate = false;
00127 rnd = r.next();
00128 if (rnd<0) {
00129 centers_length = index;
00130 return;
00131 }
00132
00133 centers[index] = indices[rnd];
00134
00135 for (int j=0; j<index; ++j) {
00136 DistanceType sq = distance_(points_[centers[index]], points_[centers[j]], cols_);
00137 if (sq<1e-16) {
00138 duplicate = true;
00139 }
00140 }
00141 }
00142 }
00143
00144 centers_length = index;
00145 }
00146 };
00147
00148
00149
00153 template <typename Distance>
00154 class GonzalesCenterChooser : public CenterChooser<Distance>
00155 {
00156 public:
00157 typedef typename Distance::ElementType ElementType;
00158 typedef typename Distance::ResultType DistanceType;
00159
00160 using CenterChooser<Distance>::points_;
00161 using CenterChooser<Distance>::distance_;
00162 using CenterChooser<Distance>::cols_;
00163
00164 GonzalesCenterChooser(const Distance& distance, const std::vector<ElementType*>& points) :
00165 CenterChooser<Distance>(distance, points) {}
00166
00167 void operator()(int k, int* indices, int indices_length, int* centers, int& centers_length)
00168 {
00169 int n = indices_length;
00170
00171 int rnd = rand_int(n);
00172 assert(rnd >=0 && rnd < n);
00173
00174 centers[0] = indices[rnd];
00175
00176 int index;
00177 for (index=1; index<k; ++index) {
00178
00179 int best_index = -1;
00180 DistanceType best_val = 0;
00181 for (int j=0; j<n; ++j) {
00182 DistanceType dist = distance_(points_[centers[0]],points_[indices[j]],cols_);
00183 for (int i=1; i<index; ++i) {
00184 DistanceType tmp_dist = distance_(points_[centers[i]],points_[indices[j]],cols_);
00185 if (tmp_dist<dist) {
00186 dist = tmp_dist;
00187 }
00188 }
00189 if (dist>best_val) {
00190 best_val = dist;
00191 best_index = j;
00192 }
00193 }
00194 if (best_index!=-1) {
00195 centers[index] = indices[best_index];
00196 }
00197 else {
00198 break;
00199 }
00200 }
00201 centers_length = index;
00202 }
00203 };
00204
00205
00210 template <typename Distance>
00211 class KMeansppCenterChooser : public CenterChooser<Distance>
00212 {
00213 public:
00214 typedef typename Distance::ElementType ElementType;
00215 typedef typename Distance::ResultType DistanceType;
00216
00217 using CenterChooser<Distance>::points_;
00218 using CenterChooser<Distance>::distance_;
00219 using CenterChooser<Distance>::cols_;
00220
00221 KMeansppCenterChooser(const Distance& distance, const std::vector<ElementType*>& points) :
00222 CenterChooser<Distance>(distance, points) {}
00223
00224 void operator()(int k, int* indices, int indices_length, int* centers, int& centers_length)
00225 {
00226 int n = indices_length;
00227
00228 double currentPot = 0;
00229 DistanceType* closestDistSq = new DistanceType[n];
00230
00231
00232 int index = rand_int(n);
00233 assert(index >=0 && index < n);
00234 centers[0] = indices[index];
00235
00236
00237
00238 for (int i = 0; i < n; i++) {
00239 closestDistSq[i] = distance_(points_[indices[i]], points_[indices[index]], cols_);
00240 closestDistSq[i] = ensureSquareDistance<Distance>( closestDistSq[i] );
00241 currentPot += closestDistSq[i];
00242 }
00243
00244
00245 const int numLocalTries = 1;
00246
00247
00248 int centerCount;
00249 for (centerCount = 1; centerCount < k; centerCount++) {
00250
00251
00252 double bestNewPot = -1;
00253 int bestNewIndex = 0;
00254 for (int localTrial = 0; localTrial < numLocalTries; localTrial++) {
00255
00256
00257
00258 double randVal = rand_double(currentPot);
00259 for (index = 0; index < n-1; index++) {
00260 if (randVal <= closestDistSq[index]) break;
00261 else randVal -= closestDistSq[index];
00262 }
00263
00264
00265 double newPot = 0;
00266 for (int i = 0; i < n; i++) {
00267 DistanceType dist = distance_(points_[indices[i]], points_[indices[index]], cols_);
00268 newPot += std::min( ensureSquareDistance<Distance>(dist), closestDistSq[i] );
00269 }
00270
00271
00272 if ((bestNewPot < 0)||(newPot < bestNewPot)) {
00273 bestNewPot = newPot;
00274 bestNewIndex = index;
00275 }
00276 }
00277
00278
00279 centers[centerCount] = indices[bestNewIndex];
00280 currentPot = bestNewPot;
00281 for (int i = 0; i < n; i++) {
00282 DistanceType dist = distance_(points_[indices[i]], points_[indices[bestNewIndex]], cols_);
00283 closestDistSq[i] = std::min( ensureSquareDistance<Distance>(dist), closestDistSq[i] );
00284 }
00285 }
00286
00287 centers_length = centerCount;
00288
00289 delete[] closestDistSq;
00290 }
00291 };
00292
00293
00294
00306 template <typename Distance>
00307 class GroupWiseCenterChooser : public CenterChooser<Distance>
00308 {
00309 public:
00310 typedef typename Distance::ElementType ElementType;
00311 typedef typename Distance::ResultType DistanceType;
00312
00313 using CenterChooser<Distance>::points_;
00314 using CenterChooser<Distance>::distance_;
00315 using CenterChooser<Distance>::cols_;
00316
00317 GroupWiseCenterChooser(const Distance& distance, const std::vector<ElementType*>& points) :
00318 CenterChooser<Distance>(distance, points) {}
00319
00320 void operator()(int k, int* indices, int indices_length, int* centers, int& centers_length)
00321 {
00322 const float kSpeedUpFactor = 1.3f;
00323
00324 int n = indices_length;
00325
00326 DistanceType* closestDistSq = new DistanceType[n];
00327
00328
00329 int index = rand_int(n);
00330 assert(index >=0 && index < n);
00331 centers[0] = indices[index];
00332
00333 for (int i = 0; i < n; i++) {
00334 closestDistSq[i] = distance_(points_[indices[i]], points_[indices[index]], cols_);
00335 }
00336
00337
00338
00339 int centerCount;
00340 for (centerCount = 1; centerCount < k; centerCount++) {
00341
00342
00343 double bestNewPot = -1;
00344 int bestNewIndex = 0;
00345 DistanceType furthest = 0;
00346 for (index = 0; index < n; index++) {
00347
00348
00349 if( closestDistSq[index] > kSpeedUpFactor * (float)furthest ) {
00350
00351
00352 double newPot = 0;
00353 for (int i = 0; i < n; i++) {
00354 newPot += std::min( distance_(points_[indices[i]], points_[indices[index]], cols_)
00355 , closestDistSq[i] );
00356 }
00357
00358
00359 if ((bestNewPot < 0)||(newPot <= bestNewPot)) {
00360 bestNewPot = newPot;
00361 bestNewIndex = index;
00362 furthest = closestDistSq[index];
00363 }
00364 }
00365 }
00366
00367
00368 centers[centerCount] = indices[bestNewIndex];
00369 for (int i = 0; i < n; i++) {
00370 closestDistSq[i] = std::min( distance_(points_[indices[i]], points_[indices[bestNewIndex]], cols_)
00371 , closestDistSq[i] );
00372 }
00373 }
00374
00375 centers_length = centerCount;
00376
00377 delete[] closestDistSq;
00378 }
00379 };
00380
00381
00382 }
00383
00384
00385 #endif