Skip to content

Commit eaea67b

Browse files
ggerganovarthw
authored andcommitted
metal : minor fixup in FA kernel (ggml-org#10143)
* metal : minor fixup in FA kernel ggml-ci * metal : use the unrolled loop variable * metal : remove unused var
1 parent 037d06f commit eaea67b

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

ggml/src/ggml-metal.metal

+8-9
Original file line numberDiff line numberDiff line change
@@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
27762776
const short iv3 = iq3 / rv3;
27772777

27782778
// load the queries from shared memory into local memory
2779-
float4 mq[D4];
2779+
float4 mq[D4/NW];
27802780

27812781
for (short ii = 0; ii < D4; ii += NW) {
27822782
short i = ii + tiisg;
2783-
mq[i] = (float4) sq4[i];
2783+
mq[ii/NW] = (float4) sq4[i];
27842784
}
27852785

27862786
// pointer to the mask
@@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
28122812
mk[2] = (float4) pk4[i + 2*(nb11/8)];
28132813
mk[3] = (float4) pk4[i + 3*(nb11/8)];
28142814

2815-
mqk += (float4) (mq[i] * mk);
2815+
mqk += (float4) (mq[ii/NW] * mk);
28162816
}
28172817

28182818
// reduce the results from the threads in the simdgroup
@@ -2857,8 +2857,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
28572857
// O = diag(ms)*O
28582858
#pragma unroll
28592859
for (short ii = 0; ii < D4; ii += NW) {
2860-
const short i = ii + tiisg;
2861-
lo[i/NW] *= ms;
2860+
lo[ii/NW] *= ms;
28622861
}
28632862
}
28642863

@@ -2872,10 +2871,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
28722871
for (short ii = 0; ii < D4; ii += NW) {
28732872
const short i = ii + tiisg;
28742873

2875-
lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
2876-
lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
2877-
lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
2878-
lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
2874+
lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
2875+
lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
2876+
lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
2877+
lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
28792878
}
28802879
}
28812880
}

0 commit comments

Comments
 (0)