Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) #10980

Merged
merged 80 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
57f3329
wip
afeldman-nm Feb 9, 2025
50584f6
wip
afeldman-nm Feb 9, 2025
bf3cfd0
Merge branch 'main' into sample
afeldman-nm Feb 9, 2025
98726ed
stream=false works
afeldman-nm Feb 10, 2025
cd649df
streaming nearly works
afeldman-nm Feb 10, 2025
fdc3296
Merge branch 'sample' into sample_merge
afeldman-nm Feb 10, 2025
a5415ef
refactor
afeldman-nm Feb 11, 2025
eb9042a
Merge branch 'main' into sample_merge
afeldman-nm Feb 11, 2025
a6637a9
good async implementation
afeldman-nm Feb 11, 2025
8808c7c
Merge branch 'main' into sample_merge
afeldman-nm Feb 11, 2025
af11e41
seed
afeldman-nm Feb 11, 2025
07f0c17
feedback
afeldman-nm Feb 11, 2025
522d34c
Merge branch 'main' into sample_merge
afeldman-nm Feb 11, 2025
2e828a8
linting
afeldman-nm Feb 11, 2025
374f1c7
Update vllm/v1/engine/async_llm.py
afeldman-nm Feb 13, 2025
35036ea
async def -> def
afeldman-nm Feb 13, 2025
b45c413
Update vllm/v1/engine/parallel_sampling.py
afeldman-nm Feb 13, 2025
00bb1f2
Update vllm/v1/engine/parallel_sampling.py
afeldman-nm Feb 13, 2025
b16ba2b
index
afeldman-nm Feb 13, 2025
fbcd213
sort by index
afeldman-nm Feb 13, 2025
a4ded40
no warmup
afeldman-nm Feb 13, 2025
9752657
Merge branch 'main' into sample_merge
afeldman-nm Feb 14, 2025
119a77c
refactor transform_output()
afeldman-nm Feb 14, 2025
a64e3b3
wip
afeldman-nm Feb 14, 2025
36cd555
refactor
afeldman-nm Feb 14, 2025
1f32836
Merge branch 'main' into sample_merge
afeldman-nm Feb 14, 2025
39d3d0b
sampling params caching
afeldman-nm Feb 14, 2025
5462e83
wip
afeldman-nm Feb 14, 2025
c0f8fb1
wip
afeldman-nm Feb 16, 2025
252e2ae
Merge branch 'main' into sample_merge
afeldman-nm Feb 16, 2025
97feb43
Merge branch 'main' into sample_merge
afeldman-nm Feb 17, 2025
103ceb6
parallel sampling core
afeldman-nm Feb 17, 2025
e6a1134
wip
afeldman-nm Feb 17, 2025
625e161
wip
afeldman-nm Feb 17, 2025
e28388e
add parallel sampling requests
afeldman-nm Feb 17, 2025
1c18dc2
wip
afeldman-nm Feb 17, 2025
c9c3dbb
working parallel sampling in LLMEngine
afeldman-nm Feb 17, 2025
d927a4a
Merge branch 'main' into sync_sample_merge
afeldman-nm Feb 18, 2025
0f0075c
unit test
afeldman-nm Feb 18, 2025
e4a0e6c
refactoring; bugfix
afeldman-nm Feb 18, 2025
bcdec42
Merge branch 'main' into sync_sample_merge
afeldman-nm Feb 18, 2025
196fc68
seed
afeldman-nm Feb 18, 2025
f86708a
refactor
afeldman-nm Feb 18, 2025
7e653f9
Merge branch 'main' into sample_merge
afeldman-nm Feb 18, 2025
6b1be36
Parallel sampling unit tests
afeldman-nm Feb 18, 2025
081a695
Merge branch 'main' into sync_sample_merge
afeldman-nm Feb 18, 2025
a3c5b22
Merge branch 'sync_sample_merge' into sync_sample
afeldman-nm Feb 18, 2025
13f3424
Merge branch 'main' into sample_merge
afeldman-nm Feb 18, 2025
5f57964
Merge branch 'sync_sample' into sample_merge
afeldman-nm Feb 18, 2025
933a90e
pre-commit hook fixes
afeldman-nm Feb 18, 2025
4876c85
Merge branch 'main' into sample_merge
afeldman-nm Feb 19, 2025
fe9f88c
Update vllm/v1/engine/async_llm.py
afeldman-nm Feb 19, 2025
d40089e
Update vllm/v1/engine/parallel_sampling.py
afeldman-nm Feb 19, 2025
cb0a2b4
Update vllm/v1/engine/parallel_sampling.py
afeldman-nm Feb 19, 2025
ef49ba7
rename
afeldman-nm Feb 19, 2025
d96617b
Merge branch 'afeldman-nm/sample' of https://github.com/neuralmagic/v…
afeldman-nm Feb 19, 2025
94261e2
Update vllm/v1/engine/llm_engine.py
afeldman-nm Feb 19, 2025
cfe46a7
Merge branch 'afeldman-nm/sample' of https://github.com/neuralmagic/v…
afeldman-nm Feb 19, 2025
9cc19de
Update vllm/v1/engine/parallel_sampling.py
afeldman-nm Feb 19, 2025
150fc93
fix
afeldman-nm Feb 19, 2025
9e8d755
refactor generate_parallel_sampling_async() into parallel_sampling.py
afeldman-nm Feb 19, 2025
79ebacf
Merge branch 'main' into sample_merge
afeldman-nm Feb 19, 2025
50b4154
Merge branch 'afeldman-nm/sample' of https://github.com/neuralmagic/v…
afeldman-nm Feb 19, 2025
73ccfb3
refactor
afeldman-nm Feb 19, 2025
b334d60
refactor
afeldman-nm Feb 19, 2025
38ea057
refactor
afeldman-nm Feb 19, 2025
f1bd068
Merge branch 'main' into sample_merge
afeldman-nm Feb 19, 2025
ecb39ae
refactor
afeldman-nm Feb 19, 2025
05db54d
Merge branch 'main' into sample_merge
afeldman-nm Feb 20, 2025
dfb8513
reorg
afeldman-nm Feb 20, 2025
53e35df
refactor
afeldman-nm Feb 20, 2025
382edd6
stream mode finished flag
afeldman-nm Feb 20, 2025
3204c36
Merge branch 'main' into sample_merge
afeldman-nm Feb 20, 2025
2001bef
refactor
afeldman-nm Feb 20, 2025
892429c
refactor
afeldman-nm Feb 20, 2025
267c1b8
rename
afeldman-nm Feb 20, 2025
bef174e
protocol-based types
afeldman-nm Feb 20, 2025
ba5a0dd
merge
afeldman-nm Feb 22, 2025
7e845cb
llm_engine test fixtures
afeldman-nm Feb 24, 2025
3d2fff4
Merge branch 'main' into sample_merge
afeldman-nm Feb 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 98 additions & 5 deletions tests/v1/engine/test_llm_engine.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,114 @@
# SPDX-License-Identifier: Apache-2.0

import random
from typing import Dict, List, Optional, Tuple

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import LLM, SamplingParams

MODEL = "facebook/opt-125m"
DTYPE = "half"

def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
"""Test passes if LLMEngine raises an exception when it is configured
for automatic prefix caching and it receives a request with
prompt_logprobs enabled, which is incompatible."""

def _vllm_model(apc: bool, vllm_runner, monkeypatch):
"""Set up VllmRunner instance."""
monkeypatch.setenv("VLLM_USE_V1", "1")
# TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
return vllm_runner(
MODEL,
dtype=DTYPE,
max_model_len=128,
enforce_eager=True,
enable_prefix_caching=apc,
gpu_memory_utilization=0.5,
)


@pytest.fixture(
# Function scope decouples tests & allows
# env var adjustment via monkeypatch
scope="function",
# Prefix caching
params=[False, True])
def vllm_model(vllm_runner, request, monkeypatch):
"""VllmRunner test fixture parameterized by APC True/False."""
with _vllm_model(request.param, vllm_runner, monkeypatch) as vllm_model:
yield vllm_model


@pytest.fixture(scope="function")
def vllm_model_apc(vllm_runner, monkeypatch):
"""VllmRunner test fixture with APC."""
with _vllm_model(True, vllm_runner, monkeypatch) as vllm_model:
yield vllm_model


def _get_test_sampling_params(
prompt_list: List[str],
seed: Optional[int] = 42,
) -> Tuple[List[SamplingParams], List[int]]:
"""Generate random sampling params for a batch."""

def get_mostly_n_gt1() -> int:
"""Mostly n \in [2,20], ~1/3 n=1"""
x = random.randint(0, 28)
if x < 10:
return 1
else:
return x - 8

n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
# High temperature to maximize the chance of unique completions
return [
SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
for n in n_list
], n_list


def test_parallel_sampling(vllm_model, example_prompts) -> None:
"""Test passes if parallel sampling `n>1` yields `n` unique completions.

Args:
vllm_model: VllmRunner instance under test.
example_prompt: test fixture providing prompts for testing.
"""
sampling_params_list, n_list = _get_test_sampling_params(example_prompts)
model: LLM = vllm_model.model
outputs = model.generate(example_prompts, sampling_params_list)

# Validate each request response
for out, n in zip(outputs, n_list):
completion_counts: Dict[str, int] = {}
# Assert correct number of completions
assert len(out.outputs) == n, (
f"{len(out.outputs)} completions; {n} expected.")
for idx in range(n):
comp = out.outputs[idx]
# Assert correct completion indices
assert comp.index == idx, (f"Index {comp.index}; expected {idx}.")
text = comp.text
completion_counts[text] = completion_counts.get(text, 0) + 1
# Assert unique completions
if len(completion_counts) != n:
repeats = {
txt: num
for (txt, num) in completion_counts.items() if num > 1
}
raise AssertionError(
f"{len(completion_counts)} unique completions; expected"
f" {n}. Repeats: {repeats}")


def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc):
"""Test passes if LLMEngine raises an exception when it is configured
for automatic prefix caching and it receives a request with
prompt_logprobs enabled, which is incompatible."""
model: LLM = vllm_model_apc.model
with pytest.raises(ValueError) as excinfo:
LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate(
model.generate(
"Hello, my name is",
SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))

Expand Down
102 changes: 102 additions & 0 deletions tests/v1/entrypoints/openai/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,108 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert "".join(chunks) == single_output


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
model_name: str):
"""Parallel sampling without streaming.
A single request output contains a list of completions.
"""

prompt = "What is an LLM?"
n = 3
max_tokens = 5

# High temperature to maximize chance of unique completions.
completion = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
temperature=0.95,
stream=False,
seed=42)

# Assert `n` completions
num_completions = len(completion.choices)
assert num_completions == n, (
f"Num completions {num_completions} but expected {n}.")
completion_repeats: Dict[str, int] = {}
for idx, choice in enumerate(completion.choices):
# Assert correct completion index & some finish reason.
assert choice.index == idx, (
f"Index {choice.index} but expected {idx}.")
assert choice.finish_reason is not None, (
"None finish_reason is invalid.")
text = choice.text
completion_repeats[text] = completion_repeats.get(text, 0) + 1
# Assert `n` unique completions
num_unique = len(completion_repeats)
if num_unique != n:
repeats = {
txt: num
for (txt, num) in completion_repeats.items() if num > 1
}
raise AssertionError(
f"Expected {n} unique completions, got {num_unique};"
f" repeats: {repeats}.")


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling.
The tokens from multiple samples, are flattened into a single stream,
with an index to indicate which sample the token belongs to.
"""

prompt = "What is an LLM?"
n = 3
max_tokens = 5

stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
temperature=0.95,
stream=True,
seed=42)
chunks: List[List[str]] = [[] for i in range(n)]
finish_reason_count = 0
async for chunk in stream:
index = chunk.choices[0].index
text = chunk.choices[0].text
chunks[index].append(text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# Assert `n` completions with correct finish reasons
assert finish_reason_count == n, (
f"Expected {n} completions with valid indices and finish_reason.")
completion_repeats: Dict[str, int] = {}
for chunk in chunks:
chunk_len = len(chunk)
# Assert correct number of completion tokens
assert chunk_len == max_tokens, (
f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
text = "".join(chunk)
completion_repeats[text] = completion_repeats.get(text, 0) + 1
print(text)
# Assert `n` unique completions
num_unique = len(completion_repeats)
if num_unique != n:
repeats = {
txt: num
for (txt, num) in completion_repeats.items() if num > 1
}
raise AssertionError(f"{num_unique} unique completions, expected {n};"
f" repeats: {repeats}")


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
Expand Down
27 changes: 26 additions & 1 deletion vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from vllm.utils import cdiv, kill_process_tree
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
Expand Down Expand Up @@ -170,7 +171,7 @@ async def add_request(
# requests we don't need to send multiple messages to core proc,
# and so we don't need multiple streams which then get
# re-multiplexed in the API server anyhow.
async def generate(
async def _generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
Expand Down Expand Up @@ -241,6 +242,30 @@ async def generate(
await self.abort(request_id)
raise

def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
kwargs = dict(prompt=prompt,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority)
if sampling_params.n is None or sampling_params.n == 1:
return self._generate(**kwargs)
else:
# Special handling for parallel sampling requests
return generate_parallel_sampling_async(generate=self._generate,
**kwargs)

async def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""

Expand Down
43 changes: 40 additions & 3 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor

Expand Down Expand Up @@ -48,6 +49,9 @@ def __init__(
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config

# Bookkeeping for parallel sampling requests
self.parallel_manager = SyncParallelSamplingManager()

# important: init dp group before init the engine_core
self.parallel_config = vllm_config.parallel_config
self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa
Expand Down Expand Up @@ -115,7 +119,8 @@ def from_engine_args(
multiprocess_mode=enable_multiprocessing)

def get_num_unfinished_requests(self) -> int:
return self.output_processor.get_num_unfinished_requests()
return self.parallel_manager.get_num_unfinished_requests(
self.output_processor.get_num_unfinished_requests())

def has_unfinished_requests(self) -> bool:
has_unfinished = self.output_processor.has_unfinished_requests()
Expand Down Expand Up @@ -151,7 +156,36 @@ def add_request(
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:

"""Add request."""
kwargs = dict(request_id=request_id,
prompt=prompt,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority)
# Handle parallel sampling requests differently.
if params is None or isinstance(params,
PoolingParams) or params.n == 1:
self._add_request(**kwargs)
else:
# Special handling for parallel sampling requests
self.parallel_manager.add_request_parallel_sampling(
add_request=self._add_request, **kwargs)

def _add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
"""Add request, `n=1`"""
# 1) Process raw inputs into the request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
Expand Down Expand Up @@ -182,7 +216,10 @@ def step(self) -> List[RequestOutput]:
# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

return processed_outputs.request_outputs
request_outputs = processed_outputs.request_outputs

# 4) Process unfinished parallel sampling requests
return self.parallel_manager.step(request_outputs)

def get_model_config(self):
return self.model_config
Expand Down
Loading