1 file changed: +7 -6 lines changed
Columns: original file line number | diff line number | diff line change
@@ -3450,7 +3450,7 @@ kernel void kernel_flash_attn_ext_vec(
3450
3450
{
3451
3451
// each simdgroup processes 1 query and 4 keys
3452
3452
for (short cc = 0 ; cc < C/4 ; ++cc) {
3453
- qk_t mqk = 0.0 ;
3453
+ qk_t mqka[ 4 ] = { 0.0 , 0.0 , 0.0 , 0.0 } ;
3454
3454
3455
3455
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4 *cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3456
3456
@@ -3461,13 +3461,14 @@ kernel void kernel_flash_attn_ext_vec(
3461
3461
k4x4_t mk;
3462
3462
deq_k (pk + i/nl_k, i%nl_k, mk);
3463
3463
3464
- mqk +=
3465
- dot (mq[ii/NL][0 ], mk[0 ]) +
3466
- dot (mq[ii/NL][1 ], mk[1 ]) +
3467
- dot (mq[ii/NL][2 ], mk[2 ]) +
3468
- dot (mq[ii/NL][3 ], mk[3 ]);
3464
+ mqka[0 ] += dot (mq[ii/NL][0 ], mk[0 ]);
3465
+ mqka[1 ] += dot (mq[ii/NL][1 ], mk[1 ]);
3466
+ mqka[2 ] += dot (mq[ii/NL][2 ], mk[2 ]);
3467
+ mqka[3 ] += dot (mq[ii/NL][3 ], mk[3 ]);
3469
3468
}
3470
3469
3470
+ qk_t mqk = mqka[0 ] + mqka[1 ] + mqka[2 ] + mqka[3 ];
3471
+
3471
3472
// simdgroup reduce
3472
3473
// [ 0 .. 7] -> [ 0]
3473
3474
// [ 8 .. 15] -> [ 8]
You can’t perform that action at this time.
0 commit comments