Skip to content

Commit 1428e18

Browse files
author
ZhouYu
committed
musa: enable MMA
1 parent f0dd6a1 commit 1428e18

File tree

3 files changed

+53
-16
lines changed

3 files changed

+53
-16
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ typedef float2 dfloat2;
215215
#define FP16_MMA_AVAILABLE
216216
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
217217

218+
#if defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
219+
#define FP16_MMA_AVAILABLE
220+
#endif // defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
221+
218222
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
219223
#define NEW_MMA_AVAILABLE
220224
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -237,7 +241,7 @@ static bool fast_fp16_available(const int cc) {
237241

238242
// To be used for feature selection of external libraries, e.g. cuBLAS.
239243
static bool fast_fp16_hardware_available(const int cc) {
240-
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
244+
return cc >= GGML_CUDA_CC_PASCAL && cc != 610 && cc != GGML_CUDA_CC_QY1;
241245
}
242246

243247
// Any FP16 tensor core instructions are available for ggml code.
@@ -246,13 +250,15 @@ static bool fp16_mma_available(const int cc) {
246250
return false;
247251
#else
248252
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
253+
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
249254
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
250255
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
251256
}
252257

253258
// To be used for feature selection of external libraries, e.g. cuBLAS.
254259
static bool fp16_mma_hardware_available(const int cc) {
255260
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
261+
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
256262
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
257263
}
258264

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
#ifdef FP16_MMA_AVAILABLE
1010
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
1111
#include <mma.h>
12+
#ifdef GGML_USE_MUSA
13+
namespace wmma = mtmusa::wmma;
14+
#else // GGML_USE_MUSA
1215
namespace wmma = nvcuda::wmma;
16+
#endif // GGML_USE_MUSA
1317
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
1418
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
1519
#include <rocwmma/rocwmma.hpp>

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1851,13 +1851,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18511851
// use cublasGemmBatchedEx
18521852
const int ne23 = ne12*ne13;
18531853

1854+
#ifdef GGML_USE_MUSA
1855+
const void ** ptrs_src;
1856+
void ** ptrs_dst;
1857+
CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
1858+
CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
1859+
#else // GGML_USE_MUSA
18541860
ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
18551861
ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
1862+
#endif // GGML_USE_MUSA
18561863

18571864
dim3 block_dims(ne13, ne12);
18581865
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
18591866
src0_f16, src1_f16, dst_t,
1867+
#ifdef GGML_USE_MUSA
1868+
ptrs_src, ptrs_dst,
1869+
#else // GGML_USE_MUSA
18601870
ptrs_src.get(), ptrs_dst.get(),
1871+
#endif // GGML_USE_MUSA
18611872
ne12, ne13,
18621873
ne23,
18631874
nb02, nb03,
@@ -1867,15 +1878,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18671878
r2, r3);
18681879
CUDA_CHECK(cudaGetLastError());
18691880

1870-
CUBLAS_CHECK(
1881+
#ifdef GGML_USE_MUSA
1882+
cudaDeviceSynchronize();
1883+
const void **Aarray = (const void **) (ptrs_src + 0 * ne23);
1884+
const void **Barray = (const void **) (ptrs_src + 1 * ne23);
1885+
void **Carray = (void **) (ptrs_dst + 0 * ne23);
1886+
#else // GGML_USE_MUSA
1887+
const void **Aarray = (const void **) (ptrs_src.get() + 0 * ne23);
1888+
const void **Barray = (const void **) (ptrs_src.get() + 1 * ne23);
1889+
void **Carray = (void **) (ptrs_dst.get() + 0 * ne23);
1890+
#endif // GGML_USE_MUSA
1891+
1892+
CUBLAS_CHECK(
18711893
cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
18721894
ne01, ne11, ne10,
1873-
alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
1874-
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10,
1875-
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
1895+
alpha, Aarray, CUDA_R_16F, nb01/nb00,
1896+
Barray, CUDA_R_16F, nb11/nb10,
1897+
beta, Carray, cu_data_type, ne01,
18761898
ne23,
18771899
cu_compute_type,
18781900
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1901+
1902+
#ifdef GGML_USE_MUSA
1903+
CUDA_CHECK(cudaFree(ptrs_src));
1904+
CUDA_CHECK(cudaFree(ptrs_dst));
1905+
#endif // GGML_USE_MUSA
18791906
}
18801907
#endif
18811908

@@ -3011,12 +3038,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30113038
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
30123039
return false;
30133040
}
3014-
#ifdef GGML_USE_MUSA
3015-
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
3016-
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
3017-
return false;
3018-
}
3019-
#endif // GGML_USE_MUSA
3041+
// #ifdef GGML_USE_MUSA
3042+
// if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
3043+
// !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
3044+
// return false;
3045+
// }
3046+
// #endif // GGML_USE_MUSA
30203047
switch (a->type) {
30213048
case GGML_TYPE_F32:
30223049
case GGML_TYPE_F16:
@@ -3041,11 +3068,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30413068
case GGML_TYPE_IQ4_NL:
30423069
case GGML_TYPE_IQ4_XS:
30433070
case GGML_TYPE_BF16:
3044-
#ifdef GGML_USE_MUSA
3045-
if (a->type == GGML_TYPE_Q3_K) {
3046-
return false;
3047-
}
3048-
#endif // GGML_USE_MUSA
3071+
// #ifdef GGML_USE_MUSA
3072+
// if (a->type == GGML_TYPE_Q3_K) {
3073+
// return false;
3074+
// }
3075+
// #endif // GGML_USE_MUSA
30493076
return true;
30503077
default:
30513078
return false;

0 commit comments

Comments (0)