diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index 916cc2efa389..a08c874407e3 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -1,10 +1,10 @@
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
@@ -12,6 +12,13 @@
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear lru cache to ensure each test case runs without caching.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@@ -24,67 +31,70 @@ def test_env(name: str, device: str, monkeypatch):
 
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "TORCH_SDPA"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "ROCM_FLASH"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
-                   OpenVinoPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "OPENVINO"
+                   OpenVinoPlatform()), patch.dict('sys.modules',
+                                                   {'openvino': Mock()}):
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "OPENVINO"
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == name
+        if name in ["XFORMERS", "FLASHINFER"]:
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float16, torch.float16,
+                                           16, False)
+                assert backend.get_name() == name
 
 
 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
-    # which_attn_to_use
+    # get_attn_backend
 
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(17, torch.float16, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, torch.float16, None, 16, False)
+        get_attn_backend(16, torch.float16, None, 16, False)
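A note on the new autouse fixture above: `get_attn_backend` delegates to `_cached_get_attn_backend`, which is memoized (hence the `cache_clear()` call), so without the fixture the backend chosen for one parametrized case would be replayed for the next one even after the environment variable changes. Below is a minimal sketch of that failure mode using only the standard library; `pick_backend` and `FAKE_ATTENTION_BACKEND` are made-up names, not vLLM code.

```python
# Illustrative only: why a memoized selector must be cleared between tests.
# `pick_backend` and FAKE_ATTENTION_BACKEND are hypothetical stand-ins.
import os
from functools import lru_cache


@lru_cache(maxsize=None)
def pick_backend(head_size: int) -> str:
    # The choice depends on state read at call time (an env var here),
    # which the cache key does not capture.
    return os.environ.get("FAKE_ATTENTION_BACKEND", "FLASH_ATTN")


os.environ["FAKE_ATTENTION_BACKEND"] = "XFORMERS"
assert pick_backend(16) == "XFORMERS"

os.environ["FAKE_ATTENTION_BACKEND"] = "FLASHINFER"
assert pick_backend(16) == "XFORMERS"  # stale result served from the cache

pick_backend.cache_clear()  # what the autouse fixture does for the real cache
assert pick_backend(16) == "FLASHINFER"
```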
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index d26383970569..0ff007c87b1c 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -9,7 +9,7 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR
+from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
 
 logger = init_logger(__name__)
 
@@ -114,83 +114,19 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
-                                is_attention_free, use_v1)
-    if backend == _Backend.FLASH_ATTN:
-        logger.info("Using Flash Attention backend.")
-        from vllm.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend)
-        return FlashAttentionBackend
-    if backend == _Backend.FLASH_ATTN_VLLM_V1:
-        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend as FlashAttentionBackendV1)
-        return FlashAttentionBackendV1
-    if backend == _Backend.XFORMERS:
-        logger.info("Using XFormers backend.")
-        from vllm.attention.backends.xformers import (  # noqa: F401
-            XFormersBackend)
-        return XFormersBackend
-    elif backend == _Backend.ROCM_FLASH:
-        logger.info("Using ROCmFlashAttention backend.")
-        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
-            ROCmFlashAttentionBackend)
-        return ROCmFlashAttentionBackend
-    elif backend == _Backend.TORCH_SDPA:
-        assert current_platform.is_cpu(), RuntimeError(
-            "Torch SDPA backend is only used for the CPU device.")
-        logger.info("Using Torch SDPA backend.")
-        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
-        return TorchSDPABackend
-    elif backend == _Backend.OPENVINO:
-        logger.info("Using OpenVINO Attention backend.")
-        from vllm.attention.backends.openvino import OpenVINOAttentionBackend
-        return OpenVINOAttentionBackend
-    elif backend == _Backend.IPEX:
-        assert current_platform.is_xpu(), RuntimeError(
-            "IPEX attention backend is only used for the XPU device.")
-        logger.info("Using IPEX attention backend.")
-        from vllm.attention.backends.ipex_attn import IpexAttnBackend
-        return IpexAttnBackend
-    elif backend == _Backend.FLASHINFER:
-        logger.info("Using Flashinfer backend.")
-        from vllm.attention.backends.flashinfer import FlashInferBackend
-        return FlashInferBackend
-    elif backend == _Backend.HPU_ATTN:
-        logger.info("Using HPUAttention backend.")
-        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
-        return HPUAttentionBackend
-    elif backend == _Backend.PALLAS:
-        logger.info("Using Pallas backend.")
-        from vllm.attention.backends.pallas import PallasAttentionBackend
-        return PallasAttentionBackend
-    elif backend == _Backend.NO_ATTENTION:
-        from vllm.attention.backends.placeholder_attn import (
-            PlaceholderAttentionBackend)
-        return PlaceholderAttentionBackend
-    else:
-        raise ValueError("Invalid attention backend.")
-
-
-def which_attn_to_use(head_size: int,
-                      dtype: torch.dtype,
-                      kv_cache_dtype: Optional[str],
-                      block_size: int,
-                      is_attention_free: bool,
-                      use_v1: bool = False) -> _Backend:
-    """Returns which flash attention backend to use."""
-    # Default case.
-    selected_backend = _Backend.FLASH_ATTN
-
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return _Backend.NO_ATTENTION
+        from vllm.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend
 
     # Check whether a particular choice of backend was
     # previously forced.
     #
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
+    selected_backend = None
     backend_by_global_setting: Optional[_Backend] = (
         get_global_forced_attn_backend())
     if backend_by_global_setting is not None:
@@ -201,64 +137,13 @@ def which_attn_to_use(head_size: int,
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)
 
-    # get device-specific default attn_backend
-    default_backend = current_platform.get_default_attn_backend(
-        selected_backend)
-    if default_backend is not None:
-        return default_backend
-
-    if use_v1:
-        return _Backend.FLASH_ATTN_VLLM_V1
-
-    # FlashAttn in NVIDIA GPUs.
-    if selected_backend == _Backend.FLASH_ATTN:
-        if not current_platform.has_device_capability(80):
-            # Volta and Turing NVIDIA GPUs.
-            logger.info(
-                "Cannot use FlashAttention-2 backend for Volta and Turing "
-                "GPUs.")
-            selected_backend = _Backend.XFORMERS
-        elif dtype not in (torch.float16, torch.bfloat16):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for dtype other than "
-                "torch.float16 or torch.bfloat16.")
-            selected_backend = _Backend.XFORMERS
-        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
-            logger.warning(
-                "Please use FlashInfer backend with FP8 KV Cache for "
-                "better performance by setting environment variable "
-                "VLLM_ATTENTION_BACKEND=FLASHINFER")
-            selected_backend = _Backend.XFORMERS
-        elif block_size % 16 != 0:
-            logger.info(
-                "Cannot use FlashAttention-2 backend for block size not "
-                "divisible by 16.")
-            selected_backend = _Backend.XFORMERS
-
-    # FlashAttn is valid for the model, checking if the package is installed.
-    if selected_backend == _Backend.FLASH_ATTN:
-        try:
-            import vllm.vllm_flash_attn  # noqa: F401
-            from vllm.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
-
-            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
-            if head_size not in supported_sizes:
-                logger.info(
-                    "Cannot use FlashAttention-2 backend for head size %d.",
-                    head_size)
-                selected_backend = _Backend.XFORMERS
-        except ImportError:
-            logger.info(
-                "Cannot use FlashAttention-2 backend because the "
-                "vllm.vllm_flash_attn package is not found. "
-                "Make sure that vllm_flash_attn was built and installed "
-                "(on by default).")
-            selected_backend = _Backend.XFORMERS
-
-    return selected_backend
+    # get device-specific attn_backend
+    attention_cls = current_platform.get_attn_backend_cls(
+        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
+    if not attention_cls:
+        raise ValueError(
+            f"Invalid attention backend for {current_platform.device_name}")
+    return resolve_obj_by_qualname(attention_cls)
 
 
 @contextmanager
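With the hunk above, the selector no longer switches over a `_Backend` enum; it asks the active platform for a fully qualified class path and resolves it. Below is a rough sketch of that flow, assuming `resolve_obj_by_qualname` behaves like a plain `importlib` lookup; `_resolve`, `DummyPlatform`, and the returned path are illustrative stand-ins rather than vLLM APIs.

```python
# Sketch of the new selection flow; assumes resolve_obj_by_qualname amounts to
# an importlib-based lookup. None of the names below are real vLLM code.
import importlib


def _resolve(qualname: str):
    """Return the object named by a dotted 'pkg.module.ClassName' path."""
    module_name, _, obj_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), obj_name)


class DummyPlatform:
    device_name = "dummy"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1) -> str:
        # A real platform inspects its arguments; this one always points at a
        # stdlib class purely to demonstrate the string-to-class resolution.
        return "collections.OrderedDict"


def get_attn_backend_sketch(platform, head_size, dtype, kv_cache_dtype,
                            block_size, use_v1=False):
    cls_path = platform.get_attn_backend_cls(None, head_size, dtype,
                                             kv_cache_dtype, block_size,
                                             use_v1)
    if not cls_path:
        raise ValueError(
            f"Invalid attention backend for {platform.device_name}")
    return _resolve(cls_path)


print(get_attn_backend_sketch(DummyPlatform, 16, "float16", None, 16))
```

One apparent benefit of passing a dotted path instead of a class object is that a backend module is only imported once its platform actually selects it.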
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 7ba7f5150150..eb3e269cac28 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -28,10 +28,13 @@ def get_device_name(cls, device_id: int = 0) -> str:
         return "cpu"
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
-        return _Backend.TORCH_SDPA
+        logger.info("Using Torch SDPA backend.")
+        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 3c5350b77834..23ceac83e49d 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -16,7 +16,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 
-from .interface import DeviceCapability, Platform, PlatformEnum
+from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -141,6 +141,81 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+    @classmethod
+    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1) -> str:
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+        if selected_backend == _Backend.FLASHINFER:
+            logger.info("Using FlashInfer backend.")
+            return "vllm.attention.backends.flashinfer.FlashInferBackend"
+        elif selected_backend == _Backend.XFORMERS:
+            logger.info("Using XFormers backend.")
+            return "vllm.attention.backends.xformers.XFormersBackend"
+        elif selected_backend == _Backend.FLASH_ATTN:
+            pass
+        elif selected_backend:
+            raise ValueError(
+                f"Invalid attention backend for {cls.device_name}")
+
+        target_backend = _Backend.FLASH_ATTN
+        if not cls.has_device_capability(80):
+            # Volta and Turing NVIDIA GPUs.
+            logger.info(
+                "Cannot use FlashAttention-2 backend for Volta and Turing "
+                "GPUs.")
+            target_backend = _Backend.XFORMERS
+        elif dtype not in (torch.float16, torch.bfloat16):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for dtype other than "
+                "torch.float16 or torch.bfloat16.")
+            target_backend = _Backend.XFORMERS
+        elif kv_cache_dtype is not None and \
+                kv_cache_dtype.startswith("fp8"):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+            logger.warning(
+                "Please use FlashInfer backend with FP8 KV Cache for "
+                "better performance by setting environment variable "
+                "VLLM_ATTENTION_BACKEND=FLASHINFER")
+            target_backend = _Backend.XFORMERS
+        elif block_size % 16 != 0:
+            logger.info(
+                "Cannot use FlashAttention-2 backend for block size not "
+                "divisible by 16.")
+            target_backend = _Backend.XFORMERS
+
+        # FlashAttn is valid for the model, checking if the package is
+        # installed.
+        if target_backend == _Backend.FLASH_ATTN:
+            try:
+                import vllm.vllm_flash_attn  # noqa: F401
+                from vllm.attention.backends.flash_attn import (  # noqa: F401
+                    FlashAttentionBackend)
+
+                supported_sizes = \
+                    FlashAttentionBackend.get_supported_head_sizes()
+                if head_size not in supported_sizes:
+                    logger.info(
+                        "Cannot use FlashAttention-2 backend for head size %d.",
+                        head_size)
+                    target_backend = _Backend.XFORMERS
+            except ImportError:
+                logger.info(
+                    "Cannot use FlashAttention-2 backend because the "
+                    "vllm.vllm_flash_attn package is not found. "
+                    "Make sure that vllm_flash_attn was built and installed "
+                    "(on by default).")
+                target_backend = _Backend.XFORMERS
+
+        if target_backend == _Backend.XFORMERS:
+            logger.info("Using XFormers backend.")
+            return "vllm.attention.backends.xformers.XFormersBackend"
+
+        logger.info("Using Flash Attention backend.")
+        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index 0a44f2b74163..8152d881fa8d 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -21,8 +21,11 @@ class HpuPlatform(Platform):
     dispatch_key: str = "HPU"
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
-        return _Backend.HPU_ATTN
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
+        logger.info("Using HPUAttention backend.")
+        return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index ddccaa2ce014..f440358f65fb 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -112,9 +112,11 @@ def is_cuda_alike(self) -> bool:
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend):
-        """Get the default attention backend of a device."""
-        return None
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
+        """Get the attention backend class of a device."""
+        return ""
 
     @classmethod
     def get_device_capability(
diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py
index 16eb8dc81efc..9390eda535c8 100644
--- a/vllm/platforms/openvino.py
+++ b/vllm/platforms/openvino.py
@@ -28,10 +28,13 @@ class OpenVinoPlatform(Platform):
     dispatch_key: str = "CPU"
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.OPENVINO:
             logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
-        return _Backend.OPENVINO
+        logger.info("Using OpenVINO Attention backend.")
+        return "vllm.attention.backends.openvino.OpenVINOAttentionBackend"
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index aa779f265135..1c2f602efc85 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -70,7 +70,8 @@ class RocmPlatform(Platform):
     ]
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1) -> str:
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if selected_backend == _Backend.ROCM_FLASH:
@@ -79,7 +80,8 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
             logger.info("flash_attn is not supported on NAVI GPUs.")
         else:
             logger.info("%s is not supported in AMD GPUs.", selected_backend)
-        return _Backend.ROCM_FLASH
+        logger.info("Using ROCmFlashAttention backend.")
+        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
 
     @classmethod
     @lru_cache(maxsize=8)
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 77f5c8401424..10e07b782dc9 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -22,10 +22,13 @@ class TpuPlatform(Platform):
     supported_quantization: list[str] = ["tpu_int8"]
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
-        return _Backend.PALLAS
+        logger.info("Using Pallas backend.")
+        return "vllm.attention.backends.pallas.PallasAttentionBackend"
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 78e17c2afec6..00692a5d2303 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -21,10 +21,13 @@ class XPUPlatform(Platform):
     dispatch_key: str = "XPU"
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.IPEX:
             logger.info("Cannot use %s backend on XPU.", selected_backend)
-        return _Backend.IPEX
+        logger.info("Using IPEX attention backend.")
+        return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
 
     @staticmethod
     def get_device_capability(device_id: int = 0) -> DeviceCapability:
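Taken together, the platform hunks above all follow the same pattern: optionally warn about an unsupported `selected_backend`, log the chosen backend, and return its dotted class path. An out-of-tree platform would presumably plug in the same way; the sketch below is hypothetical (the class, device name, and backend path are invented) and assumes `Platform` and `_Backend` remain importable from `vllm.platforms.interface` as shown in the diff.

```python
# Hypothetical out-of-tree platform mirroring the in-tree pattern above.
# The device name and the returned backend path are invented for illustration.
from typing import Optional

import torch

from vllm.platforms.interface import Platform, _Backend


class MyAcceleratorPlatform(Platform):
    device_name: str = "myaccel"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
        # Mirror the CPU/TPU/XPU hunks: note a mismatched request, then return
        # the one backend this device supports as a dotted class path.
        if (selected_backend is not None
                and selected_backend != _Backend.FLASH_ATTN):
            print(f"Cannot use {selected_backend} backend on myaccel.")
        return "my_plugin.attention.MyAccelAttentionBackend"
```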