Skip to content

Commit 262b3c6

Browse files
Authored — Merge pull request #93 from iotamudelta/batchnorm
Batchnorm
2 parents 8035a4b + e41b01f commit 262b3c6

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

aten/src/ATen/native/cuda/Loops.cuh

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,7 @@ namespace at { namespace native {
2828

2929
template<int nt, int vt, typename func_t>
3030
__launch_bounds__(nt, 4)
31-
#ifdef __HIP_PLATFORM_HCC__
32-
__global__ void elementwise_kernel(int N, const func_t& f) {
33-
#else
3431
__global__ void elementwise_kernel(int N, func_t f) {
35-
#endif
3632
int tid = threadIdx.x;
3733
int nv = nt * vt;
3834
int idx = nv * blockIdx.x + tid;

aten/src/THCUNN/BatchNormalization.cu

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,26 @@
77
#include "THCDeviceTensor.cuh"
88
#include "THCDeviceTensorUtils.cuh"
99
#include "THCDeviceUtils.cuh"
10+
#if defined(__HIP_PLATFORM_HCC__)
11+
const int WARP_SIZE = 64;
12+
#else
1013
const int WARP_SIZE = 32;
14+
#endif
1115

1216
// The maximum number of threads in a block
17+
#if defined(__HIP_PLATFORM_HCC__)
18+
const int MAX_BLOCK_SIZE = 256;
19+
#else
1320
const int MAX_BLOCK_SIZE = 512;
21+
#endif
1422

1523
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
1624
static int getNumThreads(int nElem) {
25+
#if defined(__HIP_PLATFORM_HCC__)
26+
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
27+
#else
1728
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
29+
#endif
1830
for (int i = 0; i != 5; ++i) {
1931
if (nElem <= threadSizes[i]) {
2032
return threadSizes[i];
@@ -116,7 +128,7 @@ __device__ T reduce(Op op, DeviceTensor3 tensor, int plane) {
116128
sum = warpSum(sum);
117129

118130
// 'transpose', and reduce within warp again
119-
__shared__ T shared[32];
131+
__shared__ T shared[WARP_SIZE];
120132
__syncthreads();
121133
if (threadIdx.x % WARP_SIZE == 0) {
122134
shared[threadIdx.x / WARP_SIZE] = sum;

aten/src/THCUNN/generic/BatchNormalization.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void THNN_(BatchNormalization_updateOutput)(
6464
dim3 blocks(input.getSize(1));
6565
dim3 threads(getNumThreads(input.getSize(2)));
6666
BatchNormalizationUpdateOutput_kernel<real, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>(
67-
input, output, weight, bias, eps, momentum, runningMean, runningVar,
67+
input, output, weight, bias, static_cast<accreal>(eps), static_cast<accreal>(momentum), runningMean, runningVar,
6868
saveMean, saveStd);
6969
}
7070
THCudaCheck(cudaGetLastError());

0 commit comments

Comments (0)