131 template <
typename ElementType,
typename ElementType2,
typename Adapter>
133 const Adapter &adapter,
int ndims,
double worstDist)
const
137 for (
int i = 0; i < ndims; i++)
139 double d = adapter(elema, i) - adapter(elemb, i);
148 template <
int DIMS,
typename Adapter,
typename DistanceType = L2>
155 template <
typename Container>
156 inline void build(
const Container &container)
159 _index.reserve(container.size() * 2);
164 for (
size_t i = 0; i < container.size(); i++)
166 if (container.size() == 0)
181 inline void toStream(std::ostream &str)
const;
185 template <
typename Type,
typename Container>
186 inline std::vector<std::pair<uint32_t, double> >
searchKnn(
const Container &container,
187 const Type &val,
int nn,
190 std::vector<std::pair<uint32_t, double> > res;
191 generalSearch<Type, Container>(res, container, val, -1, sorted, nn);
196 template <
typename Type,
typename Container>
197 inline std::vector<std::pair<uint32_t, double> >
radiusSearch(
const Container &container,
198 const Type &val,
double dist,
200 int maxNN = -1)
const
202 std::vector<std::pair<uint32_t, double> > res;
203 generalSearch<Type, Container>(res, container, val, dist, sorted, maxNN);
208 template <
typename Type,
typename Container>
209 inline void radiusSearch(std::vector<std::pair<uint32_t, double> > &res,
210 const Container &container,
const Type &val,
double dist,
211 bool sorted =
true,
int maxNN = -1)
213 generalSearch<Type, Container>(res, container, val, dist, sorted, maxNN);
235 void toStream(std::ostream &str)
const;
242 struct Index :
public std::vector<Node>
247 inline void toStream(std::ostream &str)
const;
260 template <
typename Container>
261 void divideTree(Index &index, uint64_t nodeIdx,
int startIndex,
int endIndex,
265 Node &currNode = index[nodeIdx];
266 int count = endIndex - startIndex;
267 assert(startIndex < endIndex);
271 currNode.idx.resize(count);
272 for (
int i = 0; i < count; i++)
274 computeBoundingBox<Container>(bbox, startIndex, endIndex, container);
280 currNode.setNodesInfo(index.size(), index.size() + 1);
281 index.push_back(Node());
282 int leftNode = index.size() - 1;
283 index.push_back(Node());
284 int rightNode = index.size() - 1;
291 computeBoundingBox<Container>(_bbox, startIndex, endIndex, container);
293 double max_spread = -1;
294 currNode.col_index = 0;
295 for (
int i = 0; i < DIMS; i++)
297 double spread = _bbox[i].second - _bbox[i].first;
298 if (spread > max_spread)
301 currNode.col_index = i;
305 double split_val = (bbox[currNode.col_index].first + bbox[currNode.col_index].second) / 2;
306 if (split_val < _bbox[currNode.col_index].first)
307 currNode.div_val = _bbox[currNode.col_index].first;
308 else if (split_val > _bbox[currNode.col_index].second)
309 currNode.div_val = _bbox[currNode.col_index].second;
311 currNode.div_val = split_val;
316 double var[DIMS], mean[DIMS];
318 mean_var_calculate<Container>(startIndex, endIndex, var, mean, container);
319 currNode.col_index = 0;
321 for (
int i = 1; i < DIMS; i++)
322 if (var[i] > var[currNode.col_index])
323 currNode.col_index = i;
327 currNode.div_val = mean[currNode.col_index];
337 planeSplit<Container>(&
all_indices[startIndex], count, currNode.col_index,
338 currNode.div_val, lim1, lim2, container);
342 if (lim1 > count / 2)
344 else if (lim2 < count / 2)
347 split_index = count / 2;
352 if ((lim1 == count) || (lim2 == 0))
353 split_index = count / 2;
359 [&](
const uint32_t &a,
const uint32_t &b)
361 return adapter(container.at(a), currNode.col_index) <
362 adapter(container.at(b), currNode.col_index);
364 split_index = count / 2;
374 left_bbox[currNode.col_index].second = currNode.div_val;
375 divideTree<Container>(index, leftNode, startIndex, startIndex + split_index,
376 left_bbox, container);
377 left_bbox[currNode.col_index].second = currNode.div_val;
378 assert(left_bbox[currNode.col_index].second <= currNode.div_val);
380 right_bbox[currNode.col_index].first = currNode.div_val;
381 divideTree<Container>(index, rightNode, startIndex + split_index, endIndex,
382 right_bbox, container);
384 currNode.divlow = left_bbox[currNode.col_index].second;
385 currNode.divhigh = right_bbox[currNode.col_index].first;
386 assert(currNode.divlow <= currNode.divhigh);
388 for (
int i = 0; i < DIMS; ++i)
390 bbox[i].first = std::min(left_bbox[i].first, right_bbox[i].first);
391 bbox[i].second = std::max(left_bbox[i].second, right_bbox[i].second);
397 template <
typename Container>
401 for (
int i = 0; i < DIMS; ++i)
404 for (
int k = start + 1; k < end; ++k)
406 for (
int i = 0; i < DIMS; ++i)
409 if (v < bbox[i].first)
411 if (v > bbox[i].second)
417 template <
typename Container>
419 const Container &container)
421 const int MAX_ELEM_MEAN = 100;
424 memset(mean, 0,
sizeof(
double) * DIMS);
426 memset(sum2, 0,
sizeof(
double) * DIMS);
431 if (endIndex - startindex >= 2 * MAX_ELEM_MEAN)
432 increment = (endIndex - startindex) / MAX_ELEM_MEAN;
433 for (
int i = startindex; i < endIndex; i += increment)
435 for (
int c = 0; c < DIMS; c++)
439 sum2[c] += val * val;
444 double invcnt = 1. / double(cnt);
445 for (
int c = 0; c < DIMS; c++)
448 var[c] = sum2[c] * invcnt - mean[c] * mean[c];
462 template <
typename Container>
463 void planeSplit(uint32_t *ind,
int count,
int cutfeat,
float cutval,
int &lim1,
464 int &lim2,
const Container &container)
468 int right = count - 1;
471 while (left <= right &&
adapter(container.at(ind[left]), cutfeat) < cutval)
473 while (left <= right &&
adapter(container.at(ind[right]), cutfeat) >= cutval)
477 std::swap(ind[left], ind[right]);
485 while (left <= right &&
adapter(container.at(ind[left]), cutfeat) <= cutval)
487 while (left <= right &&
adapter(container.at(ind[right]), cutfeat) > cutval)
491 std::swap(ind[left], ind[right]);
499 template <
typename Type>
505 for (
int i = 0; i < DIMS; ++i)
507 double elem_i =
adapter(elem, i);
508 if (elem_i < bbox[i].first)
510 auto d = elem_i - bbox[i].first;
514 if (elem_i > bbox[i].second)
516 auto d = elem_i - bbox[i].second;
524 template <
typename Type,
typename Container>
526 const Container &container,
const Type &val,
double dist,
528 uint32_t maxNn = std::numeric_limits<int>::max())
const
531 memset(dists, 0,
sizeof(
double) * DIMS);
533 ResultSet hres(res, maxNn, dist > 0 ? dist * dist : -1.
f);
534 float distsq = computeInitialDistances<Type>(val, dists,
_index.
rootBBox);
535 searchExactLevel<Type, Container>(
_index, 0, val, hres, distsq, dists, 1, container);
536 if (sorted && res.size() > 1)
537 std::sort(res.begin(), res.end(),
538 [](
const std::pair<uint32_t, double> &a,
const std::pair<uint32_t, double> &b)
539 { return a.second < b.second; });
547 std::vector<std::pair<uint32_t, double> > &
array;
549 double maxValue = std::numeric_limits<double>::max();
553 ResultSet(std::vector<std::pair<uint32_t, double> > &data_ref,
554 uint32_t MaxSize = std::numeric_limits<uint32_t>::max(),
double MaxV = -1)
567 inline void push(
const std::pair<uint32_t, double> &val)
571 array.push_back(val);
578 if (val.second <
array[0].second)
582 if (
array.size() > 1)
588 array.push_back(val);
589 if (
array.size() > 1)
600 return std::numeric_limits<double>::max();
601 return array[0].second;
605 assert(!
array.empty());
606 return array[0].second;
614 size_t parentIndex = (index - 1) / 2;
615 if (
array[parentIndex].second <
array[index].second)
621 inline void up(
size_t index)
623 size_t leftIndex = 2 * index + 1;
624 size_t rightIndex = 2 * index + 2;
627 if (leftIndex >=
array.size())
631 if (rightIndex >=
array.size())
633 if (
array[index].second <
array[leftIndex].second)
640 if (
array[rightIndex].second <
array[leftIndex].second)
643 if (
array[index].second <
array[leftIndex].second)
652 if (
array[index].second <
array[rightIndex].second)
662 template <
typename Type,
typename Container>
664 ResultSet &res,
double mindistsq,
double dists[],
665 double epsError,
const Container &container)
const
667 const Node &currNode = index[nodeIdx];
668 if (currNode.isLeaf())
670 double worstDist = res.worstDist();
671 for (
size_t i = 0; i < currNode.idx.size(); i++)
673 double sqd =
_distance.compute_distance(elem, container.at(currNode.idx[i]),
677 res.push({ currNode.idx[i], sqd });
678 worstDist = res.worstDist();
684 double val =
adapter(elem, currNode.col_index);
685 double diff1 = val - currNode.divlow;
686 double diff2 = val - currNode.divhigh;
691 if ((diff1 + diff2) < 0)
693 bestChild = currNode._ileft;
694 otherChild = currNode._iright;
695 cut_dist = diff2 * diff2;
699 bestChild = currNode._iright;
700 otherChild = currNode._ileft;
701 cut_dist = diff1 * diff1;
704 searchExactLevel<Type, Container>(index, bestChild, elem, res, mindistsq, dists,
705 epsError, container);
707 float dst = dists[currNode.col_index];
708 mindistsq = mindistsq + cut_dist - dst;
709 dists[currNode.col_index] = cut_dist;
710 if (mindistsq * epsError <= res.worstDist())
711 searchExactLevel<Type, Container>(index, otherChild, elem, res, mindistsq, dists,
712 epsError, container);
713 dists[currNode.col_index] = dst;
717 template <
int DIMS,
typename AAdapter,
typename DistanceType>
726 uint64_t s =
idx.size();
727 str.write((
char *)&s,
sizeof(s));
728 str.write((
char *)&
idx[0],
sizeof(
idx[0]) *
idx.size());
731 template <
int DIMS,
typename AAdapter,
typename DistanceType>
734 str.read((
char *)&div_val,
sizeof(div_val));
735 str.read((
char *)&col_index,
sizeof(col_index));
736 str.read((
char *)&divhigh,
sizeof(divhigh));
737 str.read((
char *)&divlow,
sizeof(divlow));
738 str.read((
char *)&_ileft,
sizeof(_ileft));
739 str.read((
char *)&_iright,
sizeof(_iright));
741 str.read((
char *)&s,
sizeof(s));
743 str.read((
char *)&idx[0],
sizeof(idx[0]) * idx.size());
746 template <
int DIMS,
typename AAdapter,
typename DistanceType>
749 str.write((
char *)&dims,
sizeof(dims));
750 str.write((
char *)&rootBBox[0],
sizeof(rootBBox[0]) * dims);
751 str.write((
char *)&nValues,
sizeof(nValues));
753 uint64_t s = std::vector<Node>::size();
754 str.write((
char *)&s,
sizeof(s));
755 for (
size_t i = 0; i < std::vector<Node>::size(); i++)
756 std::vector<Node>::at(i).toStream(str);
759 template <
int DIMS,
typename AAdapter,
typename DistanceType>
762 str.read((
char *)&dims,
sizeof(dims));
763 rootBBox.resize(dims);
764 str.read((
char *)&rootBBox[0],
sizeof(rootBBox[0]) * dims);
765 str.read((
char *)&nValues,
sizeof(nValues));
770 str.read((
char *)&s,
sizeof(s));
771 std::vector<Node>::resize(s);
772 for (
size_t i = 0; i < std::vector<Node>::size(); i++)
773 std::vector<Node>::at(i).fromStream(str);
774 if (dims != DIMS && this->size() != 0 && nValues != 0)
775 throw std::runtime_error(
776 "Number of dimensions of the index in the stream is different from the number of dimensions of this");
779 template <
int DIMS,
typename AAdapter,
typename DistanceType>
785 template <
int DIMS,
typename AAdapter,
typename DistanceType>