Merged
Changes from all commits (57 commits, all by youkaichao)
d5c329d  adapt for dynamo (Oct 3, 2024)
12e29fe  fix tpu (Oct 3, 2024)
504bd6c  add backend (Oct 3, 2024)
6353613  add use_custom_dispatcher (Oct 3, 2024)
77ae8e7  update wrapper (Oct 3, 2024)
4d99a58  update envs (Oct 3, 2024)
2b79376  update custom op (Oct 3, 2024)
7dfddcd  support llama (Oct 3, 2024)
abd1a65  update plugins (Oct 3, 2024)
ce1907f  update model runner (Oct 3, 2024)
e1ea867  add support (Oct 3, 2024)
511e07b  add files (Oct 3, 2024)
3bb8950  fix not use_custom_dispatcher (Oct 4, 2024)
c4d7189  Merge branch 'main' into compile_integration (Oct 5, 2024)
ed573fa  do not test inductor (Oct 5, 2024)
93ef0b5  add compile context (Oct 5, 2024)
3cd40db  remove model reference (Oct 5, 2024)
4e28930  lint (Oct 5, 2024)
2ac7274  change levels (Oct 7, 2024)
34fe820  Merge branch 'main' into compile_integration (Oct 8, 2024)
a3c947e  add levels (Oct 8, 2024)
1a41c57  use const (Oct 8, 2024)
db61567  use const (Oct 8, 2024)
275ede9  use const (Oct 8, 2024)
d1f084d  use const (Oct 8, 2024)
326c5b4  use const (Oct 8, 2024)
9b7b0f3  use const (Oct 8, 2024)
9cfa70c  use const (Oct 8, 2024)
e819be7  use const (Oct 8, 2024)
d9cb162  use const (Oct 8, 2024)
825f384  use const (Oct 8, 2024)
c785fc8  use const (Oct 8, 2024)
28e9f6f  restore (Oct 8, 2024)
718c5e4  use const (Oct 8, 2024)
03081cd  use const (Oct 8, 2024)
fbac08d  error on inductor for tpu (Oct 8, 2024)
3c688ea  fix llava (Oct 8, 2024)
32676f8  restore tpu (Oct 8, 2024)
5ae34df  Merge branch 'main' into compile_integration (Oct 8, 2024)
3ed89da  adjust for tpu (Oct 8, 2024)
a3c3e21  fix env var (Oct 8, 2024)
30ff04f  fix calling (Oct 8, 2024)
13256c4  revert tpu (Oct 8, 2024)
bf0e935  revert utils (Oct 8, 2024)
39571c5  fix typo (Oct 8, 2024)
e3aea56  add typing (Oct 8, 2024)
6181795  move DYNAMO_AS_IS to model runner level (Oct 8, 2024)
1a80a7b  fix default context (Oct 8, 2024)
92d240b  use eager for DYNAMO_AS_IS by default (Oct 8, 2024)
f4b0f50  update tests (Oct 8, 2024)
896431a  update tests (Oct 8, 2024)
388d563  llava uses fullgraph=false (Oct 8, 2024)
3642b77  Merge branch 'main' into compile_integration (Oct 9, 2024)
3e3ea58  Merge branch 'main' into compile_integration (Oct 10, 2024)
ce7cd8e  disable tests first (Oct 10, 2024)
ab41d84  Merge branch 'main' into compile_integration (Oct 10, 2024)
d1f8ae8  add supports_dynamo in the decorator (Oct 10, 2024)
20 changes: 12 additions & 8 deletions .buildkite/test-pipeline.yaml
@@ -121,7 +121,9 @@ steps:
   - vllm/core/
   - tests/distributed
   - tests/spec_decode/e2e/test_integration_dist_tp4
+  - tests/compile
   commands:
+  - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py

@@ -231,14 +233,16 @@ steps:
   - vllm/
   - tests/compile
   commands:
-  - pytest -v -s compile/test_full_graph_smoke.py
+  - pytest -v -s compile/test_basic_correctness.py

-- label: "PyTorch Fullgraph Test" # 18min
-  source_file_dependencies:
-  - vllm/
-  - tests/compile
-  commands:
-  - pytest -v -s compile/test_full_graph.py
+# TODO: re-write in comparison tests, and fix symbolic shape
+# for quantization ops.
+# - label: "PyTorch Fullgraph Test" # 18min
+#   source_file_dependencies:
+#   - vllm/
+#   - tests/compile
+#   commands:
+#   - pytest -v -s compile/test_full_graph.py

 - label: Kernels Test %N # 1h each
   mirror_hardwares: [amd]
@@ -394,7 +398,7 @@ steps:
   - tests/distributed/
   - vllm/compilation
   commands:
-  - pytest -v -s ./compile/test_full_graph_multi_gpu.py
+  - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
   - TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus
48 changes: 48 additions & 0 deletions tests/compile/test_basic_correctness.py
@@ -0,0 +1,48 @@
from typing import Dict, List, Optional

import pytest

from vllm.compilation.levels import CompilationLevel
from vllm.utils import cuda_device_count_stateless

from ..utils import compare_all_settings


# we cannot afford to test the full Cartesian product
# of all models and all levels
@pytest.mark.parametrize(
    "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
    [
        ("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate",
         True),
        ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
         ["--quantization", "compressed-tensors"
          ], 1, 1, "FLASH_ATTN", "generate", True),
        ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
        # TODO: add multi-modality test for llava
        ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
    ])
def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
                             method, fullgraph):
    # this test is run under multiple test suites, with different GPUs.
    # make sure we only run the test with the correct number of CUDA devices.
    # don't use "<", as it will duplicate the tests.
    if cuda_device_count_stateless() != pp_size * tp_size:
        pytest.skip("Wrong number of CUDA devices for the test.")
    import os
    os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
    if not fullgraph:
        os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"
    all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"]
                + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3
    # don't test the VLLM_TORCH_COMPILE_LEVEL == 3 case:
    # Inductor will change the output, so we cannot compare the results.
    all_envs: List[Optional[Dict[str, str]]] = [{
        "VLLM_TORCH_COMPILE_LEVEL": str(level)
    } for level in [
        CompilationLevel.NO_COMPILATION,
        CompilationLevel.DYNAMO_AS_IS,
        CompilationLevel.DYNAMO_ONCE,
    ]]
    compare_all_settings(model, all_args, all_envs, method=method)
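The vllm.compilation.levels module imported above is not part of this diff. A minimal sketch of what the tests assume: ordered integer constants, so that VLLM_TORCH_COMPILE_LEVEL can be set from str(level) and compared with >=. The names come from the imports; the exact values are an assumption, apart from the comment above tying level 3 to Inductor.

# Hypothetical reconstruction of vllm/compilation/levels.py, for reference
# only; the real file is not shown in this PR.
class CompilationLevel:
    NO_COMPILATION = 0
    DYNAMO_AS_IS = 1  # apparently: hand the model to Dynamo on every call
    DYNAMO_ONCE = 2  # apparently: compile once via vLLM's custom dispatcher
    INDUCTOR = 3  # per the "level == 3" comment in the test above
    INDUCTOR_MAX_AUTOTUNE = 4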
15 changes: 11 additions & 4 deletions tests/compile/test_full_graph.py
@@ -1,13 +1,20 @@
 import pytest

-from vllm.compilation.backends import vllm_backend
+from vllm.compilation.levels import CompilationLevel

+from ..utils import fork_new_process_for_each_test
 from .utils import TEST_MODELS, check_full_graph_support


 @pytest.mark.parametrize("model_info", TEST_MODELS)
-@pytest.mark.parametrize("backend", ["eager", vllm_backend])
-def test_full_graph(model_info, backend):
+@pytest.mark.parametrize(
+    "optimization_level",
+    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR])
+@fork_new_process_for_each_test
+def test_full_graph(model_info, optimization_level):
     model = model_info[0]
     model_kwargs = model_info[1]
-    check_full_graph_support(model, model_kwargs, backend, tp_size=1)
+    check_full_graph_support(model,
+                             model_kwargs,
+                             optimization_level,
+                             tp_size=1)
22 changes: 0 additions & 22 deletions tests/compile/test_full_graph_multi_gpu.py

This file was deleted.

13 changes: 0 additions & 13 deletions tests/compile/test_full_graph_smoke.py

This file was deleted.

24 changes: 9 additions & 15 deletions tests/compile/utils.py
@@ -4,16 +4,9 @@

 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.plugins import set_torch_compile_backend
+from vllm.compilation.levels import CompilationLevel
 from vllm.utils import is_hip

-TEST_MODELS_SMOKE = [
-    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
-        "quantization": "compressed-tensors"
-    }),
-    ("meta-llama/Meta-Llama-3-8B", {}),
-]
-
 TEST_MODELS = [
     ("facebook/opt-125m", {}),
     ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@@ -68,20 +61,21 @@
     }))


-def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
+def check_full_graph_support(model,
+                             model_kwargs,
+                             optimization_level,
+                             tp_size=1):
     # make sure these models can be captured in full graph mode
-    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
-        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
-        os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
+    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
+    os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

     # Inductor doesn't support fp8/gptq_marlin_24 yet.
     quantization = model_kwargs.get("quantization")
     if (quantization == "fp8" or quantization == "gptq_marlin"
-            or quantization == "gptq_marlin_24") and backend != "eager":
+            or quantization == "gptq_marlin_24"
+        ) and optimization_level >= CompilationLevel.INDUCTOR:
         return

-    set_torch_compile_backend(backend)
-
     prompts = [
         "Hello, my name is",
         "The president of the United States is",
4 changes: 3 additions & 1 deletion tests/tpu/test_compilation.py
@@ -5,9 +5,11 @@

 import depyf

+from vllm.compilation.levels import CompilationLevel
+
 # disable the custom dispatcher and let Dynamo take over
 # all the control
-os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"
+os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS)

 temp_dir = tempfile.mkdtemp()
 with depyf.prepare_debug(temp_dir):
13 changes: 8 additions & 5 deletions tests/tpu/test_custom_dispatcher.py
@@ -1,5 +1,7 @@
 import os

+from vllm.compilation.levels import CompilationLevel
+
 from ..utils import compare_two_settings

 # --enforce-eager on TPU causes graph compilation
@@ -9,8 +11,9 @@


 def test_custom_dispatcher():
-    compare_two_settings("google/gemma-2b",
-                         arg1=["--enforce-eager"],
-                         arg2=["--enforce-eager"],
-                         env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
-                         env2={})
+    compare_two_settings(
+        "google/gemma-2b",
+        arg1=["--enforce-eager"],
+        arg2=["--enforce-eager"],
+        env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)},
+        env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)})
115 changes: 114 additions & 1 deletion vllm/compilation/backends.py
@@ -1,8 +1,17 @@
+import copy
 import operator
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.fx as fx

+from vllm.logger import init_logger
+
+from .compile_context import get_compile_context
+from .levels import CompilationLevel
+
+logger = init_logger(__name__)


 def fix_functionalization(graph: fx.Graph):
     """
@@ -148,9 +157,113 @@ def fix_functionalization(graph: fx.Graph):
 # print(graph.python_code(root_module="self", verbose=True).src, file=f)


-def vllm_backend(graph, example_inputs):
+def wrap_inductor(graph, example_inputs, additional_inductor_config):
     from torch._inductor import config
     current_config = config.shallow_copy_dict()
     from torch._inductor.compile_fx import compile_fx

+    if additional_inductor_config is not None:
+        current_config.update(additional_inductor_config)
+    if current_config['post_grad_custom_post_pass'] is not None:
+        logger.warning(
+            "post_grad_custom_post_pass is already set in the config. "
+            "Overwriting it with fix_functionalization.")
Comment on lines +165 to +170

ProExpertProg (Collaborator) commented on Oct 9, 2024:

I think it will be common to add custom passes - I think we should handle this by wrapping the provided pass instead of overwriting:

Suggested change:

-    if additional_inductor_config is not None:
-        current_config.update(additional_inductor_config)
-    if current_config['post_grad_custom_post_pass'] is not None:
-        logger.warning(
-            "post_grad_custom_post_pass is already set in the config. "
-            "Overwriting it with fix_functionalization.")
+    if additional_inductor_config is not None:
+        current_config.update(additional_inductor_config)
+
+    def noop_pass(graph: fx.Graph):
+        pass
+
+    old_post_pass = current_config['post_grad_custom_post_pass'] or noop_pass
+
+    def post_grad_custom_post_pass(graph: fx.Graph):
+        old_post_pass(graph)
+        fix_functionalization(graph)
+
+    current_config['post_grad_custom_post_pass'] = post_grad_custom_post_pass

     current_config['post_grad_custom_post_pass'] = fix_functionalization
     return compile_fx(graph, example_inputs, config_patches=current_config)
+
+
+def vllm_backend(
+        graph,
+        example_inputs,
+        additional_inductor_config: Optional[Dict] = None) -> Callable:
+
+    context = get_compile_context()
+    context = copy.deepcopy(context) if context is not None else []
+    sizes_to_specialize: List[int] = context
+
+    # flags for all the seen shapes, whether we need to specialize
+    runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}
+
+    # if we need to specialize, the compiled graph for that shape
+    runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}
+
+    # this is the first compilation: we will compile a graph with
+    # dynamic shape, as the caller will mark the first dimension as dynamic
+    logger.info("Compiling a graph for general shapes")
+    graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
+                                             additional_inductor_config)
+
+    # TODO: Dynamo does not pass all dynamic shapes.
+    # Need to investigate why. It works now because all the dynamic
+    # shapes have the same value, and either of them can be used.
+    sym_shape_indices = [
+        i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
+    ]
+
+    first_run = True
+
+    # this is the function we return to Dynamo to run finally
+    def compiled_graph_wrapper(*args):
+
+        runtime_shapes: Tuple[int,
+                              ...] = tuple(args[i] for i in sym_shape_indices)
+
+        nonlocal first_run
+        nonlocal runtime_shapes_to_compile_flags
+        nonlocal runtime_shapes_to_compiled_graph
+
+        if first_run:
+            # the first compilation is for profiling; we directly run it
+            first_run = False
+            return graph_for_symbolic_shape(*args)
+
+        if runtime_shapes not in runtime_shapes_to_compile_flags:
+            # we haven't seen this shape before;
+            # query if we need to specialize for this shape.
+            # we only specialize for the first dimension.
+            # TODO: investigate if any model needs to specialize
+            # beyond the first dimension
+            runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[
+                0] in sizes_to_specialize
+
+        if not runtime_shapes_to_compile_flags[runtime_shapes]:
+            # we don't need to specialize for this shape
+            return graph_for_symbolic_shape(*args)
+
+        if runtime_shapes not in runtime_shapes_to_compiled_graph:
+            # we need to specialize for this shape, and we haven't
+            # compiled a graph for it yet
+            logger.info("Compiling a graph for shapes %s", runtime_shapes)
+            runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
+                graph, args, additional_inductor_config)
+
+        return runtime_shapes_to_compiled_graph[runtime_shapes](*args)
+
+    return compiled_graph_wrapper


+def select_default_backend(level: int) -> Union[str, Callable]:
+    if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
+        backend = "eager"
+        return backend
+    assert level in [
+        CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
+    ], f"Invalid level {level}"
+
+    from vllm.compilation.backends import vllm_backend
+    from vllm.plugins import get_inductor_additional_configs
+    additional_configs = get_inductor_additional_configs()
+
+    if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
+        if "max_autotune" in additional_configs and not additional_configs[
+                "max_autotune"]:
+            logger.warning(
+                "max_autotune is disabled, but is overridden by level %s",
+                CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
+        additional_configs['max_autotune'] = True
+
+    from functools import partial
+    backend = partial(vllm_backend,
+                      additional_inductor_config=additional_configs)
+
+    return backend
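For orientation, a minimal sketch of how a backend chosen by select_default_backend would be handed to Dynamo. The wrapper that actually does this in vLLM is not part of this diff, and MyToyModule below is a stand-in, not a vLLM model:

import torch
import torch.nn as nn

from vllm.compilation.backends import select_default_backend
from vllm.compilation.levels import CompilationLevel


class MyToyModule(nn.Module):  # stand-in module, for illustration only
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2 + 1


# DYNAMO_AS_IS / DYNAMO_ONCE map to the "eager" backend per the code above;
# INDUCTOR levels return a partial of vllm_backend instead.
backend = select_default_backend(CompilationLevel.DYNAMO_ONCE)
compiled = torch.compile(MyToyModule(), backend=backend, dynamic=True)
print(compiled(torch.randn(4, 8)))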
23 changes: 23 additions & 0 deletions vllm/compilation/compile_context.py
@@ -0,0 +1,23 @@
from contextlib import contextmanager
from typing import Any

_compile_context: Any = None


def get_compile_context() -> Any:
    """Get the current compile context."""
    return _compile_context


@contextmanager
def set_compile_context(context: Any):
    """A context manager that stores the current compile context,
    usually it is a list of sizes to specialize.
    """
    global _compile_context
    prev_context = _compile_context
    _compile_context = context
    try:
        yield
    finally:
        _compile_context = prev_context

Comment on lines +14 to +15

A collaborator commented:

This is a bit vague - could you improve the comment to add who uses this context and when?
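In sketch form, the answer to the question above: the model runner presumably wraps model execution in set_compile_context so that vllm_backend (in backends.py above) can read the sizes to specialize at compile time. The call site is not shown in this diff, so the snippet below is illustrative only:

from vllm.compilation.compile_context import (get_compile_context,
                                              set_compile_context)


def fake_forward_pass():
    # placeholder for the real compiled forward pass; a Dynamo backend such
    # as vllm_backend would call get_compile_context() while compiling
    print("sizes to specialize:", get_compile_context())


# [1, 2, 4, 8] is an assumed list of batch sizes, not taken from this PR
with set_compile_context([1, 2, 4, 8]):
    fake_forward_pass()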