Program Listing for File KDTree.h
↰ Return to documentation for file (include/dynotree/KDTree.h)
#pragma once
#include <algorithm>
#include <cmath>
#include <cwchar>
#include <limits>
#include <queue>
#include <set>
#include <vector>
#include <eigen3/Eigen/Core>
#include <eigen3/Eigen/Dense>
#include "StateSpace.h"
#include "dynotree/dynotree_macros.h"
namespace dynotree {
template <class Id, int Dimensions, std::size_t BucketSize = 32,
typename Scalar = double,
typename StateSpace = Rn<Scalar, Dimensions>>
class KDTree {
private:
struct Node;
std::vector<Node> m_nodes;
std::set<std::size_t> waitingForSplit;
StateSpace state_space;
public:
using scalar_t = Scalar;
using id_t = Id;
using point_t = Eigen::Matrix<Scalar, Dimensions, 1>;
using cref_t = const Eigen::Ref<const Eigen::Matrix<Scalar, Dimensions, 1>> &;
using ref_t = Eigen::Ref<Eigen::Matrix<Scalar, Dimensions, 1>>;
using state_space_t = StateSpace;
int m_dimensions = Dimensions;
static const std::size_t bucketSize = BucketSize;
using tree_t = KDTree<Id, Dimensions, BucketSize, Scalar, StateSpace>;
StateSpace &getStateSpace() { return state_space; }
KDTree() = default;
void init_tree(int runtime_dimension = -1,
const StateSpace &t_state_space = StateSpace()) {
state_space = t_state_space;
if constexpr (Dimensions == Eigen::Dynamic) {
assert(runtime_dimension > 0);
m_dimensions = runtime_dimension;
m_nodes.emplace_back(BucketSize, m_dimensions);
} else {
m_nodes.emplace_back(BucketSize, -1);
}
}
size_t size() const { return m_nodes[0].m_entries; }
void addPoint(const point_t &x, const Id &id, bool autosplit = true) {
std::size_t addNode = 0;
assert(m_dimensions > 0);
while (m_nodes[addNode].m_splitDimension != m_dimensions) {
m_nodes[addNode].expandBounds(x);
if (x[m_nodes[addNode].m_splitDimension] <
m_nodes[addNode].m_splitValue) {
addNode = m_nodes[addNode].m_children.first;
} else {
addNode = m_nodes[addNode].m_children.second;
}
}
m_nodes[addNode].add(PointId{x, id});
if (m_nodes[addNode].shouldSplit() &&
m_nodes[addNode].m_entries % BucketSize == 0) {
if (autosplit) {
split(addNode);
} else {
waitingForSplit.insert(addNode);
}
}
}
void splitOutstanding() {
std::vector<std::size_t> searchStack(waitingForSplit.begin(),
waitingForSplit.end());
waitingForSplit.clear();
while (searchStack.size() > 0) {
std::size_t addNode = searchStack.back();
searchStack.pop_back();
if (m_nodes[addNode].m_splitDimension == m_dimensions &&
m_nodes[addNode].shouldSplit() && split(addNode)) {
searchStack.push_back(m_nodes[addNode].m_children.first);
searchStack.push_back(m_nodes[addNode].m_children.second);
}
}
}
struct DistanceId {
Scalar distance;
Id id;
inline bool operator<(const DistanceId &dp) const {
return distance < dp.distance;
}
};
std::vector<DistanceId> searchKnn(const point_t &x,
std::size_t maxPoints) const {
return searcher().search(x, std::numeric_limits<Scalar>::max(), maxPoints,
state_space);
}
std::vector<DistanceId> searchBall(const point_t &x, Scalar maxRadius) const {
return searcher().search(
x, maxRadius, std::numeric_limits<std::size_t>::max(), state_space);
}
std::vector<DistanceId>
searchCapacityLimitedBall(const point_t &x, Scalar maxRadius,
std::size_t maxPoints) const {
return searcher().search(x, maxRadius, maxPoints, state_space);
}
DistanceId search(const point_t &x) const {
DistanceId result;
result.distance = std::numeric_limits<Scalar>::infinity();
if (m_nodes[0].m_entries > 0) {
std::vector<std::size_t> searchStack;
searchStack.reserve(
1 +
std::size_t(1.5 * std::log2(1 + m_nodes[0].m_entries / BucketSize)));
searchStack.push_back(0);
while (searchStack.size() > 0) {
std::size_t nodeIndex = searchStack.back();
searchStack.pop_back();
const Node &node = m_nodes[nodeIndex];
if (result.distance > node.distance_to_rectangle(x, state_space)) {
if (node.m_splitDimension == m_dimensions) {
for (const auto &lp : node.m_locationId) {
// Allow to have inactive nodes in the tree
if (!lp.active)
continue;
Scalar nodeDist = state_space.distance(x, lp.x);
if (nodeDist < result.distance) {
result = DistanceId{nodeDist, lp.id};
}
}
} else {
node.queueChildren(x, searchStack);
}
}
}
}
return result;
}
void set_inactive(const point_t &x) {
DistanceId result;
result.distance = std::numeric_limits<Scalar>::infinity();
bool found = false;
if (m_nodes[0].m_entries > 0) {
std::vector<std::size_t> searchStack;
searchStack.reserve(
1 +
std::size_t(1.5 * std::log2(1 + m_nodes[0].m_entries / BucketSize)));
searchStack.push_back(0);
while (!found && searchStack.size() > 0) {
std::size_t nodeIndex = searchStack.back();
searchStack.pop_back();
Node &node = m_nodes[nodeIndex];
if (result.distance > node.distance_to_rectangle(x, state_space)) {
if (node.m_splitDimension == m_dimensions) {
for (auto &lp : node.m_locationId) {
// Allow to have inactive nodes in the tree
if (!lp.active)
continue;
Scalar nodeDist = state_space.distance(x, lp.x);
if (nodeDist < result.distance) {
result = DistanceId{nodeDist, lp.id};
if (result.distance < 1e-8) {
found = true;
lp.active = false;
break;
}
}
}
} else {
node.queueChildren(x, searchStack);
}
}
}
}
CHECK_PRETTY_DYNOTREE__(found);
// return result;
}
class Searcher {
public:
Searcher(const tree_t &tree) : m_tree(tree) {}
Searcher(const Searcher &searcher) : m_tree(searcher.m_tree) {}
// NB! this method is not const. Do not call this on same instance from
// different threads simultaneously.
const std::vector<DistanceId> &search(const point_t &x, Scalar maxRadius,
std::size_t maxPoints,
const StateSpace &state_space) {
// clear results from last time
m_results.clear();
// reserve capacities
m_searchStack.reserve(
1 + std::size_t(1.5 * std::log2(1 + m_tree.m_nodes[0].m_entries /
BucketSize)));
if (m_prioqueueCapacity < maxPoints &&
maxPoints < m_tree.m_nodes[0].m_entries) {
std::vector<DistanceId> container;
container.reserve(maxPoints);
m_prioqueue = std::priority_queue<DistanceId, std::vector<DistanceId>>(
std::less<DistanceId>(), std::move(container));
m_prioqueueCapacity = maxPoints;
}
m_tree.searchCapacityLimitedBall(x, maxRadius, maxPoints, m_searchStack,
m_prioqueue, m_results, state_space);
m_prioqueueCapacity = std::max(m_prioqueueCapacity, m_results.size());
return m_results;
}
private:
const tree_t &m_tree;
std::vector<std::size_t> m_searchStack;
std::priority_queue<DistanceId, std::vector<DistanceId>> m_prioqueue;
std::size_t m_prioqueueCapacity = 0;
std::vector<DistanceId> m_results;
};
// NB! returned class has no const methods. Get one instance per thread!
Searcher searcher() const { return Searcher(*this); }
private:
struct PointId {
point_t x;
Id id;
bool active = true;
};
std::vector<PointId> m_bucketRecycle;
void searchCapacityLimitedBall(
const point_t &x, Scalar maxRadius, std::size_t maxPoints,
std::vector<std::size_t> &searchStack,
std::priority_queue<DistanceId, std::vector<DistanceId>> &prioqueue,
std::vector<DistanceId> &results, const StateSpace &state_space) const {
std::size_t numSearchPoints = std::min(maxPoints, m_nodes[0].m_entries);
if (numSearchPoints > 0) {
searchStack.push_back(0);
while (searchStack.size() > 0) {
std::size_t nodeIndex = searchStack.back();
searchStack.pop_back();
const Node &node = m_nodes[nodeIndex];
Scalar minDist = node.distance_to_rectangle(x, state_space);
if (maxRadius > minDist && (prioqueue.size() < numSearchPoints ||
prioqueue.top().distance > minDist)) {
if (node.m_splitDimension == m_dimensions) {
node.searchCapacityLimitedBall(x, maxRadius, numSearchPoints,
prioqueue, state_space);
} else {
node.queueChildren(x, searchStack);
}
}
}
results.reserve(prioqueue.size());
while (prioqueue.size() > 0) {
results.push_back(prioqueue.top());
prioqueue.pop();
}
std::reverse(results.begin(), results.end());
}
}
bool split(std::size_t index) {
if (m_nodes.capacity() < m_nodes.size() + 2) {
m_nodes.reserve((m_nodes.capacity() + 1) * 2);
}
Node &splitNode = m_nodes[index];
splitNode.m_splitDimension = m_dimensions;
Scalar width(0);
state_space.choose_split_dimension(splitNode.m_lb, splitNode.m_ub,
splitNode.m_splitDimension, width);
if (splitNode.m_splitDimension == m_dimensions) {
return false;
}
std::vector<Scalar> splitDimVals;
splitDimVals.reserve(splitNode.m_entries);
for (const auto &lp : splitNode.m_locationId) {
splitDimVals.push_back(lp.x[splitNode.m_splitDimension]);
}
std::nth_element(splitDimVals.begin(),
splitDimVals.begin() + splitDimVals.size() / 2 + 1,
splitDimVals.end());
std::nth_element(splitDimVals.begin(),
splitDimVals.begin() + splitDimVals.size() / 2,
splitDimVals.begin() + splitDimVals.size() / 2 + 1);
splitNode.m_splitValue = (splitDimVals[splitDimVals.size() / 2] +
splitDimVals[splitDimVals.size() / 2 + 1]) /
Scalar(2);
splitNode.m_children = std::make_pair(m_nodes.size(), m_nodes.size() + 1);
std::size_t entries = splitNode.m_entries;
m_nodes.emplace_back(m_bucketRecycle, entries, m_dimensions);
Node &leftNode = m_nodes.back();
m_nodes.emplace_back(entries, m_dimensions);
Node &rightNode = m_nodes.back();
for (const auto &lp : splitNode.m_locationId) {
if (lp.x[splitNode.m_splitDimension] < splitNode.m_splitValue) {
leftNode.add(lp);
} else {
rightNode.add(lp);
}
}
if (leftNode.m_entries ==
0) // points with equality to splitValue go in rightNode
{
splitNode.m_splitValue = 0;
splitNode.m_splitDimension = m_dimensions;
splitNode.m_children = std::pair<std::size_t, std::size_t>(0, 0);
std::swap(rightNode.m_locationId, m_bucketRecycle);
m_nodes.pop_back();
m_nodes.pop_back();
return false;
} else {
splitNode.m_locationId.clear();
// if it was a standard sized bucket, recycle the memory to reduce
// allocator pressure otherwise clear the memory used by the bucket
// since it is a branch not a leaf anymore
if (splitNode.m_locationId.capacity() == BucketSize) {
std::swap(splitNode.m_locationId, m_bucketRecycle);
} else {
std::vector<PointId> empty;
std::swap(splitNode.m_locationId, empty);
}
return true;
}
}
struct Node {
Node(std::size_t capacity, int runtime_dimension = -1) {
init(capacity, runtime_dimension);
}
Node(std::vector<PointId> &recycle, std::size_t capacity,
int runtime_dimension) {
std::swap(m_locationId, recycle);
init(capacity, runtime_dimension);
}
void init(std::size_t capacity, int runtime_dimension) {
if constexpr (Dimensions == Eigen::Dynamic) {
assert(runtime_dimension > 0);
m_lb.resize(runtime_dimension);
m_ub.resize(runtime_dimension);
m_splitDimension = runtime_dimension;
}
m_lb.setConstant(std::numeric_limits<Scalar>::max());
m_ub.setConstant(std::numeric_limits<Scalar>::lowest());
m_locationId.reserve(std::max(BucketSize, capacity));
}
void expandBounds(const point_t &x) {
m_lb = m_lb.cwiseMin(x);
m_ub = m_ub.cwiseMax(x);
m_entries++;
}
void add(const PointId &lp) {
expandBounds(lp.x);
m_locationId.push_back(lp);
}
bool shouldSplit() const { return m_entries >= BucketSize; }
void searchCapacityLimitedBall(const point_t &x, Scalar maxRadius,
std::size_t K,
std::priority_queue<DistanceId> &results,
const StateSpace &state_space) const {
std::size_t i = 0;
// this fills up the queue if it isn't full yet
for (; results.size() < K && i < m_entries; i++) {
const auto &lp = m_locationId[i];
Scalar distance = state_space.distance(x, lp.x);
if (distance < maxRadius) {
results.emplace(DistanceId{distance, lp.id});
}
}
// this adds new things to the queue once it is full
for (; i < m_entries; i++) {
const auto &lp = m_locationId[i];
Scalar distance = state_space.distance(x, lp.x);
if (distance < maxRadius && distance < results.top().distance) {
results.pop();
results.emplace(DistanceId{distance, lp.id});
}
}
}
void queueChildren(const point_t &x,
std::vector<std::size_t> &searchStack) const {
if (x[m_splitDimension] < m_splitValue) {
searchStack.push_back(m_children.second);
searchStack.push_back(m_children.first); // left is popped first
} else {
searchStack.push_back(m_children.first);
searchStack.push_back(m_children.second); // right is popped first
}
}
Scalar distance_to_rectangle(const point_t &x,
const StateSpace &distance) const {
return distance.distance_to_rectangle(x, m_lb, m_ub);
}
std::size_t m_entries = 0;
int m_splitDimension = Dimensions;
Scalar m_splitValue = 0;
// struct Range {
// Scalar min, max;
// };
// std::array<Range, Dimensions> m_bounds; /// bounding box of this node
Eigen::Matrix<Scalar, Dimensions, 1> m_lb;
Eigen::Matrix<Scalar, Dimensions, 1> m_ub;
std::pair<std::size_t, std::size_t>
m_children;
std::vector<PointId> m_locationId;
};
};
} // namespace dynotree