
Commit 38e72f8: Add ROCm support
Parent: 4585e2c
22 files changed: +105 -21 lines

cuda_ext.py

Lines changed: 17 additions & 8 deletions
@@ -11,21 +11,30 @@
 library_dir = os.path.dirname(os.path.abspath(__file__))
 extension_name = "exllama_ext"

+if torch.version.hip:
+    # FIXME: To build, I had to comment "flags += ['-fno-gpu-rdc']" in torch/utils/cpp_extension.py.
+    # I am not sure if it's possible to find a way to build without editing that file.
+    # If building without gpu-rdc, build will error with "lld: error: undefined hidden symbol: __llvm_amdgcn_rcp_f16".
+    extra_cuda_cflags = ["-U__HIP_NO_HALF_CONVERSIONS__", "-fgpu-rdc"]
+else:
+    extra_cuda_cflags = []
+
 exllama_ext = load(
     name = extension_name,
     sources = [
         os.path.join(library_dir, "exllama_ext/cuda_buffers.cu"),
         os.path.join(library_dir, "exllama_ext/cpu_func/rep_penalty.cpp"),
-        os.path.join(library_dir, "exllama_ext/cuda_func/column_remap.cu"),
-        os.path.join(library_dir, "exllama_ext/cuda_func/half_matmul.cu"),
-        os.path.join(library_dir, "exllama_ext/cuda_func/q4v2_matmul.cu"),
-        os.path.join(library_dir, "exllama_ext/cuda_func/q4v2_mlp.cu"),
-        os.path.join(library_dir, "exllama_ext/cuda_func/q4v2_recons.cu"),
-        os.path.join(library_dir, "exllama_ext/cuda_func/q4v2_sequential.cu"),
-        os.path.join(library_dir, "exllama_ext/cuda_func/rms_norm.cu"),
-        os.path.join(library_dir, "exllama_ext/cuda_func/rope.cu"),
+        os.path.join(library_dir, "exllama_ext/cu_func/column_remap.cu"),
+        os.path.join(library_dir, "exllama_ext/cu_func/half_matmul.cu"),
+        os.path.join(library_dir, "exllama_ext/cu_func/q4v2_matmul.cu"),
+        os.path.join(library_dir, "exllama_ext/cu_func/q4v2_mlp.cu"),
+        os.path.join(library_dir, "exllama_ext/cu_func/q4v2_recons.cu"),
+        os.path.join(library_dir, "exllama_ext/cu_func/q4v2_sequential.cu"),
+        os.path.join(library_dir, "exllama_ext/cu_func/rms_norm.cu"),
+        os.path.join(library_dir, "exllama_ext/cu_func/rope.cu"),
         os.path.join(library_dir, "exllama_ext/exllama_ext.cpp")
     ],
+    extra_cuda_cflags = extra_cuda_cflags
     # verbose = True,
     # extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
 )
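
For readers unfamiliar with the JIT build path, here is a minimal standalone sketch of the same pattern; it is not taken from this commit and the extension/source names are placeholders. torch.version.hip is a version string on ROCm builds of PyTorch and None on CUDA builds, which is what makes it usable as a backend check, and load() forwards extra_cuda_cflags to whichever device compiler it invokes.

import os
import torch
from torch.utils.cpp_extension import load

# On ROCm builds of PyTorch, torch.version.hip is a version string; on CUDA builds it is None.
if torch.version.hip:
    # Mirror the flags added in this commit (see the FIXME above about -fgpu-rdc).
    extra_cuda_cflags = ["-U__HIP_NO_HALF_CONVERSIONS__", "-fgpu-rdc"]
else:
    extra_cuda_cflags = []

this_dir = os.path.dirname(os.path.abspath(__file__))

# load() JIT-compiles and links the extension; extra_cuda_cflags is passed to the
# device compiler (nvcc on CUDA, hipcc on ROCm).
my_ext = load(
    name = "my_ext",                            # placeholder name
    sources = [
        os.path.join(this_dir, "my_ext.cpp"),   # placeholder sources
        os.path.join(this_dir, "my_kernel.cu"),
    ],
    extra_cuda_cflags = extra_cuda_cflags,
)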

exllama_ext/cuda_func/column_remap.cuh renamed to exllama_ext/cu_func/column_remap.cuh

Lines changed: 6 additions & 0 deletions
@@ -1,8 +1,14 @@
 #ifndef _column_remap_cuh
 #define _column_remap_cuh

+#if USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#define cudaError_t hipError_t
+#else
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#endif
 #include <cstdint>

 cudaError_t column_remap_cuda
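
The same USE_ROCM block recurs in each of the headers below. As a design note, here is a hedged sketch of how it could be hoisted into one shared header that every .cuh file then includes; the file name is hypothetical and this is not something the commit adds.

// hip_compat.cuh (hypothetical): single home for the ROCm/CUDA include switch.
#ifndef _hip_compat_cuh
#define _hip_compat_cuh

#if USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// Alias the CUDA spellings used throughout the sources to their HIP counterparts.
#define cudaError_t hipError_t
#else
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#endif

#endif // _hip_compat_cuh

Including one such header would avoid keeping eight near-identical copies of the switch in sync.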

exllama_ext/cuda_func/half_matmul.cu renamed to exllama_ext/cu_func/half_matmul.cu

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ cudaError_t half_matmul_cublas_cuda
     const half alpha = __float2half(1.0f);
     const half beta = __float2half(0.0f);

-    cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, w, width, x, dim, &beta, out, width);
+    cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, reinterpret_cast<const rocblas_half*>(&alpha), reinterpret_cast<const rocblas_half*>(w), width, reinterpret_cast<const rocblas_half*>(x), dim, reinterpret_cast<const rocblas_half*>(&beta), reinterpret_cast<rocblas_half*>(out), width);

     // cudaDeviceSynchronize();
     // _cuda_check(cudaGetLastError());
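
Some context on this one-line change, hedged: as the header change below suggests (cublasHandle_t is defined to rocblas_handle and <rocblas/rocblas.h> is included), PyTorch's hipify step maps this cuBLAS call to its rocBLAS counterpart on ROCm, which takes rocblas_half pointers rather than __half pointers. The two types share the same 16-bit layout, so reinterpret_cast is enough to bridge them; the inline casts assume rocblas_half is in scope, which the USE_ROCM branch of half_matmul.cuh provides. A hypothetical helper that keeps the call site portable could look like this (not part of this commit):

// Hypothetical cast helper, assuming USE_ROCM marks HIP builds.
#if defined(USE_ROCM)
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
typedef rocblas_half blas_half_t;   // rocBLAS's 16-bit float type
#else
#include <cuda_fp16.h>
typedef half blas_half_t;           // cuBLAS consumes __half directly
#endif

// Same bit layout on both backends, so a reinterpret_cast suffices.
static inline const blas_half_t* as_blas_half(const half* p)
{
    return reinterpret_cast<const blas_half_t*>(p);
}

With a helper of this kind, the gemm call could pass as_blas_half(&alpha), as_blas_half(w), and so on, instead of spelling out each cast inline.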

exllama_ext/cuda_func/half_matmul.cuh renamed to exllama_ext/cu_func/half_matmul.cuh

Lines changed: 10 additions & 1 deletion
@@ -1,10 +1,19 @@
 #ifndef _half_matmul_cuh
 #define _half_matmul_cuh

+#if USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#include <rocblas/rocblas.h>
+#include <ATen/hip/HIPContext.h>
+#define cudaError_t hipError_t
+#define cublasHandle_t rocblas_handle
+#else
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cstdint>
 #include <ATen/cuda/CUDAContext.h>
+#endif
+#include <cstdint>

 cudaError_t half_matmul_cuda
 (

exllama_ext/cuda_func/q4v2_matmul.cuh renamed to exllama_ext/cu_func/q4v2_matmul.cuh

Lines changed: 6 additions & 0 deletions
@@ -1,8 +1,14 @@
 #ifndef _q4v2_matmul_cuh
 #define _q4v2_matmul_cuh

+#if USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#define cudaError_t hipError_t
+#else
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#endif
 #include <cstdint>
 #include <cstdio>

File renamed without changes.

exllama_ext/cuda_func/q4v2_mlp.cuh renamed to exllama_ext/cu_func/q4v2_mlp.cuh

Lines changed: 6 additions & 0 deletions
@@ -1,8 +1,14 @@
 #ifndef _q4v2_mlp_cuh
 #define _q4v2_mlp_cuh

+#if USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#define cudaError_t hipError_t
+#else
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#endif
 #include <cstdint>

 cudaError_t q4v2_mlp_cuda

exllama_ext/cuda_func/q4v2_recons.cuh renamed to exllama_ext/cu_func/q4v2_recons.cuh

Lines changed: 6 additions & 0 deletions
@@ -1,8 +1,14 @@
 #ifndef _q4v2_recons_cuh
 #define _q4v2_recons_cuh

+#if USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#define cudaError_t hipError_t
+#else
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#endif
 #include <cstdint>

 cudaError_t q4v2_recons_cuda

exllama_ext/cuda_func/q4v2_sequential.cuh renamed to exllama_ext/cu_func/q4v2_sequential.cuh

Lines changed: 6 additions & 0 deletions
@@ -1,8 +1,14 @@
 #ifndef _q4v2_sequential_cuh
 #define _q4v2_sequential_cuh

+#if USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#define cudaError_t hipError_t
+#else
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#endif
 #include <cstdint>
 #include <cstdio>

File renamed without changes.

exllama_ext/cuda_func/rms_norm.cuh renamed to exllama_ext/cu_func/rms_norm.cuh

Lines changed: 6 additions & 0 deletions
@@ -1,8 +1,14 @@
 #ifndef _rms_norm_cuh
 #define _rms_norm_cuh

+#if USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#define cudaError_t hipError_t
+#else
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#endif
 #include <cstdint>

 cudaError_t rms_norm_cuda

File renamed without changes.

exllama_ext/cuda_func/rope.cuh renamed to exllama_ext/cu_func/rope.cuh

Lines changed: 6 additions & 0 deletions
@@ -1,8 +1,14 @@
 #ifndef _rope_cuh
 #define _rope_cuh

+#if USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#define cudaError_t hipError_t
+#else
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#endif
 #include <cstdint>

 cudaError_t rope_cuda

exllama_ext/cuda_buffers.cuh

Lines changed: 6 additions & 0 deletions
@@ -1,8 +1,14 @@
 #ifndef _cuda_buffers_cuh
 #define _cuda_buffers_cuh

+#if USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#define cudaError_t hipError_t
+#else
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#endif
 #include <cstdint>
 #include <cstdio>

exllama_ext/cuda_compat.cuh

Lines changed: 2 additions & 2 deletions
@@ -41,8 +41,8 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)

 //

-#ifdef __CUDA_ARCH__
-#if __CUDA_ARCH__ < 700
+#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

 __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
 __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
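
For context: the guard now also covers ROCm, presumably because the HIP toolchain targeted here does not provide the half overloads of atomicAdd, so the software fallback defined just above this hunk (atomicAdd_half / atomicAdd_half2) is reused. A hedged, minimal sketch of how such a fallback is typically written, which may differ in detail from the definitions in cuda_compat.cuh:

// Sketch only: emulate atomicAdd on a 16-bit half by CAS-ing the containing 32-bit word.
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    unsigned int* base = (unsigned int*)((size_t)address & ~(size_t)2);  // aligned 32-bit word
    bool high = ((size_t)address & 2) != 0;                              // upper or lower half?
    unsigned int old = *base, assumed;
    do
    {
        assumed = old;
        unsigned short cur = high ? (assumed >> 16) : (assumed & 0xffff);
        float sum = __half2float(__ushort_as_half(cur)) + __half2float(val);
        unsigned short upd = __half_as_ushort(__float2half(sum));
        unsigned int word = high ? ((assumed & 0x0000ffffu) | ((unsigned int)upd << 16))
                                 : ((assumed & 0xffff0000u) | upd);
        old = atomicCAS(base, assumed, word);
    }
    while (assumed != old);   // retry if another thread changed the word meanwhile
}

Native hardware atomicAdd on half exists only on sm_70 and newer, which is why the overloads are compiled in only below that architecture or, after this change, on ROCm.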

exllama_ext/exllama_ext.cpp

Lines changed: 8 additions & 8 deletions
@@ -9,14 +9,14 @@
 #include "cpu_func/rep_penalty.h"

 #include "cuda_buffers.cuh"
-#include "cuda_func/column_remap.cuh"
-#include "cuda_func/half_matmul.cuh"
-#include "cuda_func/q4v2_matmul.cuh"
-#include "cuda_func/q4v2_mlp.cuh"
-#include "cuda_func/q4v2_recons.cuh"
-#include "cuda_func/q4v2_sequential.cuh"
-#include "cuda_func/rms_norm.cuh"
-#include "cuda_func/rope.cuh"
+#include "cu_func/column_remap.cuh"
+#include "cu_func/half_matmul.cuh"
+#include "cu_func/q4v2_matmul.cuh"
+#include "cu_func/q4v2_mlp.cuh"
+#include "cu_func/q4v2_recons.cuh"
+#include "cu_func/q4v2_sequential.cuh"
+#include "cu_func/rms_norm.cuh"
+#include "cu_func/rope.cuh"
 #include "util.cuh"

 // Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a

exllama_ext/matrix.cuh

Lines changed: 5 additions & 0 deletions
@@ -1,8 +1,13 @@
 #ifndef _matrix_cuh
 #define _matrix_cuh

+#if USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#else
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#endif

 class MatrixView_half
 {

exllama_ext/util.cuh

Lines changed: 14 additions & 1 deletion
@@ -1,12 +1,25 @@
 #ifndef _util_cuh
 #define _util_cuh

+#if USE_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaMalloc hipMalloc
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaSuccess hipSuccess
+#define cudaUnspecified hipErrorUnknown
+#else
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#define cudaUnspecified cudaErrorApiFailureBase
+#endif
 #include <cstdint>
 #include <cstdio>

-#define cudaUnspecified cudaErrorApiFailureBase

 // React to failure on return code != cudaSuccess
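
The comment above refers to util.cuh's return-code handling, whose body falls outside this hunk. Purely as an illustration of what the hip defines enable, here is a hedged sketch of such a check; the name echoes the _cuda_check(...) call commented out in half_matmul.cu, but the real macro in util.cuh may differ.

#include <cstdio>

// Hypothetical error-check macro. With the defines above, cudaError_t and cudaSuccess
// resolve to hipError_t and hipSuccess on ROCm, so the body compiles unchanged on both backends.
#define _cuda_check(expr)                                                       \
do                                                                              \
{                                                                               \
    cudaError_t _err_ = (expr);                                                 \
    if (_err_ != cudaSuccess)                                                   \
        printf("CUDA error %d at %s:%d\n", (int)_err_, __FILE__, __LINE__);     \
}                                                                               \
while (false)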
