From 2432b4dafd698d523f5c9fa2037f13942a0f4ebb Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 4 Sep 2025 09:29:48 -0600 Subject: [PATCH] feat(ggml-metal): Add template specialization for mul_mm_id w/ ne20 == 10 Branch: GGMLMetalNE20 Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.m | 3 +++ ggml/src/ggml-metal/ggml-metal.metal | 1 + 2 files changed, 4 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 9b4006d987c3b..c1a0a2bef171e 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -407,6 +407,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, @@ -1439,6 +1440,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10, mul_mm_id_map0_f16_ne20_10, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm); @@ -3979,6 +3981,7 @@ static int ggml_metal_encode_node( case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break; case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break; case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break; + case 10: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10].pipeline; break; case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break; default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20); } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9c5933d24a0e3..2d56c62674c8e 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -7618,6 +7618,7 @@ template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>; template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>; template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; +template [[host_name("kernel_mul_mm_id_map0_f16_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; template