vllm/attention/backends/mla/utils.py (13 changes: 11 additions & 2 deletions)
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

+import functools
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, List, Optional, Tuple
@@ -183,6 +184,15 @@ def __init__(
         self.o_proj = o_proj
         self.vllm_flash_attn_version = get_flash_attn_version()

+        # Handle the differences between the flash_attn_varlen from flash_attn
+        # and the one from vllm_flash_attn. The former is used on ROCm and the
+        # latter has an additional parameter to control FA2 vs FA3.
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+        if self.vllm_flash_attn_version is not None:
+            self.flash_attn_varlen_func = \
+                functools.partial(flash_attn_varlen_func,
+                                  fa_version=self.vllm_flash_attn_version)
+
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
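The block added to __init__ above folds the backend difference into a single callable: when vllm_flash_attn is in use, functools.partial binds fa_version once, and on the upstream flash_attn path (ROCm, where get_flash_attn_version() returns None) the kernel is kept unchanged. A minimal standalone sketch of that pattern follows; the make_varlen_func helper and its parameter names are hypothetical and do not appear in the PR.

# A minimal sketch (not vLLM's actual API): return one flash-attn varlen
# callable with a uniform keyword interface, binding fa_version up front
# only when the kernel accepts that parameter.
import functools
from typing import Any, Callable, Optional


def make_varlen_func(varlen_func: Callable[..., Any],
                     fa_version: Optional[int]) -> Callable[..., Any]:
    # fa_version is None on the upstream flash_attn path (e.g. ROCm),
    # which has no such parameter; return the kernel unchanged.
    if fa_version is None:
        return varlen_func
    # vllm_flash_attn exposes fa_version (FA2 vs FA3); bind it once so
    # call sites never have to pass it explicitly.
    return functools.partial(varlen_func, fa_version=fa_version)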
@@ -487,7 +497,7 @@ def _forward_prefill_flash(
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)

-        attn_output = flash_attn_varlen_func(
+        attn_output = self.flash_attn_varlen_func(
             q=q,
             k=k,
             v=v_padded,
@@ … @@
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
-            fa_version=self.vllm_flash_attn_version,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
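With the version bound at construction time, the call in _forward_prefill_flash (second hunk above) no longer passes fa_version and is identical whether the ROCm flash_attn or the vllm_flash_attn kernel is installed; the backend selection lives in one place instead of at every call site.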