kdtree_opencl.cpp
Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2010--2011, Stephane Magnenat, ASL, ETHZ, Switzerland
00004 You can contact the author at <stephane at magnenat dot net>
00005 
00006 All rights reserved.
00007 
00008 Redistribution and use in source and binary forms, with or without
00009 modification, are permitted provided that the following conditions are met:
00010     * Redistributions of source code must retain the above copyright
00011       notice, this list of conditions and the following disclaimer.
00012     * Redistributions in binary form must reproduce the above copyright
00013       notice, this list of conditions and the following disclaimer in the
00014       documentation and/or other materials provided with the distribution.
00015     * Neither the name of the <organization> nor the
00016       names of its contributors may be used to endorse or promote products
00017       derived from this software without specific prior written permission.
00018 
00019 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
00020 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00021 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00022 DISCLAIMED. IN NO EVENT SHALL ETH-ASL BE LIABLE FOR ANY
00023 DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00024 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00025 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00026 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00028 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 #ifdef HAVE_OPENCL
00033 
#include "nabo_private.h"
#include "index_heap.h"
#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <queue>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
// #include <map>
#include <boost/numeric/conversion/bounds.hpp>
#include <boost/limits.hpp>
#include <boost/format.hpp>
#include <boost/thread.hpp>
00048 
00049 
00055 namespace cl
00056 {
00058         typedef std::vector<Device> Devices;
00059 }
00060 
00061 namespace Nabo
00062 {
00064 
00067         template<typename T, typename CloudType>
00068         size_t argMax(const typename NearestNeighbourSearch<T, CloudType>::Vector& v)
00069         {
00070                 T maxVal(0);
00071                 size_t maxIdx(0);
00072                 for (int i = 0; i < v.size(); ++i)
00073                 {
00074                         if (v[i] > maxVal)
00075                         {
00076                                 maxVal = v[i];
00077                                 maxIdx = i;
00078                         }
00079                 }
00080                 return maxIdx;
00081         }
00082         
00084 
00085         
00087         #define MAX_K 32
00088         
00089         using namespace std;
00090         
00092         template<typename T, typename CloudType>
00093         struct EnableCLTypeSupport {};
00094         
00096         template<typename CloudType>
00097         struct EnableCLTypeSupport<float, CloudType>
00098         {
00100                 static string code(const cl::Device& device)
00101                 {
00102                         return "typedef float T;\n";
00103                 }
00104         };
00105         
00107         template<typename CloudType>
00108         struct EnableCLTypeSupport<double, CloudType>
00109         {
00111 
00112                 static string code(const cl::Device& device)
00113                 {
00114                         string s;
00115                         const string& exts(device.getInfo<CL_DEVICE_EXTENSIONS>());
00116                         //cerr << "extensions: " << exts << endl;
00117                         // first try generic 64-bits fp, otherwise try to fall back on vendor-specific extensions
00118                         if (exts.find("cl_khr_fp64") != string::npos)
00119                                 s += "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
00120                         else if (exts.find("cl_amd_fp64") != string::npos)
00121                                 s += "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n";
00122                         else
00123                                 throw runtime_error("The OpenCL platform does not support 64 bits double-precision floating-points scalars.");
00124                         s += "typedef double T;\n";
00125                         return s;
00126                 }
00127         };
00128         
00130         struct SourceCacher
00131         {
00133                 typedef std::vector<cl::Device> Devices;
00135                 typedef std::map<std::string, cl::Program> ProgramCache;
00136                 
00137                 cl::Context context; 
00138                 Devices devices; 
00139                 ProgramCache cachedPrograms; 
00140                 
00142                 SourceCacher(const cl_device_type deviceType)
00143                 {
00144                         // looking for platforms, AMD drivers do not like the default for creating context
00145                         vector<cl::Platform> platforms;
00146                         cl::Platform::get(&platforms);
00147                         if (platforms.empty())
00148                                 throw runtime_error("No OpenCL platform found");
00149                         //for(vector<cl::Platform>::iterator i = platforms.begin(); i != platforms.end(); ++i)
00150                         //      cerr << "platform " << i - platforms.begin() << " is " << (*i).getInfo<CL_PLATFORM_VENDOR>() << endl;
00151                         cl::Platform platform = platforms[0];
00152                         const char *userDefinedPlatform(getenv("NABO_OPENCL_USE_PLATFORM"));
00153                         if (userDefinedPlatform)
00154                         {
00155                                 size_t userDefinedPlatformId = atoi(userDefinedPlatform);
00156                                 if (userDefinedPlatformId < platforms.size())
00157                                         platform = platforms[userDefinedPlatformId];
00158                         }
00159                         
00160                         // create OpenCL contexts
00161                         cl_context_properties properties[] = { CL_CONTEXT_PLATFORM, (cl_context_properties)platform(), 0 };
00162                         bool deviceFound = false;
00163                         try {
00164                                 context = cl::Context(deviceType, properties);
00165                                 deviceFound = true;
00166                         } catch (const cl::Error& e) {
00167                                 cerr << "Cannot find device type " << deviceType << " for OpenCL, falling back to any device" << endl;
00168                         }
00169                         if (!deviceFound)
00170                                 context = cl::Context(CL_DEVICE_TYPE_ALL, properties);
00171                         devices = context.getInfo<CL_CONTEXT_DEVICES>();
00172                         if (devices.empty())
00173                                 throw runtime_error("No devices on OpenCL platform");
00174                 }
00175                 
00177                 ~SourceCacher()
00178                 {
00179                         cerr << "Destroying source cacher containing " << cachedPrograms.size() << " cached programs" << endl;
00180                 }
00181                 
00183                 bool contains(const std::string& source)
00184                 {
00185                         return cachedPrograms.find(source) != cachedPrograms.end();
00186                 }
00187         };
00188         
00190         class ContextManager
00191         {
00192         public:
00194                 typedef std::map<cl_device_type, SourceCacher*> Devices;
00195                 
00197                 ~ContextManager()
00198                 {
00199                         cerr << "Destroying CL context manager, used " << devices.size() << " contexts" << endl;
00200                         for (Devices::iterator it(devices.begin()); it != devices.end(); ++it)
00201                                 delete it->second;
00202                 }
00204                 cl::Context& createContext(const cl_device_type deviceType)
00205                 {
00206                         boost::mutex::scoped_lock lock(mutex);
00207                         Devices::iterator it(devices.find(deviceType));
00208                         if (it == devices.end())
00209                         {
00210                                 it = devices.insert(
00211                                         pair<cl_device_type, SourceCacher*>(deviceType, new SourceCacher(deviceType))
00212                                         ).first;
00213                         }
00214                         return it->second->context;
00215                 }
00217                 SourceCacher* getSourceCacher(const cl_device_type deviceType)
00218                 {
00219                         boost::mutex::scoped_lock lock(mutex);
00220                         Devices::iterator it(devices.find(deviceType));
00221                         if (it == devices.end())
00222                                 throw runtime_error("Attempt to get source cacher before creating a context");
00223                         return it->second;
00224                 }
00225                 
00226         protected:
00227                 Devices devices; 
00228                 boost::mutex mutex; 
00229         };
00230         
00232         static ContextManager contextManager;
00233         
00234         template<typename T, typename CloudType>
00235         OpenCLSearch<T, CloudType>::OpenCLSearch(const CloudType& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00236                 NearestNeighbourSearch<T, CloudType>::NearestNeighbourSearch(cloud, dim, creationOptionFlags),
00237                 deviceType(deviceType),
00238                 context(contextManager.createContext(deviceType))
00239         {
00240         }
00241         
00242         template<typename T, typename CloudType>
00243         void OpenCLSearch<T, CloudType>::initOpenCL(const char* clFileName, const char* kernelName, const std::string& additionalDefines)
00244         {
00245                 const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T, CloudType>::TOUCH_STATISTICS);
00246                 
00247                 SourceCacher* sourceCacher(contextManager.getSourceCacher(deviceType));
00248                 SourceCacher::Devices& devices(sourceCacher->devices);
00249                 
00250                 // build and load source files
00251                 cl::Program::Sources sources;
00252                 // build defines
00253                 ostringstream oss;
00254                 oss << EnableCLTypeSupport<T, CloudType>::code(devices.back());
00255                 oss << "#define EPSILON " << numeric_limits<T>::epsilon() << "\n";
00256                 oss << "#define DIM_COUNT " << dim << "\n";
00257                 //oss << "#define CLOUD_POINT_COUNT " << cloud.cols() << "\n";
00258                 oss << "#define POINT_STRIDE " << cloud.stride() << "\n";
00259                 oss << "#define MAX_K " << MAX_K << "\n";
00260                 if (collectStatistics)
00261                         oss << "#define TOUCH_STATISTICS\n";
00262                 oss << additionalDefines;
00263                 //cerr << "params:\n" << oss.str() << endl;
00264                 
00265                 const std::string& source(oss.str());
00266                 if (!sourceCacher->contains(source))
00267                 {
00268                         const size_t defLen(source.length());
00269                         char *defContent(new char[defLen+1]);
00270                         strcpy(defContent, source.c_str());
00271                         sources.push_back(std::make_pair(defContent, defLen));
00272                         string sourceFileName(OPENCL_SOURCE_DIR);
00273                         sourceFileName += clFileName;
00274                         // load files
00275                         const char* files[] = {
00276                                 OPENCL_SOURCE_DIR "structure.cl",
00277                                 OPENCL_SOURCE_DIR "heap.cl",
00278                                 sourceFileName.c_str(),
00279                                 NULL 
00280                         };
00281                         for (const char** file = files; *file != NULL; ++file)
00282                         {
00283                                 std::ifstream stream(*file);
00284                                 if (!stream.good())
00285                                         throw runtime_error((string("cannot open file: ") + *file));
00286                                 
00287                                 stream.seekg(0, std::ios_base::end);
00288                                 size_t size(stream.tellg());
00289                                 stream.seekg(0, std::ios_base::beg);
00290                                 
00291                                 char* content(new char[size + 1]);
00292                                 std::copy(std::istreambuf_iterator<char>(stream),
00293                                                         std::istreambuf_iterator<char>(), content);
00294                                 content[size] = '\0';
00295                                 
00296                                 sources.push_back(std::make_pair(content, size));
00297                         }
00298                         sourceCacher->cachedPrograms[source] = cl::Program(context, sources);
00299                         cl::Program& program = sourceCacher->cachedPrograms[source];
00300                         
00301                         // build
00302                         cl::Error error(CL_SUCCESS);
00303                         try {
00304                                 program.build(devices);
00305                         } catch (cl::Error e) {
00306                                 error = e;
00307                         }
00308                         
00309                         // dump
00310                         for (cl::Devices::const_iterator it = devices.begin(); it != devices.end(); ++it)
00311                         {
00312                                 cerr << "device : " << it->getInfo<CL_DEVICE_NAME>() << "\n";
00313                                 cerr << "compilation log:\n" << program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(*it) << endl;
00314                         }
00315                         // cleanup sources
00316                         for (cl::Program::Sources::iterator it = sources.begin(); it != sources.end(); ++it)
00317                         {
00318                                 delete[] it->first;
00319                         }
00320                         sources.clear();
00321                         
00322                         // make sure to stop if compilation failed
00323                         if (error.err() != CL_SUCCESS)
00324                                 throw error;
00325                 }
00326                 cl::Program& program = sourceCacher->cachedPrograms[source];
00327                 
00328                 // build kernel and command queue
00329                 knnKernel = cl::Kernel(program, kernelName); 
00330                 queue = cl::CommandQueue(context, devices.back());
00331                 
00332                 // map cloud
00333                 if (!(cloud.Flags & Eigen::DirectAccessBit) || (cloud.Flags & Eigen::RowMajorBit))
00334                         throw runtime_error("wrong memory mapping in point cloud");
00335                 const size_t cloudCLSize(cloud.cols() * cloud.stride() * sizeof(T));
00336                 cloudCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, cloudCLSize, const_cast<T*>(&cloud.coeff(0,0)));
00337                 knnKernel.setArg(0, sizeof(cl_mem), &cloudCL);
00338         }
00339         
00340         template<typename T, typename CloudType>
00341         unsigned long OpenCLSearch<T, CloudType>::knn(const Matrix& query, IndexMatrix& indices, Matrix& dists2, const Index k, const T epsilon, const unsigned optionFlags, const T maxRadius) const
00342         {
00343                 checkSizesKnn(query, indices, dists2, k, optionFlags);
00344                 const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T, CloudType>::TOUCH_STATISTICS);
00345                 
00346                 // check K
00347                 if (k > MAX_K)
00348                         throw runtime_error("number of neighbors too large for OpenCL");
00349                 
00350                 // check consistency of query wrt cloud
00351                 if (query.stride() != cloud.stride() ||
00352                         query.rows() != cloud.rows())
00353                         throw runtime_error("query is not of the same dimensionality as the point cloud");
00354                 
00355                 // map query
00356                 if (!(query.Flags & Eigen::DirectAccessBit) || (query.Flags & Eigen::RowMajorBit))
00357                         throw runtime_error("wrong memory mapping in query data");
00358                 const size_t queryCLSize(query.cols() * query.stride() * sizeof(T));
00359                 cl::Buffer queryCL(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, queryCLSize, const_cast<T*>(&query.coeff(0,0)));
00360                 knnKernel.setArg(1, sizeof(cl_mem), &queryCL);
00361                 // map indices
00362                 assert((indices.Flags & Eigen::DirectAccessBit) && (!(indices.Flags & Eigen::RowMajorBit)));
00363                 const int indexStride(indices.stride());
00364                 const size_t indicesCLSize(indices.cols() * indexStride * sizeof(int));
00365                 cl::Buffer indicesCL(context, CL_MEM_WRITE_ONLY | CL_MEM_USE_HOST_PTR, indicesCLSize, &indices.coeffRef(0,0));
00366                 knnKernel.setArg(2, sizeof(cl_mem), &indicesCL);
00367                 // map dists2
00368                 assert((dists2.Flags & Eigen::DirectAccessBit) && (!(dists2.Flags & Eigen::RowMajorBit)));
00369                 const int dists2Stride(dists2.stride());
00370                 const size_t dists2CLSize(dists2.cols() * dists2Stride * sizeof(T));
00371                 cl::Buffer dists2CL(context, CL_MEM_WRITE_ONLY | CL_MEM_USE_HOST_PTR, dists2CLSize, &dists2.coeffRef(0,0));
00372                 knnKernel.setArg(3, sizeof(cl_mem), &dists2CL);
00373                 
00374                 // set resulting parameters
00375                 knnKernel.setArg(4, k);
00376                 knnKernel.setArg(5, (1 + epsilon)*(1 + epsilon));
00377                 knnKernel.setArg(6, maxRadius*maxRadius);
00378                 knnKernel.setArg(7, optionFlags);
00379                 knnKernel.setArg(8, indexStride);
00380                 knnKernel.setArg(9, dists2Stride);
00381                 knnKernel.setArg(10, cl_uint(cloud.cols()));
00382                 
00383                 // if required, map visit count
00384                 vector<cl_uint> visitCounts;
00385                 const size_t visitCountCLSize(query.cols() * sizeof(cl_uint));
00386                 cl::Buffer visitCountCL;
00387                 if (collectStatistics)
00388                 {
00389                         visitCounts.resize(query.cols());
00390                         visitCountCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, visitCountCLSize, &visitCounts[0]);
00391                         knnKernel.setArg(11, sizeof(cl_mem), &visitCountCL);
00392                 }
00393                 
00394                 // execute query
00395                 queue.enqueueNDRangeKernel(knnKernel, cl::NullRange, cl::NDRange(query.cols()), cl::NullRange);
00396                 queue.enqueueMapBuffer(indicesCL, true, CL_MAP_READ, 0, indicesCLSize, 0, 0);
00397                 queue.enqueueMapBuffer(dists2CL, true, CL_MAP_READ, 0, dists2CLSize, 0, 0);
00398                 if (collectStatistics)
00399                         queue.enqueueMapBuffer(visitCountCL, true, CL_MAP_READ, 0, visitCountCLSize, 0, 0);
00400                 queue.finish();
00401                 
00402                 // if required, collect statistics
00403                 if (collectStatistics)
00404                 {
00405                         unsigned long totalVisitCounts(0);
00406                         for (size_t i = 0; i < visitCounts.size(); ++i)
00407                                 totalVisitCounts += (unsigned long)visitCounts[i];
00408                         return totalVisitCounts;
00409                 }
00410                 else
00411                         return 0;
00412         }
00413         
00414         template<typename T, typename CloudType>
00415         BruteForceSearchOpenCL<T, CloudType>::BruteForceSearchOpenCL(const CloudType& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00416         OpenCLSearch<T, CloudType>::OpenCLSearch(cloud, dim, creationOptionFlags, deviceType)
00417         {
00418 #ifdef EIGEN3_API
00419                 const_cast<Vector&>(this->minBound) = cloud.topRows(this->dim).rowwise().minCoeff();
00420                 const_cast<Vector&>(this->maxBound) = cloud.topRows(this->dim).rowwise().maxCoeff();
00421 #else // EIGEN3_API
00422                 // compute bounds
00423                 for (int i = 0; i < cloud.cols(); ++i)
00424                 {
00425                         const Vector& v(cloud.block(0,i,this->dim,1));
00426                         const_cast<Vector&>(this->minBound) = this->minBound.cwise().min(v);
00427                         const_cast<Vector&>(this->maxBound) = this->maxBound.cwise().max(v);
00428                 }
00429 #endif // EIGEN3_API
00430                 // init openCL
00431                 initOpenCL("knn_bf.cl", "knnBruteForce");
00432         }
00433 
00434         template struct BruteForceSearchOpenCL<float>;
00435         template struct BruteForceSearchOpenCL<double>;
00436         template struct BruteForceSearchOpenCL<float, Eigen::Matrix3Xf>;
00437         template struct BruteForceSearchOpenCL<double, Eigen::Matrix3Xd>;
00438         template struct BruteForceSearchOpenCL<float, Eigen::Map<const Eigen::Matrix3Xf, Eigen::Aligned> >;
00439         template struct BruteForceSearchOpenCL<double, Eigen::Map<const Eigen::Matrix3Xd, Eigen::Aligned> >;
00440         
00441         
00442 
00443         template<typename T, typename CloudType>
00444         size_t KDTreeBalancedPtInLeavesStackOpenCL<T, CloudType>::getTreeSize(size_t elCount) const
00445         {
00446                 // FIXME: 64 bits safe stuff, only work for 2^32 elements right now
00447                 assert(elCount > 0);
00448                 elCount --;
00449                 size_t count = 0;
00450                 int i = 31;
00451                 for (; i >= 0; --i)
00452                 {
00453                         if (elCount & (1 << i))
00454                                 break;
00455                 }
00456                 for (int j = 0; j <= i; ++j)
00457                         count |= (1 << j);
00458                 count <<= 1;
00459                 count |= 1;
00460                 return count;
00461         }
00462         
00463         template<typename T, typename CloudType>
00464         size_t KDTreeBalancedPtInLeavesStackOpenCL<T, CloudType>::getTreeDepth(size_t elCount) const
00465         {
00466                 if (elCount <= 1)
00467                         return 0;
00468                 elCount --;
00469                 size_t i = 31;
00470                 for (; i >= 0; --i)
00471                 {
00472                         if (elCount & (1 << i))
00473                                 break;
00474                 }
00475                 return i+1;
00476         }
00477 
00478         template<typename T, typename CloudType>
00479         void KDTreeBalancedPtInLeavesStackOpenCL<T, CloudType>::buildNodes(const BuildPointsIt first, const BuildPointsIt last, const size_t pos, const Vector minValues, const Vector maxValues)
00480         {
00481                 const size_t count(last - first);
00482                 //cerr << count << endl;
00483                 if (count == 1)
00484                 {
00485                         const int d = -2-(first->index);
00486                         assert(pos < nodes.size());
00487                         nodes[pos] = Node(d);
00488                         return;
00489                 }
00490                 
00491                 // find the largest dimension of the box
00492                 size_t cutDim = argMax<T, CloudType>(maxValues - minValues);
00493                 
00494                 // compute number of elements
00495                 const size_t rightCount(count/2);
00496                 const size_t leftCount(count - rightCount);
00497                 assert(last - rightCount == first + leftCount);
00498                 
00499                 // sort
00500                 nth_element(first, first + leftCount, last, CompareDim(cutDim));
00501                 
00502                 // set node
00503                 const T cutVal((first+leftCount)->pos.coeff(cutDim));
00504                 nodes[pos] = Node(cutDim, cutVal);
00505                 
00506                 //cerr << pos << " cutting on " << cutDim << " at " << (first+leftCount)->pos[cutDim] << endl;
00507                 
00508                 // update bounds for left
00509                 Vector leftMaxValues(maxValues);
00510                 leftMaxValues[cutDim] = cutVal;
00511                 // update bounds for right
00512                 Vector rightMinValues(minValues);
00513                 rightMinValues[cutDim] = cutVal;
00514                 
00515                 // recurse
00516                 buildNodes(first, first + leftCount, childLeft(pos), minValues, leftMaxValues);
00517                 buildNodes(first + leftCount, last, childRight(pos), rightMinValues, maxValues);
00518         }
00519         
00520         template<typename T, typename CloudType>
00521         KDTreeBalancedPtInLeavesStackOpenCL<T, CloudType>::KDTreeBalancedPtInLeavesStackOpenCL(const CloudType& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00522                 OpenCLSearch<T, CloudType>::OpenCLSearch(cloud, dim, creationOptionFlags, deviceType)
00523         {
00524                 const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T>::TOUCH_STATISTICS);
00525                 
00526                 // build point vector and compute bounds
00527                 BuildPoints buildPoints;
00528                 buildPoints.reserve(cloud.cols());
00529                 for (int i = 0; i < cloud.cols(); ++i)
00530                 {
00531                         const Vector& v(cloud.block(0,i,this->dim,1));
00532                         buildPoints.push_back(BuildPoint(v, i));
00533 #ifdef EIGEN3_API
00534                         const_cast<Vector&>(minBound) = minBound.array().min(v.array());
00535                         const_cast<Vector&>(maxBound) = maxBound.array().max(v.array());
00536 #else // EIGEN3_API
00537                         const_cast<Vector&>(minBound) = minBound.cwise().min(v);
00538                         const_cast<Vector&>(maxBound) = maxBound.cwise().max(v);
00539 #endif // EIGEN3_API
00540                 }
00541                 
00542                 // create nodes
00543                 nodes.resize(getTreeSize(cloud.cols()));
00544                 buildNodes(buildPoints.begin(), buildPoints.end(), 0, minBound, maxBound);
00545                 const unsigned maxStackDepth(getTreeDepth(nodes.size()) + 1);
00546                 
00547                 // init openCL
00548                 initOpenCL("knn_kdtree_pt_in_leaves.cl", "knnKDTree", (boost::format("#define MAX_STACK_DEPTH %1%\n") % maxStackDepth).str());
00549                 
00550                 // map nodes, for info about alignment, see sect 6.1.5 
00551                 const size_t nodesCLSize(nodes.size() * sizeof(Node));
00552                 nodesCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, nodesCLSize, &nodes[0]);
00553                 if (collectStatistics)
00554                         knnKernel.setArg(12, sizeof(cl_mem), &nodesCL);
00555                 else
00556                         knnKernel.setArg(11, sizeof(cl_mem), &nodesCL);
00557         }
00558 
00559         template struct KDTreeBalancedPtInLeavesStackOpenCL<float>;
00560         template struct KDTreeBalancedPtInLeavesStackOpenCL<double>;
00561         template struct KDTreeBalancedPtInLeavesStackOpenCL<float, Eigen::Matrix3Xf>;
00562         template struct KDTreeBalancedPtInLeavesStackOpenCL<double, Eigen::Matrix3Xd>;
00563         template struct KDTreeBalancedPtInLeavesStackOpenCL<float, Eigen::Map<const Eigen::Matrix3Xf, Eigen::Aligned> >;
00564         template struct KDTreeBalancedPtInLeavesStackOpenCL<double, Eigen::Map<const Eigen::Matrix3Xd, Eigen::Aligned> >;
00565         
00566         
00567         template<typename T, typename CloudType>
00568         size_t KDTreeBalancedPtInNodesStackOpenCL<T, CloudType>::getTreeSize(size_t elCount) const
00569         {
00570                 // FIXME: 64 bits safe stuff, only work for 2^32 elements right now
00571                 size_t count = 0;
00572                 int i = 31;
00573                 for (; i >= 0; --i)
00574                 {
00575                         if (elCount & (1 << i))
00576                                 break;
00577                 }
00578                 for (int j = 0; j <= i; ++j)
00579                         count |= (1 << j);
00580                 //cerr << "tree size " << count << " (" << elCount << " elements)\n";
00581                 return count;
00582         }
00583         
00584         template<typename T, typename CloudType>
00585         size_t KDTreeBalancedPtInNodesStackOpenCL<T, CloudType>::getTreeDepth(size_t elCount) const
00586         {
00587                 // FIXME: 64 bits safe stuff, only work for 2^32 elements right now
00588                 int i = 31;
00589                 for (; i >= 0; --i)
00590                 {
00591                         if (elCount & (1 << i))
00592                                 break;
00593                 }
00594                 return i + 1;
00595         }
00596         
00597         template<typename T, typename CloudType>
00598         void KDTreeBalancedPtInNodesStackOpenCL<T, CloudType>::buildNodes(const BuildPointsIt first, const BuildPointsIt last, const size_t pos, const Vector minValues, const Vector maxValues)
00599         {
00600                 const size_t count(last - first);
00601                 //cerr << count << endl;
00602                 if (count == 1)
00603                 {
00604                         nodes[pos] = Node(-1, *first);
00605                         return;
00606                 }
00607                 
00608                 // find the largest dimension of the box
00609                 const size_t cutDim = argMax<T, CloudType>(maxValues - minValues);
00610                 
00611                 // compute number of elements
00612                 const size_t recurseCount(count-1);
00613                 const size_t rightCount(recurseCount/2);
00614                 const size_t leftCount(recurseCount-rightCount);
00615                 assert(last - rightCount == first + leftCount + 1);
00616                 
00617                 // sort
00618                 nth_element(first, first + leftCount, last, CompareDim(cloud, cutDim));
00619                 
00620                 // set node
00621                 const Index index(*(first+leftCount));
00622                 const T cutVal(cloud.coeff(cutDim, index));
00623                 nodes[pos] = Node(cutDim, index);
00624                 
00625                 //cerr << pos << " cutting on " << cutDim << " at " << (first+leftCount)->pos[cutDim] << endl;
00626                 
00627                 // update bounds for left
00628                 Vector leftMaxValues(maxValues);
00629                 leftMaxValues[cutDim] = cutVal;
00630                 // update bounds for right
00631                 Vector rightMinValues(minValues);
00632                 rightMinValues[cutDim] = cutVal;
00633                 
00634                 // recurse
00635                 if (count > 2)
00636                 {
00637                         buildNodes(first, first + leftCount, childLeft(pos), minValues, leftMaxValues);
00638                         buildNodes(first + leftCount + 1, last, childRight(pos), rightMinValues, maxValues);
00639                 }
00640                 else
00641                 {
00642                         nodes[childLeft(pos)] = Node(-1, *first);
00643                         nodes[childRight(pos)] = Node(-2, 0);
00644                 }
00645         }
00646         
00647         template<typename T, typename CloudType>
00648         KDTreeBalancedPtInNodesStackOpenCL<T, CloudType>::KDTreeBalancedPtInNodesStackOpenCL(const CloudType& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00649         OpenCLSearch<T, CloudType>::OpenCLSearch(cloud, dim, creationOptionFlags, deviceType)
00650         {
00651                 const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T, CloudType>::TOUCH_STATISTICS);
00652                 
00653                 // build point vector and compute bounds
00654                 BuildPoints buildPoints;
00655                 buildPoints.reserve(cloud.cols());
00656                 for (int i = 0; i < cloud.cols(); ++i)
00657                 {
00658                         buildPoints.push_back(i);
00659                         const Vector& v(cloud.block(0,i,this->dim,1));
00660 #ifdef EIGEN3_API
00661                         const_cast<Vector&>(minBound) = minBound.array().min(v.array());
00662                         const_cast<Vector&>(maxBound) = maxBound.array().max(v.array());
00663 #else // EIGEN3_API
00664                         const_cast<Vector&>(minBound) = minBound.cwise().min(v);
00665                         const_cast<Vector&>(maxBound) = maxBound.cwise().max(v);
00666 #endif // EIGEN3_API
00667                 }
00668                 
00669                 // create nodes
00670                 nodes.resize(getTreeSize(cloud.cols()));
00671                 buildNodes(buildPoints.begin(), buildPoints.end(), 0, minBound, maxBound);
00672                 const unsigned maxStackDepth(getTreeDepth(nodes.size()) + 1);
00673                 
00674                 // init openCL
00675                 initOpenCL("knn_kdtree_pt_in_nodes.cl", "knnKDTree", (boost::format("#define MAX_STACK_DEPTH %1%\n") % maxStackDepth).str());
00676                 
00677                 // map nodes, for info about alignment, see sect 6.1.5 
00678                 const size_t nodesCLSize(nodes.size() * sizeof(Node));
00679                 nodesCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, nodesCLSize, &nodes[0]);
00680                 if (collectStatistics)
00681                         knnKernel.setArg(12, sizeof(cl_mem), &nodesCL);
00682                 else
00683                         knnKernel.setArg(11, sizeof(cl_mem), &nodesCL);
00684         }
00685         
00686         template struct KDTreeBalancedPtInNodesStackOpenCL<float>;
00687         template struct KDTreeBalancedPtInNodesStackOpenCL<double>;
00688         template struct KDTreeBalancedPtInNodesStackOpenCL<float, Eigen::Matrix3Xf>;
00689         template struct KDTreeBalancedPtInNodesStackOpenCL<double, Eigen::Matrix3Xd>;
00690         template struct KDTreeBalancedPtInNodesStackOpenCL<float, Eigen::Map<const Eigen::Matrix3Xf, Eigen::Aligned> >;
00691         template struct KDTreeBalancedPtInNodesStackOpenCL<double, Eigen::Map<const Eigen::Matrix3Xd, Eigen::Aligned> >;
00692         
00694 }
00695 
00696 #endif // HAVE_OPENCL


libnabo
Author(s): Stéphane Magnenat
autogenerated on Thu Sep 10 2015 10:54:55