From b6e7d0b6533fdd37d88e4acf8695e40174723b83 Mon Sep 17 00:00:00 2001
From: Jevin Jiang
Date: Fri, 2 May 2025 05:46:15 +0000
Subject: [PATCH] Set page size based on max-model-len

Signed-off-by: Jevin Jiang
---
 examples/offline_inference/tpu.py    |  3 ++-
 requirements/tpu.txt                 | 10 +++++-----
 vllm/platforms/tpu.py                | 10 ++++++----
 vllm/utils.py                        |  7 +++++++
 vllm/v1/attention/backends/pallas.py | 16 +++++++++++++++-
 5 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py
index dea717c36082..71cd88f2788a 100644
--- a/examples/offline_inference/tpu.py
+++ b/examples/offline_inference/tpu.py
@@ -22,7 +22,8 @@ def main():
     # In real workloads, `enforace_eager` should be `False`.
     llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
               max_num_batched_tokens=64,
-              max_num_seqs=4)
+              max_num_seqs=4,
+              max_model_len=128)
     outputs = llm.generate(prompts, sampling_params)
     print("-" * 50)
     for output, answer in zip(outputs, answers):
diff --git a/requirements/tpu.txt b/requirements/tpu.txt
index 17d57058bfa8..11501bc5d92f 100644
--- a/requirements/tpu.txt
+++ b/requirements/tpu.txt
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.8.0.dev20250408
-torchvision==0.22.0.dev20250408
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.8.0.dev20250430
+torchvision==0.22.0.dev20250430
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index d5923557a211..27284da97a8d 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -73,9 +73,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.config import CompilationLevel

         cache_config = vllm_config.cache_config
+        # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
-
         compilation_config = vllm_config.compilation_config

         # TPU only supports DYNAMO_ONCE compilation level
@@ -98,16 +98,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if envs.VLLM_USE_V1:
             from vllm.v1.attention.backends.pallas import (
                 PallasAttentionBackend)
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)
             min_page_size = PallasAttentionBackend.get_min_page_size(
                 vllm_config)
-            if min_page_size > vllm_config.cache_config.block_size:
+            if min_page_size > cache_config.block_size:
                 logger.warning(
                     "Increase the page size from %s to %s to make sure there's"
                     "no SMEM OOM",
-                    vllm_config.cache_config.block_size,
+                    cache_config.block_size,
                     min_page_size,
                 )
-                vllm_config.cache_config.block_size = min_page_size
+                cache_config.block_size = min_page_size

         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
diff --git a/vllm/utils.py b/vllm/utils.py
index f85bbe3a5990..011fd73f942d 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -704,6 +704,13 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)


+def next_power_of_2(n: int) -> int:
+    """Return the smallest power of 2 that is >= n (1 if n < 1)."""
+    if n < 1:
+        return 1
+    return 1 << (n - 1).bit_length()
+
+
 def round_up(x: int, y: int) -> int:
     return ((x + y - 1) // y) * y

diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 05b97172bc6c..79ec67b89e97 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -12,7 +12,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import cdiv
+from vllm.utils import cdiv, next_power_of_2

 logger = init_logger(__name__)

@@ -65,6 +65,20 @@ def get_min_page_size(vllm_config: VllmConfig) -> int:
         min_page_size = 1 << (min_page_size - 1).bit_length()
         return min_page_size

+    # TPU has a limited number of SREGs (scalar registers). If page_size is
+    # too small, we can easily spill SREGs, which leads to bad performance.
+    # The strategy here is to split max-model-len into 16 pages, which makes
+    # spills less likely, while keeping the page size within [16, 256].
+    @staticmethod
+    def get_page_size(vllm_config: VllmConfig) -> int:
+        page_size = next_power_of_2(
+            vllm_config.model_config.max_model_len) // 16
+        if page_size <= 16:
+            return 16
+        if page_size >= 256:
+            return 256
+        return page_size
+

 @dataclass
 class PallasMetadata:
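
Illustrative sketch (not part of the patch; it only restates the logic added above): the helpers below mirror next_power_of_2 and PallasAttentionBackend.get_page_size, but take max_model_len directly instead of a VllmConfig, to show which KV-cache block size the TPU backend would now pick for a few hypothetical model lengths.

    def next_power_of_2(n: int) -> int:
        # Smallest power of 2 >= n (1 for n < 1), same as the vllm.utils helper.
        if n < 1:
            return 1
        return 1 << (n - 1).bit_length()

    def page_size_for(max_model_len: int) -> int:
        # Mirror of get_page_size: split max-model-len into ~16 pages,
        # then clamp the resulting page size to the [16, 256] range.
        page_size = next_power_of_2(max_model_len) // 16
        return min(max(page_size, 16), 256)

    # Hypothetical inputs:
    #   page_size_for(128)   -> 16   (8 before clamping up)
    #   page_size_for(2048)  -> 128
    #   page_size_for(32768) -> 256  (2048 before clamping down)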