Skip to content

Commit bb38cdd

Browse files
authored
metal : fix F32 accumulation in FA vec kernel (#10232)
1 parent f018acb commit bb38cdd

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

ggml/src/ggml-metal.metal

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3450,7 +3450,7 @@ kernel void kernel_flash_attn_ext_vec(
34503450
{
34513451
// each simdgroup processes 1 query and 4 keys
34523452
for (short cc = 0; cc < C/4; ++cc) {
3453-
qk_t mqk = 0.0;
3453+
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
34543454

34553455
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
34563456

@@ -3461,13 +3461,14 @@ kernel void kernel_flash_attn_ext_vec(
34613461
k4x4_t mk;
34623462
deq_k(pk + i/nl_k, i%nl_k, mk);
34633463

3464-
mqk +=
3465-
dot(mq[ii/NL][0], mk[0]) +
3466-
dot(mq[ii/NL][1], mk[1]) +
3467-
dot(mq[ii/NL][2], mk[2]) +
3468-
dot(mq[ii/NL][3], mk[3]);
3464+
mqka[0] += dot(mq[ii/NL][0], mk[0]);
3465+
mqka[1] += dot(mq[ii/NL][1], mk[1]);
3466+
mqka[2] += dot(mq[ii/NL][2], mk[2]);
3467+
mqka[3] += dot(mq[ii/NL][3], mk[3]);
34693468
}
34703469

3470+
qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
3471+
34713472
// simdgroup reduce
34723473
// [ 0 .. 7] -> [ 0]
34733474
// [ 8 .. 15] -> [ 8]

0 commit comments

Comments
 (0)