#include <cuda_bf16.h>
#endif

- namespace at {
- namespace native {
+ // ROCm 6.3 is planned to have these functions, but until then here they are.
+ #if defined(USE_ROCM) && ROCM_VERSION >= 60201
+ #include <hip/hip_bf16.h>
+ #include <hip/hip_fp16.h>
+
+ __device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
+ #if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \
+     __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
+   typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
+   static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
+   union {
+     __hip_bfloat162_raw bf162_raw;
+     vec_short2 vs2;
+   } u{static_cast<__hip_bfloat162_raw>(value)};
+   u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
+   return static_cast<__hip_bfloat162>(u.bf162_raw);
+ #else
+   static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
+   union u_hold {
+     __hip_bfloat162_raw h2r;
+     unsigned int u32;
+   };
+   u_hold old_val, new_val;
+   old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
+   do {
+     new_val.h2r = __hadd2(old_val.h2r, value);
+   } while (!__hip_atomic_compare_exchange_strong(
+       (unsigned int*)address, &old_val.u32, new_val.u32,
+       __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
+   return old_val.h2r;
+ #endif
+ }
+
+ __device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
+ #if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \
+     __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
+   // The api expects an ext_vector_type of half
+   typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
+   static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
+   union {
+     __half2_raw h2r;
+     vec_fp162 fp16;
+   } u{static_cast<__half2_raw>(value)};
+   u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
+   return static_cast<__half2>(u.h2r);
+ #else
+   static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
+   union u_hold {
+     __half2_raw h2r;
+     unsigned int u32;
+   };
+   u_hold old_val, new_val;
+   old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
+   do {
+     new_val.h2r = __hadd2(old_val.h2r, value);
+   } while (!__hip_atomic_compare_exchange_strong(
+       (unsigned int*)address, &old_val.u32, new_val.u32,
+       __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
+   return old_val.h2r;
+ #endif
+ }
+ #define ATOMICADD preview_unsafeAtomicAdd
+ #define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
+ #else
+ #define ATOMICADD atomicAdd
+ #define NATIVE_ZERO_BF16 __int2bfloat16_rz(0)
+ #endif
+
+ namespace at::native {

__device__ __forceinline__ size_t
idx(const size_t nc,
@@ -48,7 +115,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
    const index_t numel,
    scalar_t value) {
#if ( \
-     (defined(USE_ROCM)) || \
+     (defined(USE_ROCM) && ROCM_VERSION < 60201) || \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  gpuAtomicAddNoReturn(
      reinterpret_cast<at::Half*>(tensor) + index,
@@ -62,17 +129,22 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
    __half2 value2;
    value2.x = static_cast<__half>(value);
    value2.y = __int2half_rz(0);
-     atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);
+     ATOMICADD(reinterpret_cast<__half2*>(target_addr), value2);

  } else if (!low_byte && index > 0) {
    __half2 value2;
    value2.x = __int2half_rz(0);
    value2.y = static_cast<__half>(value);
-     atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);
+     ATOMICADD(reinterpret_cast<__half2*>(target_addr - 1), value2);

  } else {
+ #ifdef USE_ROCM
+     gpuAtomicAddNoReturn(
+         reinterpret_cast<at::Half*>(tensor) + index, static_cast<at::Half>(value));
+ #else
    atomicAdd(
        reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value));
+ #endif
  }
#endif
}
@@ -88,7 +160,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
    const index_t numel,
    scalar_t value) {
#if ( \
-     (defined(USE_ROCM)) || \
+     (defined(USE_ROCM) && ROCM_VERSION < 60201) || \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
  gpuAtomicAddNoReturn(
      reinterpret_cast<at::BFloat16*>(tensor) + index,
@@ -101,18 +173,23 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
  if (low_byte && index < (numel - 1)) {
    __nv_bfloat162 value2;
    value2.x = *reinterpret_cast<__nv_bfloat16*>(&value);
-     value2.y = __int2bfloat16_rz(0);
-     atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);
+     value2.y = NATIVE_ZERO_BF16;
+     ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);

  } else if (!low_byte && index > 0) {
    __nv_bfloat162 value2;
-     value2.x = __int2bfloat16_rz(0);
+     value2.x = NATIVE_ZERO_BF16;
    value2.y = *reinterpret_cast<__nv_bfloat16*>(&value);
-     atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);
+     ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);

  } else {
+ #ifdef USE_ROCM
+     gpuAtomicAddNoReturn(
+         reinterpret_cast<at::BFloat16*>(tensor) + index, static_cast<at::BFloat16>(value));
+ #else
    atomicAdd(
        reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value));
+ #endif
  }
#endif
}
@@ -145,5 +222,7 @@ __device__ __forceinline__ void fastAtomicAdd(
  }
}

- } // namespace native
- } // namespace at
+ #undef ATOMICADD
+ #undef NATIVE_ZERO_BF16
+
+ } // namespace at::native