23 #if defined(EIGEN_USE_SYCL) && \ 24 !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H) 25 #define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H 27 #include <CL/sycl.hpp> 28 #ifdef EIGEN_EXCEPTIONS 34 #include <unordered_map> 37 namespace TensorSycl {
40 using sycl_acc_target = cl::sycl::access::target;
46 using buffer_data_type_t =
uint8_t;
47 const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
48 const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
65 struct virtual_pointer_t {
68 base_ptr_t m_contents;
73 operator void *()
const {
return reinterpret_cast<void *
>(m_contents); }
78 operator base_ptr_t()
const {
return m_contents; }
84 virtual_pointer_t
operator+(
size_t off) {
return m_contents + off; }
87 bool operator<(virtual_pointer_t rhs)
const {
88 return (static_cast<base_ptr_t>(m_contents) <
89 static_cast<base_ptr_t>(rhs.m_contents));
92 bool operator>(virtual_pointer_t rhs)
const {
93 return (static_cast<base_ptr_t>(m_contents) >
94 static_cast<base_ptr_t>(rhs.m_contents));
100 bool operator==(virtual_pointer_t rhs)
const {
101 return (static_cast<base_ptr_t>(m_contents) ==
102 static_cast<base_ptr_t>(rhs.m_contents));
108 bool operator!=(virtual_pointer_t rhs)
const {
118 virtual_pointer_t(
const void *ptr)
119 : m_contents(reinterpret_cast<base_ptr_t>(ptr)){};
125 virtual_pointer_t(base_ptr_t u) : m_contents(u){};
130 const virtual_pointer_t null_virtual_ptr =
nullptr;
136 static inline bool is_nullptr(virtual_pointer_t ptr) {
137 return (static_cast<void *>(ptr) ==
nullptr);
142 using buffer_t = cl::sycl::buffer_mem;
154 pMapNode_t(buffer_t
b,
size_t size,
bool f)
155 : m_buffer{b}, m_size{
size}, m_free{
f} {
156 m_buffer.set_final_data(
nullptr);
159 bool operator<=(
const pMapNode_t &rhs) {
return (m_size <= rhs.m_size); }
164 using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
171 typename pointerMap_t::iterator get_insertion_point(
size_t requiredSize) {
172 typename pointerMap_t::iterator retVal;
174 if (!m_freeList.empty()) {
176 for (
auto freeElem : m_freeList) {
177 if (freeElem->second.m_size >= requiredSize) {
181 m_freeList.erase(freeElem);
187 retVal = std::prev(m_pointerMap.end());
202 typename pointerMap_t::iterator get_node(
const virtual_pointer_t ptr) {
203 if (this->count() == 0) {
204 m_pointerMap.clear();
205 EIGEN_THROW_X(std::out_of_range(
"There are no pointers allocated\n"));
208 if (is_nullptr(ptr)) {
209 m_pointerMap.clear();
210 EIGEN_THROW_X(std::out_of_range(
"Cannot access null pointer\n"));
214 auto node = m_pointerMap.lower_bound(ptr);
217 if (node ==
std::end(m_pointerMap)) {
219 }
else if (node->first != ptr) {
220 if (node == std::begin(m_pointerMap)) {
221 m_pointerMap.clear();
223 std::out_of_range(
"The pointer is not registered in the map\n"));
235 template <
typename buffer_data_type = buffer_data_type_t>
236 cl::sycl::buffer<buffer_data_type, 1> get_buffer(
237 const virtual_pointer_t ptr) {
238 using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
244 auto node = get_node(ptr);
246 eigen_assert(ptr < static_cast<virtual_pointer_t>(node->second.m_size +
248 return *(
static_cast<sycl_buffer_t *
>(&node->second.m_buffer));
257 template <sycl_acc_mode access_mode = default_acc_mode,
258 sycl_acc_target access_target = default_acc_target,
259 typename buffer_data_type = buffer_data_type_t>
260 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
261 get_access(
const virtual_pointer_t ptr) {
262 auto buf = get_buffer<buffer_data_type>(ptr);
263 return buf.template get_access<access_mode, access_target>();
274 template <sycl_acc_mode access_mode = default_acc_mode,
275 sycl_acc_target access_target = default_acc_target,
276 typename buffer_data_type = buffer_data_type_t>
277 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
278 get_access(
const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
279 auto buf = get_buffer<buffer_data_type>(ptr);
280 return buf.template get_access<access_mode, access_target>(cgh);
286 inline std::ptrdiff_t get_offset(
const virtual_pointer_t ptr) {
289 auto node = get_node(ptr);
290 auto start = node->first;
293 return (ptr - start);
300 template <
typename buffer_data_type>
301 inline size_t get_element_offset(
const virtual_pointer_t ptr) {
302 return get_offset(ptr) /
sizeof(buffer_data_type);
308 PointerMapper(base_ptr_t baseAddress = 4096)
309 : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
310 if (m_baseAddress == 0) {
311 EIGEN_THROW_X(std::invalid_argument(
"Base address cannot be zero\n"));
318 PointerMapper(
const PointerMapper &) =
delete;
323 inline void clear() {
325 m_pointerMap.clear();
331 inline virtual_pointer_t add_pointer(
const buffer_t &
b) {
332 return add_pointer_impl(b);
338 inline virtual_pointer_t add_pointer(buffer_t &&
b) {
339 return add_pointer_impl(
b);
348 void fuse_forward(
typename pointerMap_t::iterator &node) {
349 while (node != std::prev(m_pointerMap.end())) {
352 auto fwd_node = std::next(node);
353 if (!fwd_node->second.m_free) {
356 auto fwd_size = fwd_node->second.m_size;
357 m_freeList.erase(fwd_node);
358 m_pointerMap.erase(fwd_node);
360 node->second.m_size += fwd_size;
370 void fuse_backward(
typename pointerMap_t::iterator &node) {
371 while (node != m_pointerMap.begin()) {
374 auto prev_node = std::prev(node);
375 if (!prev_node->second.m_free) {
378 prev_node->second.m_size += node->second.m_size;
381 m_freeList.erase(node);
382 m_pointerMap.erase(node);
393 template <
bool ReUse = true>
394 void remove_pointer(
const virtual_pointer_t ptr) {
395 if (is_nullptr(ptr)) {
398 auto node = this->get_node(ptr);
400 node->second.m_free =
true;
401 m_freeList.emplace(node);
410 if (node == std::prev(m_pointerMap.end())) {
411 m_freeList.erase(node);
412 m_pointerMap.erase(node);
420 size_t count()
const {
return (m_pointerMap.size() - m_freeList.size()); }
427 template <
class BufferT>
428 virtual_pointer_t add_pointer_impl(BufferT
b) {
429 virtual_pointer_t retVal =
nullptr;
430 size_t bufSize = b.get_count();
431 pMapNode_t
p{
b, bufSize,
false};
433 if (m_pointerMap.empty()) {
434 virtual_pointer_t initialVal{m_baseAddress};
435 m_pointerMap.emplace(initialVal,
p);
439 auto lastElemIter = get_insertion_point(bufSize);
441 if (lastElemIter->second.m_free) {
442 lastElemIter->second.m_buffer =
b;
443 lastElemIter->second.m_free =
false;
447 if (lastElemIter->second.m_size > bufSize) {
449 auto remainingSize = lastElemIter->second.m_size - bufSize;
450 pMapNode_t
p2{
b, remainingSize,
true};
453 lastElemIter->second.m_size = bufSize;
456 auto newFreePtr = lastElemIter->first + bufSize;
457 auto freeNode = m_pointerMap.emplace(newFreePtr,
p2).first;
458 m_freeList.emplace(freeNode);
461 retVal = lastElemIter->first;
463 size_t lastSize = lastElemIter->second.m_size;
464 retVal = lastElemIter->first + lastSize;
465 m_pointerMap.emplace(retVal,
p);
476 typename pointerMap_t::iterator b)
const {
477 return ((a->first < b->first) && (a->second <= b->second)) ||
478 ((a->first < b->first) && (b->second <= a->second));
484 pointerMap_t m_pointerMap;
488 std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;
500 inline void PointerMapper::remove_pointer<false>(
const virtual_pointer_t ptr) {
501 if (is_nullptr(ptr)) {
504 m_pointerMap.erase(this->get_node(ptr));
514 inline void *SYCLmalloc(
size_t size, PointerMapper &pMap) {
519 using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
520 auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{size}));
522 return static_cast<void *
>(thePointer);
/**
 * free-style entry point: releases the allocation behind `ptr`.  With
 * ReUse == true (the default) the space is kept for re-use by later
 * allocations; with ReUse == false it is erased from the map.
 */
template <bool ReUse = true, typename PointerMapper>
inline void SYCLfree(void *ptr, PointerMapper &pMap) {
  pMap.template remove_pointer<ReUse>(ptr);
}

/**
 * Releases every allocation tracked by the given pointer mapper.
 */
template <typename PointerMapper>
inline void SYCLfreeAll(PointerMapper &pMap) {
  pMap.clear();
}
545 template <cl::sycl::access::mode AcMd,
typename T>
547 static const auto global_access = cl::sycl::access::target::global_buffer;
548 static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
550 typedef scalar_t &ref_t;
551 typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;
554 typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
557 typedef RangeAccess<AcMd, T>
self_t;
561 : access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {}
563 RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
564 cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
565 : access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {}
568 RangeAccess(std::nullptr_t) : RangeAccess() {}
571 return (access_.get_pointer().get() + offset_);
573 template <
typename Index>
578 template <
typename Index>
580 return self_t(access_, offset_ + offset, virtual_ptr_);
582 template <
typename Index>
584 return self_t(access_, offset_ - offset, virtual_ptr_);
586 template <
typename Index>
594 const RangeAccess &lhs, std::nullptr_t) {
595 return ((lhs.virtual_ptr_ == -1));
598 const RangeAccess &lhs, std::nullptr_t
i) {
604 std::nullptr_t,
const RangeAccess &rhs) {
605 return ((rhs.virtual_ptr_ == -1));
608 std::nullptr_t
i,
const RangeAccess &rhs) {
620 self_t temp_iterator(*
this);
622 return temp_iterator;
626 return (access_.get_count() - offset_);
638 return *get_pointer();
642 return *get_pointer();
648 return *(get_pointer() +
x);
652 return *(get_pointer() +
x);
656 return reinterpret_cast<scalar_t *
>(virtual_ptr_ +
657 (offset_ *
sizeof(scalar_t)));
661 return (virtual_ptr_ != -1);
665 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
669 operator RangeAccess<AcMd, const T>()
const {
670 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
674 cl::sycl::handler &cgh)
const {
675 cgh.require(access_);
684 template <cl::sycl::access::mode AcMd,
typename T>
685 struct RangeAccess<AcMd, const
T> : RangeAccess<AcMd, T> {
686 typedef RangeAccess<AcMd, T>
Base;
694 #endif // EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
#define EIGEN_STRONG_INLINE
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 & operator+=(bfloat16 &a, const bfloat16 &b)
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator==(const Tuple< U, V > &x, const Tuple< U, V > &y)
bool operator<(const benchmark_t &b1, const benchmark_t &b2)
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy y set format x g set format y g set format x2 g set format y2 g set format z g set angles radians set nogrid set key title set key left top Right noreverse box linetype linewidth samplen spacing width set nolabel set noarrow set nologscale set logscale x set set pointsize set encoding default set nopolar set noparametric set set set set surface set nocontour set clabel set mapping cartesian set nohidden3d set cntrparam order set cntrparam linear set cntrparam levels auto set cntrparam points set size set set xzeroaxis lt lw set x2zeroaxis lt lw set yzeroaxis lt lw set y2zeroaxis lt lw set tics in set ticslevel set tics set mxtics default set mytics default set mx2tics default set my2tics default set xtics border mirror norotate autofreq set ytics border mirror norotate autofreq set ztics border nomirror norotate autofreq set nox2tics set noy2tics set timestamp bottom norotate offset
Namespace containing all symbols from the Eigen library.
EIGEN_STRONG_INLINE const CwiseBinaryOp< internal::scalar_sum_op< typename DenseDerived::Scalar, typename SparseDerived::Scalar >, const DenseDerived, const SparseDerived > operator+(const MatrixBase< DenseDerived > &a, const SparseMatrixBase< SparseDerived > &b)
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const bfloat16 &a, const bfloat16 &b)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
EIGEN_STRONG_INLINE const CwiseBinaryOp< internal::scalar_difference_op< typename DenseDerived::Scalar, typename SparseDerived::Scalar >, const DenseDerived, const SparseDerived > operator-(const MatrixBase< DenseDerived > &a, const SparseMatrixBase< SparseDerived > &b)
EIGEN_DEVICE_FUNC const Product< MatrixDerived, PermutationDerived, AliasFreeProduct > operator*(const MatrixBase< MatrixDerived > &matrix, const PermutationBase< PermutationDerived > &permutation)
#define EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 & operator-=(bfloat16 &a, const bfloat16 &b)
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const bfloat16 &a, const bfloat16 &b)
static EIGEN_DEPRECATED const end_t end
static const DiscreteKey mode(modeKey, 2)
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator!=(const Tuple< U, V > &x, const Tuple< U, V > &y)
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
internal::enable_if< internal::valid_indexed_view_overload< RowIndices, ColIndices >::value &&internal::traits< typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::ReturnAsIndexedView, typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::type operator()(const RowIndices &rowIndices, const ColIndices &colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
#define EIGEN_UNUSED_VARIABLE(var)
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16 &a)