Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 11 additions & 20 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,40 +320,31 @@ typedef struct {
} ggml_metal_kargs_mul_mv_ext;

typedef struct {
int32_t ne02;
int32_t ne10;
int32_t ne11; // n_expert_used (bcast)
uint64_t nb11;
uint64_t nb12;
int32_t neh11; // n_tokens
uint64_t nbh11;
int32_t ne21; // n_tokens
int32_t ne20; // n_expert_used
uint64_t nb21;
} ggml_metal_kargs_mul_mm_id_map0;

typedef struct {
int32_t ne20; // n_expert_used
int32_t neh0;
int32_t neh1;
uint64_t nbh1;
uint64_t nbh2;
int32_t ne0;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_mul_mm_id_map1;

typedef struct {
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t neh12;
uint64_t nbh10;
uint64_t nbh11;
uint64_t nbh12;
uint64_t nbh13;
int32_t neh0;
int32_t neh1;
int32_t ne11;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne20;
int32_t ne21;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mm_id;
Expand Down
141 changes: 52 additions & 89 deletions ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
Expand Down Expand Up @@ -1428,8 +1432,12 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
Expand Down Expand Up @@ -3908,38 +3916,6 @@ static int ggml_metal_encode_node(
default: break;
}

const int64_t neh10 = ne10; // n_embd
const int64_t neh11 = ne21; // n_tokens
const int64_t neh12 = ne02; // n_expert

const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
const uint64_t nbh11 = nbh10*neh10;
const uint64_t nbh12 = nbh11*neh11;
const uint64_t nbh13 = nbh12*neh12;

const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
if (!h_src1) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
return 0;
}

const int64_t neh0 = ne0;
const int64_t neh1 = ne21;
const int64_t neh2 = ne02;

const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
const uint64_t nbh1 = nbh0*neh0;
const uint64_t nbh2 = nbh1*neh1;
//const uint64_t nbh3 = nbh2*neh2;

const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
if (!h_dst) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
return 0;
}

// tokens per expert
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
Expand All @@ -3949,41 +3925,54 @@ static int ggml_metal_encode_node(
}

// id map
// [n_expert_used, n_tokens]
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
// [n_tokens, n_expert]
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
if (!h_ids) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
return 0;
}

{
const int nth = MIN(1024, ne10/4);

ggml_metal_kargs_mul_mm_id_map0 args = {
ne02,
ne10,
ne11, // n_expert_used (bcast)
ne11, // n_expert_used (bcast)
nb11,
nb12,
neh11, // n_tokens
nbh11,
ne20, // n_expert_used
ne21, // n_tokens
ne20, // n_expert_used
nb21,
};

id<MTLComputePipelineState> pipeline = nil;

pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
pipeline = nil;

switch (ne20) {
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
}

GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);

const size_t smem = ne02*ne20*sizeof(uint16_t);

GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer: h_src1 offset:0 atIndex:3];
[encoder setBuffer: h_tpe offset:0 atIndex:4];
[encoder setBuffer: h_ids offset:0 atIndex:5];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
[encoder setBuffer: h_tpe offset:0 atIndex:2];
[encoder setBuffer: h_ids offset:0 atIndex:3];
[encoder setThreadgroupMemoryLength:smem atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
}

{
Expand Down Expand Up @@ -4022,56 +4011,30 @@ static int ggml_metal_encode_node(
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.neh12 =*/ neh12,
/*.nbh10 =*/ nbh10,
/*.nbh11 =*/ nbh11,
/*.nbh12 =*/ nbh12,
/*.nbh13 =*/ nbh13,
/*.neh0 =*/ neh0,
/*.neh1 =*/ neh1,
/*.ne11 =*/ ne11, // n_expert_used (bcast)
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne20 =*/ ne20, // n_expert_used
/*.ne21 =*/ ne21, // n_tokens
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};

[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer: h_src1 offset:0 atIndex:2];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer: h_tpe offset:0 atIndex:3];
[encoder setBuffer: h_dst offset:0 atIndex:4];
[encoder setBuffer: h_ids offset:0 atIndex:4];
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];

[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
}

{
GGML_ASSERT(ne0 % 4 == 0);

const int nth = MIN(1024, ne0/4);

ggml_metal_kargs_mul_mm_id_map1 args = {
ne20, // n_expert_used
neh0,
neh1,
nbh1,
nbh2,
ne0,
nb1,
nb2,
};

id<MTLComputePipelineState> pipeline = nil;

pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;

[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer: h_dst offset:0 atIndex:1];
[encoder setBuffer: h_ids offset:0 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];

[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} else {
id<MTLComputePipelineState> pipeline = nil;

Expand Down
Loading
Loading