Skip to content

Commit b992719

Browse files
Add ROCm support
Co-authored-by: [ ] <[email protected]>
1 parent 45de2b5 commit b992719

File tree

9 files changed

+70
-3
lines changed

9 files changed

+70
-3
lines changed

cuda_ext.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ def find_msvc():
5353
os.path.join(library_dir, "exllama_ext/cuda_func/q4_mlp.cu"),
5454
os.path.join(library_dir, "exllama_ext/cpu_func/rep_penalty.cpp")
5555
],
56+
extra_include_paths = [os.path.join(library_dir, "exllama_ext")],
5657
verbose = verbose,
57-
extra_ldflags = ["cublas.lib"] if windows else []
58+
extra_ldflags = ["cublas.lib"] if windows else [],
59+
extra_cuda_cflags = ["-U__HIP_NO_HALF_CONVERSIONS__"] if torch.version.hip else []
5860
# extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
5961
)
6062

exllama_ext/cuda_compat.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
4141

4242
//
4343

44-
#ifdef __CUDA_ARCH__
45-
#if __CUDA_ARCH__ < 700
44+
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
45+
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
4646

4747
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
4848
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }

exllama_ext/cuda_func/half_matmul.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#include "../util.cuh"
33
#include "../matrix.cuh"
44
#include "../cuda_compat.cuh"
5+
#if defined(USE_ROCM)
6+
#include "../hip_compat.cuh"
7+
#endif
58

69
// Block size
710

exllama_ext/cuda_func/half_matmul.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
#include <cstdint>
77
#include <ATen/cuda/CUDAContext.h>
88

9+
// Workaround for hipify_python using rocblas instead of hipblas.
10+
#if defined(USE_ROCM)
11+
#include <hipblas/hipblas.h>
12+
#define rocblas_handle hipblasHandle_t
13+
#endif
14+
915
void half_matmul_cuda
1016
(
1117
const half* x,

exllama_ext/cuda_func/q4_matmul.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include "../matrix.cuh"
55
#include "../cuda_compat.cuh"
66
#include "../cuda_buffers.cuh"
7+
#if defined(USE_ROCM)
8+
#include "../hip_compat.cuh"
9+
#endif
710

811
const int THREADS_X = 32; // Block size and thread count along columns in w and out
912
const int THREADS_Y = 1; // Block size and thread count along rows in x and out

exllama_ext/cuda_func/q4_matmul.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
#include "q4_matrix.cuh"
1111
#include "../tuning.h"
1212

13+
// Workaround for hipify_python using rocblas instead of hipblas.
14+
#if defined(USE_ROCM)
15+
#include <hipblas/hipblas.h>
16+
#define rocblas_handle hipblasHandle_t
17+
#endif
18+
1319
void q4_matmul_cuda
1420
(
1521
ExLlamaTuning* tuningParams,

exllama_ext/cuda_func/q4_mlp.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include "../cuda_buffers.cuh"
55
#include "../util.cuh"
66
#include "../matrix.cuh"
7+
#if defined(USE_ROCM)
8+
#include "../hip_compat.cuh"
9+
#endif
710

811
const int THREADS_X = 32;
912
const int THREADS_Y = 4;

exllama_ext/hip_compat.cuh

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#ifndef _hip_compat_cuh
2+
#define _hip_compat_cuh
3+
4+
// Workaround for a bug in hipamd, backported from upstream.
5+
// Device-only, force-inlined helper operating on a single __half value.
// It converts x to __half_raw to reach the `.data` member, passes that to the
// __builtin_amdgcn_rcph compiler builtin, casts the builtin's result to
// _Float16 and wraps it in a __half_raw, which is returned as __half.
// NOTE(review): presumably this yields the half-precision reciprocal 1/x,
// judging by the `hrcp` name it is aliased to further down — confirm against
// the HIP half-precision headers.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
6+
return __half_raw{
7+
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
8+
}
9+
10+
// Device-only, force-inlined helper operating on a __half2 value.
// It calls __builtin_amdgcn_rcph separately on the x.x and x.y members, casts
// each result to _Float16, and packs the pair into a _Float16_2 that is
// returned as __half2.
// NOTE(review): presumably the element-wise reciprocal of both halves, going
// by the `h2rcp` name it is aliased to further down — confirm.
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
11+
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
12+
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
13+
}
14+
15+
// Textually redirect every later use of `hrcp` / `h2rcp` in the including
// translation unit to the __compat_* helpers defined in this header.
// Because these are plain object-like macros, they affect any code that
// appears after this header is included.
#define hrcp __compat_hrcp
16+
#define h2rcp __compat_h2rcp
17+
18+
// Workaround for hipify_python using rocblas instead of hipblas.
19+
// Host-only, force-inlined adapter around hipblasHgemm.
// It accepts the matrix and scalar arguments as `half` pointers and forwards
// them to hipblasHgemm after reinterpret_cast-ing each one to the matching
// (const) hipblasHalf pointer type. handle, transA, transB, m, n, k, lda, ldb
// and ldc are passed through unchanged.
// Returns: the hipblasStatus_t produced by hipblasHgemm, unmodified.
// NOTE(review): the casts assume `half` and hipblasHalf share the same object
// representation — presumably both are 16-bit floats; confirm against the
// hipBLAS headers in use.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
20+
hipblasOperation_t transA,
21+
hipblasOperation_t transB,
22+
int m,
23+
int n,
24+
int k,
25+
const half* alpha,
26+
const half* AP,
27+
int lda,
28+
const half* BP,
29+
int ldb,
30+
const half* beta,
31+
half* CP,
32+
int ldc) {
33+
// Only the pointer types change; argument order matches hipblasHgemm's call below.
return hipblasHgemm(handle, transA, transB, m, n, k, reinterpret_cast<const hipblasHalf*>(alpha), reinterpret_cast<const hipblasHalf*>(AP), lda, reinterpret_cast<const hipblasHalf*>(BP), ldb, reinterpret_cast<const hipblasHalf*>(beta), reinterpret_cast<hipblasHalf*>(CP), ldc);
34+
}
35+
36+
// Map the rocblas-style names onto hipBLAS equivalents: the handle type
// becomes hipblasHandle_t, the "no transpose" operation becomes HIPBLAS_OP_N,
// and rocblas_hgemm resolves to the __compat_hipblasHgemm adapter defined in
// this header. Per the workaround comment earlier in this header, the rocblas
// names originate from hipify_python's translation of the CUDA sources.
#define rocblas_handle hipblasHandle_t
37+
#define rocblas_operation_none HIPBLAS_OP_N
38+
#define rocblas_hgemm __compat_hipblasHgemm
39+
40+
#endif

exllama_ext/util.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
#include <cstdint>
77
#include <cstdio>
88

9+
// Pick the error code that `cudaUnspecified` expands to, depending on the
// build: hipErrorUnknown when USE_ROCM is defined, cudaErrorApiFailureBase
// otherwise.
#if defined(USE_ROCM)
10+
#define cudaUnspecified hipErrorUnknown
11+
#else
912
#define cudaUnspecified cudaErrorApiFailureBase
13+
#endif
1014

1115
// React to failure on return code != cudaSuccess
1216

0 commit comments

Comments
 (0)