From dc09d66d7084f158de34dcd933a15a176435f377 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 17 Mar 2025 12:22:34 +0000 Subject: [PATCH 01/18] add AITER paged attention kernel Signed-off-by: vllmellm --- vllm/attention/backends/rocm_flash_attn.py | 94 +++++++++++++++----- vllm/attention/ops/rocm_aiter_paged_attn.py | 98 +++++++++++++++++++++ vllm/envs.py | 15 ++++ 3 files changed, 187 insertions(+), 20 deletions(-) create mode 100644 vllm/attention/ops/rocm_aiter_paged_attn.py diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index c47202099ac6..27030d97cc26 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -15,12 +15,15 @@ CommonMetadataBuilder) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) +from vllm.attention.ops.rocm_aiter_paged_attn import AITERPagedAttention from vllm.logger import init_logger from vllm.platforms import current_platform if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata +USE_AITER_PAGED_ATTN = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN + logger = init_logger(__name__) _PARTITION_SIZE_ROCM = 256 @@ -29,6 +32,29 @@ _ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"]) +class AttentionOps: + """ + Initializes the appropriate PagedAttention module from `attention/ops`, + which is a component of the attention mechanism used + by `ROCmFlashAttentionImpl`. + + The choice of attention module depends on whether + AITER paged attention is enabled: + - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`. + - Otherwise, it defaults to using the original `PagedAttention`. + """ + + def __init__(self): + if USE_AITER_PAGED_ATTN: + self._attn_module = AITERPagedAttention() + else: + self._attn_module = PagedAttention() + + @property + def attn_module(self) -> PagedAttention: + return self._attn_module + + class ROCmFlashAttentionBackend(AttentionBackend): @staticmethod @@ -540,6 +566,9 @@ def __init__( self.attn_func = _sdpa_attention logger.debug("Using naive (SDPA) attention in ROCmBackend") + self.attn_module = AttentionOps().attn_module + self.aiter_kv_scales_initialized = False + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" tokens, n_kv_heads, head_dim = x.shape @@ -616,6 +645,30 @@ def forward( else: assert value is None + attn_module = self.attn_module + # Reshaping kv tensors is required for AITER paged attention kernel + # because it works on a different tensor shape, + # when the size of one element is one byte (int8/fp8 dtypes). + # This reshaping is only required on the first forward call + # and the kv cache must not be empty. 
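        # The scalar layer._k_scale / layer._v_scale values are expanded below
        # into per-token scale buffers of shape
        # (num_kv_heads, num_blocks * block_size), the layout consumed by
        # rocm_aiter.reshape_and_cache_with_pertoken_quant.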
+ if (USE_AITER_PAGED_ATTN and kv_cache.dtype.itemsize == 1 + and not self.aiter_kv_scales_initialized + and kv_cache.shape != torch.Size([0])): + num_blocks = kv_cache.shape[1] + block_size = kv_cache.shape[2] // (self.num_kv_heads * + self.head_size) + k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), + dtype=torch.float32, + device=kv_cache.device) + v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), + dtype=torch.float32, + device=kv_cache.device) + self.aiter_kv_scales_initialized = True + k_scale.fill_(layer._k_scale.item()) + v_scale.fill_(layer._v_scale.item()) + layer._k_scale = k_scale + layer._v_scale = v_scale + # Only update KV cache for decoder self-attention # and encoder-decoder cross-attention if self.attn_type not in [ @@ -629,7 +682,7 @@ def forward( # cache. If kv_cache is not provided, the new key and value # tensors are not cached. This happens during the initial # memory profiling run. - PagedAttention.write_to_paged_cache( + attn_module.write_to_paged_cache( key, value, key_cache, @@ -765,23 +818,23 @@ def forward( # prefix-enabled attention - # not applicable for encoder-only models if self.attn_type != AttentionType.ENCODER_ONLY: - output[: - num_prefill_tokens] = PagedAttention.forward_prefix( - query, - key, - value, - self.kv_cache_dtype, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window[0], - layer._k_scale, - layer._v_scale, - ) + attn_module = attn_module + output[:num_prefill_tokens] = attn_module.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], + layer._k_scale, + layer._v_scale, + ) # Skip decode phase for encoder-only models if (decode_meta := attn_metadata.decode_metadata) and ( self.attn_type != AttentionType.ENCODER_ONLY): @@ -841,7 +894,7 @@ def forward( layer._v_scale, ) else: - output[num_prefill_tokens:] = PagedAttention.forward_decode( + output[num_prefill_tokens:] = attn_module.forward_decode( decode_query, key_cache, value_cache, @@ -909,4 +962,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 + and not USE_AITER_PAGED_ATTN) diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py new file mode 100644 index 000000000000..9b7a0ce3f1b6 --- /dev/null +++ b/vllm/attention/ops/rocm_aiter_paged_attn.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import aiter as rocm_aiter +import torch + +from vllm.attention.ops.paged_attn import PagedAttention + + +class AITERPagedAttention(PagedAttention): + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> None: + if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, 
slot_mapping, + kv_cache_dtype, k_scale, + v_scale) + else: + if "fp8" in kv_cache_dtype: + key_cache = key_cache.view(torch.float8_e4m3fnuz) + value_cache = value_cache.view(torch.float8_e4m3fnuz) + else: + key_cache = key_cache.view(torch.int8) + value_cache = value_cache.view(torch.int8) + rocm_aiter.reshape_and_cache_with_pertoken_quant( + key, value, key_cache, value_cache, k_scale, v_scale, + slot_mapping.flatten(), True) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> torch.Tensor: + if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: + return PagedAttention.forward_decode( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + kv_cache_dtype=kv_cache_dtype, + num_kv_heads=num_kv_heads, + scale=scale, + alibi_slopes=alibi_slopes, + k_scale=k_scale, + v_scale=v_scale, + tp_rank=tp_rank, + blocksparse_local_blocks=blocksparse_local_blocks, + blocksparse_vert_stride=blocksparse_vert_stride, + blocksparse_block_size=blocksparse_block_size, + blocksparse_head_sliding_step=blocksparse_head_sliding_step) + + if "fp8" in kv_cache_dtype: + key_cache = key_cache.view(torch.float8_e4m3fnuz) + value_cache = value_cache.view(torch.float8_e4m3fnuz) + + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + output = torch.empty_like(query) + block_size = value_cache.shape[3] + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + + rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, + seq_lens, max_num_blocks_per_seq, k_scale, + v_scale, output) + return output \ No newline at end of file diff --git a/vllm/envs.py b/vllm/envs.py index bf214f314c45..24085775461b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -75,6 +75,8 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] VLLM_USE_V1: bool = True + VLLM_ROCM_USE_AITER: bool = False + VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_FP8_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 @@ -522,6 +524,19 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), + # use aiter ops unless specifically disabled. + # Acts as a parent switch to enable the rest of the other operations. + "VLLM_ROCM_USE_AITER": + lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in + ("true", "1")), + + # use aiter paged attention if aiter ops are enabled. + # this is disabled by default. 
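    # Note: as defined by the lambda below, this sub-flag defaults to on but
    # only takes effect when VLLM_ROCM_USE_AITER is also enabled.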
+ "VLLM_ROCM_USE_AITER_PAGED_ATTN": + lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in + ("true", "1") and os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", + "True").lower() in ("true", "1")), + # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), From fe9ff98be4681463304c2e881efac7769450b421 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 17 Mar 2025 12:24:33 +0000 Subject: [PATCH 02/18] include AITER enable for rocm platforms in model end to end tests Signed-off-by: vllmellm --- .buildkite/run-amd-test.sh | 4 + .../decoder_only/language/test_mistral.py | 80 +++++++++++-------- .../decoder_only/language/test_models.py | 10 +++ .../decoder_only/language/test_phimoe.py | 19 ++--- tests/quantization/test_fp8.py | 22 ++++- 5 files changed, 91 insertions(+), 44 deletions(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 0680bae13ddb..2e15533ffcf8 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -72,6 +72,10 @@ HF_CACHE="$(realpath ~)/huggingface" mkdir -p "${HF_CACHE}" HF_MOUNT="/root/.cache/huggingface" +# environment variables +SKIP_ROCM_ATIER_MODEL_TEST_CASES="True" +echo $SKIP_ROCM_ATIER_MODEL_TEST_CASES + commands=$@ echo "Commands:$commands" #ignore certain kernels tests diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 4c2055361d44..2809b0c98012 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -12,6 +12,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa MistralToolParser) +from vllm.platforms import current_platform from vllm.sampling_params import GuidedDecodingParams, SamplingParams from ...utils import check_logprobs_close @@ -174,15 +175,16 @@ @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_models(hf_runner, vllm_runner, example_prompts, model: str, + dtype: str, max_tokens: int, num_logprobs: int, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + # TODO(sang): Sliding window should be tested separately. 
with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( @@ -206,14 +208,16 @@ def test_models( @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_mistral_format( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, + max_tokens: int, num_logprobs: int, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with vllm_runner( model, dtype=dtype, @@ -244,11 +248,15 @@ def test_mistral_format( @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_mistral_symbolic_languages( - vllm_runner, - model: str, - dtype: str, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with vllm_runner(model, dtype=dtype, max_model_len=8192, @@ -266,11 +274,15 @@ def test_mistral_symbolic_languages( @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) # v1 can't do func calling -def test_mistral_function_calling( - vllm_runner, - model: str, - dtype: str, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_mistral_function_calling(vllm_runner, model: str, dtype: str, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral", @@ -301,11 +313,15 @@ def test_mistral_function_calling( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("guided_backend", ["outlines", "lm-format-enforcer", "xgrammar"]) -def test_mistral_guided_decoding( - vllm_runner, - model: str, - guided_backend: str, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_mistral_guided_decoding(vllm_runner, model: str, guided_backend: str, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with vllm_runner(model, dtype='bfloat16', tokenizer_mode="mistral") as vllm_model: diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index a49926ea220e..7a25d652195d 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -5,6 +5,8 @@ """ import pytest +from vllm.platforms import current_platform + 
from ...utils import check_logprobs_close # These have unsupported head_dim for FA. We do not @@ -69,6 +71,8 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_models( hf_runner, vllm_runner, @@ -77,11 +81,17 @@ def test_models( dtype: str, max_tokens: int, num_logprobs: int, + use_rocm_aiter: bool, monkeypatch, ) -> None: if model in REQUIRES_V0: monkeypatch.setenv("VLLM_USE_V1", "0") + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with hf_runner(model, dtype=dtype) as hf_model: if model.startswith("THUDM/chatglm3"): hf_model.model.get_output_embeddings = lambda: \ diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index f9757d6ac295..2badcaf104bd 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -79,15 +79,16 @@ def test_phimoe_routing_function(): @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_models(hf_runner, vllm_runner, example_prompts, model: str, + dtype: str, max_tokens: int, num_logprobs: int, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 19cf29d3e659..5cadc8d5dd49 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -23,11 +23,16 @@ reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model_id", MODELS) @pytest.mark.parametrize("force_marlin", [False, True]) +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, - monkeypatch) -> None: + use_rocm_aiter: bool, monkeypatch) -> None: if force_marlin: monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") + if use_rocm_aiter: + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with vllm_runner(model_id) as llm: # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy @@ -47,7 +52,13 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model_id", KV_CACHE_MODELS) -def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch): +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, + use_rocm_aiter: 
bool, monkeypatch): + if use_rocm_aiter: + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + # vllm_runner.apply_model() relies on V0 internals. monkeypatch.setenv("VLLM_USE_V1", "0") with vllm_runner(model_id, kv_cache_dtype="fp8") as llm: @@ -86,8 +97,13 @@ def check_model(model): reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("force_marlin", [False, True]) +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, - monkeypatch) -> None: + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + # vllm_runner.apply_model() relies on V0 internals. monkeypatch.setenv("VLLM_USE_V1", "0") From d7c5dfba01ff35045ffb74fc40018b08c21583a5 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 17 Mar 2025 12:26:20 +0000 Subject: [PATCH 03/18] add AITER into rocm docker base file Signed-off-by: vllmellm --- Dockerfile.rocm_base | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/Dockerfile.rocm_base b/Dockerfile.rocm_base index e33e73b30309..50d23cfc9ad5 100644 --- a/Dockerfile.rocm_base +++ b/Dockerfile.rocm_base @@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="b7d29fb" ARG FA_REPO="https://github.com/ROCm/flash-attention.git" +ARG AITER_BRANCH="e1ec015" +ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base @@ -129,6 +131,15 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ pip install /install/*.whl +ARG AITER_REPO +ARG AITER_BRANCH +RUN git clone --recursive ${AITER_REPO} +RUN cd aiter \ + && git checkout ${AITER_BRANCH} \ + && git submodule update --init --recursive \ + && pip install -r requirements.txt \ + && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter + ARG BASE_IMAGE ARG HIPBLASLT_BRANCH ARG LEGACY_HIPBLASLT_OPTION @@ -156,3 +167,5 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt + && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \ + && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt \ No newline at end of file From 1732f9a412c4997418836b391c5cf2546c378306 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 18 Mar 2025 05:32:17 +0000 Subject: [PATCH 04/18] use clearer name for paged attention module used in ROCmFlashAttentionImmp Signed-off-by: vllmellm --- vllm/attention/backends/rocm_flash_attn.py | 33 +++++++++++----------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 27030d97cc26..058ad432cbfc 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -22,8 +22,6 @@ if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata -USE_AITER_PAGED_ATTN = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN - logger = init_logger(__name__) _PARTITION_SIZE_ROCM = 256 @@ -32,7 +30,7 @@ _ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"]) -class 
AttentionOps: +class PagedAttentionOps: """ Initializes the appropriate PagedAttention module from `attention/ops`, which is a component of the attention mechanism used @@ -45,14 +43,14 @@ class AttentionOps: """ def __init__(self): - if USE_AITER_PAGED_ATTN: - self._attn_module = AITERPagedAttention() + if envs.VLLM_ROCM_USE_AITER_PAGED_ATTN: + self._paged_attn_module = AITERPagedAttention() else: - self._attn_module = PagedAttention() + self._paged_attn_module = PagedAttention() @property - def attn_module(self) -> PagedAttention: - return self._attn_module + def paged_attn_module(self) -> PagedAttention: + return self._paged_attn_module class ROCmFlashAttentionBackend(AttentionBackend): @@ -566,8 +564,8 @@ def __init__( self.attn_func = _sdpa_attention logger.debug("Using naive (SDPA) attention in ROCmBackend") - self.attn_module = AttentionOps().attn_module - self.aiter_kv_scales_initialized = False + self.paged_attn_module = PagedAttentionOps().paged_attn_module + self.aiter_kv_scales_initialized = False def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" @@ -645,13 +643,15 @@ def forward( else: assert value is None - attn_module = self.attn_module + paged_attn = self.paged_attn_module + # Reshaping kv tensors is required for AITER paged attention kernel # because it works on a different tensor shape, # when the size of one element is one byte (int8/fp8 dtypes). # This reshaping is only required on the first forward call # and the kv cache must not be empty. - if (USE_AITER_PAGED_ATTN and kv_cache.dtype.itemsize == 1 + if (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN + and kv_cache.dtype.itemsize == 1 and not self.aiter_kv_scales_initialized and kv_cache.shape != torch.Size([0])): num_blocks = kv_cache.shape[1] @@ -682,7 +682,7 @@ def forward( # cache. If kv_cache is not provided, the new key and value # tensors are not cached. This happens during the initial # memory profiling run. 
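                # With AITER paged attention enabled and an int8/fp8 KV cache,
                # this dispatches to
                # rocm_aiter.reshape_and_cache_with_pertoken_quant; otherwise
                # it falls back to the default PagedAttention implementation.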
- attn_module.write_to_paged_cache( + paged_attn.write_to_paged_cache( key, value, key_cache, @@ -818,8 +818,7 @@ def forward( # prefix-enabled attention - # not applicable for encoder-only models if self.attn_type != AttentionType.ENCODER_ONLY: - attn_module = attn_module - output[:num_prefill_tokens] = attn_module.forward_prefix( + output[:num_prefill_tokens] = paged_attn.forward_prefix( query, key, value, @@ -894,7 +893,7 @@ def forward( layer._v_scale, ) else: - output[num_prefill_tokens:] = attn_module.forward_decode( + output[num_prefill_tokens:] = paged_attn.forward_decode( decode_query, key_cache, value_cache, @@ -963,4 +962,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 - and not USE_AITER_PAGED_ATTN) + and not envs.VLLM_ROCM_USE_AITER_PAGED_ATTN) From 85296f78f83e843604a52b3dbab2525d43bef92b Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 18 Mar 2025 07:42:32 +0000 Subject: [PATCH 05/18] fix get envs variables in unit tests Signed-off-by: vllmellm --- tests/models/decoder_only/language/test_mistral.py | 11 ++++++----- tests/models/decoder_only/language/test_models.py | 4 +++- tests/models/decoder_only/language/test_phimoe.py | 4 +++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 2809b0c98012..8c6353d5b3eb 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -5,6 +5,7 @@ """ import copy import json +import os import jsonschema import jsonschema.exceptions @@ -181,7 +182,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -214,7 +215,7 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -253,7 +254,7 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -279,7 +280,7 @@ def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str, def test_mistral_function_calling(vllm_runner, model: str, dtype: str, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ 
-318,7 +319,7 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str, def test_mistral_guided_decoding(vllm_runner, model: str, guided_backend: str, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 7a25d652195d..593fc7af2fb4 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -3,6 +3,8 @@ Run `pytest tests/models/test_models.py`. """ +import os + import pytest from vllm.platforms import current_platform @@ -88,7 +90,7 @@ def test_models( monkeypatch.setenv("VLLM_USE_V1", "0") if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index 2badcaf104bd..d9cfac1d3b38 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -3,6 +3,8 @@ Run `pytest tests/models/test_phimoe.py`. """ +import os + import pytest import torch @@ -85,7 +87,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") From 07ac4d443683f95e5375636c10b5956c1b90fc11 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 18 Mar 2025 11:42:04 +0000 Subject: [PATCH 06/18] Remove AttentionOps class instead use a simple funtion to return appropriate paged attention module Signed-off-by: vllmellm --- vllm/attention/backends/rocm_flash_attn.py | 36 ++++++++++------------ 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 058ad432cbfc..f317f12e5ca9 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -30,27 +30,20 @@ _ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"]) -class PagedAttentionOps: +def _get_paged_attn_module() -> PagedAttention: """ Initializes the appropriate PagedAttention module from `attention/ops`, which is a component of the attention mechanism used - by `ROCmFlashAttentionImpl`. + by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`. The choice of attention module depends on whether AITER paged attention is enabled: - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`. - Otherwise, it defaults to using the original `PagedAttention`. 
""" - - def __init__(self): - if envs.VLLM_ROCM_USE_AITER_PAGED_ATTN: - self._paged_attn_module = AITERPagedAttention() - else: - self._paged_attn_module = PagedAttention() - - @property - def paged_attn_module(self) -> PagedAttention: - return self._paged_attn_module + if envs.VLLM_ROCM_USE_AITER_PAGED_ATTN: + return AITERPagedAttention() + return PagedAttention() class ROCmFlashAttentionBackend(AttentionBackend): @@ -82,8 +75,9 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + paged_attn = _get_paged_attn_module() + return paged_attn.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -91,14 +85,16 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + paged_attn = _get_paged_attn_module() + paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) + paged_attn = _get_paged_attn_module() + paged_attn.copy_blocks(kv_caches, src_to_dists) @dataclass @@ -514,7 +510,10 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - supported_head_sizes = PagedAttention.get_supported_head_sizes() + self.paged_attn_module = _get_paged_attn_module() + supported_head_sizes = self.paged_attn_module.get_supported_head_sizes( + ) + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " @@ -564,7 +563,6 @@ def __init__( self.attn_func = _sdpa_attention logger.debug("Using naive (SDPA) attention in ROCmBackend") - self.paged_attn_module = PagedAttentionOps().paged_attn_module self.aiter_kv_scales_initialized = False def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -674,7 +672,7 @@ def forward( if self.attn_type not in [ AttentionType.ENCODER, AttentionType.ENCODER_ONLY ] and kv_cache.numel() > 0: - key_cache, value_cache = PagedAttention.split_kv_cache( + key_cache, value_cache = paged_attn.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) if key is not None and value is not None: From 1592e7e225f94a76ad835bc19c0b6741db52390a Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 19 Mar 2025 09:31:50 +0000 Subject: [PATCH 07/18] remove cascading logic from vllm.envs Signed-off-by: vllmellm --- vllm/attention/backends/rocm_flash_attn.py | 14 +++++++++----- vllm/envs.py | 14 +++++++------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index f317f12e5ca9..912f2045a9a3 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -30,10 +30,15 @@ _ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"]) +def is_rocm_aiter_paged_attn_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \ + and envs.VLLM_ROCM_USE_AITER \ + + def _get_paged_attn_module() -> PagedAttention: """ Initializes the appropriate PagedAttention module from `attention/ops`, - which is a component of the attention mechanism used + which is used as helper function by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`. 
The choice of attention module depends on whether @@ -41,7 +46,7 @@ def _get_paged_attn_module() -> PagedAttention: - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`. - Otherwise, it defaults to using the original `PagedAttention`. """ - if envs.VLLM_ROCM_USE_AITER_PAGED_ATTN: + if is_rocm_aiter_paged_attn_enabled(): return AITERPagedAttention() return PagedAttention() @@ -648,8 +653,7 @@ def forward( # when the size of one element is one byte (int8/fp8 dtypes). # This reshaping is only required on the first forward call # and the kv cache must not be empty. - if (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN - and kv_cache.dtype.itemsize == 1 + if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1 and not self.aiter_kv_scales_initialized and kv_cache.shape != torch.Size([0])): num_blocks = kv_cache.shape[1] @@ -960,4 +964,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 - and not envs.VLLM_ROCM_USE_AITER_PAGED_ATTN) + and not is_rocm_aiter_paged_attn_enabled()) diff --git a/vllm/envs.py b/vllm/envs.py index 24085775461b..6f319a21d9e1 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -524,18 +524,18 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), - # use aiter ops unless specifically disabled. - # Acts as a parent switch to enable the rest of the other operations. + # Whether to enable all aiter ops. + # Acts as a parent switch to enable/disable the aiter ops. + # By default is enabled. "VLLM_ROCM_USE_AITER": lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")), - # use aiter paged attention if aiter ops are enabled. - # this is disabled by default. + # Whether to use aiter paged attention. + # By default is enabled. "VLLM_ROCM_USE_AITER_PAGED_ATTN": - lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in - ("true", "1") and os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", - "True").lower() in ("true", "1")), + lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "True").lower() in + ("true", "1")), # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": From 07bf5c6bd606f68b3b6fb51d49bc1b37ff37311c Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Wed, 19 Mar 2025 08:39:09 +0000 Subject: [PATCH 08/18] refactor aiter unit test flags into decorator Signed-off-by: tjtanaa Signed-off-by: vllmellm --- .../decoder_only/language/test_models.py | 28 +++------------ tests/utils.py | 34 +++++++++++++++++++ 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 593fc7af2fb4..99b3d646f147 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -3,11 +3,10 @@ Run `pytest tests/models/test_models.py`. 
""" -import os import pytest -from vllm.platforms import current_platform +from tests.utils import maybe_test_rocm_aiter from ...utils import check_logprobs_close @@ -73,26 +72,9 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, - use_rocm_aiter: bool, - monkeypatch, -) -> None: - if model in REQUIRES_V0: - monkeypatch.setenv("VLLM_USE_V1", "0") - - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") +@maybe_test_rocm_aiter +def test_models(hf_runner, vllm_runner, example_prompts, model: str, + dtype: str, max_tokens: int, num_logprobs: int) -> None: with hf_runner(model, dtype=dtype) as hf_model: if model.startswith("THUDM/chatglm3"): @@ -111,4 +93,4 @@ def test_models( outputs_1_lst=vllm_outputs, name_0="hf", name_1="vllm", - ) + ) \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py index 627cf567afcc..aba7bd04e8f5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -914,3 +914,37 @@ def get_client_text_logprob_generations( return [(text_generations, text, (None if x.logprobs is None else x.logprobs.top_logprobs)) for completion in completions for x in completion.choices] + + +def maybe_test_rocm_aiter(func): + if not current_platform.is_rocm(): + return func + + def test_case(use_rocm_aiter, *args, **kwargs): + if use_rocm_aiter and (os.getenv("ROCM_SKIP_AITER_TEST_CASES", + "False").lower() in ("true", "1")): + print("AITER tests are skipped/disabled.") + return None + + with pytest.MonkeyPatch.context() as ctx: + ctx.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0") + return func(*args, **kwargs) + + error_messages = [] + + @functools.wraps(func) + def wrapper(*args, **kwargs): + for use_rocm_aiter in (True, False): + try: + test_case(use_rocm_aiter, *args, **kwargs) + except Exception: + import traceback + error_type = (f"With ROCM_AITER=" + f"{'ON' if use_rocm_aiter else 'OFF'}") + error_trace = traceback.format_exc() + error_messages.append(f"{error_type}:\n{error_trace}") + + if error_messages: + raise Exception('\n\n' + '=' * 50 + '\n\n'.join(error_messages)) + + return wrapper From 1fdd695223093bb1cf5a72db73dfac58c488fad5 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 19 Mar 2025 11:29:55 +0000 Subject: [PATCH 09/18] modify the rocm AITER check tests based on new decorator and include granite model test as well Signed-off-by: vllmellm --- .../decoder_only/language/test_granite.py | 3 + .../decoder_only/language/test_mistral.py | 60 ++++--------------- .../decoder_only/language/test_phimoe.py | 15 +---- tests/quantization/test_fp8.py | 26 +++----- 4 files changed, 26 insertions(+), 78 deletions(-) diff --git a/tests/models/decoder_only/language/test_granite.py b/tests/models/decoder_only/language/test_granite.py index 119b79d64c96..b3bab255c4d4 100644 --- a/tests/models/decoder_only/language/test_granite.py +++ b/tests/models/decoder_only/language/test_granite.py @@ -5,6 +5,8 @@ """ import pytest +from tests.utils import maybe_test_rocm_aiter + from ...utils import check_logprobs_close MODELS = [ @@ -18,6 +20,7 @@ @pytest.mark.parametrize("dtype", ["bfloat16"]) 
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) +@maybe_test_rocm_aiter def test_models( hf_runner, vllm_runner, diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 8c6353d5b3eb..70d7cfacb462 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -5,15 +5,14 @@ """ import copy import json -import os import jsonschema import jsonschema.exceptions import pytest +from tests.utils import maybe_test_rocm_aiter from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa MistralToolParser) -from vllm.platforms import current_platform from vllm.sampling_params import GuidedDecodingParams, SamplingParams from ...utils import check_logprobs_close @@ -176,16 +175,9 @@ @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +@maybe_test_rocm_aiter def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - + dtype: str, max_tokens: int, num_logprobs: int) -> None: # TODO(sang): Sliding window should be tested separately. with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( @@ -209,16 +201,9 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +@maybe_test_rocm_aiter def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int, num_logprobs: int, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - + max_tokens: int, num_logprobs: int) -> None: with vllm_runner( model, dtype=dtype, @@ -249,15 +234,9 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - +@maybe_test_rocm_aiter +def test_mistral_symbolic_languages(vllm_runner, model: str, + dtype: str) -> None: with vllm_runner(model, dtype=dtype, max_model_len=8192, @@ -275,15 +254,8 @@ def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) # v1 can't do func calling -@pytest.mark.parametrize( - "use_rocm_aiter", 
[True, False] if current_platform.is_rocm() else [False]) -def test_mistral_function_calling(vllm_runner, model: str, dtype: str, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - +@maybe_test_rocm_aiter +def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral", @@ -314,15 +286,9 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str, @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("guided_backend", ["outlines", "lm-format-enforcer", "xgrammar"]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_mistral_guided_decoding(vllm_runner, model: str, guided_backend: str, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - +@maybe_test_rocm_aiter +def test_mistral_guided_decoding(vllm_runner, model: str, + guided_backend: str) -> None: with vllm_runner(model, dtype='bfloat16', tokenizer_mode="mistral") as vllm_model: diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index d9cfac1d3b38..5eea9bf7e23f 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -3,14 +3,12 @@ Run `pytest tests/models/test_phimoe.py`. """ -import os - import pytest import torch from vllm.platforms import current_platform -from ....utils import large_gpu_test +from ....utils import large_gpu_test, maybe_test_rocm_aiter from ...utils import check_logprobs_close MODELS = [ @@ -81,16 +79,9 @@ def test_phimoe_routing_function(): @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +@maybe_test_rocm_aiter def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - + dtype: str, max_tokens: int, num_logprobs: int) -> None: with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 5cadc8d5dd49..07a8ce42728d 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -7,6 +7,7 @@ import torch from tests.quantization.utils import is_quant_method_supported +from tests.utils import maybe_test_rocm_aiter from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod, Fp8LinearMethod) @@ -23,16 +24,12 @@ reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model_id", MODELS) @pytest.mark.parametrize("force_marlin", [False, True]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) 
+@maybe_test_rocm_aiter def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, - use_rocm_aiter: bool, monkeypatch) -> None: + monkeypatch) -> None: if force_marlin: monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") - if use_rocm_aiter: - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - with vllm_runner(model_id) as llm: # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy @@ -52,13 +49,8 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model_id", KV_CACHE_MODELS) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, - use_rocm_aiter: bool, monkeypatch): - if use_rocm_aiter: - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - +@maybe_test_rocm_aiter +def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch): # vllm_runner.apply_model() relies on V0 internals. monkeypatch.setenv("VLLM_USE_V1", "0") with vllm_runner(model_id, kv_cache_dtype="fp8") as llm: @@ -97,13 +89,9 @@ def check_model(model): reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("force_marlin", [False, True]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +@maybe_test_rocm_aiter def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - + monkeypatch) -> None: # vllm_runner.apply_model() relies on V0 internals. 
monkeypatch.setenv("VLLM_USE_V1", "0") From bb3687d7244551bab782745e2e81667ca012a5a5 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 26 Mar 2025 05:01:50 +0000 Subject: [PATCH 10/18] remove the decorator for enability of rocm AITER ops in tests Signed-off-by: vllmellm --- .../decoder_only/language/test_granite.py | 3 -- .../decoder_only/language/test_mistral.py | 6 ---- .../decoder_only/language/test_models.py | 3 -- .../decoder_only/language/test_phimoe.py | 3 +- tests/quantization/test_fp8.py | 4 --- tests/utils.py | 34 ------------------- 6 files changed, 1 insertion(+), 52 deletions(-) diff --git a/tests/models/decoder_only/language/test_granite.py b/tests/models/decoder_only/language/test_granite.py index b3bab255c4d4..119b79d64c96 100644 --- a/tests/models/decoder_only/language/test_granite.py +++ b/tests/models/decoder_only/language/test_granite.py @@ -5,8 +5,6 @@ """ import pytest -from tests.utils import maybe_test_rocm_aiter - from ...utils import check_logprobs_close MODELS = [ @@ -20,7 +18,6 @@ @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@maybe_test_rocm_aiter def test_models( hf_runner, vllm_runner, diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 70d7cfacb462..ec885386dd94 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -10,7 +10,6 @@ import jsonschema.exceptions import pytest -from tests.utils import maybe_test_rocm_aiter from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa MistralToolParser) from vllm.sampling_params import GuidedDecodingParams, SamplingParams @@ -175,7 +174,6 @@ @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@maybe_test_rocm_aiter def test_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, num_logprobs: int) -> None: # TODO(sang): Sliding window should be tested separately. 
@@ -201,7 +199,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@maybe_test_rocm_aiter def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, num_logprobs: int) -> None: with vllm_runner( @@ -234,7 +231,6 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@maybe_test_rocm_aiter def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str) -> None: with vllm_runner(model, @@ -254,7 +250,6 @@ def test_mistral_symbolic_languages(vllm_runner, model: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) # v1 can't do func calling -@maybe_test_rocm_aiter def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: with vllm_runner(model, dtype=dtype, @@ -286,7 +281,6 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("guided_backend", ["outlines", "lm-format-enforcer", "xgrammar"]) -@maybe_test_rocm_aiter def test_mistral_guided_decoding(vllm_runner, model: str, guided_backend: str) -> None: with vllm_runner(model, dtype='bfloat16', diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 99b3d646f147..7f4386913670 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -6,8 +6,6 @@ import pytest -from tests.utils import maybe_test_rocm_aiter - from ...utils import check_logprobs_close # These have unsupported head_dim for FA. 
We do not @@ -72,7 +70,6 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) -@maybe_test_rocm_aiter def test_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, num_logprobs: int) -> None: diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index 5eea9bf7e23f..5e43f20bd2b1 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -8,7 +8,7 @@ from vllm.platforms import current_platform -from ....utils import large_gpu_test, maybe_test_rocm_aiter +from ....utils import large_gpu_test from ...utils import check_logprobs_close MODELS = [ @@ -79,7 +79,6 @@ def test_phimoe_routing_function(): @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@maybe_test_rocm_aiter def test_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, num_logprobs: int) -> None: with hf_runner(model, dtype=dtype) as hf_model: diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 07a8ce42728d..19cf29d3e659 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -7,7 +7,6 @@ import torch from tests.quantization.utils import is_quant_method_supported -from tests.utils import maybe_test_rocm_aiter from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod, Fp8LinearMethod) @@ -24,7 +23,6 @@ reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model_id", MODELS) @pytest.mark.parametrize("force_marlin", [False, True]) -@maybe_test_rocm_aiter def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, monkeypatch) -> None: if force_marlin: @@ -49,7 +47,6 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model_id", KV_CACHE_MODELS) -@maybe_test_rocm_aiter def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch): # vllm_runner.apply_model() relies on V0 internals. monkeypatch.setenv("VLLM_USE_V1", "0") @@ -89,7 +86,6 @@ def check_model(model): reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("force_marlin", [False, True]) -@maybe_test_rocm_aiter def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, monkeypatch) -> None: # vllm_runner.apply_model() relies on V0 internals. 
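With the maybe_test_rocm_aiter helper removed from tests/utils.py in the diff that follows, a test that still wants to exercise the AITER path can fall back to the plain parametrize-plus-monkeypatch pattern used earlier in this series. A minimal sketch, assuming the env flags added in vllm/envs.py; the test name, model, and body are placeholders rather than code from this diff:

import pytest

from vllm.platforms import current_platform


@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_model_with_optional_aiter(vllm_runner, monkeypatch,
                                   use_rocm_aiter: bool) -> None:
    if use_rocm_aiter:
        # VLLM_ROCM_USE_AITER is the parent switch; the paged-attention
        # sub-flag defaults to on once the parent is enabled.
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    # Placeholder body: run any small model through the runner fixture.
    with vllm_runner("facebook/opt-125m") as llm:
        llm.generate_greedy(["Hello, my name is"], max_tokens=8)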
diff --git a/tests/utils.py b/tests/utils.py index aba7bd04e8f5..627cf567afcc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -914,37 +914,3 @@ def get_client_text_logprob_generations( return [(text_generations, text, (None if x.logprobs is None else x.logprobs.top_logprobs)) for completion in completions for x in completion.choices] - - -def maybe_test_rocm_aiter(func): - if not current_platform.is_rocm(): - return func - - def test_case(use_rocm_aiter, *args, **kwargs): - if use_rocm_aiter and (os.getenv("ROCM_SKIP_AITER_TEST_CASES", - "False").lower() in ("true", "1")): - print("AITER tests are skipped/disabled.") - return None - - with pytest.MonkeyPatch.context() as ctx: - ctx.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0") - return func(*args, **kwargs) - - error_messages = [] - - @functools.wraps(func) - def wrapper(*args, **kwargs): - for use_rocm_aiter in (True, False): - try: - test_case(use_rocm_aiter, *args, **kwargs) - except Exception: - import traceback - error_type = (f"With ROCM_AITER=" - f"{'ON' if use_rocm_aiter else 'OFF'}") - error_trace = traceback.format_exc() - error_messages.append(f"{error_type}:\n{error_trace}") - - if error_messages: - raise Exception('\n\n' + '=' * 50 + '\n\n'.join(error_messages)) - - return wrapper From 9087f44eed6242bf85bd69e5632e7372b246ea52 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 26 Mar 2025 05:15:53 +0000 Subject: [PATCH 11/18] match the tests files and run-amd-test script to the main branch Signed-off-by: vllmellm --- .buildkite/run-amd-test.sh | 4 -- .../decoder_only/language/test_mistral.py | 41 +++++++++++++++---- .../decoder_only/language/test_phimoe.py | 11 ++++- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 2e15533ffcf8..0680bae13ddb 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -72,10 +72,6 @@ HF_CACHE="$(realpath ~)/huggingface" mkdir -p "${HF_CACHE}" HF_MOUNT="/root/.cache/huggingface" -# environment variables -SKIP_ROCM_ATIER_MODEL_TEST_CASES="True" -echo $SKIP_ROCM_ATIER_MODEL_TEST_CASES - commands=$@ echo "Commands:$commands" #ignore certain kernels tests diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index ec885386dd94..4c2055361d44 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -174,8 +174,15 @@ @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: # TODO(sang): Sliding window should be tested separately. 
with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( @@ -199,8 +206,14 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int, num_logprobs: int) -> None: +def test_mistral_format( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: with vllm_runner( model, dtype=dtype, @@ -231,8 +244,11 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_mistral_symbolic_languages(vllm_runner, model: str, - dtype: str) -> None: +def test_mistral_symbolic_languages( + vllm_runner, + model: str, + dtype: str, +) -> None: with vllm_runner(model, dtype=dtype, max_model_len=8192, @@ -250,7 +266,11 @@ def test_mistral_symbolic_languages(vllm_runner, model: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) # v1 can't do func calling -def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: +def test_mistral_function_calling( + vllm_runner, + model: str, + dtype: str, +) -> None: with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral", @@ -281,8 +301,11 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("guided_backend", ["outlines", "lm-format-enforcer", "xgrammar"]) -def test_mistral_guided_decoding(vllm_runner, model: str, - guided_backend: str) -> None: +def test_mistral_guided_decoding( + vllm_runner, + model: str, + guided_backend: str, +) -> None: with vllm_runner(model, dtype='bfloat16', tokenizer_mode="mistral") as vllm_model: diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index 5e43f20bd2b1..f9757d6ac295 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -79,8 +79,15 @@ def test_phimoe_routing_function(): @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) From 052d9e045eb0ab7558e7328e841c5b47084f3ec0 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 21 Apr 2025 06:50:19 +0000 Subject: [PATCH 12/18] import AITERPagedAttention only if flag is set Signed-off-by: vllmellm --- vllm/attention/backends/rocm_flash_attn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index ed08c0f3c2dc..f9a7d332f384 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -15,7 +15,6 @@ CommonMetadataBuilder) from vllm.attention.ops.paged_attn import 
(PagedAttention, PagedAttentionMetadata) -from vllm.attention.ops.rocm_aiter_paged_attn import AITERPagedAttention from vllm.logger import init_logger from vllm.platforms import current_platform @@ -47,6 +46,9 @@ def _get_paged_attn_module() -> PagedAttention: - Otherwise, it defaults to using the original `PagedAttention`. """ if is_rocm_aiter_paged_attn_enabled(): + # Import AITERPagedAttention only when the flag is enabled + from vllm.attention.ops.rocm_aiter_paged_attn import ( + AITERPagedAttention) return AITERPagedAttention() return PagedAttention() From 15862f1a1370667e0bb446c5a4ee41eba62d7ba9 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 21 Apr 2025 06:53:30 +0000 Subject: [PATCH 13/18] prefer current_platform.fp8_dtype over the harcoded dtype Signed-off-by: vllmellm --- vllm/attention/ops/rocm_aiter_paged_attn.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py index 9b7a0ce3f1b6..cf9c2cd1846b 100644 --- a/vllm/attention/ops/rocm_aiter_paged_attn.py +++ b/vllm/attention/ops/rocm_aiter_paged_attn.py @@ -5,6 +5,9 @@ import torch from vllm.attention.ops.paged_attn import PagedAttention +from vllm.platforms import current_platform + +FP8_DTYPE = current_platform.fp8_dtype() class AITERPagedAttention(PagedAttention): @@ -26,12 +29,11 @@ def write_to_paged_cache( kv_cache_dtype, k_scale, v_scale) else: - if "fp8" in kv_cache_dtype: - key_cache = key_cache.view(torch.float8_e4m3fnuz) - value_cache = value_cache.view(torch.float8_e4m3fnuz) - else: - key_cache = key_cache.view(torch.int8) - value_cache = value_cache.view(torch.int8) + kv_cache_torch_dtype = (FP8_DTYPE + if "fp8" in kv_cache_dtype else torch.int8) + key_cache = key_cache.view(kv_cache_torch_dtype) + value_cache = value_cache.view(kv_cache_torch_dtype) + rocm_aiter.reshape_and_cache_with_pertoken_quant( key, value, key_cache, value_cache, k_scale, v_scale, slot_mapping.flatten(), True) @@ -95,4 +97,4 @@ def forward_decode( rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, seq_lens, max_num_blocks_per_seq, k_scale, v_scale, output) - return output \ No newline at end of file + return output From 15406cb710bf5a55093bcc827ace507ae4320ac6 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 21 Apr 2025 14:59:08 +0000 Subject: [PATCH 14/18] cache aiter pa import Signed-off-by: vllmellm --- vllm/attention/backends/rocm_flash_attn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 7b1ad9fba193..37b6cadcb98a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -2,6 +2,7 @@ """Attention layer ROCm GPUs.""" import itertools from dataclasses import dataclass +from functools import cache from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch @@ -26,11 +27,13 @@ _PARTITION_SIZE_ROCM = 256 +@cache def is_rocm_aiter_paged_attn_enabled() -> bool: return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \ and envs.VLLM_ROCM_USE_AITER \ +@cache def _get_paged_attn_module() -> PagedAttention: """ Initializes the appropriate PagedAttention module from `attention/ops`, From a9ef9f977f4960e434e9bc07d19b86af34e46ba1 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 21 Apr 2025 16:32:08 +0000 Subject: [PATCH 15/18] update aiter commit Signed-off-by: vllmellm --- docker/Dockerfile.rocm_base | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index b8523fbc2a01..1776b26d445c 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="1a7f4dfa" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="8970b25b" +ARG AITER_BRANCH="7e1ed08" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base From e203aed95ffeb1491610323e571375dc94490238 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 21 Apr 2025 16:45:41 +0000 Subject: [PATCH 16/18] correct comment Signed-off-by: vllmellm --- vllm/envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index 18a9d584d0e0..79abe0ce2b99 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -535,7 +535,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: ("true", "1")), # Whether to use aiter paged attention. - # By default is enabled. + # By default is diabled. "VLLM_ROCM_USE_AITER_PAGED_ATTN": lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in ("true", "1")), From 976da61452c570f7ba90c2497d515facd04ccc47 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 22 Apr 2025 03:10:28 +0000 Subject: [PATCH 17/18] fix spelling mistake Signed-off-by: vllmellm --- vllm/envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index 79abe0ce2b99..7ef4ab429645 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -535,7 +535,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: ("true", "1")), # Whether to use aiter paged attention. - # By default is diabled. + # By default is disabled. "VLLM_ROCM_USE_AITER_PAGED_ATTN": lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in ("true", "1")), From 0f5f2d0c403570d41503656cd35b821039ed3219 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 22 Apr 2025 04:39:08 +0000 Subject: [PATCH 18/18] prefer utils cdiv Signed-off-by: vllmellm --- vllm/attention/ops/rocm_aiter_paged_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py index cf9c2cd1846b..0f3cf1842c80 100644 --- a/vllm/attention/ops/rocm_aiter_paged_attn.py +++ b/vllm/attention/ops/rocm_aiter_paged_attn.py @@ -6,6 +6,7 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.platforms import current_platform +from vllm.utils import cdiv FP8_DTYPE = current_platform.fp8_dtype() @@ -92,7 +93,7 @@ def forward_decode( output = torch.empty_like(query) block_size = value_cache.shape[3] - max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + max_num_blocks_per_seq = cdiv(max_seq_len, block_size) rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, seq_lens, max_num_blocks_per_seq, k_scale,