[V1][core] Implement pipeline parallel on Ray (#12996)
ruisearch42 authored Feb 13, 2025
1 parent 0ccd876 commit 9605c12
Showing 7 changed files with 110 additions and 45 deletions.
51 changes: 39 additions & 12 deletions tests/distributed/test_pipeline_parallel.py
@@ -40,10 +40,23 @@ class PPTestOptions(NamedTuple):
@dataclass
class PPTestSettings:
parallel_setups: List[ParallelSetup]
# NOTE: the length of distributed_backends and
# vllm_major_versions should be the same, and they
# are first zipped together to iterate over all
# test settings.
distributed_backends: List[str]
# vllm major version: "0" for V0, "1" for V1
vllm_major_versions: List[str]
task: TaskOption
test_options: PPTestOptions

def __post_init__(self):
if len(self.distributed_backends) != len(self.vllm_major_versions):
raise ValueError(
f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})")

@staticmethod
def detailed(
*,
@@ -79,7 +92,9 @@ def detailed(
eager_mode=True,
chunked_prefill=False),
],
distributed_backends=["mp", "ray"],
# only ray is supported for V1
distributed_backends=["mp", "ray", "ray"],
vllm_major_versions=["0", "0", "1"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
@@ -108,6 +123,7 @@ def fast(
chunked_prefill=False),
],
distributed_backends=["mp"],
vllm_major_versions=["0"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
@@ -120,8 +136,9 @@ def iter_params(self, model_name: str):
opts = self.test_options

for parallel_setup in self.parallel_setups:
for distributed_backend in self.distributed_backends:
yield (model_name, parallel_setup, distributed_backend,
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_name, parallel_setup, backend, vllm_major_version,
self.task, opts)
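To make the zipped iteration concrete, here is a small stand-alone sketch (not part of the diff) of how a backend list and a major-version list expand into test parameters; the values mirror the detailed() defaults below:

    # Illustrative only: mirrors the zip inside iter_params().
    distributed_backends = ["mp", "ray", "ray"]
    vllm_major_versions = ["0", "0", "1"]

    for backend, major_version in zip(distributed_backends, vllm_major_versions):
        print(backend, major_version)
    # mp 0   -> V0 with the multiprocessing executor
    # ray 0  -> V0 with the Ray executor
    # ray 1  -> V1 with the Ray executor (the path added by this PR)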


@@ -244,6 +261,7 @@ def _compare_tp(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available: int,
@@ -296,10 +314,13 @@ def _compare_tp(
if hf_overrides:
common_args.extend(["--hf-overrides", hf_overrides])

if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
and chunked_prefill):
# Test Ray ADAG for a subset of the tests
specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
if distributed_backend == "ray" and (vllm_major_version == "1"
or specific_case):
# For V1, test Ray ADAG for all the tests
# For V0, test Ray ADAG for a subset of the tests
pp_env = {
"VLLM_USE_V1": vllm_major_version,
"VLLM_USE_RAY_COMPILED_DAG": "1",
"VLLM_USE_RAY_SPMD_WORKER": "1",
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
@@ -348,8 +369,8 @@ def _compare_tp(


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_name)
@@ -361,22 +382,24 @@ def test_tp_language_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="generate")


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in EMBEDDING_MODELS.items()
for params in settings.iter_params(model_name)
@@ -388,22 +411,24 @@ def test_tp_language_embedding(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="encode")


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in MULTIMODAL_MODELS.items()
for params in settings.iter_params(model_name)
@@ -415,13 +440,15 @@ def test_tp_multimodal_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
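For orientation (not part of the diff), a hedged sketch of how the V1 + Ray pipeline-parallel path exercised by these tests might be driven directly; the environment variables mirror pp_env above, while the model name and parallel sizes are placeholders:

    import os

    # Same switches the test sets for the V1 / Ray compiled-DAG (ADAG) path.
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
    os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
    os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "1"

    from vllm import LLM

    # Placeholder model; tp=2 x pp=2 matches the parallel setups tested above.
    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
              tensor_parallel_size=2,
              pipeline_parallel_size=2,
              distributed_executor_backend="ray")
    outputs = llm.generate(["The future of distributed inference is"])
    print(outputs[0].outputs[0].text)

Note that only the ray backend is paired with major version "1" in the settings above, matching the PR's scope of V1 pipeline parallelism on Ray.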
11 changes: 9 additions & 2 deletions vllm/executor/ray_utils.py
@@ -35,7 +35,7 @@

class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -118,7 +118,14 @@ def execute_model(
) -> "ModelRunnerOutput":
self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized"
output = self.worker.model_runner.execute_model(scheduler_output)
if isinstance(scheduler_output, tuple):
scheduler_output, intermediate_tensors = scheduler_output
else:
scheduler_output, intermediate_tensors = scheduler_output, None
output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors)
if isinstance(output, IntermediateTensors):
output = scheduler_output, output
return output

def override_env_vars(self, vars: Dict[str, str]):
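To clarify the tuple handling above, here is a simplified, hypothetical sketch (not the real vLLM classes): under pipeline parallelism the Ray compiled DAG passes a single object between stages, so a non-first stage receives (scheduler_output, intermediate_tensors) packed together, and a non-last stage re-packs its IntermediateTensors with the scheduler output for the next stage:

    from typing import Any, Tuple, Union


    class IntermediateTensors:
        """Stand-in for vLLM's IntermediateTensors (activations handed between stages)."""


    def run_stage(stage_is_last: bool,
                  inp: Union[Any, Tuple[Any, "IntermediateTensors"]]):
        # Unpack whatever the previous pipeline stage sent along the DAG channel.
        if isinstance(inp, tuple):
            scheduler_output, intermediate_tensors = inp
        else:
            scheduler_output, intermediate_tensors = inp, None

        # A real worker would call model_runner.execute_model(...) here.
        output = "ModelRunnerOutput" if stage_is_last else IntermediateTensors()

        # Non-last stages forward both pieces so the next stage sees the same
        # scheduler_output plus the intermediate activations.
        if isinstance(output, IntermediateTensors):
            output = (scheduler_output, output)
        return output


    # Two-stage chain: stage 0 produces a packed tuple, stage 1 the final output.
    packed = run_stage(stage_is_last=False, inp="SchedulerOutput")
    final = run_stage(stage_is_last=True, inp=packed)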
41 changes: 27 additions & 14 deletions vllm/v1/core/kv_cache_utils.py
@@ -488,7 +488,8 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:

def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec,
available_memory: int) -> KVCacheConfig:
available_memory: int,
num_layers: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers.
@@ -497,6 +498,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
available_memory: Memory available for KV cache in bytes.
num_layers: The number of layers in the model.
Returns:
The generated KVCacheConfig
@@ -506,7 +508,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
assert len(page_sizes) == 1
page_size = page_sizes.pop()

num_blocks = int(available_memory // page_size // len(kv_cache_spec))
num_blocks = int(available_memory // page_size // num_layers)
num_blocks = max(num_blocks, 0)

if vllm_config.cache_config.num_gpu_blocks_override is not None:
@@ -536,25 +538,36 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return kv_cache_config


def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
available_memory: int) -> KVCacheConfig:
def get_kv_cache_configs(vllm_config: VllmConfig,
kv_cache_specs: List[KVCacheSpec],
available_memory: int) -> List[KVCacheConfig]:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_specs: The kv cache specs of the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for most models.
# Allocate the same amount of memory for each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)
else:
raise NotImplementedError
# Use the max number of layers to conservatively determine
# the number of blocks.
num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs)
kv_cache_configs = []
for kv_cache_spec in kv_cache_specs:
check_enough_kv_cache_memory(vllm_config, kv_cache_spec,
available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
kv_cache_configs.append(
_get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory,
num_layers))
else:
raise NotImplementedError
return kv_cache_configs
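A worked example of the sizing arithmetic (numbers are illustrative, not from the diff): with pipeline parallelism each rank owns a subset of the layers, and using the maximum layer count across ranks keeps num_blocks identical on every rank, at the cost of slightly under-allocating on ranks with fewer layers.

    # Hypothetical model: 32 layers split evenly over 2 pipeline stages.
    specs_per_rank = [
        {f"layers.{i}": "full_attention" for i in range(0, 16)},   # PP rank 0
        {f"layers.{i}": "full_attention" for i in range(16, 32)},  # PP rank 1
    ]
    num_layers = max(len(spec) for spec in specs_per_rank)         # 16

    available_memory = 8 * 1024**3                  # 8 GiB free for KV cache
    block_size, num_kv_heads, head_size, dtype_bytes = 16, 8, 128, 2
    page_size = 2 * block_size * num_kv_heads * head_size * dtype_bytes  # K + V

    num_blocks = available_memory // page_size // num_layers
    print(num_blocks)  # 8192 blocks per layer on every rank

This is what allows the assertion added in core.py below: num_gpu_blocks comes out the same across workers.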
19 changes: 12 additions & 7 deletions vllm/v1/engine/core.py
@@ -16,7 +16,7 @@
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
@@ -73,20 +73,25 @@ def _initialize_kv_caches(self,
start = time.time()

# Get all kv cache needed by the model
kv_cache_spec = self.model_executor.get_kv_cache_spec()
kv_cache_specs = self.model_executor.get_kv_cache_specs()

# Profiles the peak memory usage of the model to determine how much
# memory can be allocated for kv cache.
availble_gpu_memory = self.model_executor.determine_available_memory()
available_gpu_memory = self.model_executor.determine_available_memory()

# Get the kv cache tensor size
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
availble_gpu_memory)
num_gpu_blocks = kv_cache_config.num_blocks
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
available_gpu_memory)
num_gpu_blocks_set = set(config.num_blocks
for config in kv_cache_configs)
assert len(num_gpu_blocks_set) == 1, (
f"num_gpu_blocks need to be the same across workers, "
f"but they are different: {num_gpu_blocks_set}")
num_gpu_blocks = num_gpu_blocks_set.pop()
num_cpu_blocks = 0

# Initialize kv cache and warmup the execution
self.model_executor.initialize(kv_cache_config)
self.model_executor.initialize(kv_cache_configs)

elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
12 changes: 5 additions & 7 deletions vllm/v1/executor/abstract.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Type
from typing import List, Type

from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
@@ -48,12 +48,12 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
f"{distributed_executor_backend}")
return executor_class

def initialize(self, kv_cache_config: KVCacheConfig) -> None:
def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
self.collective_rpc("compile_or_warm_up_model")

def determine_available_memory(self) -> int: # in bytes
@@ -63,11 +63,9 @@ def determine_available_memory(self) -> int: # in bytes
# operators can be applied to all workers.
return min(output)

def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_specs(self) -> List[KVCacheSpec]:
output = self.collective_rpc("get_kv_cache_spec")
for x in output:
assert x == output[0]
return output[0]
return output

def execute_model(
self,
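Illustration (hypothetical layer names, not from the diff) of why get_kv_cache_specs now returns one spec per worker instead of asserting they are all equal: with pipeline parallelism each rank registers only the attention layers it owns, so the per-worker results of the collective RPC legitimately differ.

    # Shape of the collective_rpc("get_kv_cache_spec") result for a 2-stage
    # pipeline over a hypothetical 4-layer model.
    specs = [
        {"model.layers.0.attn": "<KVCacheSpec>",   # worker 0 (PP rank 0)
         "model.layers.1.attn": "<KVCacheSpec>"},
        {"model.layers.2.attn": "<KVCacheSpec>",   # worker 1 (PP rank 1)
         "model.layers.3.attn": "<KVCacheSpec>"},
    ]
    # Previously all entries were identical (TP-only), so output[0] sufficed;
    # now the full list is handed to get_kv_cache_configs().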
(Diffs for the remaining 2 changed files were not loaded in this view.)
