Skip to content

topk and sort fixes #244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions aten/src/THC/THCAsmUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ struct Bitfield<unsigned int> {
static __device__ __forceinline__
unsigned int getBitfield(unsigned int val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
pos &= 0x1f;
len &= 0x1f;
pos &= 0xff;
len &= 0xff;

unsigned int m = (1u << len) - 1u;
m <<= pos;
return val & m;
return (val >> pos) & m;
#else
unsigned int ret;
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
Expand All @@ -27,8 +26,8 @@ struct Bitfield<unsigned int> {
static __device__ __forceinline__
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
pos &= 0x1f;
len &= 0x1f;
pos &= 0xff;
len &= 0xff;

unsigned int m = (1u << len) - 1u;
toInsert &= m;
Expand All @@ -50,12 +49,11 @@ struct Bitfield<uint64_t> {
static __device__ __forceinline__
uint64_t getBitfield(uint64_t val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
pos &= 0x1f;
len &= 0x1f;
pos &= 0xff;
len &= 0xff;

uint64_t m = (1u << len) - 1u;
m <<= pos;
return val & m;
return (val >> pos) & m;
#else
uint64_t ret;
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
Expand All @@ -66,8 +64,8 @@ struct Bitfield<uint64_t> {
static __device__ __forceinline__
uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
pos &= 0x1f;
len &= 0x1f;
pos &= 0xff;
len &= 0xff;

uint64_t m = (1u << len) - 1u;
toInsert &= m;
Expand Down Expand Up @@ -105,16 +103,18 @@ __device__ __forceinline__ unsigned getLaneMaskLt() {
#endif
}

__device__ __forceinline__ unsigned getLaneMaskLe() {
#if defined(__HIP_PLATFORM_HCC__)
std::uint64_t m = (1ull << (getLaneId() + 1ull)) - 1ull;
#if defined (__HIP_PLATFORM_HCC__)
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1));
return m;
}
#else
__device__ __forceinline__ unsigned getLaneMaskLe() {
unsigned mask;
asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
return mask;
#endif
}
#endif

__device__ __forceinline__ unsigned getLaneMaskGt() {
#if defined(__HIP_PLATFORM_HCC__)
Expand Down
7 changes: 7 additions & 0 deletions aten/src/THC/THCDeviceUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ __device__ __forceinline__ unsigned int ACTIVE_MASK()
#endif
}

#if defined(__HIP_PLATFORM_HCC__)
__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
{
return __ballot(predicate);
}
#else
__device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
Expand All @@ -52,6 +58,7 @@ __device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int
return __ballot(predicate);
#endif
}
#endif

#ifdef __HIP_PLATFORM_HCC__
//To handle ambiguity, add a type double version.
Expand Down
8 changes: 7 additions & 1 deletion aten/src/THC/THCScanUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,15 @@ __device__ void exclusivePrefixScan(T* smem, T in, T* out, T* carry, BinaryFunct
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
// Within-warp, we use warp voting.
#if defined (__HIP_PLATFORM_HCC__)
unsigned long long int vote = WARP_BALLOT(in);
T index = __popcll(getLaneMaskLe() & vote);
T carry = __popcll(vote);
#else
T vote = WARP_BALLOT(in);
T index = __popc(getLaneMaskLe() & vote);
T carry = __popc(vote);
#endif

int warp = threadIdx.x / SCAN_UTILS_WARP_SIZE;

Expand Down Expand Up @@ -207,7 +213,7 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi
*out -= (T) in;

// The outgoing carry for all threads is the last warp's sum
*carry = smem[(blockDim.x / SCAN_UTILS_WARP_SIZE) - 1];
*carry = smem[THCCeilDiv<int>(blockDim.x, SCAN_UTILS_WARP_SIZE) - 1];

if (KillWARDependency) {
__syncthreads();
Expand Down
4 changes: 4 additions & 0 deletions aten/src/THC/THCTensorTopK.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,11 @@ __device__ void countRadixUsingMask(CountType counts[RadixSize],
#pragma unroll
for (unsigned int j = 0; j < RadixSize; ++j) {
bool vote = hasVal && (digitInRadix == j);
#if defined (__HIP_PLATFORM_HCC__)
counts[j] += __popcll(WARP_BALLOT(vote));
#else
counts[j] += __popc(WARP_BALLOT(vote, ACTIVE_MASK()));
#endif
}
}

Expand Down
5 changes: 5 additions & 0 deletions aten/src/THC/generic/THCTensorSort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,12 @@ void THCTensor_(sort)(THCState* state,
int maxSliceSize = 2048;
#endif

#ifdef __HIP_PLATFORM_HCC__
// TODO bitonicSortKVInPlace hangs on ROCm currently.
if (0) {
#else
if (sliceSize <= maxSliceSize) {
#endif
// Fill `indices` (the values) with the
// slice-relative index.
THCudaLongTensor_fillSliceWithIndex(state, indices, dim);
Expand Down
5 changes: 5 additions & 0 deletions aten/src/THC/generic/THCTensorTopK.cu
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ void THCTensor_(topk)(THCState* state,
if (sorted) {
// FIXME: the k/v inplace sort along slice only works for size <=
// 2048 at the moment
#ifdef __HIP_PLATFORM_HCC__
// TODO bitonicSortKVInPlace hangs on ROCm currently.
if (0) {
#else
if (sliceSize <= 2048) {
#endif
// This avoids any memory allocations and performs all sorting
// work inplace along the slice
THCTensor_(sortKeyValueInplace)(state, topK, indices, dim, dir);
Expand Down
1 change: 1 addition & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6577,6 +6577,7 @@ def test_tensor_shape_empty(self):
self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)])

# functions that operate over a dimension but don't reduce.
@skipIfRocm
def test_dim_function_empty(self):
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
for device in devices:
Expand Down