131 template <
typename ElementType,
typename ElementType2,
typename Adapter>
133 const Adapter &adapter,
int ndims,
double worstDist)
const 137 for (
int i = 0; i < ndims; i++)
139 double d = adapter(elema, i) - adapter(elemb, i);
148 template <
int DIMS,
typename Adapter,
typename DistanceType = L2>
155 template <
typename Container>
156 inline void build(
const Container &container)
159 _index.reserve(container.size() * 2);
161 _index.nValues = container.size();
163 all_indices.resize(container.size());
164 for (
size_t i = 0; i < container.size(); i++)
166 if (container.size() == 0)
168 computeBoundingBox<Container>(_index.rootBBox, 0, all_indices.size(), container);
169 _index.push_back(Node());
170 divideTree<Container>(_index, 0, 0, all_indices.size(), _index.rootBBox, container);
181 inline void toStream(std::ostream &str)
const;
183 inline void fromStream(std::istream &str);
185 template <
typename Type,
typename Container>
186 inline std::vector<std::pair<uint32_t, double> >
searchKnn(
const Container &container,
187 const Type &val,
int nn,
190 std::vector<std::pair<uint32_t, double> > res;
191 generalSearch<Type, Container>(res, container, val, -1, sorted, nn);
196 template <
typename Type,
typename Container>
197 inline std::vector<std::pair<uint32_t, double> >
radiusSearch(
const Container &container,
198 const Type &val,
double dist,
200 int maxNN = -1)
const 202 std::vector<std::pair<uint32_t, double> > res;
203 generalSearch<Type, Container>(res, container, val, dist, sorted, maxNN);
208 template <
typename Type,
typename Container>
209 inline void radiusSearch(std::vector<std::pair<uint32_t, double> > &res,
210 const Container &container,
const Type &val,
double dist,
211 bool sorted =
true,
int maxNN = -1)
213 generalSearch<Type, Container>(res, container, val, dist, sorted, maxNN);
223 return _ileft == -1 && _iright == -1;
234 int64_t _ileft = -1, _iright = -1;
235 void toStream(std::ostream &str)
const;
236 void fromStream(std::istream &str);
242 struct Index :
public std::vector<Node>
247 inline void toStream(std::ostream &str)
const;
248 inline void fromStream(std::istream &str);
255 int _maxLeafSize = 10;
260 template <
typename Container>
261 void divideTree(Index &index, uint64_t nodeIdx,
int startIndex,
int endIndex,
262 BoundingBox &bbox,
const Container &container)
265 Node &currNode = index[nodeIdx];
266 int count = endIndex - startIndex;
267 assert(startIndex < endIndex);
269 if (count <= _maxLeafSize)
271 currNode.idx.resize(count);
272 for (
int i = 0; i < count; i++)
273 currNode.idx[i] = all_indices[startIndex + i];
274 computeBoundingBox<Container>(bbox, startIndex, endIndex, container);
280 currNode.setNodesInfo(index.size(), index.size() + 1);
281 index.push_back(Node());
282 int leftNode = index.size() - 1;
283 index.push_back(Node());
284 int rightNode = index.size() - 1;
291 computeBoundingBox<Container>(_bbox, startIndex, endIndex, container);
293 double max_spread = -1;
294 currNode.col_index = 0;
295 for (
int i = 0; i < DIMS; i++)
297 double spread = _bbox[i].second - _bbox[i].first;
298 if (spread > max_spread)
301 currNode.col_index = i;
305 double split_val = (bbox[currNode.col_index].first + bbox[currNode.col_index].second) / 2;
306 if (split_val < _bbox[currNode.col_index].first)
307 currNode.div_val = _bbox[currNode.col_index].first;
308 else if (split_val > _bbox[currNode.col_index].second)
309 currNode.div_val = _bbox[currNode.col_index].second;
311 currNode.div_val = split_val;
316 double var[DIMS], mean[DIMS];
318 mean_var_calculate<Container>(startIndex, endIndex, var, mean, container);
319 currNode.col_index = 0;
321 for (
int i = 1; i < DIMS; i++)
322 if (var[i] > var[currNode.col_index])
323 currNode.col_index = i;
327 currNode.div_val = mean[currNode.col_index];
337 planeSplit<Container>(&all_indices[startIndex], count, currNode.col_index,
338 currNode.div_val, lim1, lim2, container);
342 if (lim1 > count / 2)
344 else if (lim2 < count / 2)
347 split_index = count / 2;
352 if ((lim1 == count) || (lim2 == 0))
353 split_index = count / 2;
355 if (_maxLeafSize != 1)
356 if (split_index < _maxLeafSize || count - split_index < _maxLeafSize)
358 std::sort(all_indices.begin() + startIndex, all_indices.begin() + endIndex,
359 [&](
const uint32_t &a,
const uint32_t &b)
361 return adapter(container.at(a), currNode.col_index) <
362 adapter(container.at(b), currNode.col_index);
364 split_index = count / 2;
366 adapter(container.at(all_indices[startIndex + split_index]), currNode.col_index);
373 BoundingBox left_bbox(bbox);
374 left_bbox[currNode.col_index].second = currNode.div_val;
375 divideTree<Container>(index, leftNode, startIndex, startIndex + split_index,
376 left_bbox, container);
377 left_bbox[currNode.col_index].second = currNode.div_val;
378 assert(left_bbox[currNode.col_index].second <= currNode.div_val);
379 BoundingBox right_bbox(bbox);
380 right_bbox[currNode.col_index].first = currNode.div_val;
381 divideTree<Container>(index, rightNode, startIndex + split_index, endIndex,
382 right_bbox, container);
384 currNode.divlow = left_bbox[currNode.col_index].second;
385 currNode.divhigh = right_bbox[currNode.col_index].first;
386 assert(currNode.divlow <= currNode.divhigh);
388 for (
int i = 0; i < DIMS; ++i)
390 bbox[i].first = std::min(left_bbox[i].first, right_bbox[i].first);
391 bbox[i].second = std::max(left_bbox[i].second, right_bbox[i].second);
397 template <
typename Container>
401 for (
int i = 0; i < DIMS; ++i)
402 bbox[i].second = bbox[i].first = adapter(container.at(all_indices[start]), i);
404 for (
int k = start + 1; k < end; ++k)
406 for (
int i = 0; i < DIMS; ++i)
408 float v = adapter(container.at(all_indices[k]), i);
409 if (v < bbox[i].first)
411 if (v > bbox[i].second)
417 template <
typename Container>
419 const Container &container)
421 const int MAX_ELEM_MEAN = 100;
424 memset(mean, 0,
sizeof(
double) * DIMS);
426 memset(sum2, 0,
sizeof(
double) * DIMS);
431 if (endIndex - startindex >= 2 * MAX_ELEM_MEAN)
432 increment = (endIndex - startindex) / MAX_ELEM_MEAN;
433 for (
int i = startindex; i < endIndex; i += increment)
435 for (
int c = 0; c < DIMS; c++)
437 auto val = adapter(container.at(all_indices[i]), c);
439 sum2[c] += val * val;
444 double invcnt = 1. / double(cnt);
445 for (
int c = 0; c < DIMS; c++)
448 var[c] = sum2[c] * invcnt - mean[c] * mean[c];
462 template <
typename Container>
463 void planeSplit(uint32_t *ind,
int count,
int cutfeat,
float cutval,
int &lim1,
464 int &lim2,
const Container &container)
468 int right = count - 1;
471 while (left <= right && adapter(container.at(ind[left]), cutfeat) < cutval)
473 while (left <= right && adapter(container.at(ind[right]), cutfeat) >= cutval)
477 std::swap(ind[left], ind[right]);
485 while (left <= right && adapter(container.at(ind[left]), cutfeat) <= cutval)
487 while (left <= right && adapter(container.at(ind[right]), cutfeat) > cutval)
491 std::swap(ind[left], ind[right]);
499 template <
typename Type>
501 const BoundingBox &bbox)
const 505 for (
int i = 0; i < DIMS; ++i)
507 double elem_i = adapter(elem, i);
508 if (elem_i < bbox[i].first)
510 auto d = elem_i - bbox[i].first;
514 if (elem_i > bbox[i].second)
516 auto d = elem_i - bbox[i].second;
524 template <
typename Type,
typename Container>
526 const Container &container,
const Type &val,
double dist,
528 uint32_t maxNn = std::numeric_limits<int>::max())
const 531 memset(dists, 0,
sizeof(
double) * DIMS);
533 ResultSet hres(res, maxNn, dist > 0 ? dist * dist : -1.
f);
534 float distsq = computeInitialDistances<Type>(val, dists, _index.rootBBox);
535 searchExactLevel<Type, Container>(_index, 0, val, hres, distsq, dists, 1, container);
536 if (sorted && res.size() > 1)
537 std::sort(res.begin(), res.end(),
538 [](
const std::pair<uint32_t, double> &a,
const std::pair<uint32_t, double> &b)
539 {
return a.second < b.second; });
547 std::vector<std::pair<uint32_t, double> > &
array;
549 double maxValue = std::numeric_limits<double>::max();
550 bool radius_search =
false;
553 ResultSet(std::vector<std::pair<uint32_t, double> > &data_ref,
554 uint32_t MaxSize = std::numeric_limits<uint32_t>::max(),
double MaxV = -1)
562 radius_search =
true;
567 inline void push(
const std::pair<uint32_t, double> &val)
569 if (radius_search && val.second < maxValue)
571 array.push_back(val);
575 if (array.size() >= size_t(maxSize))
578 if (val.second < array[0].second)
580 swap(array.front(), array.back());
582 if (array.size() > 1)
588 array.push_back(val);
589 if (array.size() > 1)
590 down(array.size() - 1);
599 else if (array.size() < size_t(maxSize))
600 return std::numeric_limits<double>::max();
601 return array[0].second;
605 assert(!array.empty());
606 return array[0].second;
614 size_t parentIndex = (index - 1) / 2;
615 if (array[parentIndex].second < array[index].second)
617 swap(array[index], array[parentIndex]);
621 inline void up(
size_t index)
623 size_t leftIndex = 2 * index + 1;
624 size_t rightIndex = 2 * index + 2;
627 if (leftIndex >= array.size())
631 if (rightIndex >= array.size())
633 if (array[index].second < array[leftIndex].second)
634 swap(array[index], array[leftIndex]);
640 if (array[rightIndex].second < array[leftIndex].second)
643 if (array[index].second < array[leftIndex].second)
645 swap(array[index], array[leftIndex]);
652 if (array[index].second < array[rightIndex].second)
654 swap(array[index], array[rightIndex]);
662 template <
typename Type,
typename Container>
664 ResultSet &res,
double mindistsq,
double dists[],
665 double epsError,
const Container &container)
const 667 const Node &currNode = index[nodeIdx];
668 if (currNode.isLeaf())
670 double worstDist = res.worstDist();
671 for (
size_t i = 0; i < currNode.idx.size(); i++)
673 double sqd = _distance.compute_distance(elem, container.at(currNode.idx[i]),
674 adapter, DIMS, worstDist);
677 res.push({ currNode.idx[i], sqd });
678 worstDist = res.worstDist();
684 double val = adapter(elem, currNode.col_index);
685 double diff1 = val - currNode.divlow;
686 double diff2 = val - currNode.divhigh;
691 if ((diff1 + diff2) < 0)
693 bestChild = currNode._ileft;
694 otherChild = currNode._iright;
695 cut_dist = diff2 * diff2;
699 bestChild = currNode._iright;
700 otherChild = currNode._ileft;
701 cut_dist = diff1 * diff1;
704 searchExactLevel<Type, Container>(index, bestChild, elem, res, mindistsq, dists,
705 epsError, container);
707 float dst = dists[currNode.col_index];
708 mindistsq = mindistsq + cut_dist - dst;
709 dists[currNode.col_index] = cut_dist;
710 if (mindistsq * epsError <= res.worstDist())
711 searchExactLevel<Type, Container>(index, otherChild, elem, res, mindistsq, dists,
712 epsError, container);
713 dists[currNode.col_index] = dst;
717 template <
int DIMS,
typename AAdapter,
typename DistanceType>
720 str.write((
char *)&div_val,
sizeof(div_val));
721 str.write((
char *)&col_index,
sizeof(col_index));
722 str.write((
char *)&divhigh,
sizeof(divhigh));
723 str.write((
char *)&divlow,
sizeof(divlow));
724 str.write((
char *)&_ileft,
sizeof(_ileft));
725 str.write((
char *)&_iright,
sizeof(_iright));
726 uint64_t s = idx.size();
727 str.write((
char *)&s,
sizeof(s));
728 str.write((
char *)&idx[0],
sizeof(idx[0]) * idx.size());
731 template <
int DIMS,
typename AAdapter,
typename DistanceType>
734 str.read((
char *)&div_val,
sizeof(div_val));
735 str.read((
char *)&col_index,
sizeof(col_index));
736 str.read((
char *)&divhigh,
sizeof(divhigh));
737 str.read((
char *)&divlow,
sizeof(divlow));
738 str.read((
char *)&_ileft,
sizeof(_ileft));
739 str.read((
char *)&_iright,
sizeof(_iright));
741 str.read((
char *)&s,
sizeof(s));
743 str.read((
char *)&idx[0],
sizeof(idx[0]) * idx.size());
746 template <
int DIMS,
typename AAdapter,
typename DistanceType>
749 str.write((
char *)&dims,
sizeof(dims));
750 str.write((
char *)&rootBBox[0],
sizeof(rootBBox[0]) * dims);
751 str.write((
char *)&nValues,
sizeof(nValues));
753 uint64_t s = std::vector<Node>::size();
754 str.write((
char *)&s,
sizeof(s));
755 for (
size_t i = 0; i < std::vector<Node>::size(); i++)
756 std::vector<Node>::at(i).toStream(str);
759 template <
int DIMS,
typename AAdapter,
typename DistanceType>
762 str.read((
char *)&dims,
sizeof(dims));
763 rootBBox.resize(dims);
764 str.read((
char *)&rootBBox[0],
sizeof(rootBBox[0]) * dims);
765 str.read((
char *)&nValues,
sizeof(nValues));
770 str.read((
char *)&s,
sizeof(s));
771 std::vector<Node>::resize(s);
772 for (
size_t i = 0; i < std::vector<Node>::size(); i++)
773 std::vector<Node>::at(i).fromStream(str);
774 if (dims != DIMS && this->size() != 0 && nValues != 0)
775 throw std::runtime_error(
776 "Number of dimensions of the index in the stream is different from the number of dimensions of this");
779 template <
int DIMS,
typename AAdapter,
typename DistanceType>
782 _index.toStream(str);
785 template <
int DIMS,
typename AAdapter,
typename DistanceType>
788 _index.fromStream(str);
The KdTreeIndex class is the simplest an minimal method to use kdtrees. You only must define an adapt...
void toStream(std::ostream &str) const
double compute_distance(const ElementType &elema, const ElementType2 &elemb, const Adapter &adapter, int ndims, double worstDist) const
ResultSet(std::vector< std::pair< uint32_t, double > > &data_ref, uint32_t MaxSize=std::numeric_limits< uint32_t >::max(), double MaxV=-1)
void mean_var_calculate(int startindex, int endIndex, double var[], double mean[], const Container &container)
void radiusSearch(std::vector< std::pair< uint32_t, double > > &res, const Container &container, const Type &val, double dist, bool sorted=true, int maxNN=-1)
void setNodesInfo(uint32_t l, uint32_t r)
void computeBoundingBox(BoundingBox &bbox, int start, int end, const Container &container)
void fromStream(std::istream &str)
void generalSearch(std::vector< std::pair< uint32_t, double > > &res, const Container &container, const Type &val, double dist, bool sorted=true, uint32_t maxNn=std::numeric_limits< int >::max()) const
void toStream(std::ostream &str) const
void toStream(std::ostream &str) const
std::vector< std::pair< double, double > > BoundingBox
void divideTree(Index &index, uint64_t nodeIdx, int startIndex, int endIndex, BoundingBox &bbox, const Container &container)
std::vector< std::pair< uint32_t, double > > & array
std::vector< std::pair< uint32_t, double > > radiusSearch(const Container &container, const Type &val, double dist, bool sorted=true, int maxNN=-1) const
void build(const Container &container)
void planeSplit(uint32_t *ind, int count, int cutfeat, float cutval, int &lim1, int &lim2, const Container &container)
double computeInitialDistances(const Type &elem, double dists[], const BoundingBox &bbox) const
void searchExactLevel(const Index &index, int64_t nodeIdx, const Type &elem, ResultSet &res, double mindistsq, double dists[], double epsError, const Container &container) const
void fromStream(std::istream &str)
std::vector< std::pair< uint32_t, double > > searchKnn(const Container &container, const Type &val, int nn, bool sorted=true)
void fromStream(std::istream &str)
void push(const std::pair< uint32_t, double > &val)
std::vector< uint32_t > all_indices