Skip to content

Commit 0608d77

Browse files
committed
Review: fix formatting, remove useless type conversion, fix naming for bools
1 parent 87aeacf commit 0608d77

File tree

1 file changed

+8
-7
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1858,7 +1858,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1858 1858
1859 1859          const auto convert_func = traits::get_nc_converter(src1->type);
1860 1860          GGML_ASSERT(convert_func != nullptr);
1861      -        convert_func((const void*)((const char*)src1->data), src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
     1861 +        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
1862 1862          src1_ptr = src1_alloc.get();
1863 1863          s11 = ne10;
1864 1864          s12 = ne11*s11;
@@ -1919,7 +1919,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1919 1919              ne01, ne11, ne10,
1920 1920              alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
1921 1921              src1_ptr, cu_data_type_b, s11, s12, // strideB
1922      -            beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
     1922 +            beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1923 1923              ne12*ne13,
1924 1924              cu_compute_type,
1925 1925              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1951,7 +1951,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1951 1951              ne01, ne11, ne10,
1952 1952              alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
1953 1953              (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
1954      -            beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
     1954 +            beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
1955 1955              ne23,
1956 1956              cu_compute_type,
1957 1957              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -2030,10 +2030,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2030 2030      //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
2031 2031      //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
2032 2032
     2033 +    //TODO update for generic tensor parallelism
2033 2034      const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2034      -    bool can_use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
2035      -    bool can_use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
2036      -    bool can_use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
     2035 +    bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
     2036 +    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
     2037 +    bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
2037 2038
2038 2039      if (!split && use_mul_mat_vec) {
2039 2040          // the custom F16 vector kernel can be used over batched cuBLAS GEMM
@@ -2043,7 +2044,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2043 2044          ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
2044 2045      } else if (!split && use_mul_mat_q) {
2045 2046          ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
2046      -    } else if (!split && (can_use_batched_cublas_f16 || can_use_batched_cublas_bf16 || can_use_batched_cublas_f32)
     2047 +    } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
2047 2048          && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
2048 2049          // general KQ + KQV multi-batch without FlashAttention
2049 2050          ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);

0 commit comments

Comments
 (0)