Commit e4d652e

[torch.compile] integration with compilation control (#9058)

1 parent 78c0b41 · commit e4d652e

22 files changed: +404 -98 lines

.buildkite/test-pipeline.yaml

Lines changed: 12 additions & 8 deletions

@@ -121,7 +121,9 @@ steps:
   - vllm/core/
   - tests/distributed
   - tests/spec_decode/e2e/test_integration_dist_tp4
+  - tests/compile
   commands:
+  - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py

@@ -231,14 +233,16 @@ steps:
   - vllm/
   - tests/compile
   commands:
-  - pytest -v -s compile/test_full_graph_smoke.py
+  - pytest -v -s compile/test_basic_correctness.py

-- label: "PyTorch Fullgraph Test" # 18min
-  source_file_dependencies:
-  - vllm/
-  - tests/compile
-  commands:
-  - pytest -v -s compile/test_full_graph.py
+# TODO: re-write in comparison tests, and fix symbolic shape
+# for quantization ops.
+# - label: "PyTorch Fullgraph Test" # 18min
+#   source_file_dependencies:
+#   - vllm/
+#   - tests/compile
+#   commands:
+#   - pytest -v -s compile/test_full_graph.py

 - label: Kernels Test %N # 1h each
   mirror_hardwares: [amd]

@@ -394,7 +398,7 @@ steps:
   - tests/distributed/
   - vllm/compilation
   commands:
-  - pytest -v -s ./compile/test_full_graph_multi_gpu.py
+  - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
   - TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus
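
The pipeline entries above now drive compilation exclusively through the VLLM_TORCH_COMPILE_LEVEL environment variable. The levels module itself (vllm/compilation/levels.py) is among the 22 changed files but is not shown in this excerpt; the sketch below is a hypothetical reconstruction, assuming plain integer constants, which is consistent with the ">=" comparisons in tests/compile/utils.py and with the test comment below that identifies level 3 with Inductor.

# Hypothetical sketch of vllm/compilation/levels.py (part of this commit but
# not shown here). Integer values are an assumption; only the names are taken
# from the diff.
class CompilationLevel:
    NO_COMPILATION = 0         # eager execution, no torch.compile
    DYNAMO_AS_IS = 1           # plain Dynamo, default dispatching
    DYNAMO_ONCE = 2            # capture the graph once and replay it
    INDUCTOR = 3               # compile the captured graph with Inductor
    INDUCTOR_MAX_AUTOTUNE = 4  # Inductor with max_autotune forced on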
tests/compile/test_basic_correctness.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+from typing import Dict, List, Optional
+
+import pytest
+
+from vllm.compilation.levels import CompilationLevel
+from vllm.utils import cuda_device_count_stateless
+
+from ..utils import compare_all_settings
+
+
+# we cannot afford to test the full Cartesian product
+# of all models and all levels
+@pytest.mark.parametrize(
+    "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
+    [
+        ("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate",
+         True),
+        ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
+         ["--quantization", "compressed-tensors"
+          ], 1, 1, "FLASH_ATTN", "generate", True),
+        ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
+        # TODO: add multi-modality test for llava
+        ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
+    ])
+def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
+                             method, fullgraph):
+    # this test is run under multiple suites, with different GPUs.
+    # make sure we only run the test with the correct CUDA devices.
+    # don't use "<", as it would duplicate the tests.
+    if cuda_device_count_stateless() != pp_size * tp_size:
+        pytest.skip("Not correct CUDA devices for the test.")
+    import os
+    os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
+    if not fullgraph:
+        os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"
+    all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"]
+                + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3
+    # don't test the VLLM_TORCH_COMPILE_LEVEL == 3 (Inductor) case:
+    # Inductor changes the output, so we cannot compare results.
+    all_envs: List[Optional[Dict[str, str]]] = [{
+        "VLLM_TORCH_COMPILE_LEVEL":
+        str(level)
+    } for level in [
+        CompilationLevel.NO_COMPILATION,
+        CompilationLevel.DYNAMO_AS_IS,
+        CompilationLevel.DYNAMO_ONCE,
+    ]]
+    compare_all_settings(model, all_args, all_envs, method=method)
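
For orientation, the first parametrized case above expands to roughly the following standalone call; compare_all_settings is assumed to take a model name, a list of per-run CLI argument lists, a parallel list of per-run environment dicts, and the generation method, exactly as it is invoked in the test.

# Hedged expansion of the first test case; the helper's absolute import path
# is an assumption based on the "..utils" relative import above.
from tests.utils import compare_all_settings
from vllm.compilation.levels import CompilationLevel

args = ["--enforce-eager", "--max_model_len", "1024", "-pp", "2", "-tp", "2"]
envs = [{"VLLM_TORCH_COMPILE_LEVEL": str(level)}
        for level in (CompilationLevel.NO_COMPILATION,
                      CompilationLevel.DYNAMO_AS_IS,
                      CompilationLevel.DYNAMO_ONCE)]

# three identical argument lists, three compile levels: the "generate"
# outputs must match across all of them
compare_all_settings("meta-llama/Meta-Llama-3-8B", [args] * 3, envs,
                     method="generate")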

tests/compile/test_full_graph.py

Lines changed: 11 additions & 4 deletions

@@ -1,13 +1,20 @@
 import pytest

-from vllm.compilation.backends import vllm_backend
+from vllm.compilation.levels import CompilationLevel

+from ..utils import fork_new_process_for_each_test
 from .utils import TEST_MODELS, check_full_graph_support


 @pytest.mark.parametrize("model_info", TEST_MODELS)
-@pytest.mark.parametrize("backend", ["eager", vllm_backend])
-def test_full_graph(model_info, backend):
+@pytest.mark.parametrize(
+    "optimization_level",
+    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR])
+@fork_new_process_for_each_test
+def test_full_graph(model_info, optimization_level):
     model = model_info[0]
     model_kwargs = model_info[1]
-    check_full_graph_support(model, model_kwargs, backend, tp_size=1)
+    check_full_graph_support(model,
+                             model_kwargs,
+                             optimization_level,
+                             tp_size=1)
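
Each case runs in a freshly forked process, presumably so the compilation level one case writes into the environment cannot leak into the next. A single case from the matrix can be reproduced standalone roughly as follows (the helper import path is an assumption):

# Hedged sketch of one parametrized case run outside pytest; the model entry
# comes from TEST_MODELS in tests/compile/utils.py below.
from tests.compile.utils import check_full_graph_support
from vllm.compilation.levels import CompilationLevel

# check_full_graph_support sets VLLM_TORCH_COMPILE_LEVEL itself, so the level
# is passed as an argument rather than exported beforehand.
check_full_graph_support("facebook/opt-125m", {},
                         CompilationLevel.DYNAMO_ONCE,
                         tp_size=1)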

tests/compile/test_full_graph_multi_gpu.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

tests/compile/test_full_graph_smoke.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

tests/compile/utils.py

Lines changed: 9 additions & 15 deletions

@@ -4,16 +4,9 @@

 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.plugins import set_torch_compile_backend
+from vllm.compilation.levels import CompilationLevel
 from vllm.utils import is_hip

-TEST_MODELS_SMOKE = [
-    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
-        "quantization": "compressed-tensors"
-    }),
-    ("meta-llama/Meta-Llama-3-8B", {}),
-]
-
 TEST_MODELS = [
     ("facebook/opt-125m", {}),
     ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {

@@ -68,20 +61,21 @@
     }))


-def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
+def check_full_graph_support(model,
+                             model_kwargs,
+                             optimization_level,
+                             tp_size=1):
     # make sure these models can be captured in full graph mode
-    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
-        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
-        os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
+    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
+    os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

     # Inductor doesn't support fp8/gptq_marlin_24 yet.
     quantization = model_kwargs.get("quantization")
     if (quantization == "fp8" or quantization == "gptq_marlin"
-            or quantization == "gptq_marlin_24") and backend != "eager":
+            or quantization == "gptq_marlin_24"
+        ) and optimization_level >= CompilationLevel.INDUCTOR:
         return

-    set_torch_compile_backend(backend)
-
     prompts = [
         "Hello, my name is",
         "The president of the United States is",

tests/tpu/test_compilation.py

Lines changed: 3 additions & 1 deletion

@@ -5,9 +5,11 @@

 import depyf

+from vllm.compilation.levels import CompilationLevel
+
 # disable the custom dispatcher, let Dynamo take over
 # all the control
-os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"
+os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS)

 temp_dir = tempfile.mkdtemp()
 with depyf.prepare_debug(temp_dir):
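
The same pattern can be reproduced off-TPU to inspect what Dynamo generates at the DYNAMO_AS_IS level. Everything below beyond the names already used in this file (depyf, CompilationLevel, the environment variable) is illustrative:

# Hedged, self-contained variant of the pattern above: pin the level before
# vLLM is imported, then let depyf dump Dynamo's transformed code for
# inspection. The model name and generation call are illustrative only.
import os
import tempfile

import depyf

from vllm.compilation.levels import CompilationLevel

os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_AS_IS)

temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
    from vllm import LLM, SamplingParams
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
# temp_dir now holds the Python source Dynamo produced for the captured graphs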
tests/tpu/test_custom_dispatcher.py

Lines changed: 8 additions & 5 deletions

@@ -1,5 +1,7 @@
 import os

+from vllm.compilation.levels import CompilationLevel
+
 from ..utils import compare_two_settings

 # --enforce-eager on TPU causes graph compilation

@@ -9,8 +11,9 @@


 def test_custom_dispatcher():
-    compare_two_settings("google/gemma-2b",
-                         arg1=["--enforce-eager"],
-                         arg2=["--enforce-eager"],
-                         env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
-                         env2={})
+    compare_two_settings(
+        "google/gemma-2b",
+        arg1=["--enforce-eager"],
+        arg2=["--enforce-eager"],
+        env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)},
+        env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)})

vllm/compilation/backends.py

Lines changed: 114 additions & 1 deletion

@@ -1,8 +1,17 @@
+import copy
 import operator
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.fx as fx

+from vllm.logger import init_logger
+
+from .compile_context import get_compile_context
+from .levels import CompilationLevel
+
+logger = init_logger(__name__)
+

 def fix_functionalization(graph: fx.Graph):
     """

@@ -148,9 +157,113 @@ def fix_functionalization(graph: fx.Graph):
     # print(graph.python_code(root_module="self", verbose=True).src, file=f)


-def vllm_backend(graph, example_inputs):
+def wrap_inductor(graph, example_inputs, additional_inductor_config):
     from torch._inductor import config
     current_config = config.shallow_copy_dict()
     from torch._inductor.compile_fx import compile_fx
+
+    if additional_inductor_config is not None:
+        current_config.update(additional_inductor_config)
+    if current_config['post_grad_custom_post_pass'] is not None:
+        logger.warning(
+            "post_grad_custom_post_pass is already set in the config. "
+            "Overwriting it with the fix_functionalization")
     current_config['post_grad_custom_post_pass'] = fix_functionalization
     return compile_fx(graph, example_inputs, config_patches=current_config)
+
+
+def vllm_backend(
+        graph,
+        example_inputs,
+        additional_inductor_config: Optional[Dict] = None) -> Callable:
+
+    context = get_compile_context()
+    context = copy.deepcopy(context) if context is not None else []
+    sizes_to_specialize: List[int] = context
+
+    # flags for all the seen shapes, whether we need to specialize
+    runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}
+
+    # if we need to specialize, the compiled graph for that shape
+    runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}
+
+    # this is the first compilation: we compile a graph with dynamic shape,
+    # as the caller will mark the first dimension as dynamic
+    logger.info("Compiling a graph for general shapes")
+    graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
+                                             additional_inductor_config)
+
+    # TODO: Dynamo does not pass all dynamic shapes.
+    # Need to investigate why. It works now because all the dynamic
+    # shapes have the same value, and either of them can be used.
+    sym_shape_indices = [
+        i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
+    ]
+
+    first_run = True

+    # this is the function we return to Dynamo to run finally
+    def compiled_graph_wrapper(*args):
+
+        runtime_shapes: Tuple[int,
+                              ...] = tuple(args[i] for i in sym_shape_indices)
+
+        nonlocal first_run
+        nonlocal runtime_shapes_to_compile_flags
+        nonlocal runtime_shapes_to_compiled_graph
+
+        if first_run:
+            # the first compilation is for profiling, we directly run it
+            first_run = False
+            return graph_for_symbolic_shape(*args)
+
+        if runtime_shapes not in runtime_shapes_to_compile_flags:
+            # we haven't seen this shape before,
+            # query if we need to specialize for this shape.
+            # we only specialize for the first dimension.
+            # TODO: investigate if any model needs to specialize
+            # beyond the first dimension
+            runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[
+                0] in sizes_to_specialize
+
+        if not runtime_shapes_to_compile_flags[runtime_shapes]:
+            # we don't need to specialize for this shape
+            return graph_for_symbolic_shape(*args)
+
+        if runtime_shapes not in runtime_shapes_to_compiled_graph:
+            # we need to specialize for this shape, and we haven't compiled
+            # for it yet: compile the graph for this shape
+            logger.info("Compiling a graph for shapes %s", runtime_shapes)
+            runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
+                graph, args, additional_inductor_config)
+
+        return runtime_shapes_to_compiled_graph[runtime_shapes](*args)
+
+    return compiled_graph_wrapper
+
+
+def select_default_backend(level: int) -> Union[str, Callable]:
+    if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
+        backend = "eager"
+        return backend
+    assert level in [
+        CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
+    ], f"Invalid level {level}"
+
+    from vllm.compilation.backends import vllm_backend
+    from vllm.plugins import get_inductor_additional_configs
+    additional_configs = get_inductor_additional_configs()
+
+    if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
+        if "max_autotune" in additional_configs and not additional_configs[
+                "max_autotune"]:
+            logger.warning(
+                "max_autotune is disabled, but is overridden by level %s",
+                CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
+        additional_configs['max_autotune'] = True
+
+    from functools import partial
+    backend = partial(vllm_backend,
+                      additional_inductor_config=additional_configs)
+
+    return backend
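
Putting the pieces together: the backend returned by select_default_backend compiles one Inductor graph for dynamic shapes up front, then lazily compiles a specialized graph the first time it sees a batch size listed in the compile context. The toy usage below is a hedged sketch of that interplay outside vLLM's model runner; only select_default_backend, set_compile_context, and CompilationLevel come from this commit, the rest is illustrative.

# Hedged toy sketch of the dispatch logic above; not how vLLM itself wires
# the backend in (that happens inside the model runner).
import torch
import torch._dynamo

from vllm.compilation.backends import select_default_backend
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel

backend = select_default_backend(CompilationLevel.INDUCTOR)


@torch.compile(backend=backend)
def toy_forward(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


with set_compile_context([1, 8]):          # batch sizes to specialize
    x = torch.randn(16, 32)
    torch._dynamo.mark_dynamic(x, 0)       # the backend expects a dynamic dim 0
    toy_forward(x)                         # first run: dynamic-shape graph
    toy_forward(torch.randn(8, 32))        # compiles a size-8 specialization
    toy_forward(torch.randn(3, 32))        # falls back to the dynamic graph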
vllm/compilation/compile_context.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+from contextlib import contextmanager
+from typing import Any
+
+_compile_context: Any = None
+
+
+def get_compile_context() -> Any:
+    """Get the current compile context."""
+    return _compile_context
+
+
+@contextmanager
+def set_compile_context(context: Any):
+    """A context manager that stores the current compile context,
+    usually it is a list of sizes to specialize.
+    """
+    global _compile_context
+    prev_context = _compile_context
+    _compile_context = context
+    try:
+        yield
+    finally:
+        _compile_context = prev_context
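
This module is a small piece of plumbing: it lets the caller hand information (here, the sizes to specialize) to the backend without threading it through torch.compile's backend interface, since vllm_backend reads the context at compile time. A minimal usage sketch:

# Minimal usage sketch of the context manager defined above.
from vllm.compilation.compile_context import (get_compile_context,
                                               set_compile_context)

with set_compile_context([1, 2, 4, 8]):
    assert get_compile_context() == [1, 2, 4, 8]
    # ... trigger torch.compile here; vllm_backend calls get_compile_context()
    # internally to pick up this list ...

assert get_compile_context() is None  # previous context restored on exit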
