forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support torchrun and SPMD-style offline inference (vllm-project#12071)
Signed-off-by: youkaichao <youkaichao@gmail.com>
- Loading branch information
1 parent
2ff12d1
commit 0d0aba1
Showing
14 changed files
with
248 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
""" | ||
experimental support for tensor-parallel inference with torchrun, | ||
see https://github.com/vllm-project/vllm/issues/11400 for | ||
the motivation and use case for this example. | ||
run the script with `torchrun --nproc-per-node=2 torchrun_example.py`, | ||
the argument 2 should match the `tensor_parallel_size` below. | ||
see `tests/distributed/test_torchrun_example.py` for the unit test. | ||
""" | ||
|
||
from vllm import LLM, SamplingParams | ||
|
||
# Create prompts, the same across all ranks | ||
prompts = [ | ||
"Hello, my name is", | ||
"The president of the United States is", | ||
"The capital of France is", | ||
"The future of AI is", | ||
] | ||
|
||
# Create sampling parameters, the same across all ranks | ||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) | ||
|
||
# Use `distributed_executor_backend="external_launcher"` so that | ||
# this llm engine/instance only creates one worker. | ||
llm = LLM( | ||
model="facebook/opt-125m", | ||
tensor_parallel_size=2, | ||
distributed_executor_backend="external_launcher", | ||
) | ||
|
||
outputs = llm.generate(prompts, sampling_params) | ||
|
||
# all ranks will have the same outputs | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
print(f"Prompt: {prompt!r}, " | ||
f"Generated text: {generated_text!r}") | ||
""" | ||
Further tips: | ||
1. to communicate control messages across all ranks, use the cpu group, | ||
a PyTorch ProcessGroup with GLOO backend. | ||
```python | ||
from vllm.distributed.parallel_state import get_world_group | ||
cpu_group = get_world_group().cpu_group | ||
torch_rank = dist.get_rank(group=cpu_group) | ||
if torch_rank == 0: | ||
# do something for rank 0, e.g. saving the results to disk. | ||
``` | ||
2. to communicate data across all ranks, use the model's device group, | ||
a PyTorch ProcessGroup with NCCL backend. | ||
```python | ||
from vllm.distributed.parallel_state import get_world_group | ||
device_group = get_world_group().device_group | ||
``` | ||
3. to access the model directly in every rank, use the following code: | ||
```python | ||
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model | ||
``` | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# unit test for `examples/offline_inference/torchrun_example.py` | ||
|
||
import random | ||
|
||
import torch.distributed as dist | ||
|
||
from vllm import LLM, SamplingParams | ||
from vllm.distributed.parallel_state import get_world_group | ||
|
||
# Create prompts | ||
prompts = [ | ||
"Hello, my name is", | ||
"The president of the United States is", | ||
"The capital of France is", | ||
"The future of AI is", | ||
] | ||
|
||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) | ||
|
||
# set different `gpu_memory_utilization` and `swap_space` for different ranks, | ||
# to test if all ranks agree on the same kv cache configuration. | ||
llm = LLM(model="facebook/opt-125m", | ||
tensor_parallel_size=2, | ||
distributed_executor_backend="external_launcher", | ||
gpu_memory_utilization=random.uniform(0.7, 0.9), | ||
swap_space=random.randint(1, 4)) | ||
|
||
outputs = llm.generate(prompts, sampling_params) | ||
|
||
cpu_group = get_world_group().cpu_group | ||
|
||
torch_rank = dist.get_rank(group=cpu_group) | ||
|
||
|
||
def test_consistent_across_ranks(obj): | ||
if torch_rank == 0: | ||
dist.broadcast_object_list([obj], src=0, group=cpu_group) | ||
else: | ||
container = [None] | ||
dist.broadcast_object_list(container, src=0, group=cpu_group) | ||
assert container[0] == obj | ||
|
||
|
||
test_consistent_across_ranks( | ||
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) | ||
test_consistent_across_ranks( | ||
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) | ||
|
||
# all ranks should have the same outputs | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
test_consistent_across_ranks(prompt) | ||
test_consistent_across_ranks(generated_text) | ||
print(f"Rank {torch_rank}, Prompt: {prompt!r}, " | ||
f"Generated text: {generated_text!r}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.