[torch.compile] integration with compilation control #9058
Merged

youkaichao merged 57 commits into vllm-project:main from youkaichao:compile_integration on Oct 10, 2024.

Changes from all commits (57 commits, all by youkaichao):
d5c329d  adapt for dynamo
12e29fe  fix tpu
504bd6c  add backend
6353613  add use_custom_dispatcher
77ae8e7  update wrapper
4d99a58  update envs
2b79376  update custom op
7dfddcd  support llama
abd1a65  update plugins
ce1907f  update model runner
e1ea867  add support
511e07b  add files
3bb8950  fix not use_custom_dispatcher
c4d7189  Merge branch 'main' into compile_integration
ed573fa  do not test inductor
93ef0b5  add compile context
3cd40db  remove model reference
4e28930  lint
2ac7274  change levels
34fe820  Merge branch 'main' into compile_integration
a3c947e  add levels
1a41c57  use const
db61567  use const
275ede9  use const
d1f084d  use const
326c5b4  use const
9b7b0f3  use const
9cfa70c  use const
e819be7  use const
d9cb162  use const
825f384  use const
c785fc8  use const
28e9f6f  restore
718c5e4  use const
03081cd  use const
fbac08d  error on inductor for tpu
3c688ea  fix llava
32676f8  restore tpu
5ae34df  Merge branch 'main' into compile_integration
3ed89da  adjust for tpu
a3c3e21  fix env var
30ff04f  fix calling
13256c4  revert tpu
bf0e935  revert utils
39571c5  fix typo
e3aea56  add typing
6181795  move DYNAMO_AS_IS to model runner level
1a80a7b  fix default context
92d240b  use eager for DYNAMO_AS_IS by default
f4b0f50  update tests
896431a  update tests
388d563  llava uses fullgraph=false
3642b77  Merge branch 'main' into compile_integration
3e3ea58  Merge branch 'main' into compile_integration
ce7cd8e  disable tests first
ab41d84  Merge branch 'main' into compile_integration
d1f8ae8  add supports_dynamo in the decorator
@@ -0,0 +1,48 @@
from typing import Dict, List, Optional

import pytest

from vllm.compilation.levels import CompilationLevel
from vllm.utils import cuda_device_count_stateless

from ..utils import compare_all_settings


# we cannot afford testing the full Cartesian product
# of all models and all levels
@pytest.mark.parametrize(
    "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
    [
        ("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate",
         True),
        ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
         ["--quantization", "compressed-tensors"
          ], 1, 1, "FLASH_ATTN", "generate", True),
        ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
        # TODO: add multi-modality test for llava
        ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
    ])
def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
                             method, fullgraph):
    # this test is run under multiple suites, with different GPUs.
    # make sure we only run the test with the correct number of CUDA devices.
    # don't use "<", as it would duplicate the tests.
    if cuda_device_count_stateless() != pp_size * tp_size:
        pytest.skip("Incorrect number of CUDA devices for the test.")
    import os
    os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
    if not fullgraph:
        os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"
    all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"]
                + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3
    # don't test the VLLM_TORCH_COMPILE_LEVEL == 3 (Inductor) case:
    # Inductor changes the output, so we cannot compare it across levels.
    all_envs: List[Optional[Dict[str, str]]] = [{
        "VLLM_TORCH_COMPILE_LEVEL": str(level)
    } for level in [
        CompilationLevel.NO_COMPILATION,
        CompilationLevel.DYNAMO_AS_IS,
        CompilationLevel.DYNAMO_ONCE,
    ]]
    compare_all_settings(model, all_args, all_envs, method=method)
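
The three levels swept above come from vllm.compilation.levels, which is not part of the hunks captured here. Inferring from how the constants are used (stringified into VLLM_TORCH_COMPILE_LEVEL, with the skipped Inductor case called out as level 3), a plausible sketch of that module — hypothetical, not the PR's actual file — is:

# Hypothetical sketch of vllm.compilation.levels, inferred from this diff;
# the test comment above pins the Inductor case to level 3.
class CompilationLevel:
    NO_COMPILATION = 0  # plain eager execution, the baseline
    DYNAMO_AS_IS = 1    # torch.compile as a user would invoke it
    DYNAMO_ONCE = 2     # capture the graph with Dynamo once, then reuse it
    INDUCTOR = 3        # full Inductor compilation; may change outputs

A level is then selected per process through the environment, e.g. VLLM_TORCH_COMPILE_LEVEL=2.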
@@ -1,13 +1,20 @@
 import pytest
 
-from vllm.compilation.backends import vllm_backend
+from vllm.compilation.levels import CompilationLevel
 
+from ..utils import fork_new_process_for_each_test
 from .utils import TEST_MODELS, check_full_graph_support
 
 
 @pytest.mark.parametrize("model_info", TEST_MODELS)
-@pytest.mark.parametrize("backend", ["eager", vllm_backend])
-def test_full_graph(model_info, backend):
+@pytest.mark.parametrize(
+    "optimization_level",
+    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR])
+@fork_new_process_for_each_test
+def test_full_graph(model_info, optimization_level):
     model = model_info[0]
     model_kwargs = model_info[1]
-    check_full_graph_support(model, model_kwargs, backend, tp_size=1)
+    check_full_graph_support(model,
+                             model_kwargs,
+                             optimization_level,
+                             tp_size=1)
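
fork_new_process_for_each_test is imported from the shared test utils, which are not shown in this diff. Its purpose is isolation: Dynamo caches and the CUDA context are process-global, so compiling one parametrization could contaminate the next. A minimal sketch of the idea, assuming a Unix-only test environment (the real helper may differ):

import functools
import os


def fork_new_process_for_each_test(f):
    """Run the test body in a forked child so that process-global state
    (Dynamo caches, the CUDA context) cannot leak across tests."""

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        pid = os.fork()
        if pid == 0:
            # child: run the test and report the result via exit code
            try:
                f(*args, **kwargs)
                os._exit(0)
            except BaseException:
                os._exit(1)
        # parent: fail the test if the child did not exit cleanly
        _, status = os.waitpid(pid, 0)
        assert os.WEXITSTATUS(status) == 0

    return wrapper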
(Two files were deleted in this diff; their contents are not shown.)
@@ -0,0 +1,23 @@
from contextlib import contextmanager
from typing import Any

_compile_context: Any = None


def get_compile_context() -> Any:
    """Get the current compile context."""
    return _compile_context


@contextmanager
def set_compile_context(context: Any):
    """A context manager that stores the current compile context,
    usually it is a list of sizes to specialize.
    """
    global _compile_context
    prev_context = _compile_context
    _compile_context = context
    try:
        yield
    finally:
        _compile_context = prev_context

Review comment on lines +14 to +15 (the docstring): This is a bit vague - could you improve the comment to add who uses this context and when?
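
A usage sketch for the context manager above; the import path is assumed (the file name is not visible in this capture) and the sizes are illustrative. A caller such as the model runner publishes the sizes to specialize, code running under compilation reads them back, and the previous context is restored on exit:

from vllm.compilation.compile_context import (get_compile_context,
                                              set_compile_context)

batch_sizes_to_capture = [1, 2, 4, 8]  # illustrative values

with set_compile_context(batch_sizes_to_capture):
    # anything running inside the block sees the published context
    assert get_compile_context() == [1, 2, 4, 8]

# the previous context (here the default, None) is restored on exit
assert get_compile_context() is None

Because set_compile_context saves and restores the previous value, nested uses compose correctly.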
Review comment: I think it will be common to add custom passes - I think we should handle this by wrapping the provided pass instead of overwriting:
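
The reviewer's code snippet did not survive this capture. A sketch of the wrapping idea under stated assumptions: a "pass" here is a callable that mutates a torch.fx.Graph in place (the convention Inductor uses for custom post-grad passes), and the names are illustrative rather than the PR's actual API:

from typing import Callable, Optional

import torch.fx

InductorPass = Callable[[torch.fx.Graph], None]


def wrap_custom_pass(default_pass: Optional[InductorPass],
                     custom_pass: InductorPass) -> InductorPass:
    """Compose a user-provided pass with the existing one instead of
    overwriting it: run the default pass first, then the custom one."""

    def combined(graph: torch.fx.Graph) -> None:
        if default_pass is not None:
            default_pass(graph)
        custom_pass(graph)

    return combined

Wrapping keeps the default optimizations intact while still letting users hook their own graph rewrites into the pipeline.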