kdtree_opencl.cpp
Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2010--2011, Stephane Magnenat, ASL, ETHZ, Switzerland
00004 You can contact the author at <stephane at magnenat dot net>
00005 
00006 All rights reserved.
00007 
00008 Redistribution and use in source and binary forms, with or without
00009 modification, are permitted provided that the following conditions are met:
00010     * Redistributions of source code must retain the above copyright
00011       notice, this list of conditions and the following disclaimer.
00012     * Redistributions in binary form must reproduce the above copyright
00013       notice, this list of conditions and the following disclaimer in the
00014       documentation and/or other materials provided with the distribution.
00015     * Neither the name of the <organization> nor the
00016       names of its contributors may be used to endorse or promote products
00017       derived from this software without specific prior written permission.
00018 
00019 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
00020 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00021 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00022 DISCLAIMED. IN NO EVENT SHALL ETH-ASL BE LIABLE FOR ANY
00023 DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00024 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00025 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00026 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00028 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 #ifdef HAVE_OPENCL
00033 
#include "nabo_private.h"
#include "index_heap.h"
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
00044 
00045 
00051 namespace cl
00052 {
00054         typedef std::vector<Device> Devices;
00055 }
00056 
00057 namespace Nabo
00058 {
00060 
00063         template<typename T, typename CloudType>
00064         size_t argMax(const typename NearestNeighbourSearch<T, CloudType>::Vector& v)
00065         {
00066                 T maxVal(0);
00067                 size_t maxIdx(0);
00068                 for (int i = 0; i < v.size(); ++i)
00069                 {
00070                         if (v[i] > maxVal)
00071                         {
00072                                 maxVal = v[i];
00073                                 maxIdx = i;
00074                         }
00075                 }
00076                 return maxIdx;
00077         }
00078         
00080 
00081         
00083         #define MAX_K 32
00084         
00085         using namespace std;
00086         
00088         template<typename T, typename CloudType>
00089         struct EnableCLTypeSupport {};
00090         
00092         template<typename CloudType>
00093         struct EnableCLTypeSupport<float, CloudType>
00094         {
00096                 static string code(const cl::Device& device)
00097                 {
00098                         return "typedef float T;\n";
00099                 }
00100         };
00101         
00103         template<typename CloudType>
00104         struct EnableCLTypeSupport<double, CloudType>
00105         {
00107 
00108                 static string code(const cl::Device& device)
00109                 {
00110                         string s;
00111                         const string& exts(device.getInfo<CL_DEVICE_EXTENSIONS>());
00112                         //cerr << "extensions: " << exts << endl;
00113                         // first try generic 64-bits fp, otherwise try to fall back on vendor-specific extensions
00114                         if (exts.find("cl_khr_fp64") != string::npos)
00115                                 s += "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
00116                         else if (exts.find("cl_amd_fp64") != string::npos)
00117                                 s += "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n";
00118                         else
00119                                 throw runtime_error("The OpenCL platform does not support 64 bits double-precision floating-points scalars.");
00120                         s += "typedef double T;\n";
00121                         return s;
00122                 }
00123         };
00124         
00126         struct SourceCacher
00127         {
00129                 typedef std::vector<cl::Device> Devices;
00131                 typedef std::map<std::string, cl::Program> ProgramCache;
00132                 
00133                 cl::Context context; 
00134                 Devices devices; 
00135                 ProgramCache cachedPrograms; 
00136                 
00138                 SourceCacher(const cl_device_type deviceType)
00139                 {
00140                         // looking for platforms, AMD drivers do not like the default for creating context
00141                         vector<cl::Platform> platforms;
00142                         cl::Platform::get(&platforms);
00143                         if (platforms.empty())
00144                                 throw runtime_error("No OpenCL platform found");
00145                         //for(vector<cl::Platform>::iterator i = platforms.begin(); i != platforms.end(); ++i)
00146                         //      cerr << "platform " << i - platforms.begin() << " is " << (*i).getInfo<CL_PLATFORM_VENDOR>() << endl;
00147                         cl::Platform platform = platforms[0];
00148                         const char *userDefinedPlatform(getenv("NABO_OPENCL_USE_PLATFORM"));
00149                         if (userDefinedPlatform)
00150                         {
00151                                 size_t userDefinedPlatformId = atoi(userDefinedPlatform);
00152                                 if (userDefinedPlatformId < platforms.size())
00153                                         platform = platforms[userDefinedPlatformId];
00154                         }
00155                         
00156                         // create OpenCL contexts
00157                         cl_context_properties properties[] = { CL_CONTEXT_PLATFORM, (cl_context_properties)platform(), 0 };
00158                         bool deviceFound = false;
00159                         try {
00160                                 context = cl::Context(deviceType, properties);
00161                                 deviceFound = true;
00162                         } catch (const cl::Error& e) {
00163                                 cerr << "Cannot find device type " << deviceType << " for OpenCL, falling back to any device" << endl;
00164                         }
00165                         if (!deviceFound)
00166                                 context = cl::Context(CL_DEVICE_TYPE_ALL, properties);
00167                         devices = context.getInfo<CL_CONTEXT_DEVICES>();
00168                         if (devices.empty())
00169                                 throw runtime_error("No devices on OpenCL platform");
00170                 }
00171                 
00173                 ~SourceCacher()
00174                 {
00175                         cerr << "Destroying source cacher containing " << cachedPrograms.size() << " cached programs" << endl;
00176                 }
00177                 
00179                 bool contains(const std::string& source)
00180                 {
00181                         return cachedPrograms.find(source) != cachedPrograms.end();
00182                 }
00183         };
00184         
00186         class ContextManager
00187         {
00188         public:
00190                 typedef std::map<cl_device_type, SourceCacher*> Devices;
00191                 
00193                 ~ContextManager()
00194                 {
00195                         cerr << "Destroying CL context manager, used " << devices.size() << " contexts" << endl;
00196                         for (Devices::iterator it(devices.begin()); it != devices.end(); ++it)
00197                                 delete it->second;
00198                 }
00200                 cl::Context& createContext(const cl_device_type deviceType)
00201                 {
00202                         std::lock_guard lock(mutex);
00203                         Devices::iterator it(devices.find(deviceType));
00204                         if (it == devices.end())
00205                         {
00206                                 it = devices.insert(
00207                                         pair<cl_device_type, SourceCacher*>(deviceType, new SourceCacher(deviceType))
00208                                         ).first;
00209                         }
00210                         return it->second->context;
00211                 }
00213                 SourceCacher* getSourceCacher(const cl_device_type deviceType)
00214                 {
00215                         std::lock_guard lock(mutex);
00216                         Devices::iterator it(devices.find(deviceType));
00217                         if (it == devices.end())
00218                                 throw runtime_error("Attempt to get source cacher before creating a context");
00219                         return it->second;
00220                 }
00221                 
00222         protected:
00223                 Devices devices; 
00224                 std::mutex mutex; 
00225         };
00226         
00228         static ContextManager contextManager;
00229         
00230         template<typename T, typename CloudType>
00231         OpenCLSearch<T, CloudType>::OpenCLSearch(const CloudType& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00232                 NearestNeighbourSearch<T, CloudType>::NearestNeighbourSearch(cloud, dim, creationOptionFlags),
00233                 deviceType(deviceType),
00234                 context(contextManager.createContext(deviceType))
00235         {
00236         }
00237         
00238         template<typename T, typename CloudType>
00239         void OpenCLSearch<T, CloudType>::initOpenCL(const char* clFileName, const char* kernelName, const std::string& additionalDefines)
00240         {
00241                 const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T, CloudType>::TOUCH_STATISTICS);
00242                 
00243                 SourceCacher* sourceCacher(contextManager.getSourceCacher(deviceType));
00244                 SourceCacher::Devices& devices(sourceCacher->devices);
00245                 
00246                 // build and load source files
00247                 cl::Program::Sources sources;
00248                 // build defines
00249                 ostringstream oss;
00250                 oss << EnableCLTypeSupport<T, CloudType>::code(devices.back());
00251                 oss << "#define EPSILON " << numeric_limits<T>::epsilon() << "\n";
00252                 oss << "#define DIM_COUNT " << dim << "\n";
00253                 //oss << "#define CLOUD_POINT_COUNT " << cloud.cols() << "\n";
00254                 oss << "#define POINT_STRIDE " << cloud.stride() << "\n";
00255                 oss << "#define MAX_K " << MAX_K << "\n";
00256                 if (collectStatistics)
00257                         oss << "#define TOUCH_STATISTICS\n";
00258                 oss << additionalDefines;
00259                 //cerr << "params:\n" << oss.str() << endl;
00260                 
00261                 const std::string& source(oss.str());
00262                 if (!sourceCacher->contains(source))
00263                 {
00264                         const size_t defLen(source.length());
00265                         char *defContent(new char[defLen+1]);
00266                         strcpy(defContent, source.c_str());
00267                         sources.push_back(std::make_pair(defContent, defLen));
00268                         string sourceFileName(OPENCL_SOURCE_DIR);
00269                         sourceFileName += clFileName;
00270                         // load files
00271                         const char* files[] = {
00272                                 OPENCL_SOURCE_DIR "structure.cl",
00273                                 OPENCL_SOURCE_DIR "heap.cl",
00274                                 sourceFileName.c_str(),
00275                                 NULL 
00276                         };
00277                         for (const char** file = files; *file != NULL; ++file)
00278                         {
00279                                 std::ifstream stream(*file);
00280                                 if (!stream.good())
00281                                         throw runtime_error((string("cannot open file: ") + *file));
00282                                 
00283                                 stream.seekg(0, std::ios_base::end);
00284                                 size_t size(stream.tellg());
00285                                 stream.seekg(0, std::ios_base::beg);
00286                                 
00287                                 char* content(new char[size + 1]);
00288                                 std::copy(std::istreambuf_iterator<char>(stream),
00289                                                         std::istreambuf_iterator<char>(), content);
00290                                 content[size] = '\0';
00291                                 
00292                                 sources.push_back(std::make_pair(content, size));
00293                         }
00294                         sourceCacher->cachedPrograms[source] = cl::Program(context, sources);
00295                         cl::Program& program = sourceCacher->cachedPrograms[source];
00296                         
00297                         // build
00298                         cl::Error error(CL_SUCCESS);
00299                         try {
00300                                 program.build(devices);
00301                         } catch (cl::Error e) {
00302                                 error = e;
00303                         }
00304                         
00305                         // dump
00306                         for (cl::Devices::const_iterator it = devices.begin(); it != devices.end(); ++it)
00307                         {
00308                                 cerr << "device : " << it->getInfo<CL_DEVICE_NAME>() << "\n";
00309                                 cerr << "compilation log:\n" << program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(*it) << endl;
00310                         }
00311                         // cleanup sources
00312                         for (cl::Program::Sources::iterator it = sources.begin(); it != sources.end(); ++it)
00313                         {
00314                                 delete[] it->first;
00315                         }
00316                         sources.clear();
00317                         
00318                         // make sure to stop if compilation failed
00319                         if (error.err() != CL_SUCCESS)
00320                                 throw error;
00321                 }
00322                 cl::Program& program = sourceCacher->cachedPrograms[source];
00323                 
00324                 // build kernel and command queue
00325                 knnKernel = cl::Kernel(program, kernelName); 
00326                 queue = cl::CommandQueue(context, devices.back());
00327                 
00328                 // map cloud
00329                 if (!(cloud.Flags & Eigen::DirectAccessBit) || (cloud.Flags & Eigen::RowMajorBit))
00330                         throw runtime_error("wrong memory mapping in point cloud");
00331                 const size_t cloudCLSize(cloud.cols() * cloud.stride() * sizeof(T));
00332                 cloudCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, cloudCLSize, const_cast<T*>(&cloud.coeff(0,0)));
00333                 knnKernel.setArg(0, sizeof(cl_mem), &cloudCL);
00334         }
00335         
00336         template<typename T, typename CloudType>
00337         unsigned long OpenCLSearch<T, CloudType>::knn(const Matrix& query, IndexMatrix& indices, Matrix& dists2, const Index k, const T epsilon, const unsigned optionFlags, const T maxRadius) const
00338         {
00339                 checkSizesKnn(query, indices, dists2, k, optionFlags);
00340                 const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T, CloudType>::TOUCH_STATISTICS);
00341                 
00342                 // check K
00343                 if (k > MAX_K)
00344                         throw runtime_error("number of neighbors too large for OpenCL");
00345                 
00346                 // check consistency of query wrt cloud
00347                 if (query.stride() != cloud.stride() ||
00348                         query.rows() != cloud.rows())
00349                         throw runtime_error("query is not of the same dimensionality as the point cloud");
00350                 
00351                 // map query
00352                 if (!(query.Flags & Eigen::DirectAccessBit) || (query.Flags & Eigen::RowMajorBit))
00353                         throw runtime_error("wrong memory mapping in query data");
00354                 const size_t queryCLSize(query.cols() * query.stride() * sizeof(T));
00355                 cl::Buffer queryCL(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, queryCLSize, const_cast<T*>(&query.coeff(0,0)));
00356                 knnKernel.setArg(1, sizeof(cl_mem), &queryCL);
00357                 // map indices
00358                 assert((indices.Flags & Eigen::DirectAccessBit) && (!(indices.Flags & Eigen::RowMajorBit)));
00359                 const int indexStride(indices.stride());
00360                 const size_t indicesCLSize(indices.cols() * indexStride * sizeof(int));
00361                 cl::Buffer indicesCL(context, CL_MEM_WRITE_ONLY | CL_MEM_USE_HOST_PTR, indicesCLSize, &indices.coeffRef(0,0));
00362                 knnKernel.setArg(2, sizeof(cl_mem), &indicesCL);
00363                 // map dists2
00364                 assert((dists2.Flags & Eigen::DirectAccessBit) && (!(dists2.Flags & Eigen::RowMajorBit)));
00365                 const int dists2Stride(dists2.stride());
00366                 const size_t dists2CLSize(dists2.cols() * dists2Stride * sizeof(T));
00367                 cl::Buffer dists2CL(context, CL_MEM_WRITE_ONLY | CL_MEM_USE_HOST_PTR, dists2CLSize, &dists2.coeffRef(0,0));
00368                 knnKernel.setArg(3, sizeof(cl_mem), &dists2CL);
00369                 
00370                 // set resulting parameters
00371                 knnKernel.setArg(4, k);
00372                 knnKernel.setArg(5, (1 + epsilon)*(1 + epsilon));
00373                 knnKernel.setArg(6, maxRadius*maxRadius);
00374                 knnKernel.setArg(7, optionFlags);
00375                 knnKernel.setArg(8, indexStride);
00376                 knnKernel.setArg(9, dists2Stride);
00377                 knnKernel.setArg(10, cl_uint(cloud.cols()));
00378                 
00379                 // if required, map visit count
00380                 vector<cl_uint> visitCounts;
00381                 const size_t visitCountCLSize(query.cols() * sizeof(cl_uint));
00382                 cl::Buffer visitCountCL;
00383                 if (collectStatistics)
00384                 {
00385                         visitCounts.resize(query.cols());
00386                         visitCountCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, visitCountCLSize, &visitCounts[0]);
00387                         knnKernel.setArg(11, sizeof(cl_mem), &visitCountCL);
00388                 }
00389                 
00390                 // execute query
00391                 queue.enqueueNDRangeKernel(knnKernel, cl::NullRange, cl::NDRange(query.cols()), cl::NullRange);
00392                 queue.enqueueMapBuffer(indicesCL, true, CL_MAP_READ, 0, indicesCLSize, 0, 0);
00393                 queue.enqueueMapBuffer(dists2CL, true, CL_MAP_READ, 0, dists2CLSize, 0, 0);
00394                 if (collectStatistics)
00395                         queue.enqueueMapBuffer(visitCountCL, true, CL_MAP_READ, 0, visitCountCLSize, 0, 0);
00396                 queue.finish();
00397                 
00398                 // if required, collect statistics
00399                 if (collectStatistics)
00400                 {
00401                         unsigned long totalVisitCounts(0);
00402                         for (size_t i = 0; i < visitCounts.size(); ++i)
00403                                 totalVisitCounts += (unsigned long)visitCounts[i];
00404                         return totalVisitCounts;
00405                 }
00406                 else
00407                         return 0;
00408         }
00409         
00410         template<typename T, typename CloudType>
00411         BruteForceSearchOpenCL<T, CloudType>::BruteForceSearchOpenCL(const CloudType& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00412         OpenCLSearch<T, CloudType>::OpenCLSearch(cloud, dim, creationOptionFlags, deviceType)
00413         {
00414 #ifdef EIGEN3_API
00415                 const_cast<Vector&>(this->minBound) = cloud.topRows(this->dim).rowwise().minCoeff();
00416                 const_cast<Vector&>(this->maxBound) = cloud.topRows(this->dim).rowwise().maxCoeff();
00417 #else // EIGEN3_API
00418                 // compute bounds
00419                 for (int i = 0; i < cloud.cols(); ++i)
00420                 {
00421                         const Vector& v(cloud.block(0,i,this->dim,1));
00422                         const_cast<Vector&>(this->minBound) = this->minBound.cwise().min(v);
00423                         const_cast<Vector&>(this->maxBound) = this->maxBound.cwise().max(v);
00424                 }
00425 #endif // EIGEN3_API
00426                 // init openCL
00427                 initOpenCL("knn_bf.cl", "knnBruteForce");
00428         }
00429 
00430         template struct BruteForceSearchOpenCL<float>;
00431         template struct BruteForceSearchOpenCL<double>;
00432         template struct BruteForceSearchOpenCL<float, Eigen::Matrix3Xf>;
00433         template struct BruteForceSearchOpenCL<double, Eigen::Matrix3Xd>;
00434         template struct BruteForceSearchOpenCL<float, Eigen::Map<const Eigen::Matrix3Xf, Eigen::Aligned> >;
00435         template struct BruteForceSearchOpenCL<double, Eigen::Map<const Eigen::Matrix3Xd, Eigen::Aligned> >;
00436         
00437         
00438 
00439         template<typename T, typename CloudType>
00440         size_t KDTreeBalancedPtInLeavesStackOpenCL<T, CloudType>::getTreeSize(size_t elCount) const
00441         {
00442                 // FIXME: 64 bits safe stuff, only work for 2^32 elements right now
00443                 assert(elCount > 0);
00444                 elCount --;
00445                 size_t count = 0;
00446                 int i = 31;
00447                 for (; i >= 0; --i)
00448                 {
00449                         if (elCount & (1 << i))
00450                                 break;
00451                 }
00452                 for (int j = 0; j <= i; ++j)
00453                         count |= (1 << j);
00454                 count <<= 1;
00455                 count |= 1;
00456                 return count;
00457         }
00458         
00459         template<typename T, typename CloudType>
00460         size_t KDTreeBalancedPtInLeavesStackOpenCL<T, CloudType>::getTreeDepth(size_t elCount) const
00461         {
00462                 if (elCount <= 1)
00463                         return 0;
00464                 elCount --;
00465                 size_t i = 31;
00466                 for (; i >= 0; --i)
00467                 {
00468                         if (elCount & (1 << i))
00469                                 break;
00470                 }
00471                 return i+1;
00472         }
00473 
00474         template<typename T, typename CloudType>
00475         void KDTreeBalancedPtInLeavesStackOpenCL<T, CloudType>::buildNodes(const BuildPointsIt first, const BuildPointsIt last, const size_t pos, const Vector minValues, const Vector maxValues)
00476         {
00477                 const size_t count(last - first);
00478                 //cerr << count << endl;
00479                 if (count == 1)
00480                 {
00481                         const int d = -2-(first->index);
00482                         assert(pos < nodes.size());
00483                         nodes[pos] = Node(d);
00484                         return;
00485                 }
00486                 
00487                 // find the largest dimension of the box
00488                 size_t cutDim = argMax<T, CloudType>(maxValues - minValues);
00489                 
00490                 // compute number of elements
00491                 const size_t rightCount(count/2);
00492                 const size_t leftCount(count - rightCount);
00493                 assert(last - rightCount == first + leftCount);
00494                 
00495                 // sort
00496                 nth_element(first, first + leftCount, last, CompareDim(cutDim));
00497                 
00498                 // set node
00499                 const T cutVal((first+leftCount)->pos.coeff(cutDim));
00500                 nodes[pos] = Node(cutDim, cutVal);
00501                 
00502                 //cerr << pos << " cutting on " << cutDim << " at " << (first+leftCount)->pos[cutDim] << endl;
00503                 
00504                 // update bounds for left
00505                 Vector leftMaxValues(maxValues);
00506                 leftMaxValues[cutDim] = cutVal;
00507                 // update bounds for right
00508                 Vector rightMinValues(minValues);
00509                 rightMinValues[cutDim] = cutVal;
00510                 
00511                 // recurse
00512                 buildNodes(first, first + leftCount, childLeft(pos), minValues, leftMaxValues);
00513                 buildNodes(first + leftCount, last, childRight(pos), rightMinValues, maxValues);
00514         }
00515         
00516         template<typename T, typename CloudType>
00517         KDTreeBalancedPtInLeavesStackOpenCL<T, CloudType>::KDTreeBalancedPtInLeavesStackOpenCL(const CloudType& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00518                 OpenCLSearch<T, CloudType>::OpenCLSearch(cloud, dim, creationOptionFlags, deviceType)
00519         {
00520                 const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T>::TOUCH_STATISTICS);
00521                 
00522                 // build point vector and compute bounds
00523                 BuildPoints buildPoints;
00524                 buildPoints.reserve(cloud.cols());
00525                 for (int i = 0; i < cloud.cols(); ++i)
00526                 {
00527                         const Vector& v(cloud.block(0,i,this->dim,1));
00528                         buildPoints.push_back(BuildPoint(v, i));
00529 #ifdef EIGEN3_API
00530                         const_cast<Vector&>(minBound) = minBound.array().min(v.array());
00531                         const_cast<Vector&>(maxBound) = maxBound.array().max(v.array());
00532 #else // EIGEN3_API
00533                         const_cast<Vector&>(minBound) = minBound.cwise().min(v);
00534                         const_cast<Vector&>(maxBound) = maxBound.cwise().max(v);
00535 #endif // EIGEN3_API
00536                 }
00537                 
00538                 // create nodes
00539                 nodes.resize(getTreeSize(cloud.cols()));
00540                 buildNodes(buildPoints.begin(), buildPoints.end(), 0, minBound, maxBound);
00541                 const unsigned maxStackDepth(getTreeDepth(nodes.size()) + 1);
00542                 
00543                 // init openCL
00544                 initOpenCL("knn_kdtree_pt_in_leaves.cl", "knnKDTree", "#define MAX_STACK_DEPTH " + std::to_string(maxStackDepth) + "\n");
00545                 
00546                 // map nodes, for info about alignment, see sect 6.1.5 
00547                 const size_t nodesCLSize(nodes.size() * sizeof(Node));
00548                 nodesCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, nodesCLSize, &nodes[0]);
00549                 if (collectStatistics)
00550                         knnKernel.setArg(12, sizeof(cl_mem), &nodesCL);
00551                 else
00552                         knnKernel.setArg(11, sizeof(cl_mem), &nodesCL);
00553         }
00554 
00555         template struct KDTreeBalancedPtInLeavesStackOpenCL<float>;
00556         template struct KDTreeBalancedPtInLeavesStackOpenCL<double>;
00557         template struct KDTreeBalancedPtInLeavesStackOpenCL<float, Eigen::Matrix3Xf>;
00558         template struct KDTreeBalancedPtInLeavesStackOpenCL<double, Eigen::Matrix3Xd>;
00559         template struct KDTreeBalancedPtInLeavesStackOpenCL<float, Eigen::Map<const Eigen::Matrix3Xf, Eigen::Aligned> >;
00560         template struct KDTreeBalancedPtInLeavesStackOpenCL<double, Eigen::Map<const Eigen::Matrix3Xd, Eigen::Aligned> >;
00561         
00562         
00563         template<typename T, typename CloudType>
00564         size_t KDTreeBalancedPtInNodesStackOpenCL<T, CloudType>::getTreeSize(size_t elCount) const
00565         {
00566                 // FIXME: 64 bits safe stuff, only work for 2^32 elements right now
00567                 size_t count = 0;
00568                 int i = 31;
00569                 for (; i >= 0; --i)
00570                 {
00571                         if (elCount & (1 << i))
00572                                 break;
00573                 }
00574                 for (int j = 0; j <= i; ++j)
00575                         count |= (1 << j);
00576                 //cerr << "tree size " << count << " (" << elCount << " elements)\n";
00577                 return count;
00578         }
00579         
00580         template<typename T, typename CloudType>
00581         size_t KDTreeBalancedPtInNodesStackOpenCL<T, CloudType>::getTreeDepth(size_t elCount) const
00582         {
00583                 // FIXME: 64 bits safe stuff, only work for 2^32 elements right now
00584                 int i = 31;
00585                 for (; i >= 0; --i)
00586                 {
00587                         if (elCount & (1 << i))
00588                                 break;
00589                 }
00590                 return i + 1;
00591         }
00592         
00593         template<typename T, typename CloudType>
00594         void KDTreeBalancedPtInNodesStackOpenCL<T, CloudType>::buildNodes(const BuildPointsIt first, const BuildPointsIt last, const size_t pos, const Vector minValues, const Vector maxValues)
00595         {
00596                 const size_t count(last - first);
00597                 //cerr << count << endl;
00598                 if (count == 1)
00599                 {
00600                         nodes[pos] = Node(-1, *first);
00601                         return;
00602                 }
00603                 
00604                 // find the largest dimension of the box
00605                 const size_t cutDim = argMax<T, CloudType>(maxValues - minValues);
00606                 
00607                 // compute number of elements
00608                 const size_t recurseCount(count-1);
00609                 const size_t rightCount(recurseCount/2);
00610                 const size_t leftCount(recurseCount-rightCount);
00611                 assert(last - rightCount == first + leftCount + 1);
00612                 
00613                 // sort
00614                 nth_element(first, first + leftCount, last, CompareDim(cloud, cutDim));
00615                 
00616                 // set node
00617                 const Index index(*(first+leftCount));
00618                 const T cutVal(cloud.coeff(cutDim, index));
00619                 nodes[pos] = Node(cutDim, index);
00620                 
00621                 //cerr << pos << " cutting on " << cutDim << " at " << (first+leftCount)->pos[cutDim] << endl;
00622                 
00623                 // update bounds for left
00624                 Vector leftMaxValues(maxValues);
00625                 leftMaxValues[cutDim] = cutVal;
00626                 // update bounds for right
00627                 Vector rightMinValues(minValues);
00628                 rightMinValues[cutDim] = cutVal;
00629                 
00630                 // recurse
00631                 if (count > 2)
00632                 {
00633                         buildNodes(first, first + leftCount, childLeft(pos), minValues, leftMaxValues);
00634                         buildNodes(first + leftCount + 1, last, childRight(pos), rightMinValues, maxValues);
00635                 }
00636                 else
00637                 {
00638                         nodes[childLeft(pos)] = Node(-1, *first);
00639                         nodes[childRight(pos)] = Node(-2, 0);
00640                 }
00641         }
00642         
00643         template<typename T, typename CloudType>
00644         KDTreeBalancedPtInNodesStackOpenCL<T, CloudType>::KDTreeBalancedPtInNodesStackOpenCL(const CloudType& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00645         OpenCLSearch<T, CloudType>::OpenCLSearch(cloud, dim, creationOptionFlags, deviceType)
00646         {
00647                 const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T, CloudType>::TOUCH_STATISTICS);
00648                 
00649                 // build point vector and compute bounds
00650                 BuildPoints buildPoints;
00651                 buildPoints.reserve(cloud.cols());
00652                 for (int i = 0; i < cloud.cols(); ++i)
00653                 {
00654                         buildPoints.push_back(i);
00655                         const Vector& v(cloud.block(0,i,this->dim,1));
00656 #ifdef EIGEN3_API
00657                         const_cast<Vector&>(minBound) = minBound.array().min(v.array());
00658                         const_cast<Vector&>(maxBound) = maxBound.array().max(v.array());
00659 #else // EIGEN3_API
00660                         const_cast<Vector&>(minBound) = minBound.cwise().min(v);
00661                         const_cast<Vector&>(maxBound) = maxBound.cwise().max(v);
00662 #endif // EIGEN3_API
00663                 }
00664                 
00665                 // create nodes
00666                 nodes.resize(getTreeSize(cloud.cols()));
00667                 buildNodes(buildPoints.begin(), buildPoints.end(), 0, minBound, maxBound);
00668                 const unsigned maxStackDepth(getTreeDepth(nodes.size()) + 1);
00669                 
00670                 // init openCL
00671                 initOpenCL("knn_kdtree_pt_in_nodes.cl", "knnKDTree", "#define MAX_STACK_DEPTH " + std::to_string(maxStackDepth) + "\n");
00672                 
00673                 // map nodes, for info about alignment, see sect 6.1.5 
00674                 const size_t nodesCLSize(nodes.size() * sizeof(Node));
00675                 nodesCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, nodesCLSize, &nodes[0]);
00676                 if (collectStatistics)
00677                         knnKernel.setArg(12, sizeof(cl_mem), &nodesCL);
00678                 else
00679                         knnKernel.setArg(11, sizeof(cl_mem), &nodesCL);
00680         }
00681         
00682         template struct KDTreeBalancedPtInNodesStackOpenCL<float>;
00683         template struct KDTreeBalancedPtInNodesStackOpenCL<double>;
00684         template struct KDTreeBalancedPtInNodesStackOpenCL<float, Eigen::Matrix3Xf>;
00685         template struct KDTreeBalancedPtInNodesStackOpenCL<double, Eigen::Matrix3Xd>;
00686         template struct KDTreeBalancedPtInNodesStackOpenCL<float, Eigen::Map<const Eigen::Matrix3Xf, Eigen::Aligned> >;
00687         template struct KDTreeBalancedPtInNodesStackOpenCL<double, Eigen::Map<const Eigen::Matrix3Xd, Eigen::Aligned> >;
00688         
00690 }
00691 
00692 #endif // HAVE_OPENCL
00693 /* vim: set ts=8 sw=8 tw=0 noexpandtab cindent softtabstop=8 :*/


libnabo
Author(s): Stéphane Magnenat
autogenerated on Sun Feb 10 2019 03:52:20