
Fix topk #150

Merged
merged 9 commits on Aug 24, 2018
14 changes: 11 additions & 3 deletions aten/src/THC/THCScanUtils.cuh
@@ -4,6 +4,12 @@
#include "THCAsmUtils.cuh"
#include "THCDeviceUtils.cuh"

+#if defined(__HIP_PLATFORM_HCC__)
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif

// Collection of in-kernel scan / prefix sum utilities

// Inclusive Scan via an upsweep/downsweep mechanism. Assumes:
@@ -157,7 +163,7 @@ __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFuncti
T index = __popc(getLaneMaskLe() & vote);
T carry = __popc(vote);

-int warp = threadIdx.x / 32;
+int warp = threadIdx.x / WARP_SIZE;

// Per each warp, write out a value
if (getLaneId() == 0) {
@@ -170,7 +176,7 @@ __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFuncti
// warp shuffle scan for CC 3.0+
if (threadIdx.x == 0) {
int current = 0;
-for (int i = 0; i < blockDim.x / 32; ++i) {
+for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) {
T v = smem[i];
smem[i] = binop(smem[i], current);
current = binop(current, v);
@@ -201,11 +207,13 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi
*out -= (T) in;

// The outgoing carry for all threads is the last warp's sum
-*carry = smem[(blockDim.x / 32) - 1];
+*carry = smem[(blockDim.x / WARP_SIZE) - 1];

if (KillWARDependency) {
__syncthreads();
}
}

+#undef WARP_SIZE

#endif // THC_SCAN_UTILS_INC
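
Note: the hunks above replace the hardcoded warp width of 32 with a WARP_SIZE macro so the block-wide prefix-scan helpers also work under ROCm, where a wavefront is 64 lanes wide. A minimal sketch of the pattern being parameterized (a hypothetical kernel, not the THC code itself):

#if defined(__HIP_PLATFORM_HCC__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

__global__ void perWarpSlotSketch(int* perWarpSums) {
  int lane = threadIdx.x % WARP_SIZE;  // lane within the warp / wavefront
  int warp = threadIdx.x / WARP_SIZE;  // warp index within the block

  // One output slot per warp. With a 64-wide wavefront, dividing by a literal
  // 32 would split a single wavefront across two slots, and the slot count
  // would no longer match blockDim.x / WARP_SIZE, which is what the cross-warp
  // scan step in inclusiveBinaryPrefixScan iterates over.
  if (lane == 0) {
    perWarpSums[warp] = warp;  // placeholder for the warp's reduced value
  }
}

#undef WARP_SIZE
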
8 changes: 8 additions & 0 deletions aten/src/THC/THCTensorTopK.cuh
@@ -213,7 +213,11 @@ __device__ DataType findPattern(DataType* smem,
IndexType withinSliceStride,
BitDataType desired,
BitDataType desiredMask) {
+#ifdef __HIP_PLATFORM_HCC__
+if (threadIdx.x < 64) {
+#else
if (threadIdx.x < 32) {
+#endif
smem[threadIdx.x] = ScalarConvert<int, DataType>::to(0);
}
__syncthreads();
@@ -366,7 +370,11 @@ __global__ void gatherTopK(TensorInfo<T, IndexType> input,
IndexType indicesWithinSliceStride) {
// Indices are limited to integer fp precision, so counts can fit in
// int32, regardless of IndexType
+#ifdef __HIP_PLATFORM_HCC__
+__shared__ int smem[64];
+#else
__shared__ int smem[32]; // one per each warp, up to warp limit
+#endif

IndexType slice = getLinearBlockId<IndexType>();
if (slice >= numInputSlices) {
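
Note: both hunks in this file size a per-warp scratch buffer in shared memory, and the guard that zeroes it, by the platform warp width (64 under HIP rather than 32). A simplified sketch of the clear-then-use pattern (hypothetical helper names, assuming only that one slot per warp is needed):

#ifdef __HIP_PLATFORM_HCC__
constexpr int kWarpSize = 64;
#else
constexpr int kWarpSize = 32;
#endif

__device__ void clearScratch(int* smem) {
  // The first kWarpSize threads zero the buffer; everyone then waits so the
  // whole block observes a fully initialized scratch area before using it.
  if (threadIdx.x < kWarpSize) {
    smem[threadIdx.x] = 0;
  }
  __syncthreads();
}

__global__ void scratchSketch() {
  // At most 1024 / kWarpSize warps per block, so kWarpSize slots is enough on
  // either platform; sizing the array by the warp width keeps it in sync with
  // the guard above.
  __shared__ int smem[kWarpSize];
  clearScratch(smem);
  // ... per-warp counters would be accumulated into smem here ...
}
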
23 changes: 15 additions & 8 deletions aten/src/THC/generic/THCTensorTopK.cu
@@ -33,17 +33,17 @@ THC_API void THCTensor_(topk)(THCState* state,
gatherTopK<real, INDEX_T, DIM, DIR> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
inputInfo, \
-sliceSize, \
-k, \
-inputSlices, \
+static_cast<INDEX_T>(sliceSize), \
+static_cast<INDEX_T>(k), \
+static_cast<INDEX_T>(inputSlices), \
/* The actual dimension that the k-selection is running in */ \
/* may have changed from collapseDims() */ \
-inputInfo.strides[collapseInputDim], \
+static_cast<INDEX_T>(inputInfo.strides[collapseInputDim]), \
topKInfo, \
-topKSlices, \
-topKInfo.strides[collapseTopKDim], \
+static_cast<INDEX_T>(topKSlices), \
+static_cast<INDEX_T>(topKInfo.strides[collapseTopKDim]), \
indicesInfo, \
-indicesInfo.strides[collapseIndicesDim])
+static_cast<INDEX_T>(indicesInfo.strides[collapseIndicesDim]))

#define RUN_DIR(INDEX_T, DIM) \
if (dir) { \
@@ -63,6 +63,12 @@ THC_API void THCTensor_(topk)(THCState* state,
RUN_DIR(INDEX_T, -1); \
}

+#ifdef __HIP_PLATFORM_HCC__
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif

#define RUN_T(INDEX_T) \
TensorInfo<real, INDEX_T> inputInfo = \
getTensorInfo<real, THCTensor, INDEX_T>(state, input); \
@@ -96,7 +102,7 @@ THC_API void THCTensor_(topk)(THCState* state,
THError("Slice to sort is too large"); \
} \
\
-dim3 block(std::min(THCRoundUp(sliceSize, (int64_t) 32), (int64_t) 1024)); \
+dim3 block(std::min(THCRoundUp(sliceSize, (int64_t) WARP_SIZE), (int64_t) 1024)); \
\
/* This is used as a template parameter to calculate indices. */ \
/* We only specialize it if all collapsed dim sizes are the */ \
@@ -124,6 +130,7 @@ THC_API void THCTensor_(topk)(THCState* state,
#undef RUN_DIM
#undef RUN_DIR
#undef RUN_K
+#undef WARP_SIZE

// Sort the results if the user wants them sorted, since our
// selection routine does not ensure sorting
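
Note: two things change in this file. The gatherTopK launch macro now casts every argument explicitly to the kernel's INDEX_T template parameter instead of relying on implicit conversions inside the macro expansion (presumably so the argument types are exact on both compilers), and the block size is rounded up to a multiple of the platform warp width rather than a literal 32. A host-side sketch of the rounding (roundUp is a hypothetical stand-in for THCRoundUp, assumed to round up to the next multiple):

#include <algorithm>
#include <cstdint>

#ifdef __HIP_PLATFORM_HCC__
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

// Hypothetical stand-in for THCRoundUp: round a up to the next multiple of b.
static inline int64_t roundUp(int64_t a, int64_t b) {
  return ((a + b - 1) / b) * b;
}

// Threads per block for one slice: a whole number of warps, capped at the
// 1024-thread block limit. For example, sliceSize = 70 gives 96 threads on
// CUDA (3 warps of 32) but 128 under ROCm (2 wavefronts of 64).
int64_t blockThreadsFor(int64_t sliceSize) {
  return std::min(roundUp(sliceSize, (int64_t)WARP_SIZE), (int64_t)1024);
}

#undef WARP_SIZE
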
1 change: 0 additions & 1 deletion test/test_torch.py
@@ -3385,7 +3385,6 @@ def test_topk_arguments(self):
self.assertRaises(TypeError, lambda: q.topk(4, True))

@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
-@skipIfRocm
def test_topk_noncontiguous_gpu(self):
t = torch.randn(20, device="cuda")[::2]
top1, idx1 = t.topk(5)
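
Note: removing the @skipIfRocm decorator re-enables test_topk_noncontiguous_gpu on ROCm builds. The visible part of the test runs topk on a strided (every-other-element) CUDA tensor; the remainder is truncated in this view. A rough C++/ATen analogue of that kind of check (the API calls here are an assumption, and the comparison against a contiguous copy is inferred rather than shown above):

#include <ATen/ATen.h>

void topkNoncontiguousCheck() {
  // A strided view: every other element of a 20-element CUDA tensor.
  at::Tensor t = at::randn({20}, at::kCUDA).slice(/*dim=*/0, /*start=*/0, /*end=*/20, /*step=*/2);

  auto strided = t.topk(5);               // top-k on the non-contiguous view
  auto contig  = t.contiguous().topk(5);  // reference result on contiguous data

  // The Python test presumably asserts that both values and indices agree.
  bool ok = at::equal(std::get<0>(strided), std::get<0>(contig)) &&
            at::equal(std::get<1>(strided), std::get<1>(contig));
  (void)ok;
}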