-
-
Notifications
You must be signed in to change notification settings - Fork 10.6k
[CI/Build] Add TP test for vision models #5892
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
4cf6b2f
Add TP test
DarkLight1337 676fe84
Merge branch 'upstream' into vlm-tp
DarkLight1337 582272b
Increase buffer size limit
DarkLight1337 e84538b
Fir env not being unset
DarkLight1337 095400a
Merge branch 'upstream' into vlm-tp
DarkLight1337 7e4830c
Move tp tests to distributed tests and reduce cost
DarkLight1337 1b70c75
Remove unnecessary `Tensor.to`
DarkLight1337 703fbb8
Test phi3v as well
DarkLight1337 0ff27f2
Fix missing images
DarkLight1337 9fc3417
Remove unnecessary env override since it is already set by the pipeline
DarkLight1337 d940866
Add ray tests
DarkLight1337 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""Compare the outputs of HF and distributed vLLM when using greedy sampling. | ||
The second test will hang if more than one test is run per command, so we need | ||
to run the tests one by one. The solution is to pass arguments (model name) by | ||
environment variables. | ||
|
||
Run: | ||
```sh | ||
TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf \ | ||
test_multimodal_broadcast.py | ||
TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct \ | ||
test_multimodal_broadcast.py | ||
``` | ||
""" | ||
import os | ||
|
||
import pytest | ||
|
||
from vllm.utils import cuda_device_count_stateless | ||
|
||
model = os.environ["TEST_DIST_MODEL"] | ||
|
||
if model.startswith("llava-hf/llava"): | ||
from ..models.test_llava import model_and_vl_config, run_test | ||
elif model.startswith("microsoft/Phi-3-vision"): | ||
from ..models.test_phi3v import model_and_vl_config, run_test | ||
else: | ||
raise NotImplementedError(f"Unsupported model: {model}") | ||
|
||
|
||
@pytest.mark.parametrize("tensor_parallel_size", [2]) | ||
@pytest.mark.parametrize("dtype", ["half"]) | ||
@pytest.mark.parametrize("max_tokens", [128]) | ||
def test_models(hf_runner, vllm_runner, image_assets, | ||
tensor_parallel_size: int, dtype: str, | ||
max_tokens: int) -> None: | ||
if cuda_device_count_stateless() < tensor_parallel_size: | ||
pytest.skip( | ||
f"Need at least {tensor_parallel_size} GPUs to run the test.") | ||
|
||
distributed_executor_backend = os.getenv("DISTRIBUTED_EXECUTOR_BACKEND") | ||
|
||
run_test( | ||
hf_runner, | ||
vllm_runner, | ||
image_assets, | ||
model_and_config=model_and_vl_config[0], | ||
dtype=dtype, | ||
max_tokens=max_tokens, | ||
tensor_parallel_size=tensor_parallel_size, | ||
distributed_executor_backend=distributed_executor_backend, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.