
Commit fad5576

[TPU] Reduce compilation time & Upgrade PyTorch XLA version (#6856)
1 parent f954d07 commit fad5576

6 files changed: 24 additions, 7 deletions

Dockerfile.tpu

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240713"
+ARG NIGHTLY_DATE="20240726"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
 
 FROM $BASE_IMAGE

docs/source/getting_started/tpu-installation.rst

Lines changed: 8 additions & 1 deletion
@@ -56,7 +56,7 @@ First, install the dependencies:
 $ pip uninstall torch torch-xla -y
 
 $ # Install PyTorch and PyTorch XLA.
-$ export DATE="+20240713"
+$ export DATE="+20240726"
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
 
@@ -75,6 +75,13 @@ Next, build vLLM from source. This will only take a few seconds:
 $ VLLM_TARGET_DEVICE="tpu" python setup.py develop
 
+.. note::
+
+    Since TPU relies on XLA, which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each shape.
+    The compilation may take 20~30 minutes on the first run.
+    Afterwards, however, it drops to ~5 minutes because the XLA graphs are cached on disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default).
+
 .. tip::
 
     If you encounter the following error:
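
The note added above concerns the on-disk XLA graph cache. As a rough sketch of how that cache location can be pointed at a persistent directory (the cache path and model name below are placeholders, not part of this commit):

import os

# Persist the compiled XLA graphs so the 20~30 minute first-run compilation
# only has to happen once. If unset, vLLM falls back to ~/.cache/vllm/xla_cache.
os.environ["VLLM_XLA_CACHE_PATH"] = "/data/xla_cache"  # hypothetical path

from vllm import LLM, SamplingParams  # assumes a TPU build (VLLM_TARGET_DEVICE="tpu")

llm = LLM(model="facebook/opt-125m")  # placeholder model, for illustration only
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)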

vllm/attention/backends/pallas.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 
 import torch
 import torch_xla.experimental.custom_kernel  # Required to register custom ops.
-import torch_xla.experimental.dynamo_set_buffer_donor
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)

vllm/distributed/device_communicators/tpu_communicator.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@
 
 if current_platform.is_tpu():
     import torch_xla.core.xla_model as xm
+    import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
 
 
@@ -20,7 +21,7 @@ def __init__(self, group: ProcessGroup):
         local_rank = dist.get_rank(group)
         world_size = dist.get_world_size(group)
         pjrt.initialize_multiprocess(local_rank, world_size)
-        xm._init_world_size_ordinal()
+        xr._init_world_size_ordinal()
 
     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
         return xm.all_reduce(xm.REDUCE_SUM, x)

vllm/worker/tpu_model_runner.py

Lines changed: 13 additions & 2 deletions
@@ -7,6 +7,7 @@
 import torch
 import torch.nn as nn
 import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr
 
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
@@ -127,7 +128,7 @@ def load_model(self) -> None:
         # determine the order of concatenating the output tensors.
         # As a workaround, we use the xm's rank assignment only when loading
         # the embedding weights.
-        xm_tp_rank = xm.get_ordinal()
+        xm_tp_rank = xr.global_ordinal()
         with patch(
                 "vllm.model_executor.layers.vocab_parallel_embedding."
                 "get_tensor_model_parallel_rank",
@@ -146,7 +147,17 @@ def load_model(self) -> None:
         xm.wait_device_ops()
 
         model = ModelWrapper(model)
-        self.model = torch.compile(model, backend="openxla", fullgraph=True)
+        # NOTE(woosuk): There are two stages of compilation: torch.compile and
+        # XLA compilation. Setting dynamic=True can reduce the torch.compile
+        # overhead by reusing the FX graph for different shapes.
+        # However, the XLA graph still requires static shapes and needs to be
+        # re-compiled for every different shape. This overhead is inevitable
+        # in the first run, but can be skipped afterwards as we cache the XLA
+        # graphs on disk (VLLM_XLA_CACHE_PATH).
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=True)
 
     def _dummy_run(
         self,
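
The NOTE above is the heart of this change: with dynamic=True, Dynamo traces a single FX graph and reuses it across the bucketized input shapes, while XLA still compiles (and caches on disk) one graph per shape. A minimal self-contained sketch of the same pattern follows; the toy module, feature size, and batch sizes are illustrative only and not taken from vLLM:

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm


class ToyModel(nn.Module):
    """Stand-in for vLLM's ModelWrapper, used only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


device = xm.xla_device()
model = ToyModel().to(device)

# dynamic=True: torch.compile (Dynamo) reuses one FX graph for all of the
# shapes below. XLA still traces and compiles a separate graph per shape,
# but those graphs can be cached on disk across runs.
compiled = torch.compile(model, backend="openxla", fullgraph=True, dynamic=True)

for batch_size in (8, 16, 32):  # example shape buckets
    x = torch.randn(batch_size, 128, device=device)
    y = compiled(x)
    xm.mark_step()  # materialize the pending XLA graph for this shape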

vllm/worker/tpu_worker.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 
 import torch
 import torch_xla.core.xla_model as xm
-import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
