#include "compat.h"

#if defined __HIP_PLATFORM_HCC__
- #define SHFL_DOWN __shfl_down
+ #define SHFL_DOWN(mask, val, i) __shfl_down(val, i)
#else
#define SHFL_DOWN __shfl_down_sync
#endif
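The HIP branch now takes the CUDA-style argument list and simply drops the mask, since `__shfl_down` on HIP has no sync/mask variant. A minimal sketch of how such a portable macro is typically exercised in a warp-level sum, assuming the `SHFL_DOWN` and `WARP_SIZE` macros from this patch and a shuffle-compatible `T` (the helper name is illustrative, not code from this file):

```cuda
// Hedged sketch: portable warp-wide sum using the SHFL_DOWN macro above.
// On CUDA this expands to __shfl_down_sync(mask, val, offset); on HIP the
// mask argument is discarded and __shfl_down(val, offset) is used instead.
template <typename T>
__device__ __forceinline__ T warp_sum_sketch(T val) {
  // Halve the shuffle distance each step; lane 0 ends up with the warp sum.
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
    val += SHFL_DOWN(0xffffffff, val, offset);
  return val;
}
```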
@@ -44,8 +44,11 @@ __host__ __forceinline__ int h_last_pow2(unsigned int n) {
  return n - (n >> 1);
}

-
+ #ifdef __HIP_PLATFORM_HCC__
+ #define WARP_SIZE 64
+ #else
#define WARP_SIZE 32
+ #endif

template <typename T>
__device__ __forceinline__ T warp_reduce_sum(T val)
@@ -61,25 +64,27 @@ __device__ __forceinline__ T reduce_block(T *x, T val)
{
  int tid = threadIdx.y*blockDim.x + threadIdx.x;
  int blockSize = blockDim.x * blockDim.y;
+   int lane = tid % WARP_SIZE;
+   int wid = tid / WARP_SIZE;

-   if (blockSize > 32) {
+   if (blockSize > WARP_SIZE) {
    val = warp_reduce_sum(val);
-     if (tid % WARP_SIZE == 0)
-       x[tid/WARP_SIZE] = val;
+     if (lane == 0)
+       x[wid] = val;

    __syncthreads();

-     val = (tid < blockSize / WARP_SIZE? x[tid%WARP_SIZE] : T(0));
+     val = (tid < blockSize / WARP_SIZE? x[lane] : T(0));
  }

-   if (tid/WARP_SIZE == 0) val = warp_reduce_sum(val);
+   if (wid == 0) val = warp_reduce_sum(val);

  return val;
}

#define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency
#define ELEMENTS_PER_THREAD 16
- #define OPTIMAL_TILE_W 32
+ #define OPTIMAL_TILE_W WARP_SIZE
#define MAX_H_BLOCK 128
#define MAX_BLOCK_SIZE 512

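The `lane`/`wid` indexing removes the hard-coded 32-wide assumption: each warp parks its partial sum in shared memory, and warp 0 re-reduces those partials. Since blocks never exceed `MAX_BLOCK_SIZE` (512) threads, there are at most `MAX_BLOCK_SIZE / WARP_SIZE` warps per block (16 on CUDA, 8 on HIP), so one `WARP_SIZE`-sized staging buffer is enough on both back ends. A hedged sketch of how `reduce_block` above might be driven from a caller (kernel name and indexing are illustrative only):

```cuda
// Hedged usage sketch for reduce_block<T>, assuming the WARP_SIZE and
// MAX_BLOCK_SIZE macros from this patch. One staging slot per warp suffices.
template <typename T>
__global__ void block_sum_sketch(const T* __restrict__ in, T* __restrict__ out) {
  __shared__ T staging[WARP_SIZE];                  // >= warps per block on CUDA and HIP
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  T val = in[blockIdx.x * blockDim.x * blockDim.y + tid];
  val = reduce_block(staging, val);                 // thread 0 holds the block-wide sum
  if (tid == 0) out[blockIdx.x] = val;
}
```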
@@ -137,11 +142,7 @@ __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
    auto num_new = SHFL_DOWN(0xffffffff, num, i);
    auto mean_new = SHFL_DOWN(0xffffffff, mean, i);
    auto m2n_new = SHFL_DOWN(0xffffffff, m2n, i);
- #if defined __HIP_PLATFORM_HCC__
-     welford_merge_element<T, int>(num, mean, m2n, num_new, mean_new, m2n_new);
- #else
    welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);
- #endif
  }
}

@@ -158,7 +159,7 @@ __device__ void welford_reduce_mean_m2n(
  int lane = thread_id % WARP_SIZE;
  int wid = thread_id / WARP_SIZE;

-   if (block_size > 32) {
+   if (block_size > WARP_SIZE) {
    warp_reduce_mean_m2n(mean, m2n, num);
    if (lane == 0) {
      x[wid*2] = mean;
@@ -265,6 +266,9 @@ __device__ __forceinline__ void merge_block_vertical(T& sum_dy,

// welford kernel calculating mean/biased_variance/unbiased_variance
template <typename scalar_t, typename accscalar_t, typename outscalar_t>
+ #ifdef __HIP_PLATFORM_HCC__
+ __launch_bounds__(MAX_BLOCK_SIZE)
+ #endif
__global__ void welford_kernel(
      const scalar_t* __restrict__ input,
      outscalar_t* __restrict__ out_mean,
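The same `#ifdef __HIP_PLATFORM_HCC__` / `__launch_bounds__(MAX_BLOCK_SIZE)` guard is added in front of each kernel below. The attribute tells the compiler the kernel is never launched with more than `MAX_BLOCK_SIZE` (512) threads per block, which lets hipcc make register-allocation and occupancy decisions for that bound rather than its default assumption; on the CUDA path the attribute is omitted, matching the original code. A hedged sketch of the pattern in isolation (placeholder kernel, not from this file):

```cuda
// Hedged sketch of the launch-bounds pattern applied throughout this patch.
#ifdef __HIP_PLATFORM_HCC__
__launch_bounds__(MAX_BLOCK_SIZE)   // at most 512 threads per block on HIP
#endif
__global__ void example_kernel_sketch(float* __restrict__ data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= 2.0f;       // placeholder work
}
```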
@@ -291,8 +295,8 @@ __global__ void welford_kernel(
    }
  }

-   static __shared__ int s_mem[160];
-   accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
+   static __shared__ int s_mem[WARP_SIZE];
+   static __shared__ accscalar_t s_mem_ac[WARP_SIZE*2];

  welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);

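The old layout aliased a fixed 160-int buffer, with the accumulator region carved out at offset 32; that offset bakes in a 32-wide warp. Declaring the two buffers separately and sizing them by `WARP_SIZE` keeps each warp's staging slots (one count plus a mean/m2n pair) valid when the warp is 64 wide. A hedged compile-time check that could sit next to the buffers, not part of the patch, just to make the sizing argument explicit:

```cuda
// Hedged sketch: each warp stages one int (count) and two accscalar_t values
// (mean, m2n), and a block of MAX_BLOCK_SIZE threads holds at most
// MAX_BLOCK_SIZE / WARP_SIZE warps, so the buffers above always suffice.
static_assert(MAX_BLOCK_SIZE / WARP_SIZE <= WARP_SIZE,
              "s_mem[WARP_SIZE] and s_mem_ac[WARP_SIZE*2] cover every warp's "
              "(count, mean, m2n) staging slots");
```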
@@ -304,6 +308,9 @@ __global__ void welford_kernel(

// elementwise BN kernel
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
+ #ifdef __HIP_PLATFORM_HCC__
+ __launch_bounds__(MAX_BLOCK_SIZE)
+ #endif
__global__ void batchnorm_forward_kernel(
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
@@ -331,6 +338,9 @@ __global__ void batchnorm_forward_kernel(
// Breaking the grad_input to two step to support sync BN, which requires all
// reduce of the intermediate results across processes.
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
+ #ifdef __HIP_PLATFORM_HCC__
+ __launch_bounds__(MAX_BLOCK_SIZE)
+ #endif
__global__ void reduce_bn_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ grad_output,
@@ -343,7 +353,7 @@ __global__ void reduce_bn_kernel(
      const int bs,
      const int fs,
      const int ss) {
-   static __shared__ int s_mem[64];
+   static __shared__ int s_mem[WARP_SIZE];
  // int total_item_num = bs * ss;

  int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
@@ -395,6 +405,9 @@ __global__ void reduce_bn_kernel(

// elementwise backward BN kernel
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
+ #ifdef __HIP_PLATFORM_HCC__
+ __launch_bounds__(MAX_BLOCK_SIZE)
+ #endif
__global__ void batchnorm_backward_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
@@ -434,6 +447,9 @@ template
    typename accscalar_t,
    typename outscalar_t,
    int PARALLEL_LOADS>
+ #ifdef __HIP_PLATFORM_HCC__
+ __launch_bounds__(MAX_BLOCK_SIZE)
+ #endif
__global__ void
welford_kernel_c_last(
      const scalar_t* __restrict__ input,
@@ -575,6 +591,9 @@ welford_kernel_c_last(
// parallel welford kernel to further reduce mean / biased_var
// into mean / unbiased_var / inv_std across multiple processes.
template <typename scalar_t>
+ #ifdef __HIP_PLATFORM_HCC__
+ __launch_bounds__(MAX_BLOCK_SIZE)
+ #endif
__global__ void welford_kernel_parallel(
      const scalar_t* __restrict__ mean,
      const scalar_t* __restrict__ var_biased,
@@ -608,6 +627,9 @@ template <
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
+ #ifdef __HIP_PLATFORM_HCC__
+ __launch_bounds__(MAX_BLOCK_SIZE)
+ #endif
__global__ void batchnorm_forward_c_last_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ z,
@@ -658,6 +680,9 @@ template <
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
+ #ifdef __HIP_PLATFORM_HCC__
+ __launch_bounds__(MAX_BLOCK_SIZE)
+ #endif
__global__ void relu_backward_c_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
@@ -708,6 +733,9 @@ template
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
+ #ifdef __HIP_PLATFORM_HCC__
+ __launch_bounds__(MAX_BLOCK_SIZE)
+ #endif
__global__ void reduce_bn_c_last_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ grad_output,
@@ -861,6 +889,9 @@ template <
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
+ #ifdef __HIP_PLATFORM_HCC__
+ __launch_bounds__(MAX_BLOCK_SIZE)
+ #endif
__global__ void batchnorm_backward_c_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
@@ -921,7 +952,7 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
  at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
  at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));

-   int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32));
+   int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / WARP_SIZE));
  int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
  const dim3 block(block_x, block_y);
  const dim3 grid(feature_size);
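Folding `WARP_SIZE` into the host-side launch math shrinks the y dimension on HIP so that `block_x * block_y` stays within `MAX_BLOCK_SIZE` while each warp is twice as wide. A hedged, standalone worked example of the shape computation; the helper approximates `h_last_pow2` (the usual bit-smear, assumed from the visible return statement) and the sample sizes (batch 64, spatial 49) are illustrative only:

```cuda
// Hedged host-side sketch of the block-shape math above, with sample sizes.
#include <algorithm>
#include <cstdio>
#include <initializer_list>

static int last_pow2(unsigned int n) {        // assumed h_last_pow2-style bit smear
  n |= (n >> 1); n |= (n >> 2); n |= (n >> 4); n |= (n >> 8); n |= (n >> 16);
  return n - (n >> 1);
}

int main() {
  const int MAX_BLOCK = 512, batch = 64, space = 49;
  for (int warp : {32, 64}) {                 // CUDA vs. HIP warp widths
    int block_y = std::min(last_pow2(batch), MAX_BLOCK / warp);
    int block_x = std::max(1, std::min(MAX_BLOCK / block_y, last_pow2(space)));
    // prints block = (32, 16) for warp 32 and block = (32, 8) for warp 64
    std::printf("warp=%d -> block = (%d, %d)\n", warp, block_x, block_y);
  }
  return 0;
}
```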
@@ -957,7 +988,7 @@ at::Tensor batchnorm_forward_CUDA(

  auto space_size = get_tensor_spatial_size(input);

-   int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+   int block_x = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
  int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
  const dim3 block(block_x, block_y);
  int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
@@ -1030,7 +1061,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(

  auto space_size = get_tensor_spatial_size(input);

-   int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ 32));
+   int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ WARP_SIZE));
  int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size)));
  const dim3 block(block_x, block_y);
  const dim3 grid(feature_size);
@@ -1097,7 +1128,7 @@ at::Tensor batchnorm_backward_CUDA(

  auto space_size = get_tensor_spatial_size(input);

-   int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+   int block_x = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
  int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
  const dim3 block(block_x, block_y);
  int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));