14 #define EIGEN_TEST_NO_LONGDOUBLE 15 #define EIGEN_TEST_NO_COMPLEX 16 #define EIGEN_TEST_FUNC cxx11_tensor_broadcast_sycl 17 #define EIGEN_DEFAULT_DENSE_INDEX_TYPE int 18 #define EIGEN_USE_SYCL 21 #include <unsupported/Eigen/CXX11/Tensor> 24 using Eigen::SyclDevice;
34 for (
size_t i = 0; i < out_range.
size(); ++i)
35 out_range[i] = in_range[i] * broadcasts[i];
40 for (
size_t i = 0; i < in_range.
size(); ++i)
41 VERIFY_IS_EQUAL(out.
dimension(i), out_range[i]);
44 for (
int i = 0; i < input.
size(); ++i)
45 input(i) =
static_cast<float>(i);
47 float * gpu_in_data =
static_cast<float*
>(sycl_device.allocate(input.
dimensions().TotalSize()*
sizeof(float)));
48 float * gpu_out_data =
static_cast<float*
>(sycl_device.allocate(out.
dimensions().TotalSize()*
sizeof(float)));
52 sycl_device.memcpyHostToDevice(gpu_in_data, input.
data(),(input.
dimensions().TotalSize())*
sizeof(
float));
53 gpu_out.
device(sycl_device) = gpu_in.broadcast(broadcasts);
54 sycl_device.memcpyDeviceToHost(out.
data(), gpu_out_data,(out.
dimensions().TotalSize())*
sizeof(
float));
56 for (
int i = 0; i < 4; ++i) {
57 for (
int j = 0; j < 9; ++j) {
58 for (
int k = 0; k < 5; ++k) {
59 for (
int l = 0; l < 28; ++l) {
60 VERIFY_IS_APPROX(input(i%2,j%3,k%5,l%7), out(i,j,k,l));
65 printf(
"Broadcast Test Passed\n");
66 sycl_device.deallocate(gpu_in_data);
67 sycl_device.deallocate(gpu_out_data);
71 cl::sycl::gpu_selector
s;
72 Eigen::SyclDevice sycl_device(s);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::size_t size()
TensorDevice< TensorMap< PlainObjectType, Options_, MakePointer_ >, DeviceType > device(const DeviceType &device)
A tensor expression mapping an existing array of data.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar * data()
void test_cxx11_tensor_broadcast_sycl()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
static void test_broadcast_sycl(const Eigen::SyclDevice &sycl_device)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const