Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1][core] Implement pipeline parallel on Ray #12996

Merged
merged 9 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,23 @@ class PPTestOptions(NamedTuple):
@dataclass
class PPTestSettings:
parallel_setups: List[ParallelSetup]
# NOTE: the length of distributed_backends and
# vllm_major_versions should be the same, and they
# are first zipped together to iterate over all
# test settings.
distributed_backends: List[str]
# vllm major version: "0" for V0, "1" for V1
vllm_major_versions: List[str]
task: TaskOption
test_options: PPTestOptions

def __post_init__(self):
if len(self.distributed_backends) != len(self.vllm_major_versions):
raise ValueError(
f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})")

@staticmethod
def detailed(
*,
Expand Down Expand Up @@ -79,7 +92,9 @@ def detailed(
eager_mode=True,
chunked_prefill=False),
],
distributed_backends=["mp", "ray"],
# only ray is supported for V1
distributed_backends=["mp", "ray", "ray"],
vllm_major_versions=["0", "0", "1"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
Expand Down Expand Up @@ -108,6 +123,7 @@ def fast(
chunked_prefill=False),
],
distributed_backends=["mp"],
vllm_major_versions=["0"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
Expand All @@ -120,8 +136,9 @@ def iter_params(self, model_name: str):
opts = self.test_options

for parallel_setup in self.parallel_setups:
for distributed_backend in self.distributed_backends:
yield (model_name, parallel_setup, distributed_backend,
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_name, parallel_setup, backend, vllm_major_version,
self.task, opts)


Expand Down Expand Up @@ -244,6 +261,7 @@ def _compare_tp(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available: int,
Expand Down Expand Up @@ -296,10 +314,13 @@ def _compare_tp(
if hf_overrides:
common_args.extend(["--hf-overrides", hf_overrides])

if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
and chunked_prefill):
# Test Ray ADAG for a subset of the tests
specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
if distributed_backend == "ray" and (vllm_major_version == "1"
or specific_case):
# For V1, test Ray ADAG for all the tests
# For V0, test Ray ADAG for a subset of the tests
pp_env = {
"VLLM_USE_V1": vllm_major_version,
"VLLM_USE_RAY_COMPILED_DAG": "1",
"VLLM_USE_RAY_SPMD_WORKER": "1",
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
Expand Down Expand Up @@ -348,8 +369,8 @@ def _compare_tp(


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_name)
Expand All @@ -361,22 +382,24 @@ def test_tp_language_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="generate")


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in EMBEDDING_MODELS.items()
for params in settings.iter_params(model_name)
Expand All @@ -388,22 +411,24 @@ def test_tp_language_embedding(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="encode")


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in MULTIMODAL_MODELS.items()
for params in settings.iter_params(model_name)
Expand All @@ -415,13 +440,15 @@ def test_tp_multimodal_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
Expand Down
11 changes: 9 additions & 2 deletions vllm/executor/ray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -118,7 +118,14 @@ def execute_model(
) -> "ModelRunnerOutput":
self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized"
output = self.worker.model_runner.execute_model(scheduler_output)
if isinstance(scheduler_output, tuple):
scheduler_output, intermediate_tensors = scheduler_output
else:
scheduler_output, intermediate_tensors = scheduler_output, None
output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors)
if isinstance(output, IntermediateTensors):
output = scheduler_output, output
return output

def override_env_vars(self, vars: Dict[str, str]):
Expand Down
41 changes: 27 additions & 14 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,8 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:

def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec,
available_memory: int) -> KVCacheConfig:
available_memory: int,
num_layers: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers.
Expand All @@ -433,6 +434,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
available_memory: Memory available for KV cache in bytes.
num_layers: The number of layers in the model.

Returns:
The generated KVCacheConfig
Expand All @@ -442,7 +444,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
assert len(page_sizes) == 1
page_size = page_sizes.pop()

num_blocks = int(available_memory // page_size // len(kv_cache_spec))
num_blocks = int(available_memory // page_size // num_layers)
num_blocks = max(num_blocks, 0)

if vllm_config.cache_config.num_gpu_blocks_override is not None:
Expand Down Expand Up @@ -472,25 +474,36 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return kv_cache_config


def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
available_memory: int) -> KVCacheConfig:
def get_kv_cache_configs(vllm_config: VllmConfig,
kv_cache_specs: List[KVCacheSpec],
available_memory: int) -> List[KVCacheConfig]:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.

Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_specs: The kv cache specs of the model
available_memory: Memory available for KV cache in bytes.

Returns:
The generated KVCacheConfig
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for most models.
# Allocate the same amount of memory for each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)
else:
raise NotImplementedError
# Use the max number of layers to conservatively determine
# the number of blocks.
num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs)
kv_cache_configs = []
for kv_cache_spec in kv_cache_specs:
check_enough_kv_cache_memory(vllm_config, kv_cache_spec,
available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
kv_cache_configs.append(
_get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory,
num_layers))
else:
raise NotImplementedError
return kv_cache_configs
19 changes: 12 additions & 7 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
Expand Down Expand Up @@ -71,20 +71,25 @@ def _initialize_kv_caches(self,
start = time.time()

# Get all kv cache needed by the model
kv_cache_spec = self.model_executor.get_kv_cache_spec()
kv_cache_specs = self.model_executor.get_kv_cache_specs()

# Profiles the peak memory usage of the model to determine how much
# memory can be allocated for kv cache.
availble_gpu_memory = self.model_executor.determine_available_memory()
available_gpu_memory = self.model_executor.determine_available_memory()

# Get the kv cache tensor size
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
availble_gpu_memory)
num_gpu_blocks = kv_cache_config.num_blocks
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
available_gpu_memory)
num_gpu_blocks_set = set(config.num_blocks
for config in kv_cache_configs)
assert len(num_gpu_blocks_set) == 1, (
f"num_gpu_blocks need to be the same across workers, "
f"but they are different: {num_gpu_blocks_set}")
num_gpu_blocks = num_gpu_blocks_set.pop()
num_cpu_blocks = 0

# Initialize kv cache and warmup the execution
self.model_executor.initialize(kv_cache_config)
self.model_executor.initialize(kv_cache_configs)

elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
Expand Down
12 changes: 5 additions & 7 deletions vllm/v1/executor/abstract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Type
from typing import List, Type

from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
Expand Down Expand Up @@ -48,12 +48,12 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
f"{distributed_executor_backend}")
return executor_class

def initialize(self, kv_cache_config: KVCacheConfig) -> None:
def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
self.collective_rpc("compile_or_warm_up_model")

def determine_available_memory(self) -> int: # in bytes
Expand All @@ -63,11 +63,9 @@ def determine_available_memory(self) -> int: # in bytes
# operators can be applied to all workers.
return min(output)

def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_specs(self) -> List[KVCacheSpec]:
output = self.collective_rpc("get_kv_cache_spec")
for x in output:
assert x == output[0]
return output[0]
return output

def execute_model(
self,
Expand Down
Loading