00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #ifndef KDTREE_VCG_H
00025 #define KDTREE_VCG_H
00026
00027 #include <vcg/space/point3.h>
00028 #include <vcg/space/box3.h>
00029 #include <vcg/space/index/kdtree/priorityqueue.h>
00030
00031 #include <vector>
00032 #include <limits>
00033 #include <iostream>
00034 #include <cstdint>
00035
00036 namespace vcg {
00037
/**
 * Read-only, non-owning view over `size` elements of type DataType that are
 * laid out a constant number of bytes apart (for example one field inside an
 * array of larger structs). The caller keeps the underlying storage alive.
 */
template<typename _DataType>
class ConstDataWrapper
{
public:
    typedef _DataType DataType;

    /// Empty view: no data, zero stride, zero elements.
    inline ConstDataWrapper()
        : mpData(nullptr), mStride(0), mSize(0)
    {}

    /// View over `size` elements starting at pData, consecutive elements being
    /// `stride` bytes apart (tightly packed by default).
    inline ConstDataWrapper(const DataType* pData, int size, int64_t stride = sizeof(DataType))
        : mpData(reinterpret_cast<const unsigned char*>(pData)), mStride(stride), mSize(size)
    {}

    /// Element i, found i*stride bytes past the first one. No bounds check.
    inline const DataType& operator[] (int i) const
    {
        const int64_t byteOffset = static_cast<int64_t>(i) * mStride;
        const unsigned char* element = mpData + byteOffset;
        return *reinterpret_cast<const DataType*>(element);
    }

    /// Number of elements in the view.
    inline size_t size() const { return mSize; }

protected:
    const unsigned char* mpData; ///< first element, addressed as raw bytes
    int64_t mStride;             ///< distance in bytes between consecutive elements
    size_t mSize;                ///< number of elements
};
00059
00060 template<class StdVectorType>
00061 class VectorConstDataWrapper :public ConstDataWrapper<typename StdVectorType::value_type>
00062 {
00063 public:
00064 inline VectorConstDataWrapper(StdVectorType &vec):
00065 ConstDataWrapper<typename StdVectorType::value_type> ( &(vec[0]), vec.size(), sizeof(typename StdVectorType::value_type))
00066 {}
00067 };
00068
00069 template<class MeshType>
00070 class VertexConstDataWrapper :public ConstDataWrapper<typename MeshType::CoordType>
00071 {
00072 public:
00073 inline VertexConstDataWrapper(MeshType &m):
00074 ConstDataWrapper<typename MeshType::CoordType> ( &(m.vert[0].P()), m.vert.size(), sizeof(typename MeshType::VertexType))
00075 {}
00076 };
00077
00082 template<typename _Scalar>
00083 class KdTree
00084 {
00085 public:
00086
00087 typedef _Scalar Scalar;
00088 typedef vcg::Point3<Scalar> VectorType;
00089 typedef vcg::Box3<Scalar> AxisAlignedBoxType;
00090
00091 typedef HeapMaxPriorityQueue<int, Scalar> PriorityQueue;
00092
00093 struct Node
00094 {
00095 union {
00096
00097 struct {
00098 Scalar splitValue;
00099 unsigned int firstChildId:24;
00100 unsigned int dim:2;
00101 unsigned int leaf:1;
00102 };
00103
00104 struct {
00105 unsigned int start;
00106 unsigned short size;
00107 };
00108 };
00109 };
00110 typedef std::vector<Node> NodeList;
00111
00112
00113 inline const NodeList& _getNodes(void) { return mNodes; }
00114 inline const std::vector<VectorType>& _getPoints(void) { return mPoints; }
00115
00116 public:
00117
00118 KdTree(const ConstDataWrapper<VectorType>& points, unsigned int nofPointsPerCell = 16, unsigned int maxDepth = 64);
00119
00120 ~KdTree();
00121
00122 void doQueryK(const VectorType& queryPoint, int k, PriorityQueue& mNeighborQueue);
00123
00124 void doQueryDist(const VectorType& queryPoint, float dist, std::vector<unsigned int>& points, std::vector<Scalar>& sqrareDists);
00125
00126 void doQueryClosest(const VectorType& queryPoint, unsigned int& index, Scalar& dist);
00127
00128 protected:
00129
00130
00131 struct QueryNode
00132 {
00133 QueryNode() {}
00134 QueryNode(unsigned int id) : nodeId(id) {}
00135 unsigned int nodeId;
00136 Scalar sq;
00137 };
00138
00139
00140
00141 unsigned int split(int start, int end, unsigned int dim, float splitValue);
00142
00143 int createTree(unsigned int nodeId, unsigned int start, unsigned int end, unsigned int level, unsigned int targetCellsize, unsigned int targetMaxDepth);
00144
00145 protected:
00146
00147 AxisAlignedBoxType mAABB;
00148 NodeList mNodes;
00149 std::vector<VectorType> mPoints;
00150 std::vector<unsigned int> mIndices;
00151 };
00152
00153 template<typename Scalar>
00154 KdTree<Scalar>::KdTree(const ConstDataWrapper<VectorType>& points, unsigned int nofPointsPerCell, unsigned int maxDepth)
00155 : mPoints(points.size()), mIndices(points.size())
00156 {
00157
00158 mPoints[0] = points[0];
00159 mAABB.Set(mPoints[0]);
00160 for (unsigned int i=1 ; i<mPoints.size() ; ++i)
00161 {
00162 mPoints[i] = points[i];
00163 mIndices[i] = i;
00164 mAABB.Add(mPoints[i]);
00165 }
00166
00167 mNodes.reserve(4*mPoints.size()/nofPointsPerCell);
00168
00169 mNodes.resize(1);
00170 mNodes.back().leaf = 0;
00171
00172 createTree(0, 0, mPoints.size(), 1, nofPointsPerCell, maxDepth);
00173 }
00174
00175 template<typename Scalar>
00176 KdTree<Scalar>::~KdTree()
00177 {
00178 }
00179
00180
00197 template<typename Scalar>
00198 void KdTree<Scalar>::doQueryK(const VectorType& queryPoint, int k, PriorityQueue& mNeighborQueue)
00199 {
00200 mNeighborQueue.setMaxSize(k);
00201 mNeighborQueue.init();
00202
00203 QueryNode mNodeStack[64];
00204 mNodeStack[0].nodeId = 0;
00205 mNodeStack[0].sq = 0.f;
00206 unsigned int count = 1;
00207
00208 while (count)
00209 {
00210
00211 QueryNode& qnode = mNodeStack[count-1];
00212
00213
00214
00215
00216 Node& node = mNodes[qnode.nodeId];
00217
00218
00219 if (mNeighborQueue.getNofElements() < k || qnode.sq < mNeighborQueue.getTopWeight())
00220 {
00221
00222 if (node.leaf)
00223 {
00224 --count;
00225
00226
00227 unsigned int end = node.start+node.size;
00228
00229 for (unsigned int i=node.start ; i<end ; ++i)
00230 mNeighborQueue.insert(mIndices[i], vcg::SquaredNorm(queryPoint - mPoints[i]));
00231 }
00232
00233 else
00234 {
00235
00236 float new_off = queryPoint[node.dim] - node.splitValue;
00237
00238
00239 if (new_off < 0.)
00240 {
00241 mNodeStack[count].nodeId = node.firstChildId;
00242
00243 qnode.nodeId = node.firstChildId+1;
00244 }
00245
00246 else
00247 {
00248 mNodeStack[count].nodeId = node.firstChildId+1;
00249 qnode.nodeId = node.firstChildId;
00250 }
00251
00252 mNodeStack[count].sq = qnode.sq;
00253
00254 qnode.sq = new_off*new_off;
00255 ++count;
00256 }
00257 }
00258 else
00259 {
00260
00261 --count;
00262 }
00263 }
00264 }
00265
00266
00272 template<typename Scalar>
00273 void KdTree<Scalar>::doQueryDist(const VectorType& queryPoint, float dist, std::vector<unsigned int>& points, std::vector<Scalar>& sqrareDists)
00274 {
00275 QueryNode mNodeStack[64];
00276 mNodeStack[0].nodeId = 0;
00277 mNodeStack[0].sq = 0.f;
00278 unsigned int count = 1;
00279
00280 float sqrareDist = dist*dist;
00281 while (count)
00282 {
00283 QueryNode& qnode = mNodeStack[count-1];
00284 Node & node = mNodes[qnode.nodeId];
00285
00286 if (qnode.sq < sqrareDist)
00287 {
00288 if (node.leaf)
00289 {
00290 --count;
00291 unsigned int end = node.start+node.size;
00292 for (unsigned int i=node.start ; i<end ; ++i)
00293 {
00294 float pointSquareDist = vcg::SquaredNorm(queryPoint - mPoints[i]);
00295 if (pointSquareDist < sqrareDist)
00296 {
00297 points.push_back(mIndices[i]);
00298 sqrareDists.push_back(pointSquareDist);
00299 }
00300 }
00301 }
00302 else
00303 {
00304
00305 float new_off = queryPoint[node.dim] - node.splitValue;
00306 if (new_off < 0.)
00307 {
00308 mNodeStack[count].nodeId = node.firstChildId;
00309 qnode.nodeId = node.firstChildId+1;
00310 }
00311 else
00312 {
00313 mNodeStack[count].nodeId = node.firstChildId+1;
00314 qnode.nodeId = node.firstChildId;
00315 }
00316 mNodeStack[count].sq = qnode.sq;
00317 qnode.sq = new_off*new_off;
00318 ++count;
00319 }
00320 }
00321 else
00322 {
00323
00324 --count;
00325 }
00326 }
00327 }
00328
00329
00335 template<typename Scalar>
00336 void KdTree<Scalar>::doQueryClosest(const VectorType& queryPoint, unsigned int& index, Scalar& dist)
00337 {
00338 QueryNode mNodeStack[64];
00339 mNodeStack[0].nodeId = 0;
00340 mNodeStack[0].sq = 0.f;
00341 unsigned int count = 1;
00342
00343 int minIndex = mIndices.size() / 2;
00344 Scalar minDist = vcg::SquaredNorm(queryPoint - mPoints[minIndex]);
00345 minIndex = mIndices[minIndex];
00346
00347 while (count)
00348 {
00349 QueryNode& qnode = mNodeStack[count-1];
00350 Node & node = mNodes[qnode.nodeId];
00351
00352 if (qnode.sq < minDist)
00353 {
00354 if (node.leaf)
00355 {
00356 --count;
00357 unsigned int end = node.start+node.size;
00358 for (unsigned int i=node.start ; i<end ; ++i)
00359 {
00360 float pointSquareDist = vcg::SquaredNorm(queryPoint - mPoints[i]);
00361 if (pointSquareDist < minDist)
00362 {
00363 minDist = pointSquareDist;
00364 minIndex = mIndices[i];
00365 }
00366 }
00367 }
00368 else
00369 {
00370
00371 float new_off = queryPoint[node.dim] - node.splitValue;
00372 if (new_off < 0.)
00373 {
00374 mNodeStack[count].nodeId = node.firstChildId;
00375 qnode.nodeId = node.firstChildId+1;
00376 }
00377 else
00378 {
00379 mNodeStack[count].nodeId = node.firstChildId+1;
00380 qnode.nodeId = node.firstChildId;
00381 }
00382 mNodeStack[count].sq = qnode.sq;
00383 qnode.sq = new_off*new_off;
00384 ++count;
00385 }
00386 }
00387 else
00388 {
00389
00390 --count;
00391 }
00392 }
00393 index = minIndex;
00394 dist = minDist;
00395 }
00396
00397
00398
00404 template<typename Scalar>
00405 unsigned int KdTree<Scalar>::split(int start, int end, unsigned int dim, float splitValue)
00406 {
00407 int l(start), r(end-1);
00408 for ( ; l<r ; ++l, --r)
00409 {
00410 while (l < end && mPoints[l][dim] < splitValue)
00411 l++;
00412 while (r >= start && mPoints[r][dim] >= splitValue)
00413 r--;
00414 if (l > r)
00415 break;
00416 std::swap(mPoints[l],mPoints[r]);
00417 std::swap(mIndices[l],mIndices[r]);
00418 }
00419
00420 return (mPoints[l][dim] < splitValue ? l+1 : l);
00421 }
00422
00440 template<typename Scalar>
00441 int KdTree<Scalar>::createTree(unsigned int nodeId, unsigned int start, unsigned int end, unsigned int level, unsigned int targetCellSize, unsigned int targetMaxDepth)
00442 {
00443
00444 Node& node = mNodes[nodeId];
00445 AxisAlignedBoxType aabb;
00446
00447
00448 aabb.Set(mPoints[start]);
00449 for (unsigned int i=start+1 ; i<end ; ++i)
00450 aabb.Add(mPoints[i]);
00451
00452
00453 VectorType diag = aabb.max - aabb.min;
00454
00455
00456 unsigned int dim;
00457 if (diag.X() > diag.Y())
00458 dim = diag.X() > diag.Z() ? 0 : 2;
00459 else
00460 dim = diag.Y() > diag.Z() ? 1 : 2;
00461
00462 node.dim = dim;
00463
00464 node.splitValue = Scalar(0.5*(aabb.max[dim] + aabb.min[dim]));
00465
00466
00467 unsigned int midId = split(start, end, dim, node.splitValue);
00468
00469
00470 node.firstChildId = mNodes.size();
00471 mNodes.resize(mNodes.size()+2);
00472 int leftLevel, rightLevel;
00473
00474 {
00475
00476 unsigned int childId = mNodes[nodeId].firstChildId;
00477 Node& child = mNodes[childId];
00478 if (midId - start <= targetCellSize || level>=targetMaxDepth)
00479 {
00480 child.leaf = 1;
00481 child.start = start;
00482 child.size = midId - start;
00483 leftLevel = level;
00484 }
00485 else
00486 {
00487 child.leaf = 0;
00488 leftLevel = createTree(childId, start, midId, level+1, targetCellSize, targetMaxDepth);
00489 }
00490 }
00491
00492 {
00493
00494 unsigned int childId = mNodes[nodeId].firstChildId+1;
00495 Node& child = mNodes[childId];
00496 if (end - midId <= targetCellSize || level>=targetMaxDepth)
00497 {
00498 child.leaf = 1;
00499 child.start = midId;
00500 child.size = end - midId;
00501 rightLevel = level;
00502 }
00503 else
00504 {
00505 child.leaf = 0;
00506 rightLevel = createTree(childId, midId, end, level+1, targetCellSize, targetMaxDepth);
00507 }
00508 }
00509 if (leftLevel > rightLevel)
00510 return leftLevel;
00511 return rightLevel;
00512 }
00513 }
00514
00515 #endif