Skip to content

Commit 30b01a1

Browse files
iotamudeltagchanan
authored andcommitted
topk and sort fixes (pytorch#12337)
Summary: * Topk part 1: fix intrinsincs for 64 wave front (pytorch#224) 64 in a wave front - intrinsics change. * Disable in-place sorting on ROCm. (pytorch#237) It is known to hang - use the Thrust fallback Skip one test - fails with the fallback. * Topk fixes (pytorch#239) * Spec (https://docs.nvidia.com/cuda/pdf/ptx_isa_6.3.pdf) Sec 9.7.1.19 (bfe) and 9.7.1.20 (bfi) requires pos and len to be limited to 0...255 * Spec (https://docs.nvidia.com/cuda/pdf/ptx_isa_6.3.pdf) Sec 9.7.1.19 requires extracted bits to be in LSBs * Correct logic for getLaneMaskLe. Previous logic would return 0x0 instead of 0xffffffffffffffff for lane 63 * Round up blockDim.x to prevent negative index for smem bddppq ezyang Note the one additional skipped test resulting from using the thrust sort fallback for all sizes. We are working on getting bitonic to work properly (and always). Until then, this needs to be skipped on ROCm. Pull Request resolved: pytorch#12337 Differential Revision: D10259481 Pulled By: ezyang fbshipit-source-id: 5c8dc6596d7a3103ba7b4b550cba895f38c8148e
1 parent ce0b6ec commit 30b01a1

File tree

7 files changed

+45
-17
lines changed

7 files changed

+45
-17
lines changed

aten/src/THC/THCAsmUtils.cuh

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@ struct Bitfield<unsigned int> {
1111
static __device__ __forceinline__
1212
unsigned int getBitfield(unsigned int val, int pos, int len) {
1313
#if defined(__HIP_PLATFORM_HCC__)
14-
pos &= 0x1f;
15-
len &= 0x1f;
14+
pos &= 0xff;
15+
len &= 0xff;
1616

1717
unsigned int m = (1u << len) - 1u;
18-
m <<= pos;
19-
return val & m;
18+
return (val >> pos) & m;
2019
#else
2120
unsigned int ret;
2221
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
@@ -27,8 +26,8 @@ struct Bitfield<unsigned int> {
2726
static __device__ __forceinline__
2827
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
2928
#if defined(__HIP_PLATFORM_HCC__)
30-
pos &= 0x1f;
31-
len &= 0x1f;
29+
pos &= 0xff;
30+
len &= 0xff;
3231

3332
unsigned int m = (1u << len) - 1u;
3433
toInsert &= m;
@@ -50,12 +49,11 @@ struct Bitfield<uint64_t> {
5049
static __device__ __forceinline__
5150
uint64_t getBitfield(uint64_t val, int pos, int len) {
5251
#if defined(__HIP_PLATFORM_HCC__)
53-
pos &= 0x1f;
54-
len &= 0x1f;
52+
pos &= 0xff;
53+
len &= 0xff;
5554

5655
uint64_t m = (1u << len) - 1u;
57-
m <<= pos;
58-
return val & m;
56+
return (val >> pos) & m;
5957
#else
6058
uint64_t ret;
6159
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
@@ -66,8 +64,8 @@ struct Bitfield<uint64_t> {
6664
static __device__ __forceinline__
6765
uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
6866
#if defined(__HIP_PLATFORM_HCC__)
69-
pos &= 0x1f;
70-
len &= 0x1f;
67+
pos &= 0xff;
68+
len &= 0xff;
7169

7270
uint64_t m = (1u << len) - 1u;
7371
toInsert &= m;
@@ -105,16 +103,18 @@ __device__ __forceinline__ unsigned getLaneMaskLt() {
105103
#endif
106104
}
107105

108-
__device__ __forceinline__ unsigned getLaneMaskLe() {
109-
#if defined(__HIP_PLATFORM_HCC__)
110-
std::uint64_t m = (1ull << (getLaneId() + 1ull)) - 1ull;
106+
#if defined (__HIP_PLATFORM_HCC__)
107+
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
108+
std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1));
111109
return m;
110+
}
112111
#else
112+
__device__ __forceinline__ unsigned getLaneMaskLe() {
113113
unsigned mask;
114114
asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
115115
return mask;
116-
#endif
117116
}
117+
#endif
118118

119119
__device__ __forceinline__ unsigned getLaneMaskGt() {
120120
#if defined(__HIP_PLATFORM_HCC__)

aten/src/THC/THCDeviceUtils.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ __device__ __forceinline__ unsigned int ACTIVE_MASK()
4444
#endif
4545
}
4646

47+
#if defined(__HIP_PLATFORM_HCC__)
48+
__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
49+
{
50+
return __ballot(predicate);
51+
}
52+
#else
4753
__device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
4854
{
4955
#if CUDA_VERSION >= 9000
@@ -52,6 +58,7 @@ __device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int
5258
return __ballot(predicate);
5359
#endif
5460
}
61+
#endif
5562

5663
#ifdef __HIP_PLATFORM_HCC__
5764
//To handle ambiguity, add a type double version.

aten/src/THC/THCScanUtils.cuh

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,15 @@ __device__ void exclusivePrefixScan(T* smem, T in, T* out, T* carry, BinaryFunct
159159
template <typename T, bool KillWARDependency, class BinaryFunction>
160160
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
161161
// Within-warp, we use warp voting.
162+
#if defined (__HIP_PLATFORM_HCC__)
163+
unsigned long long int vote = WARP_BALLOT(in);
164+
T index = __popcll(getLaneMaskLe() & vote);
165+
T carry = __popcll(vote);
166+
#else
162167
T vote = WARP_BALLOT(in);
163168
T index = __popc(getLaneMaskLe() & vote);
164169
T carry = __popc(vote);
170+
#endif
165171

166172
int warp = threadIdx.x / SCAN_UTILS_WARP_SIZE;
167173

@@ -207,7 +213,7 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi
207213
*out -= (T) in;
208214

209215
// The outgoing carry for all threads is the last warp's sum
210-
*carry = smem[(blockDim.x / SCAN_UTILS_WARP_SIZE) - 1];
216+
*carry = smem[THCCeilDiv<int>(blockDim.x, SCAN_UTILS_WARP_SIZE) - 1];
211217

212218
if (KillWARDependency) {
213219
__syncthreads();

aten/src/THC/THCTensorTopK.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,11 @@ __device__ void countRadixUsingMask(CountType counts[RadixSize],
176176
#pragma unroll
177177
for (unsigned int j = 0; j < RadixSize; ++j) {
178178
bool vote = hasVal && (digitInRadix == j);
179+
#if defined (__HIP_PLATFORM_HCC__)
180+
counts[j] += __popcll(WARP_BALLOT(vote));
181+
#else
179182
counts[j] += __popc(WARP_BALLOT(vote, ACTIVE_MASK()));
183+
#endif
180184
}
181185
}
182186

aten/src/THC/generic/THCTensorSort.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,12 @@ void THCTensor_(sort)(THCState* state,
309309
int maxSliceSize = 2048;
310310
#endif
311311

312+
#ifdef __HIP_PLATFORM_HCC__
313+
// TODO bitonicSortKVInPlace hangs on ROCm currently.
314+
if (0) {
315+
#else
312316
if (sliceSize <= maxSliceSize) {
317+
#endif
313318
// Fill `indices` (the values) with the
314319
// slice-relative index.
315320
THCudaLongTensor_fillSliceWithIndex(state, indices, dim);

aten/src/THC/generic/THCTensorTopK.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,12 @@ void THCTensor_(topk)(THCState* state,
140140
if (sorted) {
141141
// FIXME: the k/v inplace sort along slice only works for size <=
142142
// 2048 at the moment
143+
#ifdef __HIP_PLATFORM_HCC__
144+
// TODO bitonicSortKVInPlace hangs on ROCm currently.
145+
if (0) {
146+
#else
143147
if (sliceSize <= 2048) {
148+
#endif
144149
// This avoids any memory allocations and performs all sorting
145150
// work inplace along the slice
146151
THCTensor_(sortKeyValueInplace)(state, topK, indices, dim, dir);

test/test_torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6581,6 +6581,7 @@ def test_tensor_shape_empty(self):
65816581
self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)])
65826582

65836583
# functions that operate over a dimension but don't reduce.
6584+
@skipIfRocm
65846585
def test_dim_function_empty(self):
65856586
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
65866587
for device in devices:

0 commit comments

Comments
 (0)