
Commit e1946c7

Merge pull request #150 from iotamudelta/fix_topk

Fix topk: the scan and top-k kernels hard-coded the CUDA warp size of 32; this change parameterizes it as WARP_SIZE so ROCm/HIP builds use the 64-lane wavefront.

2 parents fbe1d24 + 30f2ac8

4 files changed: 34 additions and 12 deletions


aten/src/THC/THCScanUtils.cuh

Lines changed: 11 additions & 3 deletions
@@ -4,6 +4,12 @@
 #include "THCAsmUtils.cuh"
 #include "THCDeviceUtils.cuh"
 
+#if defined(__HIP_PLATFORM_HCC__)
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif
+
 // Collection of in-kernel scan / prefix sum utilities
 
 // Inclusive Scan via an upsweep/downsweep mechanism. Assumes:
@@ -157,7 +163,7 @@ __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFuncti
   T index = __popc(getLaneMaskLe() & vote);
   T carry = __popc(vote);
 
-  int warp = threadIdx.x / 32;
+  int warp = threadIdx.x / WARP_SIZE;
 
   // Per each warp, write out a value
   if (getLaneId() == 0) {
@@ -170,7 +176,7 @@ __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFuncti
   // warp shuffle scan for CC 3.0+
   if (threadIdx.x == 0) {
     int current = 0;
-    for (int i = 0; i < blockDim.x / 32; ++i) {
+    for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) {
       T v = smem[i];
       smem[i] = binop(smem[i], current);
       current = binop(current, v);
@@ -201,11 +207,13 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi
   *out -= (T) in;
 
   // The outgoing carry for all threads is the last warp's sum
-  *carry = smem[(blockDim.x / 32) - 1];
+  *carry = smem[(blockDim.x / WARP_SIZE) - 1];
 
   if (KillWARDependency) {
     __syncthreads();
   }
 }
 
+#undef WARP_SIZE
+
 #endif // THC_SCAN_UTILS_INC
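The pattern above is the whole fix for this file: every place the scan utilities divided by a literal 32 assumed one CUDA warp, but a ROCm/HIP wavefront is 64 lanes wide, so on ROCm only even-numbered 32-thread "warps" would contain a lane 0 and half of the per-warp smem slots would never be written. A minimal sketch of the corrected indexing (the kernel perWarpSlot is hypothetical and not part of the patch; it assumes a one-dimensional block of at most 1024 threads):

#if defined(__HIP_PLATFORM_HCC__)
#define WARP_SIZE 64  // ROCm/HIP wavefront width
#else
#define WARP_SIZE 32  // CUDA warp width
#endif

// Hypothetical illustration: one shared-memory slot per warp/wavefront,
// written by that warp's lane 0, as the prefix-scan utilities do.
__global__ void perWarpSlot(int* out) {
  __shared__ int smem[1024 / WARP_SIZE];  // one slot per warp
  int warp = threadIdx.x / WARP_SIZE;     // which warp this thread is in
  int lane = threadIdx.x % WARP_SIZE;     // position within the warp
  if (lane == 0) {
    smem[warp] = warp;                    // exactly one writer per slot
  }
  __syncthreads();
  out[threadIdx.x] = smem[warp];
}

With the divisor left at 32 on ROCm, threadIdx.x / 32 splits each 64-lane wavefront across two slots, and the slot for the wavefront's upper half has no lane 0 to write it.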

aten/src/THC/THCTensorTopK.cuh

Lines changed: 8 additions & 0 deletions
@@ -213,7 +213,11 @@ __device__ DataType findPattern(DataType* smem,
                                 IndexType withinSliceStride,
                                 BitDataType desired,
                                 BitDataType desiredMask) {
+#ifdef __HIP_PLATFORM_HCC__
+  if (threadIdx.x < 64) {
+#else
   if (threadIdx.x < 32) {
+#endif
     smem[threadIdx.x] = ScalarConvert<int, DataType>::to(0);
   }
   __syncthreads();
@@ -366,7 +370,11 @@ __global__ void gatherTopK(TensorInfo<T, IndexType> input,
                            IndexType indicesWithinSliceStride) {
   // Indices are limited to integer fp precision, so counts can fit in
   // int32, regardless of IndexType
+#ifdef __HIP_PLATFORM_HCC__
+  __shared__ int smem[64];
+#else
   __shared__ int smem[32]; // one per each warp, up to warp limit
+#endif
 
   IndexType slice = getLinearBlockId<IndexType>();
   if (slice >= numInputSlices) {
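The two buffers change together with the warp width: findPattern zeroes one scratch slot per lane of the first warp, and gatherTopK allocates the scratch buffer that zeroing must stay inside. A hypothetical compile-time restatement of the sizing invariant (the kWarpSize/kSmemSlots constants are illustrative, not from the patch):

#if defined(__HIP_PLATFORM_HCC__)
constexpr int kWarpSize  = 64;  // wavefront width on ROCm
constexpr int kSmemSlots = 64;  // scratch entries in gatherTopK
#else
constexpr int kWarpSize  = 32;  // warp width on CUDA
constexpr int kSmemSlots = 32;
#endif
constexpr int kMaxBlockThreads = 1024;

// One scratch slot per warp of the largest possible block...
static_assert(kSmemSlots >= kMaxBlockThreads / kWarpSize,
              "need a slot for every warp in a full block");
// ...and the findPattern-style zeroing (threadIdx.x < kWarpSize)
// must not run past the end of the buffer.
static_assert(kSmemSlots >= kWarpSize,
              "zeroing one slot per lane must stay in bounds");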

aten/src/THC/generic/THCTensorTopK.cu

Lines changed: 15 additions & 8 deletions
@@ -33,17 +33,17 @@ THC_API void THCTensor_(topk)(THCState* state,
   gatherTopK<real, INDEX_T, DIM, DIR>                                  \
     <<<grid, block, 0, THCState_getCurrentStream(state)>>>(            \
       inputInfo,                                                       \
-      sliceSize,                                                       \
-      k,                                                               \
-      inputSlices,                                                     \
+      static_cast<INDEX_T>(sliceSize),                                 \
+      static_cast<INDEX_T>(k),                                         \
+      static_cast<INDEX_T>(inputSlices),                               \
      /* The actual dimension that the k-selection is running in */     \
      /* may have changed from collapseDims() */                        \
-      inputInfo.strides[collapseInputDim],                             \
+      static_cast<INDEX_T>(inputInfo.strides[collapseInputDim]),       \
       topKInfo,                                                        \
-      topKSlices,                                                      \
-      topKInfo.strides[collapseTopKDim],                               \
+      static_cast<INDEX_T>(topKSlices),                                \
+      static_cast<INDEX_T>(topKInfo.strides[collapseTopKDim]),         \
       indicesInfo,                                                     \
-      indicesInfo.strides[collapseIndicesDim])
+      static_cast<INDEX_T>(indicesInfo.strides[collapseIndicesDim]))
 
 #define RUN_DIR(INDEX_T, DIM)                                          \
   if (dir) {                                                           \
@@ -63,6 +63,12 @@ THC_API void THCTensor_(topk)(THCState* state,
     RUN_DIR(INDEX_T, -1);                                              \
   }
 
+#ifdef __HIP_PLATFORM_HCC__
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif
+
 #define RUN_T(INDEX_T)                                                 \
   TensorInfo<real, INDEX_T> inputInfo =                                \
     getTensorInfo<real, THCTensor, INDEX_T>(state, input);             \
@@ -96,7 +102,7 @@ THC_API void THCTensor_(topk)(THCState* state,
     THError("Slice to sort is too large");                             \
   }                                                                    \
                                                                        \
-  dim3 block(std::min(THCRoundUp(sliceSize, (int64_t) 32), (int64_t) 1024)); \
+  dim3 block(std::min(THCRoundUp(sliceSize, (int64_t) WARP_SIZE), (int64_t) 1024)); \
                                                                        \
   /* This is used as a template parameter to calculate indices. */     \
   /* We only specialize it if all collapsed dim sizes are the */       \
@@ -124,6 +130,7 @@ THC_API void THCTensor_(topk)(THCState* state,
 #undef RUN_DIM
 #undef RUN_DIR
 #undef RUN_K
+#undef WARP_SIZE
 
   // Sort the results if the user wants them sorted, since our
   // selection routine does not ensure sorting
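Two independent changes here. The static_cast<INDEX_T> wrappers make each kernel-launch argument's type match gatherTopK's template parameters exactly, presumably to sidestep implicit-conversion trouble in the HIP launch path. The block-size line then rounds the slice length up to a whole number of warps, capped at the 1024-thread limit; a small host-side sketch of that computation (roundUp and pickBlockSize are illustrative helpers, with THCRoundUp assumed, as its use here implies, to round its first argument up to a multiple of its second):

#include <algorithm>
#include <cstdint>

// Round n up to the next multiple of `multiple` (stand-in for THCRoundUp).
int64_t roundUp(int64_t n, int64_t multiple) {
  return ((n + multiple - 1) / multiple) * multiple;
}

// Mirror of the dim3 block(...) expression in the patch.
int64_t pickBlockSize(int64_t sliceSize, int64_t warpSize) {
  return std::min(roundUp(sliceSize, warpSize), (int64_t)1024);
}

For example, a slice of 20 elements gets a 32-thread block on CUDA but a 64-thread block on ROCm, so the kernel always launches whole warps.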

test/test_torch.py

Lines changed: 0 additions & 1 deletion
@@ -3385,7 +3385,6 @@ def test_topk_arguments(self):
         self.assertRaises(TypeError, lambda: q.topk(4, True))
 
     @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
-    @skipIfRocm
     def test_topk_noncontiguous_gpu(self):
         t = torch.randn(20, device="cuda")[::2]
         top1, idx1 = t.topk(5)
