Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions aten/src/THC/THCAsmUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ struct Bitfield<unsigned int> {
static __device__ __forceinline__
unsigned int getBitfield(unsigned int val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
pos &= 0x1f;
len &= 0x1f;
pos &= 0xff;
len &= 0xff;

unsigned int m = (1u << len) - 1u;
m <<= pos;
return val & m;
return (val >> pos) & m;
#else
unsigned int ret;
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
Expand All @@ -27,8 +26,8 @@ struct Bitfield<unsigned int> {
static __device__ __forceinline__
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
pos &= 0x1f;
len &= 0x1f;
pos &= 0xff;
len &= 0xff;

unsigned int m = (1u << len) - 1u;
toInsert &= m;
Expand All @@ -50,12 +49,11 @@ struct Bitfield<uint64_t> {
static __device__ __forceinline__
uint64_t getBitfield(uint64_t val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
pos &= 0x1f;
len &= 0x1f;
pos &= 0xff;
len &= 0xff;

uint64_t m = (1u << len) - 1u;
m <<= pos;
return val & m;
return (val >> pos) & m;
#else
uint64_t ret;
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
Expand All @@ -66,8 +64,8 @@ struct Bitfield<uint64_t> {
static __device__ __forceinline__
uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
pos &= 0x1f;
len &= 0x1f;
pos &= 0xff;
len &= 0xff;

uint64_t m = (1u << len) - 1u;
toInsert &= m;
Expand Down Expand Up @@ -106,8 +104,8 @@ __device__ __forceinline__ unsigned getLaneMaskLt() {
}

#if defined (__HIP_PLATFORM_HCC__)
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
std::uint64_t m = (1ull << (getLaneId() + 1ull)) - 1ull;
__device__ inline unsigned long long int getLaneMaskLe() {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given how __forceinline__ is now correctly defined, I'd like to keep it that way

std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1));
return m;
}
#else
Expand Down
2 changes: 1 addition & 1 deletion aten/src/THC/THCScanUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi
*out -= (T) in;

// The outgoing carry for all threads is the last warp's sum
*carry = smem[(blockDim.x / SCAN_UTILS_WARP_SIZE) - 1];
*carry = smem[THCCeilDiv<int>(blockDim.x, SCAN_UTILS_WARP_SIZE) - 1];

if (KillWARDependency) {
__syncthreads();
Expand Down