diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index fc6526d6d5dc6..82c1ac1da0b13 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -320,40 +320,31 @@ typedef struct { } ggml_metal_kargs_mul_mv_ext; typedef struct { + int32_t ne02; int32_t ne10; int32_t ne11; // n_expert_used (bcast) uint64_t nb11; uint64_t nb12; - int32_t neh11; // n_tokens - uint64_t nbh11; + int32_t ne21; // n_tokens int32_t ne20; // n_expert_used uint64_t nb21; } ggml_metal_kargs_mul_mm_id_map0; -typedef struct { - int32_t ne20; // n_expert_used - int32_t neh0; - int32_t neh1; - uint64_t nbh1; - uint64_t nbh2; - int32_t ne0; - uint64_t nb1; - uint64_t nb2; -} ggml_metal_kargs_mul_mm_id_map1; - typedef struct { int32_t ne00; int32_t ne02; uint64_t nb01; uint64_t nb02; uint64_t nb03; - int32_t neh12; - uint64_t nbh10; - uint64_t nbh11; - uint64_t nbh12; - uint64_t nbh13; - int32_t neh0; - int32_t neh1; + int32_t ne11; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne20; + int32_t ne21; + int32_t ne0; + int32_t ne1; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mm_id; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index dcd816a430f60..7a05a98278f94 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -398,8 +398,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, @@ -1428,8 +1432,12 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat); @@ -3908,38 +3916,6 @@ static int ggml_metal_encode_node( default: break; } - const int64_t neh10 = ne10; // n_embd - const int64_t neh11 = ne21; // n_tokens - const int64_t neh12 = ne02; // n_expert - - const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16); - const uint64_t nbh11 = nbh10*neh10; - const uint64_t nbh12 = nbh11*neh11; - const uint64_t nbh13 = nbh12*neh12; - - const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12; - id h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1); - if (!h_src1) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1); - return 0; - } - - const int64_t neh0 = ne0; - const int64_t neh1 = ne21; - const int64_t neh2 = ne02; - - const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32); - const uint64_t nbh1 = nbh0*neh0; - const uint64_t nbh2 = nbh1*neh1; - //const uint64_t nbh3 = nbh2*neh2; - - const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2; - id h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst); - if (!h_dst) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst); - return 0; - } - // tokens per expert const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02; id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); @@ -3949,8 +3925,8 @@ static int ggml_metal_encode_node( } // id map - // [n_expert_used, n_tokens] - const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21; + // [n_tokens, n_expert] + const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02; id h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids); if (!h_ids) { GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids); @@ -3958,32 +3934,45 @@ static int ggml_metal_encode_node( } { - const int nth = MIN(1024, ne10/4); - ggml_metal_kargs_mul_mm_id_map0 args = { + ne02, ne10, - ne11, // n_expert_used (bcast) + ne11, // n_expert_used (bcast) nb11, nb12, - neh11, // n_tokens - nbh11, - ne20, // n_expert_used + ne21, // n_tokens + ne20, // n_expert_used nb21, }; id pipeline = nil; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline; + pipeline = nil; + + switch (ne20) { + case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break; + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break; + case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break; + case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break; + case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break; + default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20); + } + + GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup); + + const size_t smem = ne02*ne20*sizeof(uint16_t); + + GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); [encoder setComputePipelineState:pipeline]; [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer: h_src1 offset:0 atIndex:3]; - [encoder setBuffer: h_tpe offset:0 atIndex:4]; - [encoder setBuffer: h_ids offset:0 atIndex:5]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:1]; + [encoder setBuffer: h_tpe offset:0 atIndex:2]; + [encoder setBuffer: h_ids offset:0 atIndex:3]; + [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)]; } { @@ -4022,13 +4011,15 @@ static int ggml_metal_encode_node( /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, - /*.neh12 =*/ neh12, - /*.nbh10 =*/ nbh10, - /*.nbh11 =*/ nbh11, - /*.nbh12 =*/ nbh12, - /*.nbh13 =*/ nbh13, - /*.neh0 =*/ neh0, - /*.neh1 =*/ neh1, + /*.ne11 =*/ ne11, // n_expert_used (bcast) + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne20 =*/ ne20, // n_expert_used + /*.ne21 =*/ ne21, // n_tokens + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, /*.r2 =*/ r2, /*.r3 =*/ r3, }; @@ -4036,42 +4027,14 @@ static int ggml_metal_encode_node( [encoder setComputePipelineState:pipeline]; [encoder setBytes:&args length:sizeof(args) atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer: h_src1 offset:0 atIndex:2]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; [encoder setBuffer: h_tpe offset:0 atIndex:3]; - [encoder setBuffer: h_dst offset:0 atIndex:4]; + [encoder setBuffer: h_ids offset:0 atIndex:4]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } - - { - GGML_ASSERT(ne0 % 4 == 0); - - const int nth = MIN(1024, ne0/4); - - ggml_metal_kargs_mul_mm_id_map1 args = { - ne20, // n_expert_used - neh0, - neh1, - nbh1, - nbh2, - ne0, - nb1, - nb2, - }; - - id pipeline = nil; - - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer: h_dst offset:0 atIndex:1]; - [encoder setBuffer: h_ids offset:0 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } } else { id pipeline = nil; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3dd55fd325ce2..7037c1aa0d30d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -974,9 +974,16 @@ kernel void kernel_mul( device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + if (args.ne10 == 1) { + const float x = *((device float *)(src1_ptr)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + } } } @@ -1000,9 +1007,16 @@ kernel void kernel_div( device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + if (args.ne10 == 1) { + const float x = 1.0f / *((device float *)(src1_ptr)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + } } } @@ -7491,97 +7505,81 @@ kernel void kernel_mul_mm( } } -template +template // n_expert_used kernel void kernel_mul_mm_id_map0( constant ggml_metal_kargs_mul_mm_id_map0 & args, - device const char * src1, device const char * src2, - device char * hsrc1, device char * htpe, device char * hids, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int ide = tgpig[0]; // expert id - - int n_all = 0; + threadgroup char * shmem [[threadgroup(0)]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort ntg[[threads_per_threadgroup]]) { + const short ide = tpitg; // expert id - device int32_t * ids_i32 = (device int32_t *) (hids); + uint32_t n_all = 0; - for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens - device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21); + device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21; - for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used - if (src2_i32[i20] != ide) { - continue; - } + for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens + if (i21 + tpitg < args.ne21) { + device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21); - device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11); - device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11); + threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20; - for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) { - hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]); + #pragma unroll(ne20) + for (short i20 = 0; i20 < ne20; i20++) { + sids[i20] = src2_i32[i20]; } - - if (tpitg.x == 0) { - ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all; - } - - ++n_all; } - } - if (tpitg.x == 0) { - device int32_t * tpe_i32 = (device int32_t *) (htpe); - tpe_i32[ide] = n_all; - } -} - -typedef decltype(kernel_mul_mm_id_map0) kernel_mul_mm_id_map0_t; - -template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0; + threadgroup_barrier(mem_flags::mem_threadgroup); -template -kernel void kernel_mul_mm_id_map1( - constant ggml_metal_kargs_mul_mm_id_map1 & args, - device const char * hdst, - device const char * hids, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i20 = tgpig[0]; // used expert - const int i21 = tgpig[1]; // token + for (short t = 0; t < ntg; t++) { + if (i21 + t >= args.ne21) { + break; + } - device const int32_t * ids_i32 = (device const int32_t *) (hids); - device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2); + threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20; - const int id = ids_i32[i21*args.ne20 + i20]; + short sel = 0; + #pragma unroll(ne20) + for (short i20 = 0; i20 < ne20; i20++) { + sel += (sids[i20] == ide)*(i20 + 1); + } - const int ide = id / args.neh1; - const int idt = id % args.neh1; + ids_i32[n_all] = (i21 + t)*ne20 + sel - 1; - device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2); + n_all += sel > 0; + } - for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) { - dst_f32x4[i0] = hdst_f32x4[i0]; + threadgroup_barrier(mem_flags::mem_threadgroup); } + + device uint32_t * tpe_u32 = (device uint32_t *) (htpe); + tpe_u32[ide] = n_all; } -typedef decltype(kernel_mul_mm_id_map1) kernel_mul_mm_id_map1_t; +typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t; -template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1; +template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>; +template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>; +template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>; +template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>; +template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; +template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; template kernel void kernel_mul_mm_id( constant ggml_metal_kargs_mul_mm_id & args, device const char * src0, device const char * src1, - device const char * tpe, + device const char * htpe, + device const char * hids, device char * dst, threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup T * sa = (threadgroup T *)(shmem); @@ -7589,19 +7587,20 @@ kernel void kernel_mul_mm_id( const int r0 = tgpig.y; const int r1 = tgpig.x; - const int im = tgpig.z; + const int im = tgpig.z; // expert - device const int32_t * tpe_i32 = (device const int32_t *) (tpe); + device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe); + device const int32_t * ids_i32 = (device const int32_t *) (hids); - const int neh1 = tpe_i32[im]; + const int32_t neh1 = tpe_u32[im]; if (r1*BLOCK_SIZE_N >= neh1) { return; } // if this block is of 64x32 shape or smaller - const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; - const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; @@ -7617,20 +7616,23 @@ kernel void kernel_mul_mm_id( short il = (tiitg % THREAD_PER_ROW); - const int i12 = im%args.neh12; - const int i13 = im/args.neh12; + const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col]; - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short i11 = (id % args.ne20) % args.ne11; + const short i12 = (id / args.ne20); + const short i13 = 0; + + const uint64_t offset0 = im*args.nb02 + i13*args.nb03; const short offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - device const half * y = (device const half *)(src1 - + args.nbh13*i13 - + args.nbh12*i12 - + args.nbh11*(r1*BLOCK_SIZE_N + thread_col) - + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + device const float * y = (device const float *)(src1 + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*i11 + + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory @@ -7646,7 +7648,7 @@ kernel void kernel_mul_mm_id( + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; } - *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y); + *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y)); il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -7682,43 +7684,38 @@ kernel void kernel_mul_mm_id( } } - if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) { - device float * C = (device float *) dst + - (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ - (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0; + threadgroup_barrier(mem_flags::mem_threadgroup); - for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *) shmem) \ - + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; - for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); - } + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; - threadgroup_barrier(mem_flags::mem_threadgroup); + #pragma unroll(8) + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + } - if (sgitg == 0) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0; - device float4 * D4 = (device float4 *) D; + threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); - threadgroup float4 * C4 = (threadgroup float4 *) C; + for (short j = sgitg; j < n_cols; j += 4) { + const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j]; - int i = 0; - for (; i < n_rows/4; i++) { - *(D4 + i) = *(C4 + i); - } + const short ide = id % args.ne20; + const short idt = id / args.ne20; - i *= 4; - for (; i < n_rows; i++) { - *(D + i) = *(C + i); - } - } + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = tiisg; + for (; i < n_rows/4; i += 32) { + *(D4 + i) = *(C4 + i); + } + + i = (4*(n_rows/4)) + tiisg; + for (; i < n_rows; i += 32) { + *(D + i) = *(C + i); } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 765521ffebb4e..4b4299d49df61 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6018,6 +6018,7 @@ static std::vector> make_test_cases_eval() { for (bool b : {false, true}) { test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16)); test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64)); + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 50, 200, 64)); } test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));