Skip to content

Replace hipblas API with rocBLAS #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aten/src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,10 @@ ENDIF()

IF(USE_ROCM)
### Link in the ROCm libraries BLAS / RNG.
FIND_LIBRARY(HIPBLAS_LIBRARY hipblas HINTS ${HIPBLAS_PATH}/lib)
FIND_LIBRARY(ROCBLAS_LIBRARY rocblas HINTS ${ROCBLAS_PATH}/lib)
FIND_LIBRARY(HIPRAND_LIBRARY hiprand HINTS ${HIPRAND_PATH}/lib)

list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${HIPBLAS_LIBRARY} ${HIPRAND_LIBRARY})
list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${ROCBLAS_LIBRARY} ${HIPRAND_LIBRARY})
ENDIF()

# Include CPU paths for CUDA as well
Expand Down
52 changes: 48 additions & 4 deletions aten/src/THC/THCBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,17 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t m, int6
#else
cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state);
if (prop->major >= 5){
#ifndef __HIP_PLATFORM_HCC__
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
THCublasCheck(cublasGemmEx(handle, opa, opb,
#endif
THCublasCheck(cublasGemmEx(handle, opa, opb,
i_m, i_n, i_k, &fAlpha,
a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
i_ldb, &fBeta, c, CUDA_R_16F, i_ldc,
CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#ifndef __HIP_PLATFORM_HCC__
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif
}else{
THCublasCheck(cublasSgemmEx(handle, opa, opb,
i_m, i_n, i_k, &fAlpha,
Expand Down Expand Up @@ -374,6 +378,25 @@ void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, i
}
#endif

#ifdef __HIP_PLATFORM_HCC__
void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,
float beta, float *c[], int64_t ldc, int64_t batchCount)
{
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
{
THError("Cublas_SgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d", INT_MAX);
}

const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
const int64_t stridec = ldc*n;

THCudaBlas_SgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount);

}
#else
void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,
float beta, float *c[], int64_t ldc, int64_t batchCount)
Expand All @@ -395,8 +418,9 @@ void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t
&alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
(int)batchCount));
}
#endif

#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB,
float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount)
Expand All @@ -421,6 +445,25 @@ void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, i
}
#endif

#ifdef __HIP_PLATFORM_HCC__
void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb,
double beta, double *c[], int64_t ldc, int64_t batchCount)
{
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
{
THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d", INT_MAX);
}

const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
const int64_t stridec = ldc*n;

THCudaBlas_DgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount);

}
#else
void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb,
double beta, double *c[], int64_t ldc, int64_t batchCount)
Expand All @@ -442,8 +485,9 @@ void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t
&alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
(int)batchCount));
}
#endif

#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB,
double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount)
Expand Down
2 changes: 1 addition & 1 deletion aten/src/THC/THCBlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ THC_API void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb,
THC_API void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb,
double beta, double *c[], int64_t ldc, int64_t batchCount);
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
THC_API void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB,
float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount);
Expand Down
2 changes: 2 additions & 0 deletions aten/src/THC/THCGeneral.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -582,13 +582,15 @@ void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
errmsg = "an absent device architectural feature is required";
break;

#ifndef __HIP_PLATFORM_HCC__
case CUBLAS_STATUS_MAPPING_ERROR:
errmsg = "an access to GPU memory space failed";
break;

case CUBLAS_STATUS_EXECUTION_FAILED:
errmsg = "the GPU program failed to execute";
break;
#endif

case CUBLAS_STATUS_INTERNAL_ERROR:
errmsg = "an internal operation failed";
Expand Down
2 changes: 1 addition & 1 deletion aten/src/THC/generic/THCTensorMathBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
// Compute pointers to matrices in each batch.
#if CUDA_VERSION < 8000
#if CUDA_VERSION < 8000 && !defined __HIP_PLATFORM_HCC__
size_t matrices_size = num_batches * sizeof(real*);

// Copy pointers to device.
Expand Down
4 changes: 2 additions & 2 deletions cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ if(BUILD_CAFFE2 OR BUILD_ATEN)
hip_include_directories(${Caffe2_HIP_INCLUDES})

set(Caffe2_HIP_DEPENDENCY_LIBS
${rocrand_LIBRARIES} ${hiprand_LIBRARIES} ${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipblas_LIBRARIES})
${rocrand_LIBRARIES} ${hiprand_LIBRARIES} ${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need a ${rocblas_LIBRARIES} here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rest looks good afaict

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we don't — rocBLAS is added a few lines below anyway (hipBLAS always relied on rocBLAS in the backend).

# Additional libraries required by PyTorch AMD that aren't used by Caffe2 (not in Caffe2's docker image)
if(BUILD_ATEN)
set(Caffe2_HIP_DEPENDENCY_LIBS ${Caffe2_HIP_DEPENDENCY_LIBS} ${hipsparse_LIBRARIES})
Expand All @@ -553,7 +553,7 @@ endif()
# ---[ ROCm
if(USE_ROCM AND NOT BUILD_CAFFE2)
include_directories(SYSTEM ${HIP_PATH}/include)
include_directories(SYSTEM ${HIPBLAS_PATH}/include)
include_directories(SYSTEM ${ROCBLAS_PATH}/include)
include_directories(SYSTEM ${HIPSPARSE_PATH}/include)
include_directories(SYSTEM ${HIPRAND_PATH}/include)
include_directories(SYSTEM ${ROCRAND_PATH}/include)
Expand Down
11 changes: 1 addition & 10 deletions cmake/public/LoadHIP.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,6 @@ ELSE()
SET(HSA_PATH $ENV{HSA_PATH})
ENDIF()

# HIPBLAS_PATH
IF(NOT DEFINED ENV{HIPBLAS_PATH})
SET(HIPBLAS_PATH ${ROCM_PATH}/hipblas)
ELSE()
SET(HIPBLAS_PATH $ENV{HIPBLAS_PATH})
ENDIF()

# ROCBLAS_PATH
IF(NOT DEFINED ENV{ROCBLAS_PATH})
SET(ROCBLAS_PATH ${ROCM_PATH}/rocblas)
Expand Down Expand Up @@ -112,14 +105,13 @@ IF(HIP_FOUND)
set(hiprand_DIR ${HIPRAND_PATH}/lib/cmake/hiprand)
set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas)
set(miopen_DIR ${MIOPEN_PATH}/lib/cmake/miopen)
set(hipblas_DIR ${HIPBLAS_PATH}/lib/cmake/hipblas)
set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas)
set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse)

find_package(rocrand REQUIRED)
find_package(hiprand REQUIRED)
find_package(rocblas REQUIRED)
find_package(miopen REQUIRED)
#find_package(hipblas REQUIRED) There's a bug with the CMake file in the Hipblas package.
#find_package(hipsparse REQUIRED)

# TODO: hip_hcc has an interface include flag "-hc" which is only
Expand All @@ -131,7 +123,6 @@ IF(HIP_FOUND)
# however currently it's just the lib name
FIND_LIBRARY(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib)
FIND_LIBRARY(hiprand_LIBRARIES hiprand HINTS ${HIPRAND_PATH}/lib)
FIND_LIBRARY(hipblas_LIBRARIES hipblas HINTS ${HIPBLAS_PATH}/lib)
FIND_LIBRARY(hipsparse_LIBRARIES hipsparse HINTS ${HIPSPARSE_PATH}/lib)


Expand Down
1 change: 0 additions & 1 deletion docker/caffe2/jenkins/common/install_rocm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ install_ubuntu() {
miopen-hip \
miopengemm \
rocblas \
hipblas \
rocm-profiler \
cxlactivitylogger

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,15 +882,15 @@ def run(self):
if USE_ROCM:
rocm_include_path = '/opt/rocm/include'
hcc_include_path = '/opt/rocm/hcc/include'
hipblas_include_path = '/opt/rocm/hipblas/include'
rocblas_include_path = '/opt/rocm/rocblas/include'
hipsparse_include_path = '/opt/rocm/hcsparse/include'
hiprand_include_path = '/opt/rocm/hiprand/include'
rocrand_include_path = '/opt/rocm/rocrand/include'
hip_lib_path = '/opt/rocm/hip/lib'
hcc_lib_path = '/opt/rocm/hcc/lib'
include_dirs.append(rocm_include_path)
include_dirs.append(hcc_include_path)
include_dirs.append(hipblas_include_path)
include_dirs.append(rocblas_include_path)
include_dirs.append(hipsparse_include_path)
include_dirs.append(hiprand_include_path)
include_dirs.append(rocrand_include_path)
Expand Down
17 changes: 7 additions & 10 deletions tools/amd_build/disabled_features.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@
{
"path": "aten/src/THC/THCBlas.cu",
"functions": {
"cublasSgemmEx": "HIPBLAS_STATUS_INTERNAL_ERROR",
"cublasSgetrfBatched": "HIPBLAS_STATUS_INTERNAL_ERROR",
"cublasDgetrfBatched": "HIPBLAS_STATUS_INTERNAL_ERROR",
"cublasSgetrsBatched": "HIPBLAS_STATUS_INTERNAL_ERROR",
"cublasDgetrsBatched": "HIPBLAS_STATUS_INTERNAL_ERROR",
"cublasSgetriBatched": "HIPBLAS_STATUS_INTERNAL_ERROR",
"cublasDgetriBatched": "HIPBLAS_STATUS_INTERNAL_ERROR"
"cublasSgemmEx": "rocblas_status_internal_error",
"cublasSgetrfBatched": "rocblas_status_internal_error",
"cublasDgetrfBatched": "rocblas_status_internal_error",
"cublasSgetrsBatched": "rocblas_status_internal_error",
"cublasDgetrsBatched": "rocblas_status_internal_error",
"cublasSgetriBatched": "rocblas_status_internal_error",
"cublasDgetriBatched": "rocblas_status_internal_error"
},
"constants": {
"HIPBLAS_DATA_HALF": "0"
}
},
{
"path": "aten/src/THC/THCStream.cpp",
Expand Down
Loading