diff --git a/aten/src/THC/THCAsmUtils.cuh b/aten/src/THC/THCAsmUtils.cuh index e3d0cc9ce8651e..b75408e9ec512d 100644 --- a/aten/src/THC/THCAsmUtils.cuh +++ b/aten/src/THC/THCAsmUtils.cuh @@ -11,12 +11,11 @@ struct Bitfield { static __device__ __forceinline__ unsigned int getBitfield(unsigned int val, int pos, int len) { #if defined(__HIP_PLATFORM_HCC__) - pos &= 0x1f; - len &= 0x1f; + pos &= 0xff; + len &= 0xff; unsigned int m = (1u << len) - 1u; - m <<= pos; - return val & m; + return (val >> pos) & m; #else unsigned int ret; asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len)); @@ -27,8 +26,8 @@ struct Bitfield { static __device__ __forceinline__ unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) { #if defined(__HIP_PLATFORM_HCC__) - pos &= 0x1f; - len &= 0x1f; + pos &= 0xff; + len &= 0xff; unsigned int m = (1u << len) - 1u; toInsert &= m; @@ -50,12 +49,11 @@ struct Bitfield { static __device__ __forceinline__ uint64_t getBitfield(uint64_t val, int pos, int len) { #if defined(__HIP_PLATFORM_HCC__) - pos &= 0x1f; - len &= 0x1f; + pos &= 0xff; + len &= 0xff; uint64_t m = (1u << len) - 1u; - m <<= pos; - return val & m; + return (val >> pos) & m; #else uint64_t ret; asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len)); @@ -66,8 +64,8 @@ struct Bitfield { static __device__ __forceinline__ uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) { #if defined(__HIP_PLATFORM_HCC__) - pos &= 0x1f; - len &= 0x1f; + pos &= 0xff; + len &= 0xff; uint64_t m = (1u << len) - 1u; toInsert &= m; @@ -107,7 +105,7 @@ __device__ __forceinline__ unsigned getLaneMaskLt() { #if defined (__HIP_PLATFORM_HCC__) __device__ __forceinline__ unsigned long long int getLaneMaskLe() { - std::uint64_t m = (1ull << (getLaneId() + 1ull)) - 1ull; + std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1)); return m; } #else diff --git a/aten/src/THC/THCScanUtils.cuh b/aten/src/THC/THCScanUtils.cuh index 0dc92b0f120a9d..c03ded0c8dd839 100644 --- a/aten/src/THC/THCScanUtils.cuh +++ b/aten/src/THC/THCScanUtils.cuh @@ -213,7 +213,7 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi *out -= (T) in; // The outgoing carry for all threads is the last warp's sum - *carry = smem[(blockDim.x / SCAN_UTILS_WARP_SIZE) - 1]; + *carry = smem[THCCeilDiv(blockDim.x, SCAN_UTILS_WARP_SIZE) - 1]; if (KillWARDependency) { __syncthreads();