Skip to content

Commit 8c82ee2

Browse files
Bruce-Lee-LY and Bruce-Lee-LY authored
[fix] xqa precision for fp16/bf16 kv cache (#6573)
Signed-off-by: Bruce-Lee-LY <[email protected]>
Co-authored-by: Bruce-Lee-LY <[email protected]>
1 parent a54972e commit 8c82ee2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

cpp/kernels/xqa/mha_sm90.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2734,7 +2734,7 @@ __device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRa
27342734
reinterpret_cast<Vec<InputElem, 4>&>(f16Core)
27352735
= convert<InputElem>(reinterpret_cast<Vec<float, 4> const&>(core));
27362736
auto const dst = idxMat < 2
2737-
? &swizzleBuf.template at<true>(idxRow, 2 * (gmmaWarpsPerGrp * m + warpRank) + idxMat)
2737+
? &swizzleBuf.template at<true>(8 * n + idxRow, 2 * (gmmaWarpsPerGrp * m + warpRank) + idxMat)
27382738
: nullptr;
27392739
stmatrix<true, 2>(dst, f16Core);
27402740
#elif CACHE_ELEM_ENUM == 2

0 commit comments

Comments (0)