kdtree_opencl.cpp
Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2010--2011, Stephane Magnenat, ASL, ETHZ, Switzerland
00004 You can contact the author at <stephane at magnenat dot net>
00005 
00006 All rights reserved.
00007 
00008 Redistribution and use in source and binary forms, with or without
00009 modification, are permitted provided that the following conditions are met:
00010     * Redistributions of source code must retain the above copyright
00011       notice, this list of conditions and the following disclaimer.
00012     * Redistributions in binary form must reproduce the above copyright
00013       notice, this list of conditions and the following disclaimer in the
00014       documentation and/or other materials provided with the distribution.
00015     * Neither the name of the <organization> nor the
00016       names of its contributors may be used to endorse or promote products
00017       derived from this software without specific prior written permission.
00018 
00019 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
00020 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00021 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00022 DISCLAIMED. IN NO EVENT SHALL ETH-ASL BE LIABLE FOR ANY
00023 DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00024 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00025 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00026 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00028 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 #ifdef HAVE_OPENCL
00033 
00034 #include "nabo_private.h"
00035 #include "index_heap.h"
00036 #include <iostream>
00037 #include <sstream>
00038 #include <fstream>
00039 #include <stdexcept>
00040 #include <limits>
00041 #include <queue>
00042 #include <algorithm>
00043 // #include <map>
00044 #include <boost/numeric/conversion/bounds.hpp>
00045 #include <boost/limits.hpp>
00046 #include <boost/format.hpp>
00047 #include <boost/thread.hpp>
00048 
00049 
00055 namespace cl
00056 {
00058     typedef std::vector<Device> Devices;
00059 }
00060 
00061 namespace Nabo
00062 {
00063     // argmax is already defined in kdtree_cpu.cpp, which is always compiled
00064     template<typename T>
00065     size_t argMax(const typename NearestNeighbourSearch<T>::Vector& v);
00066     
00068 
00069     
00071     #define MAX_K 32
00072     
00073     using namespace std;
00074     
00076     template<typename T>
00077     struct EnableCLTypeSupport {};
00078     
00080     template<> struct EnableCLTypeSupport<float>
00081     {
00083         static string code(const cl::Device& device)
00084         {
00085             return "typedef float T;\n";
00086         }
00087     };
00088     
00090     template<> struct EnableCLTypeSupport<double>
00091     {
00093 
00094         static string code(const cl::Device& device)
00095         {
00096             string s;
00097             const string& exts(device.getInfo<CL_DEVICE_EXTENSIONS>());
00098             //cerr << "extensions: " << exts << endl;
00099             // first try generic 64-bits fp, otherwise try to fall back on vendor-specific extensions
00100             if (exts.find("cl_khr_fp64") != string::npos)
00101                 s += "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
00102             else if (exts.find("cl_amd_fp64") != string::npos)
00103                 s += "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n";
00104             else
00105                 throw runtime_error("The OpenCL platform does not support 64 bits double-precision floating-points scalars.");
00106             s += "typedef double T;\n";
00107             return s;
00108         }
00109     };
00110     
00112     struct SourceCacher
00113     {
00115         typedef std::vector<cl::Device> Devices;
00117         typedef std::map<std::string, cl::Program> ProgramCache;
00118         
00119         cl::Context context; 
00120         Devices devices; 
00121         ProgramCache cachedPrograms; 
00122         
00124         SourceCacher(const cl_device_type deviceType)
00125         {
00126             // looking for platforms, AMD drivers do not like the default for creating context
00127             vector<cl::Platform> platforms;
00128             cl::Platform::get(&platforms);
00129             if (platforms.empty())
00130                 throw runtime_error("No OpenCL platform found");
00131             //for(vector<cl::Platform>::iterator i = platforms.begin(); i != platforms.end(); ++i)
00132             //  cerr << "platform " << i - platforms.begin() << " is " << (*i).getInfo<CL_PLATFORM_VENDOR>() << endl;
00133             cl::Platform platform = platforms[0];
00134             const char *userDefinedPlatform(getenv("NABO_OPENCL_USE_PLATFORM"));
00135             if (userDefinedPlatform)
00136             {
00137                 size_t userDefinedPlatformId = atoi(userDefinedPlatform);
00138                 if (userDefinedPlatformId < platforms.size())
00139                     platform = platforms[userDefinedPlatformId];
00140             }
00141             
00142             // create OpenCL contexts
00143             cl_context_properties properties[] = { CL_CONTEXT_PLATFORM, (cl_context_properties)platform(), 0 };
00144             bool deviceFound = false;
00145             try {
00146                 context = cl::Context(deviceType, properties);
00147                 deviceFound = true;
00148             } catch (cl::Error e) {
00149                 cerr << "Cannot find device type " << deviceType << " for OpenCL, falling back to any device" << endl;
00150             }
00151             if (!deviceFound)
00152                 context = cl::Context(CL_DEVICE_TYPE_ALL, properties);
00153             devices = context.getInfo<CL_CONTEXT_DEVICES>();
00154             if (devices.empty())
00155                 throw runtime_error("No devices on OpenCL platform");
00156         }
00157         
00159         ~SourceCacher()
00160         {
00161             cerr << "Destroying source cacher containing " << cachedPrograms.size() << " cached programs" << endl;
00162         }
00163         
00165         bool contains(const std::string& source)
00166         {
00167             return cachedPrograms.find(source) != cachedPrograms.end();
00168         }
00169     };
00170     
00172     class ContextManager
00173     {
00174     public:
00176         typedef std::map<cl_device_type, SourceCacher*> Devices;
00177         
00179         ~ContextManager()
00180         {
00181             cerr << "Destroying CL context manager, used " << devices.size() << " contexts" << endl;
00182             for (Devices::iterator it(devices.begin()); it != devices.end(); ++it)
00183                 delete it->second;
00184         }
00186         cl::Context& createContext(const cl_device_type deviceType)
00187         {
00188             boost::mutex::scoped_lock lock(mutex);
00189             Devices::iterator it(devices.find(deviceType));
00190             if (it == devices.end())
00191             {
00192                 it = devices.insert(
00193                     pair<cl_device_type, SourceCacher*>(deviceType, new SourceCacher(deviceType))
00194                     ).first;
00195             }
00196             return it->second->context;
00197         }
00199         SourceCacher* getSourceCacher(const cl_device_type deviceType)
00200         {
00201             boost::mutex::scoped_lock lock(mutex);
00202             Devices::iterator it(devices.find(deviceType));
00203             if (it == devices.end())
00204                 throw runtime_error("Attempt to get source cacher before creating a context");
00205             return it->second;
00206         }
00207         
00208     protected:
00209         Devices devices; 
00210         boost::mutex mutex; 
00211     };
00212     
00214     static ContextManager contextManager;
00215     
00216     template<typename T>
00217     OpenCLSearch<T>::OpenCLSearch(const Matrix& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00218         NearestNeighbourSearch<T>::NearestNeighbourSearch(cloud, dim, creationOptionFlags),
00219         deviceType(deviceType),
00220         context(contextManager.createContext(deviceType))
00221     {
00222     }
00223     
00224     template<typename T>
00225     void OpenCLSearch<T>::initOpenCL(const char* clFileName, const char* kernelName, const std::string& additionalDefines)
00226     {
00227         const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T>::TOUCH_STATISTICS);
00228         
00229         SourceCacher* sourceCacher(contextManager.getSourceCacher(deviceType));
00230         SourceCacher::Devices& devices(sourceCacher->devices);
00231         
00232         // build and load source files
00233         cl::Program::Sources sources;
00234         // build defines
00235         ostringstream oss;
00236         oss << EnableCLTypeSupport<T>::code(devices.back());
00237         oss << "#define EPSILON " << numeric_limits<T>::epsilon() << "\n";
00238         oss << "#define DIM_COUNT " << dim << "\n";
00239         //oss << "#define CLOUD_POINT_COUNT " << cloud.cols() << "\n";
00240         oss << "#define POINT_STRIDE " << cloud.stride() << "\n";
00241         oss << "#define MAX_K " << MAX_K << "\n";
00242         if (collectStatistics)
00243             oss << "#define TOUCH_STATISTICS\n";
00244         oss << additionalDefines;
00245         //cerr << "params:\n" << oss.str() << endl;
00246         
00247         const std::string& source(oss.str());
00248         if (!sourceCacher->contains(source))
00249         {
00250             const size_t defLen(source.length());
00251             char *defContent(new char[defLen+1]);
00252             strcpy(defContent, source.c_str());
00253             sources.push_back(std::make_pair(defContent, defLen));
00254             string sourceFileName(OPENCL_SOURCE_DIR);
00255             sourceFileName += clFileName;
00256             // load files
00257             const char* files[] = {
00258                 OPENCL_SOURCE_DIR "structure.cl",
00259                 OPENCL_SOURCE_DIR "heap.cl",
00260                 sourceFileName.c_str(),
00261                 NULL 
00262             };
00263             for (const char** file = files; *file != NULL; ++file)
00264             {
00265                 std::ifstream stream(*file);
00266                 if (!stream.good())
00267                     throw runtime_error((string("cannot open file: ") + *file));
00268                 
00269                 stream.seekg(0, std::ios_base::end);
00270                 size_t size(stream.tellg());
00271                 stream.seekg(0, std::ios_base::beg);
00272                 
00273                 char* content(new char[size + 1]);
00274                 std::copy(std::istreambuf_iterator<char>(stream),
00275                             std::istreambuf_iterator<char>(), content);
00276                 content[size] = '\0';
00277                 
00278                 sources.push_back(std::make_pair(content, size));
00279             }
00280             sourceCacher->cachedPrograms[source] = cl::Program(context, sources);
00281             cl::Program& program = sourceCacher->cachedPrograms[source];
00282             
00283             // build
00284             cl::Error error(CL_SUCCESS);
00285             try {
00286                 program.build(devices);
00287             } catch (cl::Error e) {
00288                 error = e;
00289             }
00290             
00291             // dump
00292             for (cl::Devices::const_iterator it = devices.begin(); it != devices.end(); ++it)
00293             {
00294                 cerr << "device : " << it->getInfo<CL_DEVICE_NAME>() << "\n";
00295                 cerr << "compilation log:\n" << program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(*it) << endl;
00296             }
00297             // cleanup sources
00298             for (cl::Program::Sources::iterator it = sources.begin(); it != sources.end(); ++it)
00299             {
00300                 delete[] it->first;
00301             }
00302             sources.clear();
00303             
00304             // make sure to stop if compilation failed
00305             if (error.err() != CL_SUCCESS)
00306                 throw error;
00307         }
00308         cl::Program& program = sourceCacher->cachedPrograms[source];
00309         
00310         // build kernel and command queue
00311         knnKernel = cl::Kernel(program, kernelName); 
00312         queue = cl::CommandQueue(context, devices.back());
00313         
00314         // map cloud
00315         if (!(cloud.Flags & Eigen::DirectAccessBit) || (cloud.Flags & Eigen::RowMajorBit))
00316             throw runtime_error("wrong memory mapping in point cloud");
00317         const size_t cloudCLSize(cloud.cols() * cloud.stride() * sizeof(T));
00318         cloudCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, cloudCLSize, const_cast<T*>(&cloud.coeff(0,0)));
00319         knnKernel.setArg(0, sizeof(cl_mem), &cloudCL);
00320     }
00321     
00322     template<typename T>
00323     unsigned long OpenCLSearch<T>::knn(const Matrix& query, IndexMatrix& indices, Matrix& dists2, const Index k, const T epsilon, const unsigned optionFlags, const T maxRadius) const
00324     {
00325         checkSizesKnn(query, indices, dists2, k);
00326         const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T>::TOUCH_STATISTICS);
00327         
00328         // check K
00329         if (k > MAX_K)
00330             throw runtime_error("number of neighbors too large for OpenCL");
00331         
00332         // check consistency of query wrt cloud
00333         if (query.stride() != cloud.stride() ||
00334             query.rows() != cloud.rows())
00335             throw runtime_error("query is not of the same dimensionality as the point cloud");
00336         
00337         // map query
00338         if (!(query.Flags & Eigen::DirectAccessBit) || (query.Flags & Eigen::RowMajorBit))
00339             throw runtime_error("wrong memory mapping in query data");
00340         const size_t queryCLSize(query.cols() * query.stride() * sizeof(T));
00341         cl::Buffer queryCL(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, queryCLSize, const_cast<T*>(&query.coeff(0,0)));
00342         knnKernel.setArg(1, sizeof(cl_mem), &queryCL);
00343         // map indices
00344         assert((indices.Flags & Eigen::DirectAccessBit) && (!(indices.Flags & Eigen::RowMajorBit)));
00345         const int indexStride(indices.stride());
00346         const size_t indicesCLSize(indices.cols() * indexStride * sizeof(int));
00347         cl::Buffer indicesCL(context, CL_MEM_WRITE_ONLY | CL_MEM_USE_HOST_PTR, indicesCLSize, &indices.coeffRef(0,0));
00348         knnKernel.setArg(2, sizeof(cl_mem), &indicesCL);
00349         // map dists2
00350         assert((dists2.Flags & Eigen::DirectAccessBit) && (!(dists2.Flags & Eigen::RowMajorBit)));
00351         const int dists2Stride(dists2.stride());
00352         const size_t dists2CLSize(dists2.cols() * dists2Stride * sizeof(T));
00353         cl::Buffer dists2CL(context, CL_MEM_WRITE_ONLY | CL_MEM_USE_HOST_PTR, dists2CLSize, &dists2.coeffRef(0,0));
00354         knnKernel.setArg(3, sizeof(cl_mem), &dists2CL);
00355         
00356         // set resulting parameters
00357         knnKernel.setArg(4, k);
00358         knnKernel.setArg(5, (1 + epsilon)*(1 + epsilon));
00359         knnKernel.setArg(6, maxRadius*maxRadius);
00360         knnKernel.setArg(7, optionFlags);
00361         knnKernel.setArg(8, indexStride);
00362         knnKernel.setArg(9, dists2Stride);
00363         knnKernel.setArg(10, cl_uint(cloud.cols()));
00364         
00365         // if required, map visit count
00366         vector<cl_uint> visitCounts;
00367         const size_t visitCountCLSize(query.cols() * sizeof(cl_uint));
00368         cl::Buffer visitCountCL;
00369         if (collectStatistics)
00370         {
00371             visitCounts.resize(query.cols());
00372             visitCountCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, visitCountCLSize, &visitCounts[0]);
00373             knnKernel.setArg(11, sizeof(cl_mem), &visitCountCL);
00374         }
00375         
00376         // execute query
00377         queue.enqueueNDRangeKernel(knnKernel, cl::NullRange, cl::NDRange(query.cols()), cl::NullRange);
00378         queue.enqueueMapBuffer(indicesCL, true, CL_MAP_READ, 0, indicesCLSize, 0, 0);
00379         queue.enqueueMapBuffer(dists2CL, true, CL_MAP_READ, 0, dists2CLSize, 0, 0);
00380         if (collectStatistics)
00381             queue.enqueueMapBuffer(visitCountCL, true, CL_MAP_READ, 0, visitCountCLSize, 0, 0);
00382         queue.finish();
00383         
00384         // if required, collect statistics
00385         if (collectStatistics)
00386         {
00387             unsigned long totalVisitCounts(0);
00388             for (size_t i = 0; i < visitCounts.size(); ++i)
00389                 totalVisitCounts += (unsigned long)visitCounts[i];
00390             return totalVisitCounts;
00391         }
00392         else
00393             return 0;
00394     }
00395     
00396     template<typename T>
00397     BruteForceSearchOpenCL<T>::BruteForceSearchOpenCL(const Matrix& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00398     OpenCLSearch<T>::OpenCLSearch(cloud, dim, creationOptionFlags, deviceType)
00399     {
00400 #ifdef EIGEN3_API
00401         const_cast<Vector&>(this->minBound) = cloud.topRows(this->dim).rowwise().minCoeff();
00402         const_cast<Vector&>(this->maxBound) = cloud.topRows(this->dim).rowwise().maxCoeff();
00403 #else // EIGEN3_API
00404         // compute bounds
00405         for (int i = 0; i < cloud.cols(); ++i)
00406         {
00407             const Vector& v(cloud.block(0,i,this->dim,1));
00408             const_cast<Vector&>(this->minBound) = this->minBound.cwise().min(v);
00409             const_cast<Vector&>(this->maxBound) = this->maxBound.cwise().max(v);
00410         }
00411 #endif // EIGEN3_API
00412         // init openCL
00413         initOpenCL("knn_bf.cl", "knnBruteForce");
00414     }
00415     
00416     template struct BruteForceSearchOpenCL<float>;
00417     template struct BruteForceSearchOpenCL<double>;
00418     
00419     
00420 
00421     template<typename T>
00422     size_t KDTreeBalancedPtInLeavesStackOpenCL<T>::getTreeSize(size_t elCount) const
00423     {
00424         // FIXME: 64 bits safe stuff, only work for 2^32 elements right now
00425         assert(elCount > 0);
00426         elCount --;
00427         size_t count = 0;
00428         int i = 31;
00429         for (; i >= 0; --i)
00430         {
00431             if (elCount & (1 << i))
00432                 break;
00433         }
00434         for (int j = 0; j <= i; ++j)
00435             count |= (1 << j);
00436         count <<= 1;
00437         count |= 1;
00438         return count;
00439     }
00440     
00441     template<typename T>
00442     size_t KDTreeBalancedPtInLeavesStackOpenCL<T>::getTreeDepth(size_t elCount) const
00443     {
00444         if (elCount <= 1)
00445             return 0;
00446         elCount --;
00447         size_t i = 31;
00448         for (; i >= 0; --i)
00449         {
00450             if (elCount & (1 << i))
00451                 break;
00452         }
00453         return i+1;
00454     }
00455 
00456     template<typename T>
00457     void KDTreeBalancedPtInLeavesStackOpenCL<T>::buildNodes(const BuildPointsIt first, const BuildPointsIt last, const size_t pos, const Vector minValues, const Vector maxValues)
00458     {
00459         const size_t count(last - first);
00460         //cerr << count << endl;
00461         if (count == 1)
00462         {
00463             const int d = -2-(first->index);
00464             assert(pos < nodes.size());
00465             nodes[pos] = Node(d);
00466             return;
00467         }
00468         
00469         // find the largest dimension of the box
00470         size_t cutDim = argMax<T>(maxValues - minValues);
00471         
00472         // compute number of elements
00473         const size_t rightCount(count/2);
00474         const size_t leftCount(count - rightCount);
00475         assert(last - rightCount == first + leftCount);
00476         
00477         // sort
00478         nth_element(first, first + leftCount, last, CompareDim(cutDim));
00479         
00480         // set node
00481         const T cutVal((first+leftCount)->pos.coeff(cutDim));
00482         nodes[pos] = Node(cutDim, cutVal);
00483         
00484         //cerr << pos << " cutting on " << cutDim << " at " << (first+leftCount)->pos[cutDim] << endl;
00485         
00486         // update bounds for left
00487         Vector leftMaxValues(maxValues);
00488         leftMaxValues[cutDim] = cutVal;
00489         // update bounds for right
00490         Vector rightMinValues(minValues);
00491         rightMinValues[cutDim] = cutVal;
00492         
00493         // recurse
00494         buildNodes(first, first + leftCount, childLeft(pos), minValues, leftMaxValues);
00495         buildNodes(first + leftCount, last, childRight(pos), rightMinValues, maxValues);
00496     }
00497     
00498     template<typename T>
00499     KDTreeBalancedPtInLeavesStackOpenCL<T>::KDTreeBalancedPtInLeavesStackOpenCL(const Matrix& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00500         OpenCLSearch<T>::OpenCLSearch(cloud, dim, creationOptionFlags, deviceType)
00501     {
00502         const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T>::TOUCH_STATISTICS);
00503         
00504         // build point vector and compute bounds
00505         BuildPoints buildPoints;
00506         buildPoints.reserve(cloud.cols());
00507         for (int i = 0; i < cloud.cols(); ++i)
00508         {
00509             const Vector& v(cloud.block(0,i,this->dim,1));
00510             buildPoints.push_back(BuildPoint(v, i));
00511 #ifdef EIGEN3_API
00512             const_cast<Vector&>(minBound) = minBound.array().min(v.array());
00513             const_cast<Vector&>(maxBound) = maxBound.array().max(v.array());
00514 #else // EIGEN3_API
00515             const_cast<Vector&>(minBound) = minBound.cwise().min(v);
00516             const_cast<Vector&>(maxBound) = maxBound.cwise().max(v);
00517 #endif // EIGEN3_API
00518         }
00519         
00520         // create nodes
00521         nodes.resize(getTreeSize(cloud.cols()));
00522         buildNodes(buildPoints.begin(), buildPoints.end(), 0, minBound, maxBound);
00523         const unsigned maxStackDepth(getTreeDepth(nodes.size()) + 1);
00524         
00525         // init openCL
00526         initOpenCL("knn_kdtree_pt_in_leaves.cl", "knnKDTree", (boost::format("#define MAX_STACK_DEPTH %1%\n") % maxStackDepth).str());
00527         
00528         // map nodes, for info about alignment, see sect 6.1.5 
00529         const size_t nodesCLSize(nodes.size() * sizeof(Node));
00530         nodesCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, nodesCLSize, &nodes[0]);
00531         if (collectStatistics)
00532             knnKernel.setArg(12, sizeof(cl_mem), &nodesCL);
00533         else
00534             knnKernel.setArg(11, sizeof(cl_mem), &nodesCL);
00535     }
00536 
00537     template struct KDTreeBalancedPtInLeavesStackOpenCL<float>;
00538     template struct KDTreeBalancedPtInLeavesStackOpenCL<double>;
00539     
00540     
00541     template<typename T>
00542     size_t KDTreeBalancedPtInNodesStackOpenCL<T>::getTreeSize(size_t elCount) const
00543     {
00544         // FIXME: 64 bits safe stuff, only work for 2^32 elements right now
00545         size_t count = 0;
00546         int i = 31;
00547         for (; i >= 0; --i)
00548         {
00549             if (elCount & (1 << i))
00550                 break;
00551         }
00552         for (int j = 0; j <= i; ++j)
00553             count |= (1 << j);
00554         //cerr << "tree size " << count << " (" << elCount << " elements)\n";
00555         return count;
00556     }
00557     
00558     template<typename T>
00559     size_t KDTreeBalancedPtInNodesStackOpenCL<T>::getTreeDepth(size_t elCount) const
00560     {
00561         // FIXME: 64 bits safe stuff, only work for 2^32 elements right now
00562         int i = 31;
00563         for (; i >= 0; --i)
00564         {
00565             if (elCount & (1 << i))
00566                 break;
00567         }
00568         return i + 1;
00569     }
00570     
00571     template<typename T>
00572     void KDTreeBalancedPtInNodesStackOpenCL<T>::buildNodes(const BuildPointsIt first, const BuildPointsIt last, const size_t pos, const Vector minValues, const Vector maxValues)
00573     {
00574         const size_t count(last - first);
00575         //cerr << count << endl;
00576         if (count == 1)
00577         {
00578             nodes[pos] = Node(-1, *first);
00579             return;
00580         }
00581         
00582         // find the largest dimension of the box
00583         const size_t cutDim = argMax<T>(maxValues - minValues);
00584         
00585         // compute number of elements
00586         const size_t recurseCount(count-1);
00587         const size_t rightCount(recurseCount/2);
00588         const size_t leftCount(recurseCount-rightCount);
00589         assert(last - rightCount == first + leftCount + 1);
00590         
00591         // sort
00592         nth_element(first, first + leftCount, last, CompareDim(cloud, cutDim));
00593         
00594         // set node
00595         const Index index(*(first+leftCount));
00596         const T cutVal(cloud.coeff(cutDim, index));
00597         nodes[pos] = Node(cutDim, index);
00598         
00599         //cerr << pos << " cutting on " << cutDim << " at " << (first+leftCount)->pos[cutDim] << endl;
00600         
00601         // update bounds for left
00602         Vector leftMaxValues(maxValues);
00603         leftMaxValues[cutDim] = cutVal;
00604         // update bounds for right
00605         Vector rightMinValues(minValues);
00606         rightMinValues[cutDim] = cutVal;
00607         
00608         // recurse
00609         if (count > 2)
00610         {
00611             buildNodes(first, first + leftCount, childLeft(pos), minValues, leftMaxValues);
00612             buildNodes(first + leftCount + 1, last, childRight(pos), rightMinValues, maxValues);
00613         }
00614         else
00615         {
00616             nodes[childLeft(pos)] = Node(-1, *first);
00617             nodes[childRight(pos)] = Node(-2, 0);
00618         }
00619     }
00620     
00621     template<typename T>
00622     KDTreeBalancedPtInNodesStackOpenCL<T>::KDTreeBalancedPtInNodesStackOpenCL(const Matrix& cloud, const Index dim, const unsigned creationOptionFlags, const cl_device_type deviceType):
00623     OpenCLSearch<T>::OpenCLSearch(cloud, dim, creationOptionFlags, deviceType)
00624     {
00625         const bool collectStatistics(creationOptionFlags & NearestNeighbourSearch<T>::TOUCH_STATISTICS);
00626         
00627         // build point vector and compute bounds
00628         BuildPoints buildPoints;
00629         buildPoints.reserve(cloud.cols());
00630         for (int i = 0; i < cloud.cols(); ++i)
00631         {
00632             buildPoints.push_back(i);
00633             const Vector& v(cloud.block(0,i,this->dim,1));
00634 #ifdef EIGEN3_API
00635             const_cast<Vector&>(minBound) = minBound.array().min(v.array());
00636             const_cast<Vector&>(maxBound) = maxBound.array().max(v.array());
00637 #else // EIGEN3_API
00638             const_cast<Vector&>(minBound) = minBound.cwise().min(v);
00639             const_cast<Vector&>(maxBound) = maxBound.cwise().max(v);
00640 #endif // EIGEN3_API
00641         }
00642         
00643         // create nodes
00644         nodes.resize(getTreeSize(cloud.cols()));
00645         buildNodes(buildPoints.begin(), buildPoints.end(), 0, minBound, maxBound);
00646         const unsigned maxStackDepth(getTreeDepth(nodes.size()) + 1);
00647         
00648         // init openCL
00649         initOpenCL("knn_kdtree_pt_in_nodes.cl", "knnKDTree", (boost::format("#define MAX_STACK_DEPTH %1%\n") % maxStackDepth).str());
00650         
00651         // map nodes, for info about alignment, see sect 6.1.5 
00652         const size_t nodesCLSize(nodes.size() * sizeof(Node));
00653         nodesCL = cl::Buffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR, nodesCLSize, &nodes[0]);
00654         if (collectStatistics)
00655             knnKernel.setArg(12, sizeof(cl_mem), &nodesCL);
00656         else
00657             knnKernel.setArg(11, sizeof(cl_mem), &nodesCL);
00658     }
00659     
00660     template struct KDTreeBalancedPtInNodesStackOpenCL<float>;
00661     template struct KDTreeBalancedPtInNodesStackOpenCL<double>;
00662     
00664 }
00665 
00666 #endif // HAVE_OPENCL


libnabo
Author(s): Stéphane Magnenat
autogenerated on Thu Jan 2 2014 11:15:54