Skip to content

Commit 975c772

Browse files
authored
[DevMSAN] Correctly apply stride for dest/src pointer when do async copy (#19766)
Currently, we always define stride in elements when writing to destination pointer. However, according to spirv spec, the stride can only be applied to global ptr.
1 parent e789a08 commit 975c772

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

libdevice/sanitizer/msan_rtl.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,11 +303,15 @@ inline void ReportError(const uint32_t size, const char __SYCL_CONSTANT__ *file,
303303

304304
// This function is only used for shadow propagation
305305
template <typename T>
306-
void GroupAsyncCopy(uptr Dest, uptr Src, size_t NumElements, size_t Stride) {
306+
void GroupAsyncCopy(uptr Dest, uptr Src, size_t NumElements, size_t Stride,
307+
bool StrideOnSrc) {
307308
auto DestPtr = (__SYCL_GLOBAL__ T *)Dest;
308309
auto SrcPtr = (const __SYCL_GLOBAL__ T *)Src;
309310
for (size_t i = 0; i < NumElements; i++) {
310-
DestPtr[i] = SrcPtr[i * Stride];
311+
if (StrideOnSrc)
312+
DestPtr[i] = SrcPtr[i * Stride];
313+
else
314+
DestPtr[i * Stride] = SrcPtr[i];
311315
}
312316
}
313317

@@ -749,16 +753,20 @@ __msan_unpoison_strided_copy(uptr dest, uint32_t dest_as, uptr src,
749753

750754
switch (element_size) {
751755
case 1:
752-
GroupAsyncCopy<int8_t>(shadow_dest, shadow_src, counts, stride);
756+
GroupAsyncCopy<int8_t>(shadow_dest, shadow_src, counts, stride,
757+
src_as == ADDRESS_SPACE_GLOBAL);
753758
break;
754759
case 2:
755-
GroupAsyncCopy<int16_t>(shadow_dest, shadow_src, counts, stride);
760+
GroupAsyncCopy<int16_t>(shadow_dest, shadow_src, counts, stride,
761+
src_as == ADDRESS_SPACE_GLOBAL);
756762
break;
757763
case 4:
758-
GroupAsyncCopy<int32_t>(shadow_dest, shadow_src, counts, stride);
764+
GroupAsyncCopy<int32_t>(shadow_dest, shadow_src, counts, stride,
765+
src_as == ADDRESS_SPACE_GLOBAL);
759766
break;
760767
case 8:
761-
GroupAsyncCopy<int64_t>(shadow_dest, shadow_src, counts, stride);
768+
GroupAsyncCopy<int64_t>(shadow_dest, shadow_src, counts, stride,
769+
src_as == ADDRESS_SPACE_GLOBAL);
762770
break;
763771
default:
764772
__spirv_ocl_printf(__msan_print_strided_copy_unsupport_type,

0 commit comments

Comments
 (0)