
Commit 3f49dbf

fix bugs in syncbn (pytorch#46)
- incorrect use of __shfl_down
- fix warp size assumptions
- update unit tests to exit on failure
1 parent c1e88fa commit 3f49dbf
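The first bullet is the core portability bug: call sites in csrc/welford.cu invoke the shuffle as SHFL_DOWN(0xffffffff, val, i), passing a CUDA-style participation mask. HIP's __shfl_down takes no mask argument, so the old alias "#define SHFL_DOWN __shfl_down" made HIP shuffle the mask constant instead of the value. The new macro keeps the call sites unchanged and simply drops the mask on HIP. A minimal standalone sketch of the pattern (the warp_sum helper is illustrative, not part of the commit):

#if defined __HIP_PLATFORM_HCC__
// HIP: __shfl_down(var, delta) takes no mask, and a GCN wavefront has 64 lanes.
#define WARP_SIZE 64
#define SHFL_DOWN(mask, val, i) __shfl_down(val, i)
#else
// CUDA 9+: __shfl_down_sync(mask, var, delta).
#define WARP_SIZE 32
#define SHFL_DOWN __shfl_down_sync
#endif

// Illustrative warp-level tree reduction written against the portable macro;
// the full-warp mask is used on CUDA and ignored on HIP.
template <typename T>
__device__ __forceinline__ T warp_sum(T val) {
  for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1)
    val += SHFL_DOWN(0xffffffff, val, offset);
  return val;
}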

File tree

6 files changed: +68 additions, -27 deletions


csrc/welford.cu

Lines changed: 51 additions & 20 deletions
@@ -12,7 +12,7 @@
 #include "compat.h"

 #if defined __HIP_PLATFORM_HCC__
-#define SHFL_DOWN __shfl_down
+#define SHFL_DOWN(mask,val,i) __shfl_down(val, i)
 #else
 #define SHFL_DOWN __shfl_down_sync
 #endif
@@ -44,8 +44,11 @@ __host__ __forceinline__ int h_last_pow2(unsigned int n) {
   return n - (n >> 1);
 }

-
+#ifdef __HIP_PLATFORM_HCC__
+#define WARP_SIZE 64
+#else
 #define WARP_SIZE 32
+#endif

 template<typename T>
 __device__ __forceinline__ T warp_reduce_sum(T val)
@@ -61,25 +64,27 @@ __device__ __forceinline__ T reduce_block(T *x, T val)
 {
   int tid = threadIdx.y*blockDim.x + threadIdx.x;
   int blockSize = blockDim.x * blockDim.y;
+  int lane = tid % WARP_SIZE;
+  int wid = tid / WARP_SIZE;

-  if (blockSize > 32) {
+  if (blockSize > WARP_SIZE) {
     val = warp_reduce_sum(val);
-    if (tid % WARP_SIZE == 0)
-      x[tid/WARP_SIZE] = val;
+    if (lane == 0)
+      x[wid] = val;

     __syncthreads();

-    val = (tid < blockSize / WARP_SIZE? x[tid%WARP_SIZE] : T(0));
+    val = (tid < blockSize / WARP_SIZE? x[lane] : T(0));
   }

-  if(tid/WARP_SIZE==0) val = warp_reduce_sum(val);
+  if(wid==0) val = warp_reduce_sum(val);

   return val;
 }

 #define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency
 #define ELEMENTS_PER_THREAD 16
-#define OPTIMAL_TILE_W 32
+#define OPTIMAL_TILE_W WARP_SIZE
 #define MAX_H_BLOCK 128
 #define MAX_BLOCK_SIZE 512

@@ -137,11 +142,7 @@ __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
     auto num_new = SHFL_DOWN(0xffffffff, num, i);
     auto mean_new = SHFL_DOWN(0xffffffff, mean, i);
     auto m2n_new = SHFL_DOWN(0xffffffff, m2n, i);
-#if defined __HIP_PLATFORM_HCC__
-    welford_merge_element<T, int>(num, mean, m2n, num_new, mean_new, m2n_new);
-#else
     welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);
-#endif
   }
 }

@@ -158,7 +159,7 @@ __device__ void welford_reduce_mean_m2n(
   int lane = thread_id % WARP_SIZE;
   int wid = thread_id / WARP_SIZE;

-  if (block_size > 32) {
+  if (block_size > WARP_SIZE) {
     warp_reduce_mean_m2n(mean, m2n, num);
     if (lane == 0) {
       x[wid*2] = mean;
@@ -265,6 +266,9 @@ __device__ __forceinline__ void merge_block_vertical(T& sum_dy,

 // welford kernel calculating mean/biased_variance/unbiased_variance
 template <typename scalar_t, typename accscalar_t, typename outscalar_t>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void welford_kernel(
       const scalar_t* __restrict__ input,
       outscalar_t* __restrict__ out_mean,
@@ -291,8 +295,8 @@ __global__ void welford_kernel(
     }
   }

-  static __shared__ int s_mem[160];
-  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
+  static __shared__ int s_mem[WARP_SIZE];
+  static __shared__ accscalar_t s_mem_ac[WARP_SIZE*2];

   welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);

@@ -304,6 +308,9 @@ __global__ void welford_kernel(

 // elementwise BN kernel
 template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void batchnorm_forward_kernel(
       const scalar_t* __restrict__ input,
       const accscalar_t* __restrict__ mean,
@@ -331,6 +338,9 @@ __global__ void batchnorm_forward_kernel(
 // Breaking the grad_input to two step to support sync BN, which requires all
 // reduce of the intermediate results across processes.
 template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void reduce_bn_kernel(
       const scalar_t* __restrict__ input,
       const scalar_t* __restrict__ grad_output,
@@ -343,7 +353,7 @@ __global__ void reduce_bn_kernel(
       const int bs,
       const int fs,
       const int ss) {
-  static __shared__ int s_mem[64];
+  static __shared__ int s_mem[WARP_SIZE];
   //int total_item_num = bs * ss;

   int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
@@ -395,6 +405,9 @@ __global__ void reduce_bn_kernel(

 // elementwise backward BN kernel
 template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void batchnorm_backward_kernel(
       const scalar_t* __restrict__ grad_output,
       const scalar_t* __restrict__ input,
@@ -434,6 +447,9 @@ template
     typename accscalar_t,
     typename outscalar_t,
     int PARALLEL_LOADS>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void
 welford_kernel_c_last(
       const scalar_t* __restrict__ input,
@@ -575,6 +591,9 @@ welford_kernel_c_last(
 // parallel welford kernel to further reduce mean / biased_var
 // into mean / unbiased_var / inv_std across multiple processes.
 template <typename scalar_t>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void welford_kernel_parallel(
       const scalar_t* __restrict__ mean,
       const scalar_t* __restrict__ var_biased,
@@ -608,6 +627,9 @@ template <
     typename accscalar_t,
     typename layerscalar_t,
     int PARALLEL_LOADS>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void batchnorm_forward_c_last_kernel(
       const scalar_t* __restrict__ input,
       const scalar_t* __restrict__ z,
@@ -658,6 +680,9 @@ template <
     typename accscalar_t,
     typename layerscalar_t,
     int PARALLEL_LOADS>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void relu_backward_c_last_kernel(
       const scalar_t* __restrict__ grad_output,
       const scalar_t* __restrict__ input,
@@ -708,6 +733,9 @@ template
     typename accscalar_t,
     typename layerscalar_t,
     int PARALLEL_LOADS>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void reduce_bn_c_last_kernel(
       const scalar_t* __restrict__ input,
       const scalar_t* __restrict__ grad_output,
@@ -861,6 +889,9 @@ template <
     typename accscalar_t,
     typename layerscalar_t,
     int PARALLEL_LOADS>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void batchnorm_backward_c_last_kernel(
       const scalar_t* __restrict__ grad_output,
       const scalar_t* __restrict__ input,
@@ -921,7 +952,7 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
   at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
   at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));

-  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32));
+  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / WARP_SIZE));
   int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
   const dim3 block(block_x, block_y);
   const dim3 grid(feature_size);
@@ -957,7 +988,7 @@ at::Tensor batchnorm_forward_CUDA(

   auto space_size = get_tensor_spatial_size(input);

-  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+  int block_x = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
   int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
   const dim3 block(block_x, block_y);
   int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
@@ -1030,7 +1061,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(

   auto space_size = get_tensor_spatial_size(input);

-  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ 32));
+  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ WARP_SIZE));
   int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size)));
   const dim3 block(block_x, block_y);
   const dim3 grid(feature_size);
@@ -1097,7 +1128,7 @@ at::Tensor batchnorm_backward_CUDA(

   auto space_size = get_tensor_spatial_size(input);

-  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+  int block_x = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
   int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
   const dim3 block(block_x, block_y);
   int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
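The warp-size fixes above all follow one pattern: every place that hard-coded 32 (block shapes, shared scratch sizes, the OPTIMAL_TILE_W tile width) now derives from WARP_SIZE, which is 64 on HIP and 32 on CUDA. With MAX_BLOCK_SIZE at 512, a block holds at most 512 / WARP_SIZE per-warp partial sums, so a scratch buffer of WARP_SIZE entries always suffices (which is why s_mem shrinks from 160 ints). A standalone sketch of the resulting warp-size-agnostic two-stage block reduction, assuming the block size is a multiple of WARP_SIZE as the launch code above arranges; block_reduce_sum and block_sum_kernel are illustrative names, not the commit's:

#ifdef __HIP_PLATFORM_HCC__
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

template <typename T>
__device__ __forceinline__ T warp_reduce_sum(T val) {
  for (int i = WARP_SIZE / 2; i > 0; i >>= 1)
#ifdef __HIP_PLATFORM_HCC__
    val += __shfl_down(val, i);
#else
    val += __shfl_down_sync(0xffffffff, val, i);
#endif
  return val;
}

// Stage one partial sum per warp in shared memory, then let the first warp
// reduce the partials. Sizing the scratch with WARP_SIZE (not a literal 32)
// keeps the lane/wid arithmetic and the buffer bound correct for 64-lane warps.
template <typename T>
__device__ T block_reduce_sum(T* scratch, T val) {
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int block_size = blockDim.x * blockDim.y;
  int lane = tid % WARP_SIZE;
  int wid  = tid / WARP_SIZE;

  if (block_size > WARP_SIZE) {
    val = warp_reduce_sum(val);
    if (lane == 0) scratch[wid] = val;
    __syncthreads();
    val = (tid < block_size / WARP_SIZE) ? scratch[lane] : T(0);
  }
  if (wid == 0) val = warp_reduce_sum(val);
  return val;  // block total ends up in lane 0 of warp 0
}

__global__ void block_sum_kernel(const float* in, float* out, int n) {
  __shared__ float scratch[WARP_SIZE];
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  float v = (idx < n) ? in[idx] : 0.f;
  float s = block_reduce_sum(scratch, v);
  if (threadIdx.x == 0 && threadIdx.y == 0) out[blockIdx.x] = s;
}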

tests/distributed/synced_batchnorm/python_single_gpu_unit_test.py

Lines changed: 1 addition & 0 deletions
@@ -109,3 +109,4 @@ def compare(desc, inp1, inp2, error):
 else:
     print("*SBN single gpu failed*")

+assert sbn_result

tests/distributed/synced_batchnorm/single_gpu_unit_test.py

Lines changed: 3 additions & 0 deletions
@@ -157,3 +157,6 @@ def compare(desc, inp1, inp2, error):
     print("====SBN channel last single gpu passed tests")
 else:
     print("*SBN channel last single gpu failed*")
+
+assert sbn_result
+assert sbn_result_c_last

tests/distributed/synced_batchnorm/test_groups.py

Lines changed: 5 additions & 1 deletion
@@ -60,7 +60,11 @@ def compare(desc, inp1, inp2, error):
 grad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
 weight = np.random.randn(feature_size).astype(dtype)
 bias = np.random.randn(feature_size).astype(dtype)
+#count = torch.cuda.IntTensor([batch_size*space_size**2])
+count = [ space_size**2 * ( (i+1) * batch_size // args.world_size - i * batch_size // args.world_size ) for i in range(0, args.world_size)]
+count = torch.cuda.IntTensor(count)

+print("--- count : " , count)

 type_tensor = torch.cuda.FloatTensor
 if args.fp16:
@@ -153,7 +157,7 @@ def compare(desc, inp1, inp2, error):
 grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)

 mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
-grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
+grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu, count)

 if args.local_rank == 0:
     sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
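The test_groups.py change above computes, per rank, how many elements actually feed the batch-norm statistics when the global batch does not split evenly across args.world_size, and passes that count tensor through to syncbn.batchnorm_backward. A minimal host-side sketch of the same partition arithmetic, written in C++ to match the rest of this page (per_rank_counts is a hypothetical helper, not part of the commit):

#include <cstdio>
#include <vector>

// Elements contributed by each rank: the batch is split as evenly as integer
// division allows, and every sample adds space_size * space_size elements
// per channel, mirroring the list comprehension in the test.
std::vector<int> per_rank_counts(int batch_size, int space_size, int world_size) {
  std::vector<int> counts(world_size);
  for (int i = 0; i < world_size; ++i) {
    int samples = (i + 1) * batch_size / world_size - i * batch_size / world_size;
    counts[i] = samples * space_size * space_size;
  }
  return counts;
}

int main() {
  // batch_size=5, space_size=16, world_size=2 -> 512 and 768:
  // rank 0 sees 2 samples, rank 1 sees 3.
  for (int c : per_rank_counts(5, 16, 2)) std::printf("%d\n", c);
  return 0;
}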

tests/distributed/synced_batchnorm/two_gpu_unit_test.py

Lines changed: 2 additions & 0 deletions
@@ -178,3 +178,5 @@ def compare(desc, inp1, inp2, error):
     print("====SBN two gpu passed tests")
 else:
     print("*SBN two gpu failed*")
+
+assert sbn_result
Lines changed: 6 additions & 6 deletions
@@ -1,8 +1,8 @@
-python python_single_gpu_unit_test.py
-python single_gpu_unit_test.py
-python test_batchnorm1d.py
-python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
-python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16
-python -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex
+python python_single_gpu_unit_test.py || exit 1
+python single_gpu_unit_test.py || exit 1
+python test_batchnorm1d.py || exit 1
+python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py || exit 1
+python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16 || exit 1
+python -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex || exit 1
 #beware, you need a system with at least 4 gpus to test group_size<world_size
 #python -m torch.distributed.launch --nproc_per_node=4 test_groups.py --group_size=2
