#if defined(EIGEN_USE_SYCL) && \
    !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H)
#define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H

#include <CL/sycl.hpp>
#ifdef EIGEN_EXCEPTIONS
#include <stdexcept>
#endif
#include <cstdint>
#include <map>
#include <set>
#include <unordered_map>
namespace TensorSycl {

using sycl_acc_target = cl::sycl::access::target;
using sycl_acc_mode = cl::sycl::access::mode;

/* All allocations are byte-typed; typed views are created on demand. */
using buffer_data_type_t = uint8_t;
const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
/**
 * PointerMapper
 * Associates fake (virtual) pointer addresses with SYCL buffers, so that
 * buffer-based allocations can be driven through a malloc/free-style API.
 */
class PointerMapper {
 public:
  using base_ptr_t = std::intptr_t;

  /* A virtual pointer is just an integer holding the fake address. */
  struct virtual_pointer_t {
    base_ptr_t m_contents;

    /* Conversion to void* simply reinterprets the integer. */
    operator void *() const { return reinterpret_cast<void *>(m_contents); }

    /* Conversion back to the underlying integer. */
    operator base_ptr_t() const { return m_contents; }

    /* Offsets the pointer to create a new pointer into the same block. */
    virtual_pointer_t operator+(size_t off) { return m_contents + off; }

    /* Numerical ordering, used for sorting pointers in containers. */
    bool operator<(virtual_pointer_t rhs) const {
      return (static_cast<base_ptr_t>(m_contents) <
              static_cast<base_ptr_t>(rhs.m_contents));
    }

    bool operator>(virtual_pointer_t rhs) const {
      return (static_cast<base_ptr_t>(m_contents) >
              static_cast<base_ptr_t>(rhs.m_contents));
    }

    bool operator==(virtual_pointer_t rhs) const {
      return (static_cast<base_ptr_t>(m_contents) ==
              static_cast<base_ptr_t>(rhs.m_contents));
    }

    bool operator!=(virtual_pointer_t rhs) const {
      return !(*this == rhs);
    }

    /* Converts a void* into a virtual pointer. This is only meaningful if
     * the void* was obtained from a virtual_pointer_t in the first place. */
    virtual_pointer_t(const void *ptr)
        : m_contents(reinterpret_cast<base_ptr_t>(ptr)) {}

    /* Creates a virtual pointer from the given integer address. */
    virtual_pointer_t(base_ptr_t u) : m_contents(u) {}
  };
  /* Definition of the null virtual pointer. */
  const virtual_pointer_t null_virtual_ptr = nullptr;

  /* Returns true iff the given virtual pointer is null. */
  static inline bool is_nullptr(virtual_pointer_t ptr) {
    return (static_cast<void *>(ptr) == nullptr);
  }
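
  // Usage sketch (illustrative only, not part of the interface): a virtual
  // pointer round-trips through void* and supports offset arithmetic, e.g.
  //
  //   PointerMapper::virtual_pointer_t vp(base_ptr_t{4096});
  //   void *raw = static_cast<void *>(vp);         // fake address 0x1000
  //   PointerMapper::virtual_pointer_t back(raw);  // reconstruct from void*
  //   assert(back == vp && !is_nullptr(back));
  //   assert((vp + 16) != vp);                     // points 16 bytes further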
  /* Base type for all SYCL buffers, independent of their element type. */
  using buffer_t = cl::sycl::buffer_mem;

  /* A map node holds the buffer of an allocation, its size in bytes and
   * whether the block is currently free. */
  struct pMapNode_t {
    buffer_t m_buffer;
    size_t m_size;
    bool m_free;

    pMapNode_t(buffer_t b, size_t size, bool f)
        : m_buffer{b}, m_size{size}, m_free{f} {
      m_buffer.set_final_data(nullptr);
    }

    /* Nodes are ordered by the size of their allocation. */
    bool operator<=(const pMapNode_t &rhs) { return (m_size <= rhs.m_size); }
  };

  /* Storage of the pointer map: virtual address -> allocation node. */
  using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
  /* Returns an iterator to the node at which a block of the required size
   * can be placed: either a free node that is large enough (which is then
   * taken off the free list) or the last node in the map. */
  typename pointerMap_t::iterator get_insertion_point(size_t requiredSize) {
    typename pointerMap_t::iterator retVal;
    bool reuse = false;
    if (!m_freeList.empty()) {
      // Try to re-use an existing free block of sufficient size.
      for (auto freeElem : m_freeList) {
        if (freeElem->second.m_size >= requiredSize) {
          retVal = freeElem;
          reuse = true;
          // The element is not going to be free any more.
          m_freeList.erase(freeElem);
          break;
        }
      }
    }
    if (!reuse) {
      retVal = std::prev(m_pointerMap.end());
    }
    return retVal;
  }
  /* Returns an iterator to the node that stores the information of the
   * given virtual pointer. Throws std::out_of_range if the pointer is null,
   * unknown, or if no pointers have been allocated at all. */
  typename pointerMap_t::iterator get_node(const virtual_pointer_t ptr) {
    if (this->count() == 0) {
      m_pointerMap.clear();
      EIGEN_THROW_X(std::out_of_range("There are no pointers allocated\n"));
    }
    if (is_nullptr(ptr)) {
      m_pointerMap.clear();
      EIGEN_THROW_X(std::out_of_range("Cannot access null pointer\n"));
    }
    // lower_bound returns the first node whose address is not less than
    // ptr; the node that owns ptr is either that node (exact match) or the
    // previous one.
    auto node = m_pointerMap.lower_bound(ptr);
    if (node == std::end(m_pointerMap)) {
      --node;
    } else if (node->first != ptr) {
      if (node == std::begin(m_pointerMap)) {
        m_pointerMap.clear();
        EIGEN_THROW_X(
            std::out_of_range("The pointer is not registered in the map\n"));
      }
      --node;
    }
    return node;
  }
  /* Returns the SYCL buffer behind the given virtual pointer, re-typed to
   * the requested element type. */
  template <typename buffer_data_type = buffer_data_type_t>
  cl::sycl::buffer<buffer_data_type, 1> get_buffer(
      const virtual_pointer_t ptr) {
    using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
    // get_node() stores a buffer_mem; casting it back to buffer<> is valid
    // because buffer<> adds no data members to its buffer_mem base.
    auto node = get_node(ptr);
    eigen_assert(ptr <
                 static_cast<virtual_pointer_t>(node->second.m_size +
                                                node->first));
    return *(static_cast<sycl_buffer_t *>(&node->second.m_buffer));
  }
  /* Returns an accessor to the buffer of the given virtual pointer. */
  template <sycl_acc_mode access_mode = default_acc_mode,
            sycl_acc_target access_target = default_acc_target,
            typename buffer_data_type = buffer_data_type_t>
  cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
  get_access(const virtual_pointer_t ptr) {
    auto buf = get_buffer<buffer_data_type>(ptr);
    return buf.template get_access<access_mode, access_target>();
  }
  /* Returns an accessor to the buffer of the given virtual pointer within
   * the given command group scope. */
  template <sycl_acc_mode access_mode = default_acc_mode,
            sycl_acc_target access_target = default_acc_target,
            typename buffer_data_type = buffer_data_type_t>
  cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
  get_access(const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
    auto buf = get_buffer<buffer_data_type>(ptr);
    return buf.template get_access<access_mode, access_target>(cgh);
  }
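
  // Usage sketch (illustrative only; `queue`, `mapper` and `vptr` are
  // assumed to exist, with `vptr` obtained from SYCLmalloc below):
  //
  //   queue.submit([&](cl::sycl::handler &cgh) {
  //     auto acc = mapper.get_access<sycl_acc_mode::write>(vptr, cgh);
  //     cgh.single_task<class init_first_byte>([=]() { acc[0] = 42; });
  //   });
  //
  // Note that the accessor covers the whole underlying block; the offset of
  // `vptr` inside the block is obtained separately via get_offset() or
  // get_element_offset().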
  /* Returns the byte offset of the given virtual pointer from the base
   * address of the block that contains it. */
  inline std::ptrdiff_t get_offset(const virtual_pointer_t ptr) {
    auto node = get_node(ptr);
    auto start = node->first;
    return (ptr - start);
  }
  /* Returns the offset in elements of the given element type. */
  template <typename buffer_data_type>
  inline size_t get_element_offset(const virtual_pointer_t ptr) {
    return get_offset(ptr) / sizeof(buffer_data_type);
  }
  /* Constructs the pointer mapper. The base address is the first virtual
   * address that will be handed out; it cannot be zero, since zero is
   * reserved for the null virtual pointer. */
  PointerMapper(base_ptr_t baseAddress = 4096)
      : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
    if (m_baseAddress == 0) {
      EIGEN_THROW_X(std::invalid_argument("Base address cannot be zero\n"));
    }
  }

  /* The PointerMapper cannot be copied. */
  PointerMapper(const PointerMapper &) = delete;
  /* Empties the pointer map and the free list. The free list must be
   * cleared as well, since it holds iterators into the map. */
  inline void clear() {
    m_freeList.clear();
    m_pointerMap.clear();
  }
  /* Adds an existing buffer to the map and returns its virtual pointer. */
  inline virtual_pointer_t add_pointer(const buffer_t &b) {
    return add_pointer_impl(b);
  }

  /* Adds a temporary buffer to the map and returns its virtual pointer. */
  inline virtual_pointer_t add_pointer(buffer_t &&b) {
    return add_pointer_impl(b);
  }
  /* Fuses the given free node with any free nodes that follow it in the
   * pointer map, accumulating their sizes into the given node. */
  void fuse_forward(typename pointerMap_t::iterator &node) {
    while (node != std::prev(m_pointerMap.end())) {
      // If the following node is free, remove it and extend the current
      // node by its size; otherwise stop.
      auto fwd_node = std::next(node);
      if (!fwd_node->second.m_free) {
        break;
      }
      auto fwd_size = fwd_node->second.m_size;
      m_freeList.erase(fwd_node);
      m_pointerMap.erase(fwd_node);
      node->second.m_size += fwd_size;
    }
  }
  /* Fuses the given free node with any free nodes that precede it in the
   * pointer map, folding its size into the previous node. */
  void fuse_backward(typename pointerMap_t::iterator &node) {
    while (node != m_pointerMap.begin()) {
      // If the previous node is not free, stop; otherwise extend it by the
      // size of the current node and remove the current node.
      auto prev_node = std::prev(node);
      if (!prev_node->second.m_free) {
        break;
      }
      prev_node->second.m_size += node->second.m_size;
      m_freeList.erase(node);
      m_pointerMap.erase(node);
      // Continue fusing from the previous node.
      node = prev_node;
    }
  }
  /* Removes the given pointer from the map. If ReUse is true the block is
   * marked free and kept for later re-use, fusing it with any neighbouring
   * free blocks; otherwise (see the specialisation below) it is erased. */
  template <bool ReUse = true>
  void remove_pointer(const virtual_pointer_t ptr) {
    if (is_nullptr(ptr)) {
      return;
    }
    auto node = this->get_node(ptr);

    node->second.m_free = true;
    m_freeList.emplace(node);

    // Coalesce the node with free neighbours on both sides.
    fuse_forward(node);
    fuse_backward(node);

    // If after fusing the node is the last one, simply remove it, since
    // the address space past the end can always be reclaimed.
    if (node == std::prev(m_pointerMap.end())) {
      m_freeList.erase(node);
      m_pointerMap.erase(node);
    }
  }
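
  // Example of the coalescing behaviour (addresses illustrative): given
  // three consecutive 256-byte blocks A at 4096, B at 4352 and C at 4608,
  // where A and C are already free, remove_pointer(B) marks B free,
  // fuse_forward folds C into B (B grows to 512 bytes), and fuse_backward
  // folds B into A, leaving a single free 768-byte block at 4096.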
  /* Returns the number of live pointers, i.e. pointers that have been
   * allocated but not freed. */
  size_t count() const { return (m_pointerMap.size() - m_freeList.size()); }
 private:
  /* Adds a buffer to the map and returns the new virtual pointer.
   * BufferT is either a const buffer_t& or a buffer_t&&. */
  template <class BufferT>
  virtual_pointer_t add_pointer_impl(BufferT b) {
    virtual_pointer_t retVal = nullptr;
    size_t bufSize = b.get_count();
    pMapNode_t p{b, bufSize, false};

    // If this is the first allocation, place it at the base address.
    if (m_pointerMap.empty()) {
      virtual_pointer_t initialVal{m_baseAddress};
      m_pointerMap.emplace(initialVal, p);
      return initialVal;
    }

    auto lastElemIter = get_insertion_point(bufSize);
    // We are recovering an existing free node.
    if (lastElemIter->second.m_free) {
      lastElemIter->second.m_buffer = b;
      lastElemIter->second.m_free = false;

      // If the recovered node is bigger than the requested size, split it:
      // keep bufSize bytes here and register the remainder as a free node.
      if (lastElemIter->second.m_size > bufSize) {
        auto remainingSize = lastElemIter->second.m_size - bufSize;
        pMapNode_t p2{b, remainingSize, true};

        lastElemIter->second.m_size = bufSize;

        auto newFreePtr = lastElemIter->first + bufSize;
        auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first;
        m_freeList.emplace(freeNode);
      }

      retVal = lastElemIter->first;
    } else {
      // Otherwise append the new block right after the last one.
      size_t lastSize = lastElemIter->second.m_size;
      retVal = lastElemIter->first + lastSize;
      m_pointerMap.emplace(retVal, p);
    }
    return retVal;
  }
  /* Compares two map iterators; used to keep the free list ordered by the
   * size of the allocation. */
  struct SortBySize {
    bool operator()(typename pointerMap_t::iterator a,
                    typename pointerMap_t::iterator b) const {
      return ((a->first < b->first) && (a->second <= b->second)) ||
             ((a->first < b->first) && (b->second <= a->second));
    }
  };
  /* Maps virtual pointer addresses to (buffer, size, free) nodes. */
  pointerMap_t m_pointerMap;

  /* Free nodes available for re-use, ordered by size. */
  std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;

  /* First virtual address handed out; never zero. */
  base_ptr_t m_baseAddress;
};
/* Specialisation of remove_pointer for ReUse == false: the node is erased
 * from the map instead of being kept for re-use. */
template <>
inline void PointerMapper::remove_pointer<false>(const virtual_pointer_t ptr) {
  if (is_nullptr(ptr)) {
    return;
  }
  m_pointerMap.erase(this->get_node(ptr));
}
/* Malloc-like interface to the pointer mapper: given a size in bytes,
 * creates a byte-typed SYCL buffer of that size and returns a fake pointer
 * that tracks it. */
inline void *SYCLmalloc(size_t size, PointerMapper &pMap) {
  if (size == 0) {
    return nullptr;
  }
  // Create a generic byte buffer of the requested size and register it.
  using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
  auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{size}));
  return static_cast<void *>(thePointer);
}
/* Free-like interface to the pointer mapper: removes the fake pointer from
 * the map. ReUse should only be false for sub-buffers. */
template <bool ReUse = true, typename PointerMapper>
inline void SYCLfree(void *ptr, PointerMapper &pMap) {
  pMap.template remove_pointer<ReUse>(ptr);
}
/* Releases every allocation tracked by the given pointer mapper. */
template <typename PointerMapper>
inline void SYCLfreeAll(PointerMapper &pMap) {
  pMap.clear();
}
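
// Usage sketch of the malloc/free-style interface (illustrative only):
//
//   TensorSycl::PointerMapper mapper;
//   void *p = TensorSycl::SYCLmalloc(1024, mapper);  // fake device pointer
//   auto buf = mapper.get_buffer(p);                 // underlying buffer
//   TensorSycl::SYCLfree(p, mapper);                 // block kept for re-use
//   void *q = TensorSycl::SYCLmalloc(512, mapper);   // may recycle p's block
//   TensorSycl::SYCLfreeAll(mapper);                 // drop everything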
/* RangeAccess
 * A pointer-like wrapper around a placeholder accessor, an offset into it
 * and the virtual address it originates from. */
template <cl::sycl::access::mode AcMd, typename T>
struct RangeAccess {
  static const auto global_access = cl::sycl::access::target::global_buffer;
  static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
  typedef T scalar_t;
  typedef scalar_t &ref_t;
  typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;
  /* The accessor is always a placeholder accessor over global memory. */
  typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
      accessor;
  typedef RangeAccess<AcMd, T> self_t;
  RangeAccess(accessor access, size_t offset, std::intptr_t virtual_ptr)
      : access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {}

  /* A default-constructed RangeAccess wraps a dummy one-element buffer and
   * carries the invalid virtual address -1, so it compares equal to null. */
  RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
                  cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
      : access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {}

  /* Null construction, only used on the host side. */
  RangeAccess(std::nullptr_t) : RangeAccess() {}

  /* Returns the raw device pointer, adjusted by the stored offset. */
  ptr_t get_pointer() const { return (access_.get_pointer().get() + offset_); }
  /* Pointer-style arithmetic adjusts the stored offset. */
  template <typename Index>
  self_t &operator+=(Index offset) {
    offset_ += offset;
    return *this;
  }
  template <typename Index>
  self_t operator+(Index offset) const {
    return self_t(access_, offset_ + offset, virtual_ptr_);
  }
  template <typename Index>
  self_t operator-(Index offset) const {
    return self_t(access_, offset_ - offset, virtual_ptr_);
  }
  template <typename Index>
  self_t &operator-=(Index offset) {
    offset_ -= offset;
    return *this;
  }

  /* Comparisons against nullptr check for the invalid virtual address -1. */
  friend bool operator==(const RangeAccess &lhs, std::nullptr_t) {
    return (lhs.virtual_ptr_ == -1);
  }
  friend bool operator!=(const RangeAccess &lhs, std::nullptr_t i) {
    return !(lhs == i);
  }
  friend bool operator==(std::nullptr_t, const RangeAccess &rhs) {
    return (rhs.virtual_ptr_ == -1);
  }
  friend bool operator!=(std::nullptr_t i, const RangeAccess &rhs) {
    return !(i == rhs);
  }

  /* Pre- and post-increment step one element forward. */
  self_t &operator++() {
    ++offset_;
    return *this;
  }
  self_t operator++(int) {
    self_t temp_iterator(*this);
    ++offset_;
    return temp_iterator;
  }
  /* Number of elements reachable from the current offset. */
  std::ptrdiff_t get_size() const { return (access_.get_count() - offset_); }

  /* Dereference and indexing forward to the offset device pointer. */
  ref_t operator*() const { return *get_pointer(); }
  ref_t operator*() { return *get_pointer(); }
  ref_t operator[](int x) { return *(get_pointer() + x); }
  ref_t operator[](int x) const { return *(get_pointer() + x); }

  /* Reconstructs the virtual address this object points at, including the
   * element offset. */
  scalar_t *get_virtual_pointer() const {
    return reinterpret_cast<scalar_t *>(virtual_ptr_ +
                                        (offset_ * sizeof(scalar_t)));
  }

  /* A RangeAccess is "true" iff it holds a valid virtual address. */
  explicit operator bool() const { return (virtual_ptr_ != -1); }
  /* Conversions to the const-qualified variant preserve all state. */
  operator RangeAccess<AcMd, const T>() {
    return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
  }
  operator RangeAccess<AcMd, const T>() const {
    return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
  }

  /* Registers the placeholder accessor with the given command group; this
   * must be called before the accessor is used inside a kernel. */
  void bind(cl::sycl::handler &cgh) const { cgh.require(access_); }

 private:
  accessor access_;
  size_t offset_;
  std::intptr_t virtual_ptr_;  // virtual address of the underlying block
};
/* The const-element specialisation simply re-uses the mutable variant. */
template <cl::sycl::access::mode AcMd, typename T>
struct RangeAccess<AcMd, const T> : RangeAccess<AcMd, T> {
  typedef RangeAccess<AcMd, T> Base;
  using Base::Base;
};
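
// Usage sketch for RangeAccess (illustrative only; `buf` is assumed to be a
// cl::sycl::buffer<float, 1> and `queue` a cl::sycl::queue):
//
//   using read_write_access =
//       TensorSycl::RangeAccess<cl::sycl::access::mode::read_write, float>;
//   read_write_access ptr(buf);  // wraps a placeholder accessor
//   queue.submit([&](cl::sycl::handler &cgh) {
//     ptr.bind(cgh);             // register the placeholder accessor
//     cgh.single_task<class fill_first>([=]() { *ptr = 1.0f; });
//   });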
}  // namespace TensorSycl

#endif  // EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H