23#if defined(EIGEN_USE_SYCL) && \
24 !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H)
25#define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
28#ifdef EIGEN_EXCEPTIONS
34#include <unordered_map>
40using sycl_acc_target = cl::sycl::access::target;
41using sycl_acc_mode = cl::sycl::access::mode;
46using buffer_data_type_t =
uint8_t;
47const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
48const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
57 using base_ptr_t = std::intptr_t;
65 struct virtual_pointer_t {
68 base_ptr_t m_contents;
73 operator void *()
const {
return reinterpret_cast<void *
>(m_contents); }
78 operator base_ptr_t()
const {
return m_contents; }
84 virtual_pointer_t
operator+(
size_t off) {
return m_contents + off; }
87 bool operator<(virtual_pointer_t rhs)
const {
88 return (
static_cast<base_ptr_t
>(m_contents) <
89 static_cast<base_ptr_t
>(rhs.m_contents));
92 bool operator>(virtual_pointer_t rhs)
const {
93 return (
static_cast<base_ptr_t
>(m_contents) >
94 static_cast<base_ptr_t
>(rhs.m_contents));
100 bool operator==(virtual_pointer_t rhs)
const {
101 return (
static_cast<base_ptr_t
>(m_contents) ==
102 static_cast<base_ptr_t
>(rhs.m_contents));
108 bool operator!=(virtual_pointer_t rhs)
const {
118 virtual_pointer_t(
const void *ptr)
119 : m_contents(reinterpret_cast<base_ptr_t>(ptr)){};
125 virtual_pointer_t(base_ptr_t u) : m_contents(u){};
130 const virtual_pointer_t null_virtual_ptr =
nullptr;
136 static inline bool is_nullptr(virtual_pointer_t ptr) {
137 return (
static_cast<void *
>(ptr) ==
nullptr);
142 using buffer_t = cl::sycl::buffer_mem;
154 pMapNode_t(buffer_t
b,
size_t size,
bool f)
155 : m_buffer{
b}, m_size{
size}, m_free{f} {
156 m_buffer.set_final_data(
nullptr);
159 bool operator<=(
const pMapNode_t &rhs) {
return (m_size <= rhs.m_size); }
164 using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
171 typename pointerMap_t::iterator get_insertion_point(
size_t requiredSize) {
172 typename pointerMap_t::iterator retVal;
174 if (!m_freeList.empty()) {
176 for (
auto freeElem : m_freeList) {
177 if (freeElem->second.m_size >= requiredSize) {
181 m_freeList.erase(freeElem);
187 retVal = std::prev(m_pointerMap.end());
202 typename pointerMap_t::iterator get_node(
const virtual_pointer_t ptr) {
203 if (this->
count() == 0) {
204 m_pointerMap.clear();
205 EIGEN_THROW_X(std::out_of_range(
"There are no pointers allocated\n"));
208 if (is_nullptr(ptr)) {
209 m_pointerMap.clear();
210 EIGEN_THROW_X(std::out_of_range(
"Cannot access null pointer\n"));
214 auto node = m_pointerMap.lower_bound(ptr);
217 if (node == std::end(m_pointerMap)) {
219 }
else if (node->first != ptr) {
220 if (node == std::begin(m_pointerMap)) {
221 m_pointerMap.clear();
223 std::out_of_range(
"The pointer is not registered in the map\n"));
235 template <
typename buffer_data_type = buffer_data_type_t>
236 cl::sycl::buffer<buffer_data_type, 1> get_buffer(
237 const virtual_pointer_t ptr) {
238 using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
244 auto node = get_node(ptr);
246 eigen_assert(ptr <
static_cast<virtual_pointer_t
>(node->second.m_size +
248 return *(
static_cast<sycl_buffer_t *
>(&node->second.m_buffer));
257 template <sycl_acc_mode access_mode = default_acc_mode,
258 sycl_acc_target access_target = default_acc_target,
259 typename buffer_data_type = buffer_data_type_t>
260 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
261 get_access(
const virtual_pointer_t ptr) {
262 auto buf = get_buffer<buffer_data_type>(ptr);
263 return buf.template get_access<access_mode, access_target>();
274 template <sycl_acc_mode access_mode = default_acc_mode,
275 sycl_acc_target access_target = default_acc_target,
276 typename buffer_data_type = buffer_data_type_t>
277 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
278 get_access(
const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
279 auto buf = get_buffer<buffer_data_type>(ptr);
280 return buf.template get_access<access_mode, access_target>(cgh);
286 inline std::ptrdiff_t get_offset(
const virtual_pointer_t ptr) {
289 auto node = get_node(ptr);
290 auto start = node->first;
293 return (ptr - start);
300 template <
typename buffer_data_type>
301 inline size_t get_element_offset(
const virtual_pointer_t ptr) {
302 return get_offset(ptr) /
sizeof(buffer_data_type);
308 PointerMapper(base_ptr_t baseAddress = 4096)
309 : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
310 if (m_baseAddress == 0) {
311 EIGEN_THROW_X(std::invalid_argument(
"Base address cannot be zero\n"));
318 PointerMapper(
const PointerMapper &) =
delete;
323 inline void clear() {
325 m_pointerMap.clear();
331 inline virtual_pointer_t add_pointer(
const buffer_t &
b) {
332 return add_pointer_impl(
b);
338 inline virtual_pointer_t add_pointer(buffer_t &&
b) {
339 return add_pointer_impl(
b);
348 void fuse_forward(
typename pointerMap_t::iterator &node) {
349 while (node != std::prev(m_pointerMap.end())) {
352 auto fwd_node = std::next(node);
353 if (!fwd_node->second.m_free) {
356 auto fwd_size = fwd_node->second.m_size;
357 m_freeList.erase(fwd_node);
358 m_pointerMap.erase(fwd_node);
360 node->second.m_size += fwd_size;
370 void fuse_backward(
typename pointerMap_t::iterator &node) {
371 while (node != m_pointerMap.begin()) {
374 auto prev_node = std::prev(node);
375 if (!prev_node->second.m_free) {
378 prev_node->second.m_size += node->second.m_size;
381 m_freeList.erase(node);
382 m_pointerMap.erase(node);
393 template <
bool ReUse = true>
394 void remove_pointer(
const virtual_pointer_t ptr) {
395 if (is_nullptr(ptr)) {
398 auto node = this->get_node(ptr);
400 node->second.m_free =
true;
401 m_freeList.emplace(node);
410 if (node == std::prev(m_pointerMap.end())) {
411 m_freeList.erase(node);
412 m_pointerMap.erase(node);
420 size_t count()
const {
return (m_pointerMap.size() - m_freeList.size()); }
427 template <
class BufferT>
428 virtual_pointer_t add_pointer_impl(BufferT
b) {
429 virtual_pointer_t retVal =
nullptr;
430 size_t bufSize =
b.get_count();
431 pMapNode_t
p{
b, bufSize,
false};
433 if (m_pointerMap.empty()) {
434 virtual_pointer_t initialVal{m_baseAddress};
435 m_pointerMap.emplace(initialVal,
p);
439 auto lastElemIter = get_insertion_point(bufSize);
441 if (lastElemIter->second.m_free) {
442 lastElemIter->second.m_buffer =
b;
443 lastElemIter->second.m_free =
false;
447 if (lastElemIter->second.m_size > bufSize) {
449 auto remainingSize = lastElemIter->second.m_size - bufSize;
450 pMapNode_t p2{
b, remainingSize,
true};
453 lastElemIter->second.m_size = bufSize;
456 auto newFreePtr = lastElemIter->first + bufSize;
457 auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first;
458 m_freeList.emplace(freeNode);
461 retVal = lastElemIter->first;
463 size_t lastSize = lastElemIter->second.m_size;
464 retVal = lastElemIter->first + lastSize;
465 m_pointerMap.emplace(retVal,
p);
476 typename pointerMap_t::iterator
b)
const {
477 return ((
a->first <
b->first) && (
a->second <=
b->second)) ||
478 ((
a->first <
b->first) && (
b->second <=
a->second));
484 pointerMap_t m_pointerMap;
488 std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;
492 std::intptr_t m_baseAddress;
500inline void PointerMapper::remove_pointer<false>(
const virtual_pointer_t ptr) {
501 if (is_nullptr(ptr)) {
504 m_pointerMap.erase(this->get_node(ptr));
514inline void *SYCLmalloc(
size_t size, PointerMapper &pMap) {
519 using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
520 auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{
size}));
522 return static_cast<void *
>(thePointer);
532template <
bool ReUse = true,
typename Po
interMapper>
533inline void SYCLfree(
void *ptr, PointerMapper &pMap) {
534 pMap.template remove_pointer<ReUse>(ptr);
540template <
typename Po
interMapper>
541inline void SYCLfreeAll(PointerMapper &pMap) {
545template <cl::sycl::access::mode AcMd,
typename T>
547 static const auto global_access = cl::sycl::access::target::global_buffer;
548 static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
550 typedef scalar_t &ref_t;
551 typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;
554 typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
557 typedef RangeAccess<AcMd, T> self_t;
560 std::intptr_t virtual_ptr)
561 : access_(access), offset_(
offset), virtual_ptr_(virtual_ptr) {}
563 RangeAccess(cl::sycl::buffer<scalar_t, 1>
buff =
564 cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
565 : access_{accessor{
buff}}, offset_(0), virtual_ptr_(-1) {}
568 RangeAccess(std::nullptr_t) : RangeAccess() {}
571 return (access_.get_pointer().get() + offset_);
573 template <
typename Index>
578 template <
typename Index>
580 return self_t(access_, offset_ +
offset, virtual_ptr_);
582 template <
typename Index>
584 return self_t(access_, offset_ -
offset, virtual_ptr_);
586 template <
typename Index>
594 const RangeAccess &lhs, std::nullptr_t) {
595 return ((lhs.virtual_ptr_ == -1));
598 const RangeAccess &lhs, std::nullptr_t
i) {
604 std::nullptr_t,
const RangeAccess &rhs) {
605 return ((rhs.virtual_ptr_ == -1));
608 std::nullptr_t
i,
const RangeAccess &rhs) {
620 self_t temp_iterator(*
this);
622 return temp_iterator;
626 return (access_.get_count() - offset_);
638 return *get_pointer();
642 return *get_pointer();
648 return *(get_pointer() +
x);
652 return *(get_pointer() +
x);
656 return reinterpret_cast<scalar_t *
>(virtual_ptr_ +
657 (offset_ *
sizeof(scalar_t)));
661 return (virtual_ptr_ != -1);
665 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
669 operator RangeAccess<AcMd, const T>()
const {
670 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
674 cl::sycl::handler &cgh)
const {
675 cgh.require(access_);
681 std::intptr_t virtual_ptr_;
684template <cl::sycl::access::mode AcMd,
typename T>
685struct RangeAccess<AcMd, const
T> : RangeAccess<AcMd, T> {
686 typedef RangeAccess<AcMd, T> Base;
ArrayXXi a
Definition Array_initializer_list_23_cxx11.cpp:1
int i
Definition BiCGSTAB_step_by_step.cpp:9
EIGEN_DEVICE_FUNC const NegativeReturnType operator-() const
Definition CommonCwiseUnaryOps.h:45
internal::enable_if< internal::valid_indexed_view_overload< RowIndices, ColIndices >::value &&internal::traits< typenameEIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::ReturnAsIndexedView, typenameEIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::type operator()(const RowIndices &rowIndices, const ColIndices &colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
Definition IndexedViewMethods.h:73
#define EIGEN_UNUSED_VARIABLE(var)
Definition Macros.h:1076
#define EIGEN_DEVICE_FUNC
Definition Macros.h:976
#define eigen_assert(x)
Definition Macros.h:1037
#define EIGEN_STRONG_INLINE
Definition Macros.h:917
#define EIGEN_THROW_X(X)
Definition Macros.h:1403
float * p
Definition Tutorial_Map_using.cpp:9
Scalar * b
Definition benchVecAdd.cpp:17
Scalar Scalar int size
Definition benchVecAdd.cpp:17
bool operator<(const benchmark_t &b1, const benchmark_t &b2)
Definition benchmark-blocking-sizes.cpp:144
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
Definition gnuplot_common_settings.hh:12
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy y set format x g set format y g set format x2 g set format y2 g set format z g set angles radians set nogrid set key title set key left top Right noreverse box linetype linewidth samplen spacing width set nolabel set noarrow set nologscale set logscale x set set pointsize set encoding default set nopolar set noparametric set set set set surface set nocontour set clabel set mapping cartesian set nohidden3d set cntrparam order set cntrparam linear set cntrparam levels auto set cntrparam points set size set set xzeroaxis lt lw set x2zeroaxis lt lw set yzeroaxis lt lw set y2zeroaxis lt lw set tics in set ticslevel set tics set mxtics default set mytics default set mx2tics default set my2tics default set xtics border mirror norotate autofreq set ytics border mirror norotate autofreq set ztics border nomirror norotate autofreq set nox2tics set noy2tics set timestamp bottom norotate offset
Definition gnuplot_common_settings.hh:64
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator+(const bfloat16 &a, const bfloat16 &b)
Definition BFloat16.h:161
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 & operator-=(bfloat16 &a, const bfloat16 &b)
Definition BFloat16.h:192
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const bfloat16 &a, const bfloat16 &b)
Definition BFloat16.h:218
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const bfloat16 &a, const bfloat16 &b)
Definition BFloat16.h:221
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const bfloat16 &a, const bfloat16 &b)
Definition BFloat16.h:230
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator*(const bfloat16 &a, const bfloat16 &b)
Definition BFloat16.h:170
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const bfloat16 &a, const bfloat16 &b)
Definition BFloat16.h:227
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 & operator+=(bfloat16 &a, const bfloat16 &b)
Definition BFloat16.h:184
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16 &a)
Definition BFloat16.h:200
::uint8_t uint8_t
Definition Meta.h:52
Namespace containing all symbols from the Eigen library.
Definition bench_norm.cpp:85
Definition BandTriangularSolver.h:13
buff_t buff
Definition ref_serial.cpp:62
uint8_t count
Definition ref_serial.h:256
real function second()
SECOND returns nothing
Definition second_NONE.f:39