Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1][core] Implement pipeline parallel on Ray #12996

Merged
merged 9 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 44 additions & 12 deletions tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,23 @@ class PPTestOptions(NamedTuple):
@dataclass
class PPTestSettings:
parallel_setups: List[ParallelSetup]
# NOTE: the length of distributed_backends and
# vllm_major_versions should be the same, and they
# are first zipped together to iterate over all
# test settings.
distributed_backends: List[str]
# vllm major version: "0" for V0, "1" for V1
vllm_major_versions: List[str]
task: TaskOption
test_options: PPTestOptions

def __post_init__(self):
if len(self.distributed_backends) != len(self.vllm_major_versions):
raise ValueError(
f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})")

@staticmethod
def detailed(
*,
Expand Down Expand Up @@ -79,7 +92,9 @@ def detailed(
eager_mode=True,
chunked_prefill=False),
],
distributed_backends=["mp", "ray"],
# only ray is supported for V1
distributed_backends=["mp", "ray", "ray"],
vllm_major_versions=["0", "0", "1"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
Expand Down Expand Up @@ -108,6 +123,7 @@ def fast(
chunked_prefill=False),
],
distributed_backends=["mp"],
vllm_major_versions=["0"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
Expand All @@ -120,8 +136,9 @@ def iter_params(self, model_name: str):
opts = self.test_options

for parallel_setup in self.parallel_setups:
for distributed_backend in self.distributed_backends:
yield (model_name, parallel_setup, distributed_backend,
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_name, parallel_setup, backend, vllm_major_version,
self.task, opts)


Expand Down Expand Up @@ -244,6 +261,7 @@ def _compare_tp(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available: int,
Expand Down Expand Up @@ -296,9 +314,11 @@ def _compare_tp(
if hf_overrides:
common_args.extend(["--hf-overrides", hf_overrides])

if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
and chunked_prefill):
# Test Ray ADAG for a subset of the tests
specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
if distributed_backend == "ray" and (vllm_major_version == "1"
or specific_case):
# For V1, test Ray ADAG for all the tests
# For V0, test Ray ADAG for a subset of the tests
pp_env = {
"VLLM_USE_RAY_COMPILED_DAG": "1",
"VLLM_USE_RAY_SPMD_WORKER": "1",
Expand Down Expand Up @@ -348,8 +368,8 @@ def _compare_tp(


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_name)
Expand All @@ -358,25 +378,29 @@ def _compare_tp(
)
@fork_new_process_for_each_test
def test_tp_language_generation(
monkeypatch,
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
):
monkeypatch.setenv('VLLM_USE_V1', vllm_major_version)
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="generate")


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in EMBEDDING_MODELS.items()
for params in settings.iter_params(model_name)
Expand All @@ -385,25 +409,29 @@ def test_tp_language_generation(
)
@fork_new_process_for_each_test
def test_tp_language_embedding(
monkeypatch,
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
):
monkeypatch.setenv('VLLM_USE_V1', vllm_major_version)
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="encode")


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in MULTIMODAL_MODELS.items()
for params in settings.iter_params(model_name)
Expand All @@ -412,16 +440,20 @@ def test_tp_language_embedding(
)
@fork_new_process_for_each_test
def test_tp_multimodal_generation(
monkeypatch,
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
):
monkeypatch.setenv('VLLM_USE_V1', vllm_major_version)
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
Expand Down
11 changes: 9 additions & 2 deletions vllm/executor/ray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -118,7 +118,14 @@ def execute_model(
) -> "ModelRunnerOutput":
self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized"
output = self.worker.model_runner.execute_model(scheduler_output)
if isinstance(scheduler_output, tuple):
scheduler_output, intermediate_tensors = scheduler_output
else:
scheduler_output, intermediate_tensors = scheduler_output, None
output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors)
if isinstance(output, IntermediateTensors):
output = scheduler_output, output
return output

def override_env_vars(self, vars: Dict[str, str]):
Expand Down
22 changes: 16 additions & 6 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,30 @@ def _initialize_kv_caches(self,
start = time.time()

# Get all kv cache needed by the model
kv_cache_spec = self.model_executor.get_kv_cache_spec()
kv_cache_specs = self.model_executor.get_kv_cache_specs()

# Profiles the peak memory usage of the model to determine how much
# memory can be allocated for kv cache.
availble_gpu_memory = self.model_executor.determine_available_memory()
available_gpu_memory = self.model_executor.determine_available_memory()

# Get the kv cache tensor size
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
availble_gpu_memory)
num_gpu_blocks = kv_cache_config.num_blocks
kv_cache_configs = []
num_gpu_blocks = None
for kv_cache_spec in kv_cache_specs:
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_gpu_memory)
kv_cache_configs.append(kv_cache_config)
if num_gpu_blocks is None:
num_gpu_blocks = kv_cache_config.num_blocks
elif num_gpu_blocks != kv_cache_config.num_blocks:
raise NotImplementedError(
"num_gpu_blocks need to be the same across workers: "
f"{num_gpu_blocks} != {kv_cache_config.num_blocks}")
assert num_gpu_blocks is not None
num_cpu_blocks = 0

# Initialize kv cache and warmup the execution
self.model_executor.initialize(kv_cache_config)
self.model_executor.initialize(kv_cache_configs)

elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
Expand Down
12 changes: 5 additions & 7 deletions vllm/v1/executor/abstract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Type
from typing import List, Type

from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
Expand Down Expand Up @@ -48,12 +48,12 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
f"{distributed_executor_backend}")
return executor_class

def initialize(self, kv_cache_config: KVCacheConfig) -> None:
def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
self.collective_rpc("compile_or_warm_up_model")

def determine_available_memory(self) -> int: # in bytes
Expand All @@ -63,11 +63,9 @@ def determine_available_memory(self) -> int: # in bytes
# operators can be applied to all workers.
return min(output)

def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_specs(self) -> List[KVCacheSpec]:
output = self.collective_rpc("get_kv_cache_spec")
for x in output:
assert x == output[0]
return output[0]
return output

def execute_model(
self,
Expand Down
16 changes: 15 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.distributed.parallel_state import get_pp_group, graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
Expand All @@ -21,6 +21,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
Expand Down Expand Up @@ -773,6 +774,7 @@ def get_model(self) -> nn.Module:
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> ModelRunnerOutput:
batch_changed = self._update_states(scheduler_output)

Expand Down Expand Up @@ -831,8 +833,11 @@ def execute_model(
positions=positions,
kv_caches=self.kv_caches,
attn_metadata=None,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
if not get_pp_group().is_last_rank:
return hidden_states
hidden_states = hidden_states[:num_scheduled_tokens]
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
Expand Down Expand Up @@ -1007,12 +1012,19 @@ def _dummy_run(
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.positions[:num_tokens]
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=num_tokens,
dtype=self.model_config.dtype,
device=self.device)
with set_forward_context(None, self.vllm_config):
hidden_states = model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=None,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
Expand Down Expand Up @@ -1142,6 +1154,8 @@ def profile_run(self) -> None:
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens,
dummy_kv_caches)
if not get_pp_group().is_last_rank:
return hidden_states
hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None)
# TODO(woosuk): Consider the memory usage of the sampler.
Expand Down
5 changes: 3 additions & 2 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

import torch
import torch.distributed
Expand Down Expand Up @@ -195,8 +195,9 @@ def determine_available_memory(self) -> int:
def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec()

def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config = kv_cache_configs[self.rank]
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache")
Expand Down