-
-
Notifications
You must be signed in to change notification settings - Fork 10.5k
Support torchrun and SPMD-style offline inference #12071
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
8aaf7a0
support torchrun
youkaichao 12c7c9d
support
youkaichao efb3167
commit
youkaichao 5291b3e
fix format
youkaichao 342fbbe
comments
youkaichao f9b6167
add comments
youkaichao a87635e
fix
youkaichao 80e48da
fix
youkaichao 08162f3
fix tests
youkaichao 255cdd3
fix
youkaichao 7e4dd7e
fix
youkaichao bdbbdcf
fix
youkaichao 917667b
test differently
youkaichao 012800e
fix
youkaichao d95291b
add consistency test
youkaichao 0a2fe6a
unify arg names
youkaichao 7f9879a
add launch command
youkaichao 05d6595
tutorial style
youkaichao d28e8d2
add tests
youkaichao ef92e63
add tips
youkaichao 8750ea7
add tips
youkaichao 8c62a32
add code
youkaichao f1dfd2b
add config check
youkaichao 33cc388
fix use_all_gather
youkaichao 690631b
fix linter
youkaichao 5022850
fix linter
youkaichao 3886dc2
fix arg name
youkaichao 5151add
fix tests
youkaichao 78c0b29
disable v1 for torchrun
youkaichao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
""" | ||
experimental support for tensor-parallel inference with torchrun, | ||
see https://github.com/vllm-project/vllm/issues/11400 for | ||
the motivation and use case for this example. | ||
run the script with `torchrun --nproc-per-node=2 torchrun_example.py`, | ||
the argument 2 should match the `tensor_parallel_size` below. | ||
see `tests/distributed/test_torchrun_example.py` for the unit test. | ||
""" | ||
|
||
from vllm import LLM, SamplingParams | ||
|
||
# Create prompts, the same across all ranks | ||
prompts = [ | ||
"Hello, my name is", | ||
"The president of the United States is", | ||
"The capital of France is", | ||
"The future of AI is", | ||
] | ||
|
||
# Create sampling parameters, the same across all ranks | ||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) | ||
|
||
# Use `distributed_executor_backend="external_launcher"` so that | ||
# this llm engine/instance only creates one worker. | ||
llm = LLM( | ||
model="facebook/opt-125m", | ||
tensor_parallel_size=2, | ||
distributed_executor_backend="external_launcher", | ||
) | ||
|
||
outputs = llm.generate(prompts, sampling_params) | ||
|
||
# all ranks will have the same outputs | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
print(f"Prompt: {prompt!r}, " | ||
f"Generated text: {generated_text!r}") | ||
""" | ||
Further tips: | ||
|
||
1. to communicate control messages across all ranks, use the cpu group, | ||
a PyTorch ProcessGroup with GLOO backend. | ||
|
||
```python | ||
from vllm.distributed.parallel_state import get_world_group | ||
cpu_group = get_world_group().cpu_group | ||
torch_rank = dist.get_rank(group=cpu_group) | ||
if torch_rank == 0: | ||
# do something for rank 0, e.g. saving the results to disk. | ||
``` | ||
|
||
2. to communicate data across all ranks, use the model's device group, | ||
a PyTorch ProcessGroup with NCCL backend. | ||
```python | ||
from vllm.distributed.parallel_state import get_world_group | ||
device_group = get_world_group().device_group | ||
``` | ||
|
||
3. to access the model directly in every rank, use the following code: | ||
```python | ||
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model | ||
``` | ||
""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# unit test for `examples/offline_inference/torchrun_example.py` | ||
|
||
import random | ||
|
||
import torch.distributed as dist | ||
|
||
from vllm import LLM, SamplingParams | ||
from vllm.distributed.parallel_state import get_world_group | ||
|
||
# Create prompts | ||
prompts = [ | ||
"Hello, my name is", | ||
"The president of the United States is", | ||
"The capital of France is", | ||
"The future of AI is", | ||
] | ||
|
||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) | ||
|
||
# set different `gpu_memory_utilization` and `swap_space` for different ranks, | ||
# to test if all ranks agree on the same kv cache configuration. | ||
llm = LLM(model="facebook/opt-125m", | ||
tensor_parallel_size=2, | ||
distributed_executor_backend="external_launcher", | ||
gpu_memory_utilization=random.uniform(0.7, 0.9), | ||
swap_space=random.randint(1, 4)) | ||
|
||
outputs = llm.generate(prompts, sampling_params) | ||
|
||
cpu_group = get_world_group().cpu_group | ||
|
||
torch_rank = dist.get_rank(group=cpu_group) | ||
|
||
|
||
def test_consistent_across_ranks(obj): | ||
if torch_rank == 0: | ||
dist.broadcast_object_list([obj], src=0, group=cpu_group) | ||
else: | ||
container = [None] | ||
dist.broadcast_object_list(container, src=0, group=cpu_group) | ||
assert container[0] == obj | ||
|
||
|
||
test_consistent_across_ranks( | ||
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) | ||
test_consistent_across_ranks( | ||
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) | ||
|
||
# all ranks should have the same outputs | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
test_consistent_across_ranks(prompt) | ||
test_consistent_across_ranks(generated_text) | ||
print(f"Rank {torch_rank}, Prompt: {prompt!r}, " | ||
f"Generated text: {generated_text!r}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.