
Commit c3ba1e8

[release/2.5] [ROCm] fastSpecializedAtomicAdd for MI300 (pytorch#135770) (#1746)
MI300 adds hardware support for packed bfloat16 and fp16 atomics. Enable it via the existing fastSpecializedAtomicAdd path. This helps improve [torch.scatter_add_ performance](https://ontrack-internal.amd.com/browse/SWDEV-497013), among other ops.

Pull Request resolved: pytorch#135770

Co-authored-by: Jeff Daily <[email protected]>
1 parent 5d212c2 commit c3ba1e8
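
As a usage sketch (not part of this commit): kernels typically reach this code through the at::native::fastAtomicAdd(tensor, index, numel, value, fast_atomics) wrapper defined in this header, which forwards to fastSpecializedAtomicAdd. The kernel below is a hypothetical illustration of a scatter-add style caller; its names and shapes are made up.

// Hypothetical scatter-add style kernel, for illustration only. With
// fast_atomics=true, Half/BFloat16 adds funnel through
// fastSpecializedAtomicAdd and, on gfx94x with ROCm >= 6.2.1, use the
// packed HW atomics enabled by this change.
#include <ATen/native/cuda/KernelUtils.cuh>

template <typename scalar_t, typename index_t>
__global__ void scatter_add_rows(
    scalar_t* out,           // destination buffer, out_numel elements
    const scalar_t* src,     // source values, n_src elements
    const index_t* row_map,  // maps each source row to a destination row
    index_t n_src,
    index_t n_cols,
    index_t out_numel) {
  const index_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n_src) return;
  const index_t dst = row_map[i / n_cols] * n_cols + (i % n_cols);
  // For 16-bit dtypes this packs the addend into a 32-bit atomic when the
  // destination address and index allow it.
  at::native::fastAtomicAdd(out, dst, out_numel, src[i], /*fast_atomics=*/true);
}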

File tree: 1 file changed, +91 -12 lines changed


aten/src/ATen/native/cuda/KernelUtils.cuh

Lines changed: 91 additions & 12 deletions
@@ -5,8 +5,75 @@
#include <cuda_bf16.h>
#endif

-namespace at {
-namespace native {
+// ROCm 6.3 is planned to have these functions, but until then here they are.
+#if defined(USE_ROCM) && ROCM_VERSION >= 60201
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+
+__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
+#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \
+    __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
+  typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
+  static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
+  union {
+    __hip_bfloat162_raw bf162_raw;
+    vec_short2 vs2;
+  } u{static_cast<__hip_bfloat162_raw>(value)};
+  u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
+  return static_cast<__hip_bfloat162>(u.bf162_raw);
+#else
+  static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
+  union u_hold {
+    __hip_bfloat162_raw h2r;
+    unsigned int u32;
+  };
+  u_hold old_val, new_val;
+  old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
+  do {
+    new_val.h2r = __hadd2(old_val.h2r, value);
+  } while (!__hip_atomic_compare_exchange_strong(
+      (unsigned int*)address, &old_val.u32, new_val.u32,
+      __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
+  return old_val.h2r;
+#endif
+}
+
+__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
+#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \
+    __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
+  // The api expects an ext_vector_type of half
+  typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
+  static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
+  union {
+    __half2_raw h2r;
+    vec_fp162 fp16;
+  } u {static_cast<__half2_raw>(value)};
+  u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
+  return static_cast<__half2>(u.h2r);
+#else
+  static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
+  union u_hold {
+    __half2_raw h2r;
+    unsigned int u32;
+  };
+  u_hold old_val, new_val;
+  old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
+  do {
+    new_val.h2r = __hadd2(old_val.h2r, value);
+  } while (!__hip_atomic_compare_exchange_strong(
+      (unsigned int*)address, &old_val.u32, new_val.u32,
+      __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
+  return old_val.h2r;
+#endif
+}
+#define ATOMICADD preview_unsafeAtomicAdd
+#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
+#else
+#define ATOMICADD atomicAdd
+#define NATIVE_ZERO_BF16 __int2bfloat16_rz(0)
+#endif
+
+namespace at:: native {

__device__ __forceinline__ size_t
idx(const size_t nc,
@@ -48,7 +115,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
    const index_t numel,
    scalar_t value) {
#if ( \
-    (defined(USE_ROCM)) || \
+    (defined(USE_ROCM) && ROCM_VERSION < 60201) || \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  gpuAtomicAddNoReturn(
      reinterpret_cast<at::Half*>(tensor) + index,
@@ -62,17 +129,22 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
    __half2 value2;
    value2.x = static_cast<__half>(value);
    value2.y = __int2half_rz(0);
-    atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);
+    ATOMICADD(reinterpret_cast<__half2*>(target_addr), value2);

  } else if (!low_byte && index > 0) {
    __half2 value2;
    value2.x = __int2half_rz(0);
    value2.y = static_cast<__half>(value);
-    atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);
+    ATOMICADD(reinterpret_cast<__half2*>(target_addr - 1), value2);

  } else {
+#ifdef USE_ROCM
+    gpuAtomicAddNoReturn(
+        reinterpret_cast<at::Half*>(tensor) + index, static_cast<at::Half>(value));
+#else
    atomicAdd(
        reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value));
+#endif
  }
#endif
}
@@ -88,7 +160,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
    const index_t numel,
    scalar_t value) {
#if ( \
-    (defined(USE_ROCM)) || \
+    (defined(USE_ROCM) && ROCM_VERSION < 60201) || \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
  gpuAtomicAddNoReturn(
      reinterpret_cast<at::BFloat16*>(tensor) + index,
@@ -101,18 +173,23 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
  if (low_byte && index < (numel - 1)) {
    __nv_bfloat162 value2;
    value2.x = *reinterpret_cast<__nv_bfloat16*>(&value);
-    value2.y = __int2bfloat16_rz(0);
-    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);
+    value2.y = NATIVE_ZERO_BF16;
+    ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);

  } else if (!low_byte && index > 0) {
    __nv_bfloat162 value2;
-    value2.x = __int2bfloat16_rz(0);
+    value2.x = NATIVE_ZERO_BF16;
    value2.y = *reinterpret_cast<__nv_bfloat16*>(&value);
-    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);
+    ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);

  } else {
+#ifdef USE_ROCM
+    gpuAtomicAddNoReturn(
+        reinterpret_cast<at::BFloat16*>(tensor) + index, static_cast<at::BFloat16>(value));
+#else
    atomicAdd(
        reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value));
+#endif
  }
#endif
}
@@ -145,5 +222,7 @@ __device__ __forceinline__ void fastAtomicAdd(
  }
}

-} // namespace native
-} // namespace at
+#undef ATOMICADD
+#undef NATIVE_ZERO_BF16
+
+} // namespace at::native
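
On targets without the gfx94x packed-atomic builtins, preview_unsafeAtomicAdd above falls back to a compare-and-swap loop over the 32-bit word that holds both 16-bit lanes. A minimal standalone sketch of the same technique, written with the portable atomicCAS and __hadd2 intrinsics instead of the __hip_atomic_* builtins used in the diff (illustrative only; the function name is made up):

// CAS-loop sketch: read the 4-byte-aligned word covering both half lanes,
// do the packed add in registers, and publish with atomicCAS, retrying if
// another thread updated the word in the meantime.
#include <cuda_fp16.h>

__device__ inline __half2 cas_add_half2(__half2* address, __half2 value) {
  unsigned int* word = reinterpret_cast<unsigned int*>(address);
  unsigned int old_bits = *word;
  unsigned int assumed;
  do {
    assumed = old_bits;
    __half2 old_val = *reinterpret_cast<__half2*>(&assumed);
    __half2 new_val = __hadd2(old_val, value);  // packed add of .x and .y
    old_bits = atomicCAS(word, assumed, *reinterpret_cast<unsigned int*>(&new_val));
  } while (old_bits != assumed);  // lost the race; retry with the fresh word
  // Like atomicAdd, return the prior contents of the destination.
  return *reinterpret_cast<__half2*>(&old_bits);
}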
