44 #include <unordered_set>
67 template <
typename _T>
68 class NearestNeighborsGNAT :
public NearestNeighbors<_T>
74 using DataDist = std::pair<const _T*, double>;
75 struct DataDistCompare
77 bool operator()(
const DataDist& d0,
const DataDist& d1)
79 return d0.second < d1.second;
82 using NearQueue = std::priority_queue<DataDist, std::vector<DataDist>, DataDistCompare>;
87 using NodeDist = std::pair<Node*, double>;
88 struct NodeDistCompare
90 bool operator()(
const NodeDist& n0,
const NodeDist& n1)
const
92 return (n0.second - n0.first->maxRadius_) > (n1.second - n1.first->maxRadius_);
95 using NodeQueue = std::priority_queue<NodeDist, std::vector<NodeDist>, NodeDistCompare>;
99 NearestNeighborsGNAT(
unsigned int degree = 8,
unsigned int minDegree = 4,
unsigned int maxDegree = 12,
100 unsigned int maxNumPtsPerLeaf = 50,
unsigned int removedCacheSize = 500,
101 bool rebalancing =
false)
107 ,
rebuildSize_(rebalancing ? maxNumPtsPerLeaf * degree :
std::numeric_limits<
std::size_t>::max())
126 void clear()
override
135 if (
rebuildSize_ != std::numeric_limits<std::size_t>::max())
144 void add(
const _T& data)
override
158 void add(
const std::vector<_T>& data)
override
162 else if (!data.empty())
165 for (
unsigned int i = 1; i < data.size(); ++i)
167 size_ += data.size();
185 bool remove(
const _T& data)
override
192 const _T*
d = nbh_queue.top().first;
204 _T
nearest(
const _T& data)
const override
210 if (!nbh_queue.empty())
211 return *nbh_queue.top().first;
213 throw moveit::Exception(
"No elements found in nearest neighbors data structure");
217 void nearestK(
const _T& data, std::size_t k, std::vector<_T>& nbh)
const override
231 void nearestR(
const _T& data,
double radius, std::vector<_T>& nbh)
const override
242 std::size_t
size()
const override
247 void list(std::vector<_T>& data)
const override
250 data.reserve(
size());
261 if (!gnat.removed_.empty())
263 out <<
"Elements marked for removal:\n";
264 for (
typename std::unordered_set<const _T*>::const_iterator it = gnat.removed_.begin();
265 it != gnat.removed_.end(); it++)
277 std::unordered_set<const _T*> tmp;
282 for (
typename std::unordered_set<const _T*>::iterator it = tmp.begin(); it != tmp.end(); it++)
285 for (i = 0; i < lst.size(); ++i)
291 std::cout <<
"***** FAIL!! ******\n" << *
this <<
'\n';
292 for (
unsigned int j = 0; j < lst.size(); ++j)
293 std::cout << lst[j] <<
'\t';
294 std::cout << std::endl;
296 assert(i != lst.size());
302 if (lst.size() !=
size_)
303 std::cout <<
"#########################################\n" << *
this << std::endl;
304 assert(lst.size() ==
size_);
325 NodeQueue node_queue;
329 tree_->
nearestK(*
this, data, k, nbhQueue, node_queue, is_pivot);
330 while (!node_queue.empty())
332 dist = nbhQueue.top().second;
333 node_dist = node_queue.top();
335 if (nbhQueue.size() == k && (node_dist.second > node_dist.first->maxRadius_ + dist ||
336 node_dist.second < node_dist.first->minRadius_ - dist))
338 node_dist.first->nearestK(*
this, data, k, nbhQueue, node_queue, is_pivot);
343 void nearestRInternal(
const _T& data,
double radius, NearQueue& nbhQueue)
const
345 double dist = radius;
346 NodeQueue node_queue;
351 while (!node_queue.empty())
353 node_dist = node_queue.top();
355 if (node_dist.second > node_dist.first->maxRadius_ + dist || node_dist.second < node_dist.first->minRadius_ - dist)
357 node_dist.first->nearestR(*
this, data, radius, nbhQueue, node_queue);
364 typename std::vector<_T>::reverse_iterator it;
365 nbh.resize(nbhQueue.size());
366 for (it = nbh.rbegin(); it != nbh.rend(); it++, nbhQueue.pop())
367 *it = *nbhQueue.top().first;
376 Node(
int degree,
int capacity, _T pivot)
385 data_.reserve(capacity + 1);
390 for (
unsigned int i = 0; i <
children_.size(); ++i)
410 activity_ = std::max(-32, activity_ - 1);
424 void add(
GNAT& gnat,
const _T& data)
428 data_.push_back(data);
432 if (!gnat.removed_.empty())
433 gnat.rebuildDataStructure();
434 else if (gnat.size_ >= gnat.rebuildSize_)
436 gnat.rebuildSize_ <<= 1;
437 gnat.rebuildDataStructure();
445 std::vector<double> dist(
children_.size());
449 for (
unsigned int i = 1; i <
children_.size(); ++i)
450 if ((dist[i] = gnat.distFun_(data,
children_[i]->pivot_)) < min_dist)
455 for (
unsigned int i = 0; i <
children_.size(); ++i)
457 children_[min_ind]->updateRadius(min_dist);
464 unsigned int sz =
data_.size();
473 std::vector<unsigned int> pivots;
476 gnat.pivotSelector_.kcenters(
data_,
degree_, pivots, dists);
477 for (
unsigned int& pivot : pivots)
480 for (
unsigned int j = 0; j <
data_.size(); ++j)
483 for (
unsigned int i = 1; i <
degree_; ++i)
484 if (dists(j, i) < dists(j, k))
489 child->data_.push_back(
data_[j]);
490 child->updateRadius(dists(j, k));
492 for (
unsigned int i = 0; i <
degree_; ++i)
496 for (
unsigned int i = 0; i <
degree_; ++i)
510 for (
unsigned int i = 0; i <
degree_; ++i)
516 bool insertNeighborK(NearQueue& nbh, std::size_t k,
const _T& data,
const _T& key,
double dist)
const
520 nbh.push(std::make_pair(&data, dist));
523 if (dist < nbh.top().second || (dist < std::numeric_limits<double>::epsilon() && data == key))
526 nbh.push(std::make_pair(&data, dist));
537 void nearestK(
const GNAT& gnat,
const _T& data, std::size_t k, NearQueue& nbh, NodeQueue& nodeQueue,
540 for (
unsigned int i = 0; i <
data_.size(); ++i)
550 std::vector<double> dist_to_pivot(
children_.size());
551 std::vector<int> permutation(
children_.size());
552 for (
unsigned int i = 0; i < permutation.size(); ++i)
554 std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine{});
556 for (
unsigned int i = 0; i <
children_.size(); ++i)
557 if (permutation[i] >= 0)
560 dist_to_pivot[permutation[i]] = gnat.
distFun_(data, child->pivot_);
561 if (
insertNeighborK(nbh, k, child->pivot_, data, dist_to_pivot[permutation[i]]))
565 dist = nbh.top().second;
566 for (
unsigned int j = 0; j <
children_.size(); ++j)
567 if (permutation[j] >= 0 && i != j &&
568 (dist_to_pivot[permutation[i]] - dist > child->maxRange_[permutation[j]] ||
569 dist_to_pivot[permutation[i]] + dist < child->
minRange_[permutation[j]]))
574 dist = nbh.top().second;
575 for (
unsigned int i = 0; i <
children_.size(); ++i)
576 if (permutation[i] >= 0)
579 if (nbh.size() < k || (dist_to_pivot[permutation[i]] - dist <= child->
maxRadius_ &&
580 dist_to_pivot[permutation[i]] + dist >= child->minRadius_))
581 nodeQueue.push(std::make_pair(child, dist_to_pivot[permutation[i]]));
586 void insertNeighborR(NearQueue& nbh,
double r,
const _T& data,
double dist)
const
589 nbh.push(std::make_pair(&data, dist));
594 void nearestR(
const GNAT& gnat,
const _T& data,
double r, NearQueue& nbh, NodeQueue& nodeQueue)
const
598 for (
unsigned int i = 0; i <
data_.size(); ++i)
599 if (!gnat.isRemoved(
data_[i]))
604 std::vector<double> dist_to_pivot(
children_.size());
605 std::vector<int> permutation(
children_.size());
606 for (
unsigned int i = 0; i < permutation.size(); ++i)
608 std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine{});
610 for (
unsigned int i = 0; i <
children_.size(); ++i)
611 if (permutation[i] >= 0)
614 dist_to_pivot[i] = gnat.distFun_(data, child->pivot_);
616 for (
unsigned int j = 0; j <
children_.size(); ++j)
617 if (permutation[j] >= 0 && i != j &&
618 (dist_to_pivot[i] - dist > child->maxRange_[permutation[j]] ||
619 dist_to_pivot[i] + dist < child->
minRange_[permutation[j]]))
623 for (
unsigned int i = 0; i <
children_.size(); ++i)
624 if (permutation[i] >= 0)
627 if (dist_to_pivot[i] - dist <= child->
maxRadius_ && dist_to_pivot[i] + dist >= child->minRadius_)
628 nodeQueue.push(std::make_pair(child, dist_to_pivot[i]));
633 void list(
const GNAT& gnat, std::vector<_T>& data)
const
635 if (!gnat.isRemoved(
pivot_))
637 for (
unsigned int i = 0; i <
data_.size(); ++i)
638 if (!gnat.isRemoved(
data_[i]))
639 data.push_back(
data_[i]);
640 for (
unsigned int i = 0; i <
children_.size(); ++i)
644 friend std::ostream&
operator<<(std::ostream& out,
const Node& node)
646 out <<
"\ndegree:\t" << node.degree_;
647 out <<
"\nminRadius:\t" << node.minRadius_;
648 out <<
"\nmaxRadius:\t" << node.maxRadius_;
649 out <<
"\nminRange:\t";
650 for (
unsigned int i = 0; i < node.minRange_.size(); ++i)
651 out << node.minRange_[i] <<
'\t';
652 out <<
"\nmaxRange: ";
653 for (
unsigned int i = 0; i < node.maxRange_.size(); ++i)
654 out << node.maxRange_[i] <<
'\t';
655 out <<
"\npivot:\t" << node.pivot_;
657 for (
unsigned int i = 0; i < node.data_.size(); ++i)
658 out << node.data_[i] <<
'\t';
659 out <<
"\nthis:\t" << &node;
660 out <<
"\nchildren:\n";
661 for (
unsigned int i = 0; i < node.children_.size(); ++i)
662 out << node.children_[i] <<
'\t';
664 for (
unsigned int i = 0; i < node.children_.size(); ++i)
665 out << *node.children_[i] <<
'\n';
685 std::vector<_T>
data_;
692 Node*
tree_{
nullptr };
709 std::size_t
size_{ 0 };
720 std::unordered_set<const _T*>
removed_;