23 #if defined(EIGEN_USE_SYCL) && \ 
   24     !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H) 
   25 #define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H 
   27 #include <CL/sycl.hpp> 
   28 #ifdef EIGEN_EXCEPTIONS 
   34 #include <unordered_map> 
   37 namespace TensorSycl {
 
   40 using sycl_acc_target = cl::sycl::access::target;
 
   46 using buffer_data_type_t = 
uint8_t;
 
   47 const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
 
   48 const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
 
   65   struct virtual_pointer_t {
 
   68     base_ptr_t m_contents;
 
   73     operator void *() 
const { 
return reinterpret_cast<void *
>(m_contents); }
 
   78     operator base_ptr_t()
 const { 
return m_contents; }
 
   84     virtual_pointer_t 
operator+(
size_t off) { 
return m_contents + off; }
 
   87     bool operator<(virtual_pointer_t rhs)
 const {
 
   88       return (
static_cast<base_ptr_t
>(m_contents) <
 
   89               static_cast<base_ptr_t
>(rhs.m_contents));
 
   92     bool operator>(virtual_pointer_t rhs)
 const {
 
   93       return (
static_cast<base_ptr_t
>(m_contents) >
 
   94               static_cast<base_ptr_t
>(rhs.m_contents));
 
  100     bool operator==(virtual_pointer_t rhs)
 const {
 
  101       return (
static_cast<base_ptr_t
>(m_contents) ==
 
  102               static_cast<base_ptr_t
>(rhs.m_contents));
 
  108     bool operator!=(virtual_pointer_t rhs)
 const {
 
  118     virtual_pointer_t(
const void *ptr)
 
  119         : m_contents(reinterpret_cast<base_ptr_t>(ptr)){};
 
  125     virtual_pointer_t(base_ptr_t u) : m_contents(u){};
 
  130   const virtual_pointer_t null_virtual_ptr = 
nullptr;
 
  136   static inline bool is_nullptr(virtual_pointer_t ptr) {
 
  137     return (
static_cast<void *
>(ptr) == 
nullptr);
 
  142   using buffer_t = cl::sycl::buffer_mem;
 
  154     pMapNode_t(buffer_t 
b, 
size_t size, 
bool f)
 
  155         : m_buffer{
b}, m_size{
size}, m_free{
f} {
 
  156       m_buffer.set_final_data(
nullptr);
 
  159     bool operator<=(
const pMapNode_t &rhs) { 
return (m_size <= rhs.m_size); }
 
  164   using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
 
  171   typename pointerMap_t::iterator get_insertion_point(
size_t requiredSize) {
 
  172     typename pointerMap_t::iterator retVal;
 
  174     if (!m_freeList.empty()) {
 
  176       for (
auto freeElem : m_freeList) {
 
  177         if (freeElem->second.m_size >= requiredSize) {
 
  181           m_freeList.erase(freeElem);
 
  187       retVal = std::prev(m_pointerMap.end());
 
  202   typename pointerMap_t::iterator get_node(
const virtual_pointer_t ptr) {
 
  203     if (this->count() == 0) {
 
  204       m_pointerMap.clear();
 
  205       EIGEN_THROW_X(std::out_of_range(
"There are no pointers allocated\n"));
 
  208     if (is_nullptr(ptr)) {
 
  209       m_pointerMap.clear();
 
  210       EIGEN_THROW_X(std::out_of_range(
"Cannot access null pointer\n"));
 
  214     auto node = m_pointerMap.lower_bound(ptr);
 
  217     if (node == 
std::end(m_pointerMap)) {
 
  219     } 
else if (node->first != ptr) {
 
  220       if (node == std::begin(m_pointerMap)) {
 
  221         m_pointerMap.clear();
 
  223             std::out_of_range(
"The pointer is not registered in the map\n"));
 
  235   template <
typename buffer_data_type = buffer_data_type_t>
 
  236   cl::sycl::buffer<buffer_data_type, 1> get_buffer(
 
  237       const virtual_pointer_t ptr) {
 
  238     using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
 
  244     auto node = get_node(ptr);
 
  246     eigen_assert(ptr < 
static_cast<virtual_pointer_t
>(node->second.m_size +
 
  248     return *(
static_cast<sycl_buffer_t *
>(&node->second.m_buffer));
 
  257   template <sycl_acc_mode access_mode = default_acc_mode,
 
  258             sycl_acc_target access_target = default_acc_target,
 
  259             typename buffer_data_type = buffer_data_type_t>
 
  260   cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
 
  261   get_access(
const virtual_pointer_t ptr) {
 
  262     auto buf = get_buffer<buffer_data_type>(ptr);
 
  263     return buf.template get_access<access_mode, access_target>();
 
  274   template <sycl_acc_mode access_mode = default_acc_mode,
 
  275             sycl_acc_target access_target = default_acc_target,
 
  276             typename buffer_data_type = buffer_data_type_t>
 
  277   cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
 
  278   get_access(
const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
 
  279     auto buf = get_buffer<buffer_data_type>(ptr);
 
  280     return buf.template get_access<access_mode, access_target>(cgh);
 
  286   inline std::ptrdiff_t get_offset(
const virtual_pointer_t ptr) {
 
  289     auto node = get_node(ptr);
 
  290     auto start = node->first;
 
  293     return (ptr - start);
 
  300   template <
typename buffer_data_type>
 
  301   inline size_t get_element_offset(
const virtual_pointer_t ptr) {
 
  302     return get_offset(ptr) / 
sizeof(buffer_data_type);
 
  308   PointerMapper(base_ptr_t baseAddress = 4096)
 
  309       : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
 
  310     if (m_baseAddress == 0) {
 
  311       EIGEN_THROW_X(std::invalid_argument(
"Base address cannot be zero\n"));
 
  318   PointerMapper(
const PointerMapper &) = 
delete;
 
  323   inline void clear() {
 
  325     m_pointerMap.clear();
 
  331   inline virtual_pointer_t add_pointer(
const buffer_t &
b) {
 
  332     return add_pointer_impl(
b);
 
  338   inline virtual_pointer_t add_pointer(buffer_t &&
b) {
 
  339     return add_pointer_impl(
b);
 
  348   void fuse_forward(
typename pointerMap_t::iterator &node) {
 
  349     while (node != std::prev(m_pointerMap.end())) {
 
  352       auto fwd_node = std::next(node);
 
  353       if (!fwd_node->second.m_free) {
 
  356       auto fwd_size = fwd_node->second.m_size;
 
  357       m_freeList.erase(fwd_node);
 
  358       m_pointerMap.erase(fwd_node);
 
  360       node->second.m_size += fwd_size;
 
  370   void fuse_backward(
typename pointerMap_t::iterator &node) {
 
  371     while (node != m_pointerMap.begin()) {
 
  374       auto prev_node = std::prev(node);
 
  375       if (!prev_node->second.m_free) {
 
  378       prev_node->second.m_size += node->second.m_size;
 
  381       m_freeList.erase(node);
 
  382       m_pointerMap.erase(node);
 
  393   template <
bool ReUse = true>
 
  394   void remove_pointer(
const virtual_pointer_t ptr) {
 
  395     if (is_nullptr(ptr)) {
 
  398     auto node = this->get_node(ptr);
 
  400     node->second.m_free = 
true;
 
  401     m_freeList.emplace(node);
 
  410     if (node == std::prev(m_pointerMap.end())) {
 
  411       m_freeList.erase(node);
 
  412       m_pointerMap.erase(node);
 
  420   size_t count()
 const { 
return (m_pointerMap.size() - m_freeList.size()); }
 
  427   template <
class BufferT>
 
  428   virtual_pointer_t add_pointer_impl(BufferT 
b) {
 
  429     virtual_pointer_t retVal = 
nullptr;
 
  430     size_t bufSize = 
b.get_count();
 
  431     pMapNode_t 
p{
b, bufSize, 
false};
 
  433     if (m_pointerMap.empty()) {
 
  434       virtual_pointer_t initialVal{m_baseAddress};
 
  435       m_pointerMap.emplace(initialVal, 
p);
 
  439     auto lastElemIter = get_insertion_point(bufSize);
 
  441     if (lastElemIter->second.m_free) {
 
  442       lastElemIter->second.m_buffer = 
b;
 
  443       lastElemIter->second.m_free = 
false;
 
  447       if (lastElemIter->second.m_size > bufSize) {
 
  449         auto remainingSize = lastElemIter->second.m_size - bufSize;
 
  450         pMapNode_t 
p2{
b, remainingSize, 
true};
 
  453         lastElemIter->second.m_size = bufSize;
 
  456         auto newFreePtr = lastElemIter->first + bufSize;
 
  457         auto freeNode = m_pointerMap.emplace(newFreePtr, 
p2).first;
 
  458         m_freeList.emplace(freeNode);
 
  461       retVal = lastElemIter->first;
 
  463       size_t lastSize = lastElemIter->second.m_size;
 
  464       retVal = lastElemIter->first + lastSize;
 
  465       m_pointerMap.emplace(retVal, 
p);
 
  476                     typename pointerMap_t::iterator 
b)
 const {
 
  477       return ((
a->first < 
b->first) && (
a->second <= 
b->second)) ||
 
  478              ((
a->first < 
b->first) && (
b->second <= 
a->second));
 
  484   pointerMap_t m_pointerMap;
 
  488   std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;
 
  500 inline void PointerMapper::remove_pointer<false>(
const virtual_pointer_t ptr) {
 
  501   if (is_nullptr(ptr)) {
 
  504   m_pointerMap.erase(this->get_node(ptr));
 
  514 inline void *SYCLmalloc(
size_t size, PointerMapper &pMap) {
 
  519   using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
 
  520   auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{
size}));
 
  522   return static_cast<void *
>(thePointer);
 
  532 template <
bool ReUse = true, 
typename Po
interMapper>
 
  533 inline void SYCLfree(
void *ptr, PointerMapper &pMap) {
 
  534   pMap.template remove_pointer<ReUse>(ptr);
 
  540 template <
typename Po
interMapper>
 
  541 inline void SYCLfreeAll(PointerMapper &pMap) {
 
// NOTE(review): the lines below are a corrupted listing of the RangeAccess
// class — a placeholder-accessor wrapper that emulates a raw pointer for
// SYCL device code. The class head, several member-function signatures,
// the access specifiers, and most closing braces are missing from this
// copy, and stray listing numbers are baked into the text. Restore this
// block from the upstream header rather than hand-patching it; the
// comments below only annotate the surviving fragments.
   545 template <cl::sycl::access::mode AcMd, 
typename T>
// Fragment: static accessor configuration — global-buffer target and
// placeholder accessors (bound to a handler later, see bind() below).
   547   static const auto global_access = cl::sycl::access::target::global_buffer;
   548   static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
// Fragment: member typedefs — element reference, raw device pointer,
// the placeholder accessor type, and the self type.
   550   typedef scalar_t &ref_t;
   551   typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;
   554   typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
   557   typedef RangeAccess<AcMd, T> 
self_t;
// Fragment: initializer list of the (access, offset, virtual_ptr)
// constructor — its signature line is missing from this copy.
   561       : access_(access), offset_(
offset), virtual_ptr_(virtual_ptr) {}
// Fragment: default constructor — builds a 1-element dummy buffer and
// marks the instance null via virtual_ptr_ == -1.
   563   RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
   564                   cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
   565       : access_{
accessor{buff}}, offset_(0), virtual_ptr_(-1) {}
// Null-pointer construction delegates to the default (null) constructor.
   568   RangeAccess(std::nullptr_t) : RangeAccess() {}
// Fragment: body of get_pointer() — device pointer adjusted by offset_.
   571     return (access_.get_pointer().get() + offset_);
// Fragments: template heads of the pointer-arithmetic operators
// (presumably +=, +, -, -= — TODO confirm against the upstream header).
   573   template <
typename Index>
   578   template <
typename Index>
   582   template <
typename Index>
   586   template <
typename Index>
// Fragments: null-comparison operators. A RangeAccess is "null" iff its
// virtual_ptr_ is -1 (see the null constructor above).
   594       const RangeAccess &lhs, std::nullptr_t) {
   595     return ((lhs.virtual_ptr_ == -1));
   598       const RangeAccess &lhs, std::nullptr_t 
i) {
   604       std::nullptr_t, 
const RangeAccess &rhs) {
   605     return ((rhs.virtual_ptr_ == -1));
   608       std::nullptr_t 
i, 
const RangeAccess &rhs) {
// Fragment: post-increment — saves a copy of *this; the offset bump and
// enclosing signature are among the missing lines.
   620     self_t temp_iterator(*
this);
   622     return temp_iterator;
// Fragment: size remaining from the current offset.
   626     return (access_.get_count() - offset_);
// Fragments: dereference operator bodies (const and non-const variants).
   638     return *get_pointer();
   642     return *get_pointer();
// Fragments: subscript operator bodies (const and non-const variants).
   648     return *(get_pointer() + 
x);
   652     return *(get_pointer() + 
x);
// Fragment: reconstructs the host-side fake pointer from the virtual
// address plus the element offset scaled to bytes.
   656     return reinterpret_cast<scalar_t *
>(virtual_ptr_ +
   657                                         (offset_ * 
sizeof(scalar_t)));
// Fragment: boolean conversion — true iff the instance is not null.
   661     return (virtual_ptr_ != -1);
// Fragments: conversions to the const-element RangeAccess specialization.
   665     return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
   669   operator RangeAccess<AcMd, const T>()
 const {
   670     return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
// Fragment: binds the placeholder accessor to a command-group handler;
// required before the accessor can be used inside a kernel.
   674       cl::sycl::handler &cgh)
 const {
   675     cgh.require(access_);
 
  684 template <cl::sycl::access::mode AcMd, 
typename T>
 
  685 struct RangeAccess<AcMd, const 
T> : RangeAccess<AcMd, T> {
 
  686   typedef RangeAccess<AcMd, T> 
Base;
 
  694 #endif  // EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H