Skip to content

Commit 7484fb1

Browse files
jithunnair-amdiotamudelta
authored andcommitted
Topk fixes (#239)
* Spec (https://docs.nvidia.com/cuda/pdf/ptx_isa_6.3.pdf) Sec 9.7.1.19 (bfe) and 9.7.1.20 (bfi) requires pos and len to be limited to 0...255 * Spec (https://docs.nvidia.com/cuda/pdf/ptx_isa_6.3.pdf) Sec 9.7.1.19 requires extracted bits to be in LSBs * Correct logic for getLaneMaskLe. Previous logic would return 0x0 instead of 0xffffffffffffffff for lane 63 * Round up blockDim.x to prevent negative index for smem * Revert unintended change; back to __forceinline__ * Revert change to TOPK_WARP_SIZE that crept in with previous commit
1 parent 3a5fda7 commit 7484fb1

File tree

2 files changed

+12
-14
lines changed

2 files changed

+12
-14
lines changed

aten/src/THC/THCAsmUtils.cuh

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@ struct Bitfield<unsigned int> {
1111
static __device__ __forceinline__
1212
unsigned int getBitfield(unsigned int val, int pos, int len) {
1313
#if defined(__HIP_PLATFORM_HCC__)
14-
pos &= 0x1f;
15-
len &= 0x1f;
14+
pos &= 0xff;
15+
len &= 0xff;
1616

1717
unsigned int m = (1u << len) - 1u;
18-
m <<= pos;
19-
return val & m;
18+
return (val >> pos) & m;
2019
#else
2120
unsigned int ret;
2221
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
@@ -27,8 +26,8 @@ struct Bitfield<unsigned int> {
2726
static __device__ __forceinline__
2827
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
2928
#if defined(__HIP_PLATFORM_HCC__)
30-
pos &= 0x1f;
31-
len &= 0x1f;
29+
pos &= 0xff;
30+
len &= 0xff;
3231

3332
unsigned int m = (1u << len) - 1u;
3433
toInsert &= m;
@@ -50,12 +49,11 @@ struct Bitfield<uint64_t> {
5049
static __device__ __forceinline__
5150
uint64_t getBitfield(uint64_t val, int pos, int len) {
5251
#if defined(__HIP_PLATFORM_HCC__)
53-
pos &= 0x1f;
54-
len &= 0x1f;
52+
pos &= 0xff;
53+
len &= 0xff;
5554

5655
uint64_t m = (1u << len) - 1u;
57-
m <<= pos;
58-
return val & m;
56+
return (val >> pos) & m;
5957
#else
6058
uint64_t ret;
6159
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
@@ -66,8 +64,8 @@ struct Bitfield<uint64_t> {
6664
static __device__ __forceinline__
6765
uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
6866
#if defined(__HIP_PLATFORM_HCC__)
69-
pos &= 0x1f;
70-
len &= 0x1f;
67+
pos &= 0xff;
68+
len &= 0xff;
7169

7270
uint64_t m = (1u << len) - 1u;
7371
toInsert &= m;
@@ -107,7 +105,7 @@ __device__ __forceinline__ unsigned getLaneMaskLt() {
107105

108106
#if defined (__HIP_PLATFORM_HCC__)
109107
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
110-
std::uint64_t m = (1ull << (getLaneId() + 1ull)) - 1ull;
108+
std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1));
111109
return m;
112110
}
113111
#else

aten/src/THC/THCScanUtils.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi
213213
*out -= (T) in;
214214

215215
// The outgoing carry for all threads is the last warp's sum
216-
*carry = smem[(blockDim.x / SCAN_UTILS_WARP_SIZE) - 1];
216+
*carry = smem[THCCeilDiv<int>(blockDim.x, SCAN_UTILS_WARP_SIZE) - 1];
217217

218218
if (KillWARDependency) {
219219
__syncthreads();

0 commit comments

Comments
 (0)