00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef _OPENCV_KDTREE_H_
00032 #define _OPENCV_KDTREE_H_
00033
00034 #include <algorithm>
00035 #include <map>
00036 #include <cassert>
00037 #include <cstring>
00038
00039 #include "opencv2/flann/general.h"
00040 #include "opencv2/flann/nn_index.h"
00041 #include "opencv2/flann/matrix.h"
00042 #include "opencv2/flann/result_set.h"
00043 #include "opencv2/flann/heap.h"
00044 #include "opencv2/flann/allocator.h"
00045 #include "opencv2/flann/random.h"
00046 #include "opencv2/flann/saving.h"
00047
00048 using namespace std;
00049
00050
00051 namespace cvflann
00052 {
00053
00054 struct CV_EXPORTS KDTreeIndexParams : public IndexParams {
00055 KDTreeIndexParams(int trees_ = 4) : IndexParams(KDTREE), trees(trees_) {};
00056
00057 int trees;
00058
00059 flann_algorithm_t getIndexType() const { return algorithm; }
00060
00061 void print() const
00062 {
00063 logger().info("Index type: %d\n",(int)algorithm);
00064 logger().info("Trees: %d\n", trees);
00065 }
00066
00067 };
00068
00069
00076 template <typename ELEM_TYPE, typename DIST_TYPE = typename DistType<ELEM_TYPE>::type >
00077 class KDTreeIndex : public NNIndex<ELEM_TYPE>
00078 {
00079
enum {
    /**
     * chooseDivision() estimates mean and variance from the points at
     * positions first .. min(first + SAMPLE_MEAN, last) of the index array,
     * i.e. from at most SAMPLE_MEAN + 1 points.
     */
    SAMPLE_MEAN = 100,

    /**
     * selectDivision() keeps the RAND_DIM dimensions with the largest
     * variance and picks the split dimension at random among them.
     */
    RAND_DIM = 5
};
00096
00097
00101 int numTrees;
00102
00106 int* vind;
00107
00108
00112 const Matrix<ELEM_TYPE> dataset;
00113
00114 const IndexParams& index_params;
00115
00116 size_t size_;
00117 size_t veclen_;
00118
00119
00120 DIST_TYPE* mean;
00121 DIST_TYPE* var;
00122
00123
00124
00125
00132 struct TreeSt {
00138 int divfeat;
00142 DIST_TYPE divval;
00146 TreeSt *child1, *child2;
00147 };
00148 typedef TreeSt* Tree;
00149
00153 Tree* trees;
00154 typedef BranchStruct<Tree> BranchSt;
00155 typedef BranchSt* Branch;
00156
00164 PooledAllocator pool;
00165
00166
00167
00168 public:
00169
00170 flann_algorithm_t getType() const
00171 {
00172 return KDTREE;
00173 }
00174
00182 KDTreeIndex(const Matrix<ELEM_TYPE>& inputData, const KDTreeIndexParams& params = KDTreeIndexParams() ) :
00183 dataset(inputData), index_params(params)
00184 {
00185 size_ = dataset.rows;
00186 veclen_ = dataset.cols;
00187
00188 numTrees = params.trees;
00189 trees = new Tree[numTrees];
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202 vind = new int[size_];
00203 for (size_t i = 0; i < size_; i++) {
00204 vind[i] = (int)i;
00205 }
00206
00207 mean = new DIST_TYPE[veclen_];
00208 var = new DIST_TYPE[veclen_];
00209 }
00210
00214 ~KDTreeIndex()
00215 {
00216 delete[] vind;
00217 if (trees!=NULL) {
00218 delete[] trees;
00219 }
00220 delete[] mean;
00221 delete[] var;
00222 }
00223
00224
00228 void buildIndex()
00229 {
00230
00231 for (int i = 0; i < numTrees; i++) {
00232
00233 for (int j = (int)size_; j > 0; --j) {
00234 int rnd = rand_int(j);
00235 swap(vind[j-1], vind[rnd]);
00236 }
00237 trees[i] = divideTree(0, (int)size_ - 1);
00238 }
00239 }
00240
00241
00242
00243 void saveIndex(FILE* stream)
00244 {
00245 save_value(stream, numTrees);
00246 for (int i=0;i<numTrees;++i) {
00247 save_tree(stream, trees[i]);
00248 }
00249 }
00250
00251
00252
00253 void loadIndex(FILE* stream)
00254 {
00255 load_value(stream, numTrees);
00256
00257 if (trees!=NULL) {
00258 delete[] trees;
00259 }
00260 trees = new Tree[numTrees];
00261 for (int i=0;i<numTrees;++i) {
00262 load_tree(stream,trees[i]);
00263 }
00264 }
00265
00266
00270 size_t size() const
00271 {
00272 return size_;
00273 }
00274
00278 size_t veclen() const
00279 {
00280 return veclen_;
00281 }
00282
00283
00288 int usedMemory() const
00289 {
00290 return (int)(pool.usedMemory+pool.wastedMemory+dataset.rows*sizeof(int));
00291 }
00292
00293
00303 void findNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, const SearchParams& searchParams)
00304 {
00305 int maxChecks = searchParams.checks;
00306
00307 if (maxChecks<0) {
00308 getExactNeighbors(result, vec);
00309 } else {
00310 getNeighbors(result, vec, maxChecks);
00311 }
00312 }
00313
00314 const IndexParams* getParameters() const
00315 {
00316 return &index_params;
00317 }
00318
00319 private:
00320
00321
00322 void save_tree(FILE* stream, Tree tree)
00323 {
00324 save_value(stream, *tree);
00325 if (tree->child1!=NULL) {
00326 save_tree(stream, tree->child1);
00327 }
00328 if (tree->child2!=NULL) {
00329 save_tree(stream, tree->child2);
00330 }
00331 }
00332
00333
00334 void load_tree(FILE* stream, Tree& tree)
00335 {
00336 tree = pool.allocate<TreeSt>();
00337 load_value(stream, *tree);
00338 if (tree->child1!=NULL) {
00339 load_tree(stream, tree->child1);
00340 }
00341 if (tree->child2!=NULL) {
00342 load_tree(stream, tree->child2);
00343 }
00344 }
00345
00346
00356 Tree divideTree(int first, int last)
00357 {
00358 Tree node = pool.allocate<TreeSt>();
00359
00360
00361 if (first == last) {
00362 node->child1 = node->child2 = NULL;
00363 node->divfeat = vind[first];
00364 }
00365 else {
00366 chooseDivision(node, first, last);
00367 subdivide(node, first, last);
00368 }
00369
00370 return node;
00371 }
00372
00373
00379 void chooseDivision(Tree node, int first, int last)
00380 {
00381 memset(mean,0,veclen_*sizeof(DIST_TYPE));
00382 memset(var,0,veclen_*sizeof(DIST_TYPE));
00383
00384
00385
00386
00387 int end = min(first + SAMPLE_MEAN, last);
00388 for (int j = first; j <= end; ++j) {
00389 ELEM_TYPE* v = dataset[vind[j]];
00390 for (size_t k=0; k<veclen_; ++k) {
00391 mean[k] += v[k];
00392 }
00393 }
00394 for (size_t k=0; k<veclen_; ++k) {
00395 mean[k] /= (end - first + 1);
00396 }
00397
00398
00399 for (int j = first; j <= end; ++j) {
00400 ELEM_TYPE* v = dataset[vind[j]];
00401 for (size_t k=0; k<veclen_; ++k) {
00402 DIST_TYPE dist = v[k] - mean[k];
00403 var[k] += dist * dist;
00404 }
00405 }
00406
00407 node->divfeat = selectDivision(var);
00408 node->divval = mean[node->divfeat];
00409
00410 }
00411
00412
00417 int selectDivision(DIST_TYPE* v)
00418 {
00419 int num = 0;
00420 int topind[RAND_DIM];
00421
00422
00423 for (size_t i = 0; i < veclen_; ++i) {
00424 if (num < RAND_DIM || v[i] > v[topind[num-1]]) {
00425
00426 if (num < RAND_DIM) {
00427 topind[num++] = (int)i;
00428 }
00429 else {
00430 topind[num-1] = (int)i;
00431 }
00432
00433 int j = num - 1;
00434 while (j > 0 && v[topind[j]] > v[topind[j-1]]) {
00435 swap(topind[j], topind[j-1]);
00436 --j;
00437 }
00438 }
00439 }
00440
00441 int rnd = rand_int(num);
00442 return topind[rnd];
00443 }
00444
00445
00450 void subdivide(Tree node, int first, int last)
00451 {
00452
00453 int i = first;
00454 int j = last;
00455 while (i <= j) {
00456 int ind = vind[i];
00457 ELEM_TYPE val = dataset[ind][node->divfeat];
00458 if (val < node->divval) {
00459 ++i;
00460 } else {
00461
00462 swap(vind[i], vind[j]);
00463 --j;
00464 }
00465 }
00466
00467
00468
00469
00470 if ( (i == first) || (i == last+1)) {
00471 i = (first+last+1)/2;
00472 }
00473
00474 node->child1 = divideTree(first, i - 1);
00475 node->child2 = divideTree(i, last);
00476 }
00477
00478
00479
00484 void getExactNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec)
00485 {
00486
00487
00488 if (numTrees > 1) {
00489 fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
00490 }
00491 if (numTrees>0) {
00492 searchLevelExact(result, vec, trees[0], 0.0);
00493 }
00494 assert(result.full());
00495 }
00496
00502 void getNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, int maxCheck)
00503 {
00504 int i;
00505 BranchSt branch;
00506
00507 int checkCount = 0;
00508 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00509 vector<bool> checked(size_,false);
00510
00511
00512 for (i = 0; i < numTrees; ++i) {
00513 searchLevel(result, vec, trees[i], 0.0, checkCount, maxCheck, heap, checked);
00514 }
00515
00516
00517 while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
00518 searchLevel(result, vec, branch.node, branch.mindistsq, checkCount, maxCheck, heap, checked);
00519 }
00520
00521 delete heap;
00522
00523 assert(result.full());
00524 }
00525
00526
00532 void searchLevel(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, Tree node, float mindistsq, int& checkCount, int maxCheck,
00533 Heap<BranchSt>* heap, vector<bool>& checked)
00534 {
00535 if (result.worstDist()<mindistsq) {
00536
00537 return;
00538 }
00539
00540
00541 if (node->child1 == NULL && node->child2 == NULL) {
00542
00543
00544
00545
00546
00547 if (checked[node->divfeat] == true || checkCount>=maxCheck) {
00548 if (result.full()) return;
00549 }
00550 checkCount++;
00551 checked[node->divfeat] = true;
00552
00553 result.addPoint(dataset[node->divfeat],node->divfeat);
00554 return;
00555 }
00556
00557
00558 ELEM_TYPE val = vec[node->divfeat];
00559 DIST_TYPE diff = val - node->divval;
00560 Tree bestChild = (diff < 0) ? node->child1 : node->child2;
00561 Tree otherChild = (diff < 0) ? node->child2 : node->child1;
00562
00563
00564
00565
00566
00567
00568
00569
00570
00571 DIST_TYPE new_distsq = (DIST_TYPE)flann_dist(&val, &val+1, &node->divval, mindistsq);
00572
00573 if (new_distsq < result.worstDist() || !result.full()) {
00574 heap->insert( BranchSt::make_branch(otherChild, new_distsq) );
00575 }
00576
00577
00578 searchLevel(result, vec, bestChild, mindistsq, checkCount, maxCheck, heap, checked);
00579 }
00580
00584 void searchLevelExact(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, Tree node, float mindistsq)
00585 {
00586 if (mindistsq>result.worstDist()) {
00587 return;
00588 }
00589
00590
00591 if (node->child1 == NULL && node->child2 == NULL) {
00592
00593
00594
00595
00596
00597
00598
00599
00600
00601 result.addPoint(dataset[node->divfeat],node->divfeat);
00602 return;
00603 }
00604
00605
00606 ELEM_TYPE val = vec[node->divfeat];
00607 DIST_TYPE diff = val - node->divval;
00608 Tree bestChild = (diff < 0) ? node->child1 : node->child2;
00609 Tree otherChild = (diff < 0) ? node->child2 : node->child1;
00610
00611
00612
00613 searchLevelExact(result, vec, bestChild, mindistsq);
00614 DIST_TYPE new_distsq = (DIST_TYPE)flann_dist(&val, &val+1, &node->divval, mindistsq);
00615 searchLevelExact(result, vec, otherChild, new_distsq);
00616 }
00617
00618 };
00619
00620 }
00621
00622 #endif //_OPENCV_KDTREE_H_