Commit a38b572

iotamudelta authored and facebook-github-bot committed
enable unit tests and other changes (pytorch#10266)
Summary: This PR for the ROCm target does the following:

* enable some unit tests on ROCm
* fix a missing static_cast that breaks the BatchNorm call on ROCm
* fix BatchNorm to work on ROCm with ROCm warp sizes etc.
* improve the pyHIPIFY script by introducing kernel scope to some transpilations, and other improvements
* fix a linking issue on ROCm
* for more unit test sets: mark currently broken tests as broken (to be fixed)
* enable THINLTO (phase one) to parallelize linking
* address the first failure of the elementwise kernel by removing a non-working ROCm specialization

Pull Request resolved: pytorch#10266
Differential Revision: D9184178
Pulled By: ezyang
fbshipit-source-id: 03bcd1fe4ca4dd3241f09634dbd42b6a4c350297
Parent: e0d4357
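One of the summary bullets gives pyHIPIFY a notion of kernel scope: certain CUDA-to-HIP rewrites are only valid inside __global__ kernel bodies, not across the whole translation unit. The script itself is not part of the excerpt below, so here is a minimal Python sketch of the idea; the function names, the regex, and the mapping entry are all illustrative, not taken from pyHIPIFY.

import re

# Toy kernel matcher: assumes no templates or nested parentheses in the signature.
KERNEL_RE = re.compile(r"__global__\s+\w+\s+\w+\s*\([^)]*\)\s*\{")

def find_matching_brace(src, open_idx):
    """Return the index of the '}' matching the '{' at open_idx."""
    depth = 0
    for i in range(open_idx, len(src)):
        if src[i] == "{":
            depth += 1
        elif src[i] == "}":
            depth -= 1
            if depth == 0:
                return i
    raise ValueError("unbalanced braces")

def hipify_kernel_scoped(src, mapping):
    """Apply `mapping` only inside __global__ kernel bodies."""
    out, pos = [], 0
    for m in KERNEL_RE.finditer(src):
        body_end = find_matching_brace(src, m.end() - 1)
        out.append(src[pos:m.end()])          # keep everything up to the '{'
        body = src[m.end():body_end]
        for cuda_name, hip_name in mapping.items():
            body = body.replace(cuda_name, hip_name)  # device-only rewrite
        out.append(body)
        pos = body_end
    out.append(src[pos:])
    return "".join(out)

# Host code keeps std::abs; only the kernel body is rewritten.
src = "__global__ void k(float* x) { x[0] = std::abs(x[0]); }\nfloat h(float x) { return std::abs(x); }"
print(hipify_kernel_scoped(src, {"std::abs": "fabsf"}))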

File tree: 19 files changed, +216 −35 lines


.jenkins/pytorch/build.sh

Lines changed: 7 additions & 1 deletion
@@ -30,6 +30,7 @@ cmake --version
 pip install -r requirements.txt || true
 
 if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
+  # This is necessary in order to cross compile (or else we'll have missing GPU device).
   export MAX_JOBS=4
   # This is necessary in order to cross compile (or else we'll have missing GPU device).
   export HCC_AMDGPU_TARGET=gfx900

@@ -41,7 +42,12 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
 
   # This environment variable enabled HCC Optimizations that speed up the linking stage.
   # https://github.com/RadeonOpenCompute/hcc#hcc-with-thinlto-linking
-  # export KMTHINLTO=1
+  export KMTHINLTO=1
+
+  # Need the libc++1 and libc++abi1 libraries to allow torch._C to load at runtime
+  sudo apt-get install libc++1
+  sudo apt-get install libc++abi1
+
   python tools/amd_build/build_pytorch_amd.py
   USE_ROCM=1 python setup.py install --user
   exit 0

.jenkins/pytorch/disabled-configs.txt

Lines changed: 0 additions & 2 deletions
@@ -3,5 +3,3 @@
 # fail. You can use this to temporarily reserve a test name to
 # turn on CI side before PyTorch repository supports it. This
 # file has the same format as .jenkins/enabled-configs.txt
-
-py2-clang3.8-rocm1.7.1-ubuntu16.04-test

.jenkins/pytorch/enabled-configs.txt

Lines changed: 1 addition & 0 deletions
@@ -41,3 +41,4 @@ pytorch-docker-build-test
 short-perf-test-cpu
 short-perf-test-gpu
 py2-clang3.8-rocm1.7.1-ubuntu16.04-build
+py2-clang3.8-rocm1.7.1-ubuntu16.04-test

.jenkins/pytorch/test.sh

Lines changed: 2 additions & 2 deletions
@@ -74,7 +74,7 @@ test_python_all_except_nn() {
 
 test_aten() {
   # Test ATen
-  if [[ "$BUILD_ENVIRONMENT" != *asan* ]]; then
+  if ([[ "$BUILD_ENVIRONMENT" != *asan* ]] && [[ "$BUILD_ENVIRONMENT" != *rocm* ]]); then
     echo "Running ATen tests with pytorch lib"
     TORCH_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/torch/lib
     # NB: the ATen test binaries don't have RPATH set, so it's necessary to

@@ -101,7 +101,7 @@ test_torchvision() {
   # this should be a transient requirement...)
   # See https://github.com/pytorch/pytorch/issues/7525
   #time python setup.py install
-  pip install .
+  pip install --user .
   popd
 }
 

aten/src/ATen/native/cuda/Loops.cuh

Lines changed: 0 additions & 4 deletions
@@ -28,11 +28,7 @@ namespace at { namespace native {
 
 template<int nt, int vt, typename func_t>
 __launch_bounds__(nt, 4)
-#ifdef __HIP_PLATFORM_HCC__
-__global__ void elementwise_kernel(int N, const func_t& f) {
-#else
 __global__ void elementwise_kernel(int N, func_t f) {
-#endif
   int tid = threadIdx.x;
   int nv = nt * vt;
   int idx = nv * blockIdx.x + tid;

aten/src/THCUNN/BatchNormalization.cu

Lines changed: 13 additions & 1 deletion
@@ -7,14 +7,26 @@
 #include "THCDeviceTensor.cuh"
 #include "THCDeviceTensorUtils.cuh"
 #include "THCDeviceUtils.cuh"
+#if defined(__HIP_PLATFORM_HCC__)
+const int WARP_SIZE = 64;
+#else
 const int WARP_SIZE = 32;
+#endif
 
 // The maximum number of threads in a block
+#if defined(__HIP_PLATFORM_HCC__)
+const int MAX_BLOCK_SIZE = 256;
+#else
 const int MAX_BLOCK_SIZE = 512;
+#endif
 
 // Number of threads in a block given an input size up to MAX_BLOCK_SIZE
 static int getNumThreads(int nElem) {
+#if defined(__HIP_PLATFORM_HCC__)
+  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
+#else
   int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
+#endif
   for (int i = 0; i != 5; ++i) {
     if (nElem <= threadSizes[i]) {
       return threadSizes[i];

@@ -116,7 +128,7 @@ __device__ T reduce(Op op, DeviceTensor3 tensor, int plane) {
   sum = warpSum(sum);
 
   // 'transpose', and reduce within warp again
-  __shared__ T shared[32];
+  __shared__ T shared[WARP_SIZE];
   __syncthreads();
   if (threadIdx.x % WARP_SIZE == 0) {
     shared[threadIdx.x / WARP_SIZE] = sum;
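For intuition about why the staging buffer changes from shared[32] to shared[WARP_SIZE]: in the usual two-pass block reduction, each warp first reduces its own lanes, one partial per warp is staged in shared memory, and the first warp then reads a full warp's worth of slots (unused ones zeroed) to finish the job. With 64-wide ROCm wavefronts that second pass touches indices up to 63, so a hard-coded 32-entry buffer under-allocates. A small Python model of the scheme (illustrative only, not the CUDA source):

def block_reduce(values, warp_size):
    """Model a two-pass block reduction; len(values) = threads per block."""
    num_warps = len(values) // warp_size  # assumes num_warps <= warp_size
    # Pass 1: each warp reduces its own lanes.
    partials = [sum(values[w * warp_size:(w + 1) * warp_size])
                for w in range(num_warps)]
    # Staging buffer: pass 2 reads warp_size slots (unused ones are zeroed),
    # so it must hold warp_size entries. A fixed shared[32] under-allocates
    # once warp_size is 64, as on ROCm.
    shared = partials + [0] * (warp_size - num_warps)
    # Pass 2: the first warp reduces the staged partials.
    return sum(shared[:warp_size])

assert block_reduce(list(range(256)), warp_size=64) == sum(range(256))
assert block_reduce(list(range(512)), warp_size=32) == sum(range(512))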

aten/src/THCUNN/generic/BatchNormalization.cu

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ void THNN_(BatchNormalization_updateOutput)(
     dim3 blocks(input.getSize(1));
     dim3 threads(getNumThreads(input.getSize(2)));
     BatchNormalizationUpdateOutput_kernel<real, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>(
-      input, output, weight, bias, eps, momentum, runningMean, runningVar,
+      input, output, weight, bias, static_cast<accreal>(eps), static_cast<accreal>(momentum), runningMean, runningVar,
       saveMean, saveStd);
   }
   THCudaCheck(cudaGetLastError());

caffe2/CMakeLists.txt

Lines changed: 15 additions & 4 deletions
@@ -341,13 +341,24 @@ endif()
 # ---[ Caffe2 HIP sources.
 if(USE_ROCM)
   # Call again since Caffe2_HIP_INCLUDES is extended with ATen include dirs.
-  IF(BUILD_ATEN)
-    HIP_INCLUDE_DIRECTORIES(${Caffe2_HIP_INCLUDES})
-  ENDIF()
+  if(BUILD_ATEN)
+    # Get Compile Definitions from the directory (FindHIP.CMake bug)
+    get_directory_property(MY_DEFINITIONS COMPILE_DEFINITIONS)
+    if(MY_DEFINITIONS)
+      foreach(_item ${MY_DEFINITIONS})
+        LIST(APPEND HIP_HCC_FLAGS "-D${_item}")
+      endforeach()
+    endif()
+
+    # Call again since Caffe2_HIP_INCLUDES is extended with ATen include dirs.
+    hip_include_directories(${Caffe2_HIP_INCLUDES})
+  endif()
   IF(BUILD_CAFFE2)
     set_source_files_properties(${Caffe2_HIP_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
   ENDIF()
-  hip_add_library(caffe2_hip ${Caffe2_HIP_SRCS})
+
+  # FindHIP.CMake checks if the SHARED flag is set and adds extra logic accordingly.
+  hip_add_library(caffe2_hip SHARED ${Caffe2_HIP_SRCS})
 
   # Since PyTorch files contain HIP headers, these flags are required for the necessary definitions to be added.
   set_target_properties(caffe2_hip PROPERTIES COMPILE_FLAGS ${HIP_HIPCC_FLAGS})
cmake/public/LoadHIP.cmake

Lines changed: 2 additions & 0 deletions
@@ -111,6 +111,8 @@ IF(HIP_FOUND)
   set(CMAKE_HIP_ARCHIVE_CREATE ${CMAKE_CXX_ARCHIVE_CREATE})
   set(CMAKE_HIP_ARCHIVE_APPEND ${CMAKE_CXX_ARCHIVE_APPEND})
   set(CMAKE_HIP_ARCHIVE_FINISH ${CMAKE_CXX_ARCHIVE_FINISH})
+  SET(CMAKE_HCC_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
+  SET(CMAKE_HCC_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
   ### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###
 
   set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand)

test/common.py

Lines changed: 11 additions & 1 deletion
@@ -98,13 +98,23 @@ def _check_module_exists(name):
     import numpy
 
 
+def skipIfRocm(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if TEST_WITH_ROCM:
+            raise unittest.SkipTest("test doesn't currently work on the ROCm stack")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+
 def skipIfNoLapack(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
         try:
             fn(*args, **kwargs)
         except Exception as e:
-            if 'Lapack library not found' in e.args[0]:
+            if 'Lapack library not found' in repr(e):
                 raise unittest.SkipTest('Compiled without Lapack')
             raise
     return wrapper
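For context, the new decorator is used the same way as skipIfNoLapack below it: stack it on a test method and the test is skipped on ROCm builds instead of failing. A hypothetical example (the test class and method are made up for illustration):

import unittest

from common import skipIfRocm  # the helper added above in test/common.py

class TestBatchNorm(unittest.TestCase):  # hypothetical test case
    @skipIfRocm
    def test_not_yet_working_on_rocm(self):
        # On a ROCm build, skipIfRocm raises unittest.SkipTest before the
        # body runs; on every other build the test executes normally.
        self.assertEqual(2 + 2, 4)

if __name__ == '__main__':
    unittest.main()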
