From 95afb10875f98402a3adaa9869344e84a5d2a6b5 Mon Sep 17 00:00:00 2001 From: jithunnair-amd Date: Wed, 3 Oct 2018 10:54:36 -0500 Subject: [PATCH 1/6] Spec (https://docs.nvidia.com/cuda/pdf/ptx_isa_6.3.pdf) Sec 9.7.1.19 (bfe) and 9.7.1.20 (bfi) require pos and len to be limited to 0...255 --- aten/src/THC/THCAsmUtils.cuh | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/aten/src/THC/THCAsmUtils.cuh b/aten/src/THC/THCAsmUtils.cuh index e3d0cc9ce8651e..f50f22aef44a71 100644 --- a/aten/src/THC/THCAsmUtils.cuh +++ b/aten/src/THC/THCAsmUtils.cuh @@ -11,8 +11,8 @@ struct Bitfield { static __device__ __forceinline__ unsigned int getBitfield(unsigned int val, int pos, int len) { #if defined(__HIP_PLATFORM_HCC__) - pos &= 0x1f; - len &= 0x1f; + pos &= 0xff; + len &= 0xff; unsigned int m = (1u << len) - 1u; m <<= pos; @@ -27,8 +27,8 @@ struct Bitfield { static __device__ __forceinline__ unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) { #if defined(__HIP_PLATFORM_HCC__) - pos &= 0x1f; - len &= 0x1f; + pos &= 0xff; + len &= 0xff; unsigned int m = (1u << len) - 1u; toInsert &= m; @@ -50,8 +50,8 @@ struct Bitfield { static __device__ __forceinline__ uint64_t getBitfield(uint64_t val, int pos, int len) { #if defined(__HIP_PLATFORM_HCC__) - pos &= 0x1f; - len &= 0x1f; + pos &= 0xff; + len &= 0xff; uint64_t m = (1u << len) - 1u; m <<= pos; @@ -66,8 +66,8 @@ struct Bitfield { static __device__ __forceinline__ uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) { #if defined(__HIP_PLATFORM_HCC__) - pos &= 0x1f; - len &= 0x1f; + pos &= 0xff; + len &= 0xff; uint64_t m = (1u << len) - 1u; toInsert &= m; From 03e27345d2a69519a18ee3720e11e271c11ffb78 Mon Sep 17 00:00:00 2001 From: jithunnair-amd Date: Wed, 3 Oct 2018 10:56:21 -0500 Subject: [PATCH 2/6] Spec (https://docs.nvidia.com/cuda/pdf/ptx_isa_6.3.pdf) Sec 9.7.1.19 requires extracted bits to be in LSBs --- aten/src/THC/THCAsmUtils.cuh | 6 ++---- 
1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/aten/src/THC/THCAsmUtils.cuh b/aten/src/THC/THCAsmUtils.cuh index f50f22aef44a71..ae6ebf7a3fc798 100644 --- a/aten/src/THC/THCAsmUtils.cuh +++ b/aten/src/THC/THCAsmUtils.cuh @@ -15,8 +15,7 @@ struct Bitfield { len &= 0xff; unsigned int m = (1u << len) - 1u; - m <<= pos; - return val & m; + return (val >> pos) & m; #else unsigned int ret; asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len)); @@ -54,8 +53,7 @@ struct Bitfield { len &= 0xff; uint64_t m = (1u << len) - 1u; - m <<= pos; - return val & m; + return (val >> pos) & m; #else uint64_t ret; asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len)); From b4ae46542c842a71f14adfbdc0d362a2f38df4c9 Mon Sep 17 00:00:00 2001 From: jithunnair-amd Date: Wed, 3 Oct 2018 10:57:37 -0500 Subject: [PATCH 3/6] Correct logic for getLaneMaskLe. Previous logic would return 0x0 instead of 0xffffffffffffffff for lane 63 --- aten/src/THC/THCAsmUtils.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/THC/THCAsmUtils.cuh b/aten/src/THC/THCAsmUtils.cuh index ae6ebf7a3fc798..1110c41f0c6b18 100644 --- a/aten/src/THC/THCAsmUtils.cuh +++ b/aten/src/THC/THCAsmUtils.cuh @@ -104,8 +104,8 @@ __device__ __forceinline__ unsigned getLaneMaskLt() { } #if defined (__HIP_PLATFORM_HCC__) -__device__ __forceinline__ unsigned long long int getLaneMaskLe() { - std::uint64_t m = (1ull << (getLaneId() + 1ull)) - 1ull; +__device__ inline unsigned long long int getLaneMaskLe() { + std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1)); return m; } #else From 42a5ff1601e421072577b72974d8940f213b8eaf Mon Sep 17 00:00:00 2001 From: jithunnair-amd Date: Wed, 3 Oct 2018 10:58:15 -0500 Subject: [PATCH 4/6] Round up blockDim.x to prevent negative index for smem --- aten/src/THC/THCScanUtils.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/THC/THCScanUtils.cuh 
b/aten/src/THC/THCScanUtils.cuh index 0dc92b0f120a9d..c03ded0c8dd839 100644 --- a/aten/src/THC/THCScanUtils.cuh +++ b/aten/src/THC/THCScanUtils.cuh @@ -213,7 +213,7 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi *out -= (T) in; // The outgoing carry for all threads is the last warp's sum - *carry = smem[(blockDim.x / SCAN_UTILS_WARP_SIZE) - 1]; + *carry = smem[THCCeilDiv(blockDim.x, SCAN_UTILS_WARP_SIZE) - 1]; if (KillWARDependency) { __syncthreads(); From 1e39dac83d80a2a706527c019cc09b3aef65a5cb Mon Sep 17 00:00:00 2001 From: jithunnair-amd Date: Wed, 3 Oct 2018 12:51:53 -0500 Subject: [PATCH 5/6] Revert unintended change; back to __forceinline__ --- aten/src/THC/THCAsmUtils.cuh | 2 +- aten/src/THC/generic/THCTensorTopK.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/THC/THCAsmUtils.cuh b/aten/src/THC/THCAsmUtils.cuh index 1110c41f0c6b18..b75408e9ec512d 100644 --- a/aten/src/THC/THCAsmUtils.cuh +++ b/aten/src/THC/THCAsmUtils.cuh @@ -104,7 +104,7 @@ __device__ __forceinline__ unsigned getLaneMaskLt() { } #if defined (__HIP_PLATFORM_HCC__) -__device__ inline unsigned long long int getLaneMaskLe() { +__device__ __forceinline__ unsigned long long int getLaneMaskLe() { std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1)); return m; } diff --git a/aten/src/THC/generic/THCTensorTopK.cu b/aten/src/THC/generic/THCTensorTopK.cu index a195dfbe5ca7a8..9a1b94d1315fec 100644 --- a/aten/src/THC/generic/THCTensorTopK.cu +++ b/aten/src/THC/generic/THCTensorTopK.cu @@ -105,7 +105,7 @@ void THCTensor_(topk)(THCState* state, THError("Slice to sort is too large"); \ } \ \ - dim3 block(std::min(THCRoundUp(sliceSize, (int64_t) TOPK_WARP_SIZE), (int64_t) 1024)); \ + dim3 block(std::min(THCRoundUp(sliceSize, (int64_t) (1024/TOPK_WARP_SIZE)), (int64_t) 1024)); \ \ /* This is used as a template parameter to calculate indices. 
*/ \ /* We only specialize it if all collapsed dim sizes are the */ \ From 9103ba6ecf236cd8de02c07d36144174b6f11510 Mon Sep 17 00:00:00 2001 From: jithunnair-amd Date: Wed, 3 Oct 2018 13:01:25 -0500 Subject: [PATCH 6/6] Revert change to TOPK_WARP_SIZE that crept in with previous commit --- aten/src/THC/generic/THCTensorTopK.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/THC/generic/THCTensorTopK.cu b/aten/src/THC/generic/THCTensorTopK.cu index 9a1b94d1315fec..a195dfbe5ca7a8 100644 --- a/aten/src/THC/generic/THCTensorTopK.cu +++ b/aten/src/THC/generic/THCTensorTopK.cu @@ -105,7 +105,7 @@ void THCTensor_(topk)(THCState* state, THError("Slice to sort is too large"); \ } \ \ - dim3 block(std::min(THCRoundUp(sliceSize, (int64_t) (1024/TOPK_WARP_SIZE)), (int64_t) 1024)); \ + dim3 block(std::min(THCRoundUp(sliceSize, (int64_t) TOPK_WARP_SIZE), (int64_t) 1024)); \ \ /* This is used as a template parameter to calculate indices. */ \ /* We only specialize it if all collapsed dim sizes are the */ \