diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a8bd9e97641ca..fdcb62a7b373a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -665,7 +665,9 @@ struct vk_op_push_constants { }; struct vk_op_glu_push_constants { + uint32_t N; uint32_t ne00; + uint32_t ne20; uint32_t mode; // 0: default, 1: swapped, 2: split }; @@ -2761,8 +2763,8 @@ static void ggml_vk_load_shaders(vk_device& device) { #undef CREATE_UNARY #define CREATE_GLU(name) \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); CREATE_GLU(geglu) CREATE_GLU(reglu) @@ -6867,7 +6869,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: case GGML_OP_ARGMAX: - case GGML_OP_GLU: { const uint32_t nr = ggml_nrows(src0); if (nr > 262144) { @@ -6952,6 +6953,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_UNARY: + case GGML_OP_GLU: case GGML_OP_CONV_2D_DW: { uint32_t ne = ggml_nelements(dst); @@ -7600,7 +7602,7 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t mode = split ? 2 : (swapped ? 1 : 0); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], mode }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun); } static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp index 0d65baef38944..41a29889075f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp @@ -1,15 +1,15 @@ #extension GL_EXT_shader_16bit_storage : require -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {A_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; -layout (constant_id = 0) const uint BLOCK_SIZE = 32; - layout (push_constant) uniform parameter { + uint N; uint ne00; + uint ne20; uint mode; } p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp index 24814240365d2..85cf65a9ecac8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp @@ -1,31 +1,29 @@ void main() { - const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint col = gl_LocalInvocationID.x; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.N) { + return; + } + + const uint row = i / p.ne20; + const uint col = i - row * p.ne20; if (p.mode == 0) { // Default const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; - for (uint i = col; i < offset; i += BLOCK_SIZE) { - const uint idx = row * p.ne00 + i; - - data_d[row * offset + i] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); - } + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); } else if (p.mode == 1) { // Swapped const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; - for (uint i = col; i < offset; i += BLOCK_SIZE) { - const uint idx = row * p.ne00 + i; - - data_d[row * offset + i] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); - } + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); } else { // Split - for (uint i = col; i < p.ne00; i += BLOCK_SIZE) { - const uint idx = row * p.ne00 + i; + const uint idx = row * p.ne00 + col; - data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); - } + data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); } }