Support torchrun and SPMD-style offline inference (vllm-project#12071)
Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao authored and abmfy committed Jan 24, 2025
1 parent 175efd9 commit 30e90b8
Showing 14 changed files with 248 additions and 30 deletions.
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -463,6 +463,7 @@ steps:
- vllm/worker/worker.py
- vllm/worker/model_runner.py
commands:
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
64 changes: 64 additions & 0 deletions examples/offline_inference/torchrun_example.py
@@ -0,0 +1,64 @@
"""
experimental support for tensor-parallel inference with torchrun,
see https://github.com/vllm-project/vllm/issues/11400 for
the motivation and use case for this example.
run the script with `torchrun --nproc-per-node=2 torchrun_example.py`,
the argument 2 should match the `tensor_parallel_size` below.
see `tests/distributed/test_torchrun_example.py` for the unit test.
"""

from vllm import LLM, SamplingParams

# Create prompts, the same across all ranks
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
llm = LLM(
model="facebook/opt-125m",
tensor_parallel_size=2,
distributed_executor_backend="external_launcher",
)

outputs = llm.generate(prompts, sampling_params)

# all ranks will have the same outputs
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
"""
Further tips:
1. to communicate control messages across all ranks, use the cpu group,
a PyTorch ProcessGroup with GLOO backend.
```python
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
# do something for rank 0, e.g. saving the results to disk.
```
2. to communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with NCCL backend.
```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```
3. to access the model directly in every rank, use the following code:
```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""
56 changes: 56 additions & 0 deletions tests/distributed/test_torchrun_example.py
@@ -0,0 +1,56 @@
# unit test for `examples/offline_inference/torchrun_example.py`

import random

import torch.distributed as dist

from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_world_group

# Create prompts
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# set different `gpu_memory_utilization` and `swap_space` for different ranks,
# to test if all ranks agree on the same kv cache configuration.
llm = LLM(model="facebook/opt-125m",
tensor_parallel_size=2,
distributed_executor_backend="external_launcher",
gpu_memory_utilization=random.uniform(0.7, 0.9),
swap_space=random.randint(1, 4))

outputs = llm.generate(prompts, sampling_params)

cpu_group = get_world_group().cpu_group

torch_rank = dist.get_rank(group=cpu_group)


def test_consistent_across_ranks(obj):
if torch_rank == 0:
dist.broadcast_object_list([obj], src=0, group=cpu_group)
else:
container = [None]
dist.broadcast_object_list(container, src=0, group=cpu_group)
assert container[0] == obj


test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)

# all ranks should have the same outputs
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
test_consistent_across_ranks(prompt)
test_consistent_across_ranks(generated_text)
print(f"Rank {torch_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
2 changes: 1 addition & 1 deletion tests/engine/test_multiproc_workers.py
@@ -22,7 +22,7 @@ def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
# simulate error case
raise worker_input

return self.rank, input
return self.rpc_rank, input


def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
7 changes: 4 additions & 3 deletions vllm/config.py
@@ -1338,14 +1338,15 @@ def _verify_args(self) -> None:
from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform
if self.distributed_executor_backend not in (
"ray", "mp", "uni", None) and not (isinstance(
"ray", "mp", "uni",
"external_launcher", None) and not (isinstance(
self.distributed_executor_backend, type) and issubclass(
self.distributed_executor_backend, ExecutorBase)):
raise ValueError(
"Unrecognized distributed executor backend "
f"{self.distributed_executor_backend}. Supported "
"values are 'ray', 'mp' 'uni', or custom ExecutorBase"
" subclass.")
"values are 'ray', 'mp' 'uni', 'external_launcher' or"
" custom ExecutorBase subclass.")
if self.use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -388,7 +388,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
# Parallel arguments
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
choices=['ray', 'mp', 'uni', 'external_launcher'],
default=EngineArgs.distributed_executor_backend,
help='Backend to use for distributed model '
'workers, either "ray" or "mp" (multiprocessing). If the product '
5 changes: 5 additions & 0 deletions vllm/engine/llm_engine.py
@@ -457,6 +457,11 @@ def _get_executor_cls(cls,
# JAX-style, single-process, multi-device executor.
from vllm.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
elif distributed_executor_backend == "external_launcher":
# executor with external launcher
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher)
executor_class = ExecutorWithExternalLauncher
else:
from vllm.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
6 changes: 3 additions & 3 deletions vllm/executor/ray_distributed_executor.py
@@ -172,7 +172,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rank=rank)
rpc_rank=rank)
else:
worker = ray.remote(
num_cpus=0,
@@ -181,7 +181,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
rank=rank)
rpc_rank=rank)
worker_metadata.append(
RayWorkerMetaData(worker=worker, created_rank=rank))
rank += 1
@@ -204,7 +204,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
vllm_config=self.vllm_config, rank=0)
vllm_config=self.vllm_config, rpc_rank=0)
worker_metadata.pop(i)
break

81 changes: 80 additions & 1 deletion vllm/executor/uniproc_executor.py
@@ -1,5 +1,10 @@
import os
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
@@ -16,7 +21,7 @@ def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rank=0)
rpc_rank=0)
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
local_rank = 0
@@ -55,3 +60,77 @@ def check_health(self) -> None:


UniProcExecutorAsync = UniProcExecutor


class ExecutorWithExternalLauncher(UniProcExecutor):
"""An executor that uses external launchers to launch engines,
specially designed for torchrun-compatible launchers, for
offline inference with tensor parallelism.
See https://github.com/vllm-project/vllm/issues/11400 for
the motivation, and examples/offline_inference/torchrun_example.py
for a usage example.
The key idea: although this is tensor-parallel inference, we only
create one worker per executor. Users launch multiple engines with
torchrun-compatible launchers, and all these engines work together
to process the same prompts. When scheduling is deterministic, all
the engines generate the same outputs, so they do not need to
synchronize state with each other.
"""
uses_ray: bool = False

def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \
("ExecutorWithExternalLauncher does not "
"support pipeline parallelism.")
assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
("ExecutorWithExternalLauncher needs deterministic "
"execution, so it"
"does not support delay_factor in scheduling")
assert not envs.VLLM_USE_V1, \
("V1 architecture cannot guarantee deterministic execution, "
"so it is not supported in ExecutorWithExternalLauncher.")
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rpc_rank=0)
# engines are launched in torchrun-compatible launchers
# so we can use the env:// method.
# required env vars:
# - RANK
# - MASTER_ADDR
# - MASTER_PORT
distributed_init_method = "env://"
rank = int(os.environ["RANK"])
local_rank = rank
is_driver_worker = True
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
self.collective_rpc("load_model")

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""
Determine the number of available KV blocks.
Add an additional all_reduce to get the min across all ranks.
Note that even if we have the same `gpu_memory_utilization` and
`swap_space`, the available memory in every rank might still
differ because NCCL can take different amounts of memory in
different ranks. Therefore, it is necessary to test if all ranks
agree on the same KV cache configuration.
"""
a, b = super().determine_num_available_blocks()
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64)
b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64)
dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
return a_tensor.item(), b_tensor.item()
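As a side note for readers, the sketch below emulates what a torchrun-compatible launcher does for this executor: it exports the standard env:// variables (RANK, MASTER_ADDR, MASTER_PORT, plus WORLD_SIZE and LOCAL_RANK for parity) and starts one process per rank. It is an illustration of the launch contract under stated assumptions, not part of this commit; in practice you would simply use `torchrun --nproc-per-node=2`, as the test-pipeline entry above does.

```python
# Hedged sketch: hand-rolled stand-in for
# `torchrun --nproc-per-node=2 torchrun_example.py`. Assumes
# torchrun_example.py (from this commit) sits in the working directory.
import os
import subprocess
import sys

procs = []
for rank in range(2):
    env = dict(
        os.environ,
        RANK=str(rank),            # read by ExecutorWithExternalLauncher
        MASTER_ADDR="127.0.0.1",   # required for the env:// init method
        MASTER_PORT="29500",
        WORLD_SIZE="2",            # parity with torchrun; matches TP size
        LOCAL_RANK=str(rank),      # parity with torchrun; the executor
                                   # derives local_rank from RANK itself
    )
    procs.append(
        subprocess.Popen([sys.executable, "torchrun_example.py"], env=env))

for p in procs:
    assert p.wait() == 0, "a rank exited with a non-zero status"
```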
4 changes: 2 additions & 2 deletions vllm/lora/layers.py
@@ -940,8 +940,8 @@ def soft_cap(self):
return self.base_layer.soft_cap

@property
def use_gather(self):
return self.base_layer.use_gather
def use_all_gather(self):
return self.base_layer.use_all_gather

@property
def org_vocab_size(self):
16 changes: 10 additions & 6 deletions vllm/model_executor/layers/logits_processor.py
@@ -6,6 +6,7 @@
import torch.nn as nn

import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -44,8 +45,10 @@ def __init__(self,
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.

self.use_gather = not current_platform.is_tpu(
) and not envs.VLLM_USE_V1
parallel_config = get_current_vllm_config().parallel_config
self.use_all_gather = current_platform.is_tpu() \
or envs.VLLM_USE_V1 \
or parallel_config.distributed_executor_backend == "external_launcher" # noqa

def forward(
self,
@@ -88,16 +91,17 @@ def _get_logits(
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
if self.use_gather:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
else:

if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits)
else:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., :self.org_vocab_size]
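Why this switch matters for the external launcher: with plain gather, only the destination rank receives the full logits and every other rank gets `None`, which breaks SPMD-style execution where all ranks must run identical post-processing. The toy function below is written for this page to contrast the two collectives; it is not vLLM's implementation, which wraps the same primitives in `tensor_model_parallel_gather` and `tensor_model_parallel_all_gather` over the tensor-parallel group.

```python
# Toy contrast of gather vs. all-gather along the vocab dimension.
# Assumes torch.distributed is already initialized; not vLLM code.
from typing import Optional

import torch
import torch.distributed as dist


def collect_logits(local_logits: torch.Tensor,
                   use_all_gather: bool) -> Optional[torch.Tensor]:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    if use_all_gather:
        # Every rank ends up with the full logits: safe when all ranks
        # must keep executing the same code path (SPMD / external launcher).
        chunks = [torch.empty_like(local_logits) for _ in range(world_size)]
        dist.all_gather(chunks, local_logits)
        return torch.cat(chunks, dim=-1)
    # Plain gather: only the destination rank gets the full logits and
    # everyone else returns None -- fine when a single driver samples.
    chunks = ([torch.empty_like(local_logits) for _ in range(world_size)]
              if rank == 0 else None)
    dist.gather(local_logits, gather_list=chunks, dst=0)
    return torch.cat(chunks, dim=-1) if rank == 0 else None
```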
2 changes: 1 addition & 1 deletion vllm/v1/executor/multiproc_executor.py
@@ -246,7 +246,7 @@ def __init__(
ready_path: str,
):
self.rank = rank
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rank=rank)
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
# TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs: List[Dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
3 changes: 0 additions & 3 deletions vllm/worker/worker.py
@@ -55,9 +55,6 @@ def __init__(
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if is_driver_worker:
assert rank % self.parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules