@@ -1851,13 +1851,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1851
1851
// use cublasGemmBatchedEx
1852
1852
const int ne23 = ne12*ne13;
1853
1853
1854
+ #ifdef GGML_USE_MUSA
1855
+ const void ** ptrs_src;
1856
+ void ** ptrs_dst;
1857
+ CUDA_CHECK (cudaMalloc ((void **)&ptrs_src, sizeof (void *)*2 *ne23));
1858
+ CUDA_CHECK (cudaMalloc ((void **)&ptrs_dst, sizeof (void *)*1 *ne23));
1859
+ #else // GGML_USE_MUSA
1854
1860
ggml_cuda_pool_alloc<const void *> ptrs_src (ctx.pool (), 2 *ne23);
1855
1861
ggml_cuda_pool_alloc< void *> ptrs_dst (ctx.pool (), 1 *ne23);
1862
+ #endif // GGML_USE_MUSA
1856
1863
1857
1864
dim3 block_dims (ne13, ne12);
1858
1865
k_compute_batched_ptrs<<<1 , block_dims, 0 , main_stream>>> (
1859
1866
src0_f16, src1_f16, dst_t ,
1867
+ #ifdef GGML_USE_MUSA
1868
+ ptrs_src, ptrs_dst,
1869
+ #else // GGML_USE_MUSA
1860
1870
ptrs_src.get (), ptrs_dst.get (),
1871
+ #endif // GGML_USE_MUSA
1861
1872
ne12, ne13,
1862
1873
ne23,
1863
1874
nb02, nb03,
@@ -1867,15 +1878,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1867
1878
r2, r3);
1868
1879
CUDA_CHECK (cudaGetLastError ());
1869
1880
1870
- CUBLAS_CHECK (
1881
+ #ifdef GGML_USE_MUSA
1882
+ cudaDeviceSynchronize ();
1883
+ const void **Aarray = (const void **) (ptrs_src + 0 * ne23);
1884
+ const void **Barray = (const void **) (ptrs_src + 1 * ne23);
1885
+ void **Carray = (void **) (ptrs_dst + 0 * ne23);
1886
+ #else // GGML_USE_MUSA
1887
+ const void **Aarray = (const void **) (ptrs_src.get () + 0 * ne23);
1888
+ const void **Barray = (const void **) (ptrs_src.get () + 1 * ne23);
1889
+ void **Carray = (void **) (ptrs_dst.get () + 0 * ne23);
1890
+ #endif // GGML_USE_MUSA
1891
+
1892
+ CUBLAS_CHECK (
1871
1893
cublasGemmBatchedEx (ctx.cublas_handle (), CUBLAS_OP_T, CUBLAS_OP_N,
1872
1894
ne01, ne11, ne10,
1873
- alpha, ( const void **) (ptrs_src. get () + 0 *ne23) , CUDA_R_16F, nb01/nb00,
1874
- ( const void **) (ptrs_src. get () + 1 *ne23) , CUDA_R_16F, nb11/nb10,
1875
- beta, ( void **) (ptrs_dst. get () + 0 *ne23) , cu_data_type, ne01,
1895
+ alpha, Aarray , CUDA_R_16F, nb01/nb00,
1896
+ Barray , CUDA_R_16F, nb11/nb10,
1897
+ beta, Carray , cu_data_type, ne01,
1876
1898
ne23,
1877
1899
cu_compute_type,
1878
1900
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1901
+
1902
+ #ifdef GGML_USE_MUSA
1903
+ CUDA_CHECK (cudaFree (ptrs_src));
1904
+ CUDA_CHECK (cudaFree (ptrs_dst));
1905
+ #endif // GGML_USE_MUSA
1879
1906
}
1880
1907
#endif
1881
1908
@@ -3011,12 +3038,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3011
3038
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
3012
3039
return false ;
3013
3040
}
3014
- #ifdef GGML_USE_MUSA
3015
- if (b->type == GGML_TYPE_F16 && b->ne [2 ]*b->ne [3 ] > 1 &&
3016
- !ggml_is_transposed (a) && !ggml_is_transposed (b)) {
3017
- return false ;
3018
- }
3019
- #endif // GGML_USE_MUSA
3041
+ // #ifdef GGML_USE_MUSA
3042
+ // if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
3043
+ // !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
3044
+ // return false;
3045
+ // }
3046
+ // #endif // GGML_USE_MUSA
3020
3047
switch (a->type ) {
3021
3048
case GGML_TYPE_F32:
3022
3049
case GGML_TYPE_F16:
@@ -3041,11 +3068,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3041
3068
case GGML_TYPE_IQ4_NL:
3042
3069
case GGML_TYPE_IQ4_XS:
3043
3070
case GGML_TYPE_BF16:
3044
- #ifdef GGML_USE_MUSA
3045
- if (a->type == GGML_TYPE_Q3_K) {
3046
- return false ;
3047
- }
3048
- #endif // GGML_USE_MUSA
3071
+ // #ifdef GGML_USE_MUSA
3072
+ // if (a->type == GGML_TYPE_Q3_K) {
3073
+ // return false;
3074
+ // }
3075
+ // #endif // GGML_USE_MUSA
3049
3076
return true ;
3050
3077
default :
3051
3078
return false ;
0 commit comments