Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sycl/include/syclcompat/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ T shift_sub_group_right(sycl::sub_group g, T x, unsigned int delta,
template <typename T>
T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,
int logical_sub_group_size = 32) {
if (logical_sub_group_size == 32) {
return permute_group_by_xor(g, x, mask);
}
unsigned int id = g.get_local_linear_id();
unsigned int start_index =
id / logical_sub_group_size * logical_sub_group_size;
Expand Down
19 changes: 15 additions & 4 deletions sycl/test-e2e/syclcompat/util/util_permute_sub_group_by_xor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,9 @@ void test_permute_sub_group_by_xor() {
syclcompat::device_ext &dev_ct1 = syclcompat::get_current_device();
sycl::queue *q_ct1 = dev_ct1.default_queue();
bool Result = true;
int *dev_data = nullptr;
unsigned int *dev_data_u = nullptr;
sycl::range<3> GridSize(1, 1, 1);
sycl::range<3> BlockSize(1, 1, 1);
dev_data = sycl::malloc_device<int>(DATA_NUM, *q_ct1);
dev_data_u = sycl::malloc_device<unsigned int>(DATA_NUM, *q_ct1);

GridSize = sycl::range<3>(1, 1, 2);
Expand Down Expand Up @@ -120,6 +118,19 @@ void test_permute_sub_group_by_xor() {
q_ct1->memcpy(host_dev_data_u, dev_data_u, DATA_NUM * sizeof(unsigned int))
.wait();
verify_data<unsigned int>(host_dev_data_u, expect1, DATA_NUM);
sycl::free(dev_data_u, *q_ct1);
}

void test_permute_sub_group_by_xor_extra_arg() {
std::cout << __PRETTY_FUNCTION__ << std::endl;

syclcompat::device_ext &dev_ct1 = syclcompat::get_current_device();
sycl::queue *q_ct1 = dev_ct1.default_queue();
bool Result = true;
unsigned int *dev_data_u = nullptr;
sycl::range<3> GridSize(1, 1, 1);
sycl::range<3> BlockSize(1, 1, 1);
dev_data_u = sycl::malloc_device<unsigned int>(DATA_NUM, *q_ct1);

GridSize = sycl::range<3>(1, 1, 2);
BlockSize = sycl::range<3>(1, 2, 32);
Expand All @@ -133,6 +144,7 @@ void test_permute_sub_group_by_xor() {
91, 90, 93, 92, 95, 94, 97, 96, 99, 98, 101, 100, 103, 102, 105,
104, 107, 106, 109, 108, 111, 110, 113, 112, 115, 114, 117, 116, 119, 118,
121, 120, 123, 122, 125, 124, 127, 126};
unsigned int host_dev_data_u[DATA_NUM];
init_data<unsigned int>(host_dev_data_u, DATA_NUM);

q_ct1->memcpy(dev_data_u, host_dev_data_u, DATA_NUM * sizeof(unsigned int))
Expand All @@ -147,13 +159,12 @@ void test_permute_sub_group_by_xor() {
q_ct1->memcpy(host_dev_data_u, dev_data_u, DATA_NUM * sizeof(unsigned int))
.wait();
verify_data<unsigned int>(host_dev_data_u, expect2, DATA_NUM);

sycl::free(dev_data, *q_ct1);
sycl::free(dev_data_u, *q_ct1);
}

// Test driver: invokes each test routine once, in declaration order, and
// then reports a zero exit status.
int main() {
  using test_routine = void (*)();

  // Keep the order stable: the plain test runs before the extra-argument one.
  const test_routine routines[] = {
      test_permute_sub_group_by_xor,
      test_permute_sub_group_by_xor_extra_arg,
  };

  for (const test_routine run : routines) {
    run();
  }
  return 0;
}
Loading