diff --git a/examples/deepseek_mla/example_mla_decode_persistent.py b/examples/deepseek_mla/example_mla_decode_persistent.py index 36fcf5ee8..4a8221a50 100644 --- a/examples/deepseek_mla/example_mla_decode_persistent.py +++ b/examples/deepseek_mla/example_mla_decode_persistent.py @@ -106,8 +106,7 @@ def main_split_persistent( T.copy(acc_s, S_shared) T.copy(S_shared, acc_s_cast) for i in T.Parallel(block_H): - logsum[i] = logsum[i] * sco - es_scale[i] + scores_sum[i] + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_H, dim): acc_o[i, j] *= scores_scale[i] T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)