diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index e9b4dff74f42..df3fb2aeefc4 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, List, Optional, Tuple
@@ -183,6 +184,15 @@ def __init__(
         self.o_proj = o_proj
         self.vllm_flash_attn_version = get_flash_attn_version()
 
+        # Handle the differences between the flash_attn_varlen from flash_attn
+        # and the one from vllm_flash_attn. The former is used on ROCm and the
+        # latter has an additional parameter to control FA2 vs FA3
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+        if self.vllm_flash_attn_version is not None:
+            self.flash_attn_varlen_func = \
+                functools.partial(flash_attn_varlen_func,
+                                   fa_version=self.vllm_flash_attn_version)
+
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
@@ -487,7 +497,7 @@ def _forward_prefill_flash(
         v_padded = torch.nn.functional.pad(v,
                                            [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
-        attn_output = flash_attn_varlen_func(
+        attn_output = self.flash_attn_varlen_func(
             q=q,
             k=k,
             v=v_padded,
@@ -497,7 +507,6 @@ def _forward_prefill_flash(
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
-            fa_version=self.vllm_flash_attn_version,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
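
For context, here is a minimal, self-contained sketch of the pattern this diff introduces: the backend-specific `fa_version` keyword is bound once at construction time with `functools.partial`, so the prefill call site stays identical whether the ROCm `flash_attn` build or `vllm_flash_attn` is in use. The names `_varlen_stub` and `AttnSketch` below are illustrative stand-ins, not vLLM code or the real kernels.

```python
# Sketch of the functools.partial dispatch pattern from the diff above.
# _varlen_stub and AttnSketch are hypothetical stand-ins for illustration.
import functools
from typing import Optional


def _varlen_stub(q, k, v, *, softmax_scale, causal, fa_version=None):
    # Stand-in for vllm_flash_attn's flash_attn_varlen_func, which accepts an
    # extra fa_version argument; the flash_attn build used on ROCm does not.
    return (q, k, v, softmax_scale, causal, fa_version)


class AttnSketch:
    def __init__(self, scale: float, fa_version: Optional[int]):
        self.scale = scale
        # Same trick as the diff: pre-bind fa_version only when it is known,
        # so callers never have to branch on which library is installed.
        self.varlen_func = _varlen_stub
        if fa_version is not None:
            self.varlen_func = functools.partial(_varlen_stub,
                                                 fa_version=fa_version)

    def prefill(self, q, k, v):
        # Call site is identical for both backends.
        return self.varlen_func(q, k, v, softmax_scale=self.scale,
                                causal=True)


print(AttnSketch(0.1, fa_version=3).prefill("q", "k", "v"))
print(AttnSketch(0.1, fa_version=None).prefill("q", "k", "v"))
```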