@@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
2776
2776
const short iv3 = iq3 / rv3;
2777
2777
2778
2778
// load the queries from shared memory into local memory
2779
- float4 mq[D4];
2779
+ float4 mq[D4/NW ];
2780
2780
2781
2781
for (short ii = 0 ; ii < D4; ii += NW) {
2782
2782
short i = ii + tiisg;
2783
- mq[i ] = (float4) sq4[i];
2783
+ mq[ii/NW ] = (float4) sq4[i];
2784
2784
}
2785
2785
2786
2786
// pointer to the mask
@@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2812
2812
mk[2 ] = (float4) pk4[i + 2 *(nb11/8 )];
2813
2813
mk[3 ] = (float4) pk4[i + 3 *(nb11/8 )];
2814
2814
2815
- mqk += (float4) (mq[i ] * mk);
2815
+ mqk += (float4) (mq[ii/NW ] * mk);
2816
2816
}
2817
2817
2818
2818
// reduce the results from the threads in the simdgroup
@@ -2857,8 +2857,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2857
2857
// O = diag(ms)*O
2858
2858
#pragma unroll
2859
2859
for (short ii = 0 ; ii < D4; ii += NW) {
2860
- const short i = ii + tiisg;
2861
- lo[i/NW] *= ms;
2860
+ lo[ii/NW] *= ms;
2862
2861
}
2863
2862
}
2864
2863
@@ -2872,10 +2871,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
2872
2871
for (short ii = 0 ; ii < D4; ii += NW) {
2873
2872
const short i = ii + tiisg;
2874
2873
2875
- lo[i /NW] += pv4[i + 0 *(nb21/8 )] * ss[4 *cc + 0 ];
2876
- lo[i /NW] += pv4[i + 1 *(nb21/8 )] * ss[4 *cc + 1 ];
2877
- lo[i /NW] += pv4[i + 2 *(nb21/8 )] * ss[4 *cc + 2 ];
2878
- lo[i /NW] += pv4[i + 3 *(nb21/8 )] * ss[4 *cc + 3 ];
2874
+ lo[ii /NW] += pv4[i + 0 *(nb21/8 )] * ss[4 *cc + 0 ];
2875
+ lo[ii /NW] += pv4[i + 1 *(nb21/8 )] * ss[4 *cc + 1 ];
2876
+ lo[ii /NW] += pv4[i + 2 *(nb21/8 )] * ss[4 *cc + 2 ];
2877
+ lo[ii /NW] += pv4[i + 3 *(nb21/8 )] * ss[4 *cc + 3 ];
2879
2878
}
2880
2879
}
2881
2880
}
0 commit comments