3 changes: 2 additions & 1 deletion examples/offline_inference/tpu.py
@@ -22,7 +22,8 @@ def main():
# In real workloads, `enforce_eager` should be `False`.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
max_num_batched_tokens=64,
max_num_seqs=4)
max_num_seqs=4,
max_model_len=128)
outputs = llm.generate(prompts, sampling_params)
print("-" * 50)
for output, answer in zip(outputs, answers):
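As a quick cross-check (a sketch, not part of the changed files): with this example's max_model_len=128, the page-size heuristic added later in this PR lands on the minimum page size. The arithmetic, using the next_power_of_2 helper that this PR adds to vllm/utils.py:

from vllm.utils import next_power_of_2  # helper added in this PR

max_model_len = 128
page_size = next_power_of_2(max_model_len) // 16  # 128 // 16 == 8
page_size = min(max(page_size, 16), 256)           # clamp to [16, 256]
print(page_size)  # 16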
10 changes: 5 additions & 5 deletions requirements/tpu.txt
@@ -18,9 +18,9 @@ setuptools==78.1.0
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.8.0.dev20250408
torchvision==0.22.0.dev20250408
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch==2.8.0.dev20250430
torchvision==0.22.0.dev20250430
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
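These pins are consumed in the usual way on a TPU VM; a minimal sketch, assuming the requirements file is installed directly and the --find-links entries above resolve the nightly wheels:

pip install -r requirements/tpu.txt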

10 changes: 6 additions & 4 deletions vllm/platforms/tpu.py
@@ -73,9 +73,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationLevel

cache_config = vllm_config.cache_config
# For v0, the default block size is 16.
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16

compilation_config = vllm_config.compilation_config

# TPU only supports DYNAMO_ONCE compilation level
@@ -98,16 +98,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if envs.VLLM_USE_V1:
from vllm.v1.attention.backends.pallas import (
PallasAttentionBackend)
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config)
min_page_size = PallasAttentionBackend.get_min_page_size(
vllm_config)
if min_page_size > vllm_config.cache_config.block_size:
if min_page_size > cache_config.block_size:
logger.warning(
"Increase the page size from %s to %s to make sure there's"
"no SMEM OOM",
vllm_config.cache_config.block_size,
cache_config.block_size,
min_page_size,
)
vllm_config.cache_config.block_size = min_page_size
cache_config.block_size = min_page_size

parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
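In effect, check_and_update_config now sets the TPU block size in two steps: start from the heuristic page size returned by get_page_size, then raise it to get_min_page_size when the heuristic value would risk SMEM OOM. A standalone sketch of that ordering with plain integers (illustrative only; the real code reads and writes VllmConfig):

def choose_block_size(heuristic_page_size: int, min_page_size: int) -> int:
    # Mirrors the order of operations above: heuristic first, then the
    # SMEM-driven lower bound, which the real code logs as a warning.
    block_size = heuristic_page_size
    if min_page_size > block_size:
        block_size = min_page_size
    return block_size

assert choose_block_size(heuristic_page_size=16, min_page_size=32) == 32
assert choose_block_size(heuristic_page_size=128, min_page_size=32) == 128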
7 changes: 7 additions & 0 deletions vllm/utils.py
@@ -704,6 +704,13 @@ def cdiv(a: int, b: int) -> int:
return -(a // -b)


def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 that is >= n (and 1 for n < 1)."""
    if n < 1:
        return 1
    return 1 << (n - 1).bit_length()


def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y

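For reference, a few sample values of the new helper (a quick check; assumes it is imported from vllm.utils once this change lands):

from vllm.utils import next_power_of_2

assert next_power_of_2(0) == 1
assert next_power_of_2(1) == 1
assert next_power_of_2(5) == 8
assert next_power_of_2(1024) == 1024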
16 changes: 15 additions & 1 deletion vllm/v1/attention/backends/pallas.py
Expand Up @@ -12,7 +12,7 @@
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.utils import cdiv, next_power_of_2

logger = init_logger(__name__)

@@ -65,6 +65,20 @@ def get_min_page_size(vllm_config: VllmConfig) -> int:
min_page_size = 1 << (min_page_size - 1).bit_length()
return min_page_size

# TPU has a limited number of SREGs (scalar registers); if the page size is
# too small, SREGs can spill easily, which hurts performance. The strategy
# here is to split max-model-len into 16 pages, which makes spills less
# likely, while keeping the page size within [16, 256].
@staticmethod
def get_page_size(vllm_config: VllmConfig) -> int:
page_size = next_power_of_2(
vllm_config.model_config.max_model_len) // 16
if page_size <= 16:
return 16
if page_size >= 256:
return 256
return page_size


@dataclass
class PallasMetadata:
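A few hand-computed values for the heuristic above, via a standalone re-implementation used only for illustration (valid for max_model_len >= 1):

def page_size_for(max_model_len: int) -> int:
    # next_power_of_2(max_model_len) // 16, clamped to [16, 256]
    page = (1 << (max_model_len - 1).bit_length()) // 16
    return min(max(page, 16), 256)

assert page_size_for(128) == 16     # 8 before the floor kicks in
assert page_size_for(2048) == 128
assert page_size_for(3000) == 256   # next power of 2 is 4096; 4096 // 16 == 256
assert page_size_for(32768) == 256  # 2048, clamped to the 256 ceiling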