From 57f3329c0bf4b68bd8980a08adb4dbbd37804cff Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 9 Feb 2025 04:59:46 +0000 Subject: [PATCH 01/54] wip Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3c4e35e4aa274..da02eb924f9d2 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -167,7 +167,7 @@ async def add_request( # requests we don't need to send multiple messages to core proc, # and so we don't need multiple streams which then get # re-multiplexed in the API server anyhow. - async def generate( + async def _generate( self, prompt: PromptType, sampling_params: SamplingParams, @@ -238,6 +238,21 @@ async def generate( await self.abort(request_id) raise + async def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + async for output in self._generate(prompt, sampling_params, request_id, + lora_request, trace_headers, + prompt_adapter_request, priority): + yield output + async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" From 50584f62762eca1145c16836e65ce4f11d32f05a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 9 Feb 2025 05:40:36 +0000 Subject: [PATCH 02/54] wip Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 28 ++++++++++++++++-- vllm/v1/engine/parallel_sampling.py | 44 +++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 3 deletions(-) create mode 100644 vllm/v1/engine/parallel_sampling.py diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index da02eb924f9d2..39fd9328ac8f8 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -238,6 +238,18 @@ async def _generate( await self.abort(request_id) raise + async def _parallel_sampling_batch( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + pass + async def generate( self, prompt: PromptType, @@ -248,9 +260,19 @@ async def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> AsyncGenerator[RequestOutput, None]: - async for output in self._generate(prompt, sampling_params, request_id, - lora_request, trace_headers, - prompt_adapter_request, priority): + n = sampling_params.n + if n is None or sampling_params.n == 1: + generator = self._generate(prompt, sampling_params, request_id, + lora_request, trace_headers, + prompt_adapter_request, priority) + else: + generator = self._parallel_sampling_batch(prompt, sampling_params, + request_id, lora_request, + trace_headers, + prompt_adapter_request, + priority) + + async for output in generator: yield output async def _run_output_handler(self): diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py new file mode 100644 index 0000000000000..1aebcfb79e422 --- /dev/null +++ b/vllm/v1/engine/parallel_sampling.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 + 
+from copy import copy +from typing import Any, Dict, Optional + +from vllm.outputs import RequestOutput +from vllm.sampling_params import RequestOutputKind, SamplingParams + + +class ParentRequestState: + sampling_params: SamplingParams + request_output: Optional[RequestOutput] = None + + def get_child_sampling_params( + self, + kwargs: Dict[str, Any] = {}, + ) -> SamplingParams: + sampling_params = copy(self.sampling_params) + for kw in kwargs: + setattr(sampling_params, kw, kwargs[kw]) + return sampling_params + + def add_output( + self, + child_req_output: RequestOutput, + ) -> None: + if self.output_kind != RequestOutputKind.DELTA: + pass + + @property + def n(self) -> int: + return self.sampling_params.n + + @property + def logprobs(self) -> Optional[int]: + return self.sampling_params.logprobs + + @property + def prompt_logprobs(self) -> Optional[int]: + return self.sampling_params.prompt_logprobs + + @property + def output_kind(self) -> RequestOutputKind: + return self.sampling_params.output_kind From 98726edc2775a0ab84d8a286a81dc88ba632d506 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Feb 2025 15:18:53 +0000 Subject: [PATCH 03/54] stream=false works Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 68 ++++++++++++++++++++++++----- vllm/v1/engine/parallel_sampling.py | 59 ++++++++++++++++++++++++- 2 files changed, 113 insertions(+), 14 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 39fd9328ac8f8..be389f6410ba3 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -24,6 +24,8 @@ from vllm.utils import cdiv, kill_process_tree from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor +from vllm.v1.engine.parallel_sampling import (ParallelSamplingOutputProcessor, + ParentRequestState) from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, @@ -50,6 +52,7 @@ def __init__( assert start_engine_loop self.model_config = vllm_config.model_config + self.enable_prefix_caching = vllm_config.cache_config.enable_prefix_caching self.log_requests = log_requests self.log_stats = log_stats @@ -248,7 +251,50 @@ async def _parallel_sampling_batch( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> AsyncGenerator[RequestOutput, None]: - pass + parent_state = ParentRequestState(request_id, sampling_params) + output_processor = ParallelSamplingOutputProcessor(parent_state) + n = parent_state.n + + if self.enable_prefix_caching: + # If engine uses APC, generate a “warmup request” with + # max_tokens=1 which populates the APC + w_sampling_params = parent_state.get_child_sampling_params({ + "max_tokens": + 1, + "n": + 1 + }) + async for _ in self._generate( + prompt, + w_sampling_params, + parent_state.get_warmup_request_id(), + lora_request, + trace_headers, + prompt_adapter_request, + priority, + ): + pass + + seed = 42 + for idx in range(n): + c_sampling_params = parent_state.get_child_sampling_params({ + "n": + 1, + "seed": + seed + }) + seed += 1 + async for out in self._generate( + prompt, + c_sampling_params, + parent_state.get_child_request_id(idx), + lora_request, + trace_headers, + prompt_adapter_request, + priority, + ): + if req_out := output_processor.process_output(out): + yield req_out async def generate( self, @@ -262,18 +308,16 @@ async def generate( ) -> AsyncGenerator[RequestOutput, None]: 
n = sampling_params.n if n is None or sampling_params.n == 1: - generator = self._generate(prompt, sampling_params, request_id, - lora_request, trace_headers, - prompt_adapter_request, priority) + async for out in self._generate(prompt, sampling_params, + request_id, lora_request, + trace_headers, + prompt_adapter_request, priority): + yield out else: - generator = self._parallel_sampling_batch(prompt, sampling_params, - request_id, lora_request, - trace_headers, - prompt_adapter_request, - priority) - - async for output in generator: - yield output + async for out in self._parallel_sampling_batch( + prompt, sampling_params, request_id, lora_request, + trace_headers, prompt_adapter_request, priority): + yield out async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 1aebcfb79e422..626b66de2484e 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -8,9 +8,15 @@ class ParentRequestState: + request_id: str sampling_params: SamplingParams request_output: Optional[RequestOutput] = None + def __init__(self, request_id: str, + sampling_params: SamplingParams) -> None: + self.request_id = request_id + self.sampling_params = sampling_params + def get_child_sampling_params( self, kwargs: Dict[str, Any] = {}, @@ -24,8 +30,32 @@ def add_output( self, child_req_output: RequestOutput, ) -> None: - if self.output_kind != RequestOutputKind.DELTA: - pass + if self.request_output is None: + # Save the first request output; reinstate + # original request ID; metrics are not + # supported for parallel sampling + child_req_output.request_id = self.request_id + child_req_output.metrics = None + self.request_output = child_req_output + else: + # Add completion to the request output + new_completion = child_req_output.outputs[0] + new_completion.index = self.num_completions + self.request_output.outputs.append(new_completion) + + def get_warmup_request_id(self) -> str: + return "w_" + self.request_id + + def get_child_request_id( + self, + index: int, + ) -> str: + return str(index) + "_" + self.request_id + + @property + def num_completions(self) -> int: + assert self.request_output is not None + return len(self.request_output.outputs) @property def n(self) -> int: @@ -42,3 +72,28 @@ def prompt_logprobs(self) -> Optional[int]: @property def output_kind(self) -> RequestOutputKind: return self.sampling_params.output_kind + + +class ParallelSamplingOutputProcessor: + + def __init__( + self, + parent_state: ParentRequestState, + ) -> None: + self.parent_state = parent_state + + def process_output( + self, + child_req_output: RequestOutput, + ) -> Optional[RequestOutput]: + if self.parent_state.output_kind == RequestOutputKind.FINAL_ONLY: + # stream=false: accumulate child completions + self.parent_state.add_output(child_req_output) + if self.parent_state.num_completions == self.parent_state.n: + # Return accumulated request output after obtaining + # all completions + return self.parent_state.request_output + else: + # stream=true: return child completions immediately + pass + return None From cd649dfb4e06aae9f8076e4e1f4a0c64c28e9593 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Feb 2025 19:28:51 +0000 Subject: [PATCH 04/54] streaming nearly works Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 10 +++++++--- vllm/v1/engine/parallel_sampling.py | 13 +++++++++---- 2 files changed, 16 insertions(+), 7 
deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index be389f6410ba3..80b7f0ed9dce8 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -52,7 +52,8 @@ def __init__( assert start_engine_loop self.model_config = vllm_config.model_config - self.enable_prefix_caching = vllm_config.cache_config.enable_prefix_caching + self.enable_prefix_caching = ( + vllm_config.cache_config.enable_prefix_caching) self.log_requests = log_requests self.log_stats = log_stats @@ -262,7 +263,9 @@ async def _parallel_sampling_batch( "max_tokens": 1, "n": - 1 + 1, + "output_kind": + RequestOutputKind.FINAL_ONLY }) async for _ in self._generate( prompt, @@ -273,6 +276,7 @@ async def _parallel_sampling_batch( prompt_adapter_request, priority, ): + # Exhaust the generator pass seed = 42 @@ -293,7 +297,7 @@ async def _parallel_sampling_batch( prompt_adapter_request, priority, ): - if req_out := output_processor.process_output(out): + if req_out := output_processor.process_output(out, idx): yield req_out async def generate( diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 626b66de2484e..79dfd002a719f 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -19,11 +19,12 @@ def __init__(self, request_id: str, def get_child_sampling_params( self, - kwargs: Dict[str, Any] = {}, + kwargs: Optional[Dict[str, Any]] = None, ) -> SamplingParams: sampling_params = copy(self.sampling_params) - for kw in kwargs: - setattr(sampling_params, kw, kwargs[kw]) + if kwargs is not None: + for kw in kwargs: + setattr(sampling_params, kw, kwargs[kw]) return sampling_params def add_output( @@ -85,6 +86,7 @@ def __init__( def process_output( self, child_req_output: RequestOutput, + index: int, ) -> Optional[RequestOutput]: if self.parent_state.output_kind == RequestOutputKind.FINAL_ONLY: # stream=false: accumulate child completions @@ -95,5 +97,8 @@ def process_output( return self.parent_state.request_output else: # stream=true: return child completions immediately - pass + child_req_output.request_id = self.parent_state.request_id + child_req_output.outputs[0].index = index + return child_req_output + return None From a5415efea3ee507e2de61b0a55afdb9515dbc7ab Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 11 Feb 2025 13:10:21 +0000 Subject: [PATCH 05/54] refactor Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 80b7f0ed9dce8..fc67a63258f43 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -279,6 +279,7 @@ async def _parallel_sampling_batch( # Exhaust the generator pass + # n child requests seed = 42 for idx in range(n): c_sampling_params = parent_state.get_child_sampling_params({ From a6637a9d99d2da4d1c4e5e2487c91361d10ba451 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 11 Feb 2025 14:53:36 +0000 Subject: [PATCH 06/54] good async implementation Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 56 +++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index fc67a63258f43..52c38dd843306 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -242,6 +242,16 @@ async def _generate( await self.abort(request_id) raise + async def _parallel_sampling_task( + self, + gen: AsyncGenerator[RequestOutput, 
None], + output_processor: ParallelSamplingOutputProcessor, + index: int, + ) -> AsyncGenerator[RequestOutput, None]: + async for out in gen: + if req_out := output_processor.process_output(out, index): + yield req_out + async def _parallel_sampling_batch( self, prompt: PromptType, @@ -279,7 +289,9 @@ async def _parallel_sampling_batch( # Exhaust the generator pass - # n child requests + # Aggregate generators for n child requests + gens = [] + active = {} seed = 42 for idx in range(n): c_sampling_params = parent_state.get_child_sampling_params({ @@ -289,17 +301,37 @@ async def _parallel_sampling_batch( seed }) seed += 1 - async for out in self._generate( - prompt, - c_sampling_params, - parent_state.get_child_request_id(idx), - lora_request, - trace_headers, - prompt_adapter_request, - priority, - ): - if req_out := output_processor.process_output(out, idx): - yield req_out + child_gen = self._generate( + prompt, + c_sampling_params, + parent_state.get_child_request_id(idx), + lora_request, + trace_headers, + prompt_adapter_request, + priority, + ) + gen = self._parallel_sampling_task(child_gen, output_processor, + idx) + gens.append(gen) + active[asyncio.create_task(gen.__anext__())] = idx + + try: + while active: + done, _ = await asyncio.wait( + active.keys(), return_when=asyncio.FIRST_COMPLETED) + for task in done: + idx = active.pop(task) + try: + result = task.result() + yield result + # Schedule the next result + active[asyncio.create_task( + gens[idx].__anext__())] = idx + except StopAsyncIteration: + continue + finally: + for task in active: + task.cancel() async def generate( self, From af11e412b63c5542ef2e220e234fdb176ee7b608 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 11 Feb 2025 15:24:48 +0000 Subject: [PATCH 07/54] seed Signed-off-by: Andrew Feldman --- .../v1/entrypoints/openai/test_completion.py | 59 +++++++++++++++++++ vllm/v1/engine/async_llm.py | 5 +- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index ef46a16ef3447..dc72089572f8d 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -250,6 +250,65 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, assert "".join(chunks) == single_output +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_parallel_no_streaming(client: openai.AsyncOpenAI, + model_name: str): + """Parallel sampling without streaming. + A single request output contains a list of completions. + """ + + prompt = "What is an LLM?" + n = 3 + max_tokens = 5 + + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=max_tokens, + n=n, + stream=False) + + for choice in completion.choices: + assert choice.finish_reason is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): + """Streaming for parallel sampling. + The tokens from multiple samples, are flattened into a single stream, + with an index to indicate which sample the token belongs to. + """ + + prompt = "What is an LLM?" 
+ n = 3 + max_tokens = 5 + + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=max_tokens, + n=n, + stream=True) + chunks: List[List[str]] = [[] for i in range(n)] + finish_reason_count = 0 + async for chunk in stream: + index = chunk.choices[0].index + text = chunk.choices[0].text + chunks[index].append(text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == n + for chunk in chunks: + assert len(chunk) == max_tokens + print("".join(chunk)) + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 52c38dd843306..6f09c20d06932 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -292,7 +292,7 @@ async def _parallel_sampling_batch( # Aggregate generators for n child requests gens = [] active = {} - seed = 42 + seed = sampling_params.seed for idx in range(n): c_sampling_params = parent_state.get_child_sampling_params({ "n": @@ -300,7 +300,8 @@ async def _parallel_sampling_batch( "seed": seed }) - seed += 1 + if seed is not None: + seed += 1 child_gen = self._generate( prompt, c_sampling_params, From 07f0c17bbacef1c552fc7150c7d6b16fc1cdb794 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 11 Feb 2025 19:30:34 +0000 Subject: [PATCH 08/54] feedback Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 6f09c20d06932..a5f9969d53ff9 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -252,7 +252,7 @@ async def _parallel_sampling_task( if req_out := output_processor.process_output(out, index): yield req_out - async def _parallel_sampling_batch( + async def _generate_parallel_sampling( self, prompt: PromptType, sampling_params: SamplingParams, @@ -266,6 +266,12 @@ async def _parallel_sampling_batch( output_processor = ParallelSamplingOutputProcessor(parent_state) n = parent_state.n + # Adapted from sglang: + # https://github.com/sgl-project/sglang/blob/ + # 4fe92bfca5517f3cf5ca967fc5fcfdb7cf335f30/ + # python/sglang/srt/managers/ + # tokenizer_manager.py#L456-L532 + if self.enable_prefix_caching: # If engine uses APC, generate a “warmup request” with # max_tokens=1 which populates the APC @@ -352,7 +358,7 @@ async def generate( prompt_adapter_request, priority): yield out else: - async for out in self._parallel_sampling_batch( + async for out in self._generate_parallel_sampling( prompt, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request, priority): yield out From 2e828a81aaacd3a9ac1a75c1ddb3e3b99d640343 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 11 Feb 2025 20:57:31 +0000 Subject: [PATCH 09/54] linting Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 59 +++++++++------- vllm/v1/engine/parallel_sampling.py | 101 ++++++++++++++++++++++++---- 2 files changed, 124 insertions(+), 36 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 085604e88fec1..0d3f6c15e0b15 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -2,7 +2,7 @@ import asyncio import os -from typing import AsyncGenerator, List, Mapping, Optional, Type, Union +from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union import numpy as np @@ -245,13 +245,34 @@ async def _generate( await self.abort(request_id) raise - async 
def _parallel_sampling_task( + async def _parallel_sampling_child_gen( self, - gen: AsyncGenerator[RequestOutput, None], + child_gen: AsyncGenerator[RequestOutput, None], output_processor: ParallelSamplingOutputProcessor, index: int, ) -> AsyncGenerator[RequestOutput, None]: - async for out in gen: + """A single parallel sampling child request + output generator. + + Each parallel sampling request triggers at + least two child requests. This generator + yields zero or more request outputs to + return to the caller, as they become + available. + + Args: + child_gen: generator for child request + outputs. + output_processor: transform child request + outputs into parent + request outputs + index: index within the `n` child requests + + Returns: + Yields zero or more request outputs to return + to the caller. + """ + async for out in child_gen: if req_out := output_processor.process_output(out, index): yield req_out @@ -265,6 +286,8 @@ async def _generate_parallel_sampling( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> AsyncGenerator[RequestOutput, None]: + """Generation completes for parallel sampling requests.""" + parent_state = ParentRequestState(request_id, sampling_params) output_processor = ParallelSamplingOutputProcessor(parent_state) n = parent_state.n @@ -278,14 +301,7 @@ async def _generate_parallel_sampling( if self.enable_prefix_caching: # If engine uses APC, generate a “warmup request” with # max_tokens=1 which populates the APC - w_sampling_params = parent_state.get_child_sampling_params({ - "max_tokens": - 1, - "n": - 1, - "output_kind": - RequestOutputKind.FINAL_ONLY - }) + w_sampling_params = parent_state.get_warmup_sampling_params() async for _ in self._generate( prompt, w_sampling_params, @@ -299,16 +315,11 @@ async def _generate_parallel_sampling( pass # Aggregate generators for n child requests - gens = [] - active = {} + gens: List[AsyncGenerator[RequestOutput, None]] = [] + active: Dict[asyncio.Task, int] = {} seed = sampling_params.seed for idx in range(n): - c_sampling_params = parent_state.get_child_sampling_params({ - "n": - 1, - "seed": - seed - }) + c_sampling_params = parent_state.get_child_sampling_params(seed) if seed is not None: seed += 1 child_gen = self._generate( @@ -320,10 +331,10 @@ async def _generate_parallel_sampling( prompt_adapter_request, priority, ) - gen = self._parallel_sampling_task(child_gen, output_processor, - idx) + gen = self._parallel_sampling_child_gen(child_gen, + output_processor, idx) gens.append(gen) - active[asyncio.create_task(gen.__anext__())] = idx + active[asyncio.create_task(gen.__anext__())] = idx # type: ignore try: while active: @@ -336,7 +347,7 @@ async def _generate_parallel_sampling( yield result # Schedule the next result active[asyncio.create_task( - gens[idx].__anext__())] = idx + gens[idx].__anext__())] = idx # type: ignore except StopAsyncIteration: continue finally: diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 79dfd002a719f..e37689bdb0902 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,13 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 from copy import copy -from typing import Any, Dict, Optional +from typing import Optional from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams class ParentRequestState: + """Info and state for parallel sampling request. + + Store parent request ID and sampling params. 
+ Facilitate generating child request sampling params. + When stream mode is disabled, then `self.request_output` + aggregates completions. + """ + request_id: str sampling_params: SamplingParams request_output: Optional[RequestOutput] = None @@ -17,20 +25,38 @@ def __init__(self, request_id: str, self.request_id = request_id self.sampling_params = sampling_params + def get_warmup_sampling_params(self, ) -> SamplingParams: + sampling_params = copy(self.sampling_params) + sampling_params.max_tokens = 1 + sampling_params.n = 1 + sampling_params.output_kind = RequestOutputKind.FINAL_ONLY + return sampling_params + def get_child_sampling_params( self, - kwargs: Optional[Dict[str, Any]] = None, + seed: Optional[int], ) -> SamplingParams: sampling_params = copy(self.sampling_params) - if kwargs is not None: - for kw in kwargs: - setattr(sampling_params, kw, kwargs[kw]) + sampling_params.n = 1 + sampling_params.seed = seed return sampling_params def add_output( self, child_req_output: RequestOutput, ) -> None: + """Aggregate a parallel sampling child + request output. + + Non-stream-mode (`output_kind == FINAL_ONLY`) + only. Inject correct parent request ID and + completion index. + + Args: + child_req_output: a single request output + from a parallel sampling + child request. + """ if self.request_output is None: # Save the first request output; reinstate # original request ID; metrics are not @@ -39,11 +65,38 @@ def add_output( child_req_output.metrics = None self.request_output = child_req_output else: - # Add completion to the request output + # Aggregate additional completion into request + # output new_completion = child_req_output.outputs[0] new_completion.index = self.num_completions self.request_output.outputs.append(new_completion) + def transform_output( + self, + child_req_output: RequestOutput, + index: int, + ) -> RequestOutput: + """Transform a parallel sampling child + request output into a parent request output. + + Stream-mode (`output_kind == DELTA`) only. + Inject correct parent request ID and completion + index. + + Args: + child_req_output: a single request output + from a parallel sampling + child request. + index: index within `n` parallel sampling + child requests + + Returns: + Stream-mode parent request output. + """ + child_req_output.request_id = self.request_id + child_req_output.outputs[0].index = index + return child_req_output + def get_warmup_request_id(self) -> str: return "w_" + self.request_id @@ -76,11 +129,15 @@ def output_kind(self) -> RequestOutputKind: class ParallelSamplingOutputProcessor: + """For parallel sampling requests, + filter and transform child request + outputs.""" def __init__( self, parent_state: ParentRequestState, ) -> None: + """Store parent request state.""" self.parent_state = parent_state def process_output( @@ -88,17 +145,37 @@ def process_output( child_req_output: RequestOutput, index: int, ) -> Optional[RequestOutput]: + """Filter, aggregate and transform parallel sampling + child request outputs. + + If the parent request has `stream=false` + (`output_kind == FINAL_ONLY`), each child will also have + `output_kind == FINAL_ONLY`. All child request outputs + must be aggregated into a single request output, with + multiple completions. This request output is only returned + once `n` completions are aggregated. + + If the parent request has `stream=true` + (`output_kind == DELTA`), each child will also have + `output_kind == DELTA`. All child request outputs + must be streamed directly to the caller. 
+ + Args: + child_req_output: a single child request output + index: index within `n` child requests + + Returns: + `None`, unless a processed request output is ready to + send back to the caller. + """ if self.parent_state.output_kind == RequestOutputKind.FINAL_ONLY: - # stream=false: accumulate child completions + # stream=false: aggregate child completions self.parent_state.add_output(child_req_output) if self.parent_state.num_completions == self.parent_state.n: - # Return accumulated request output after obtaining + # Return aggregated request output after obtaining # all completions return self.parent_state.request_output else: # stream=true: return child completions immediately - child_req_output.request_id = self.parent_state.request_id - child_req_output.outputs[0].index = index - return child_req_output - + return self.parent_state.transform_output(child_req_output, index) return None From 374f1c7d8956201173d4f6a950bf00933d72e415 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Thu, 13 Feb 2025 06:37:54 -0500 Subject: [PATCH 10/54] Update vllm/v1/engine/async_llm.py Co-authored-by: Nick Hill --- vllm/v1/engine/async_llm.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 0d3f6c15e0b15..f9ca7d4f2f86e 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -365,17 +365,12 @@ async def generate( priority: int = 0, ) -> AsyncGenerator[RequestOutput, None]: n = sampling_params.n - if n is None or sampling_params.n == 1: - async for out in self._generate(prompt, sampling_params, - request_id, lora_request, - trace_headers, - prompt_adapter_request, priority): - yield out - else: - async for out in self._generate_parallel_sampling( - prompt, sampling_params, request_id, lora_request, - trace_headers, prompt_adapter_request, priority): - yield out + _generate = self._generate if n is None or n == 1 \ + else self._generate_parallel_sampling + return _generate(prompt, sampling_params, + request_id, lora_request, + trace_headers, + prompt_adapter_request, priority) async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" From 35036eac27c50eb918707fccabd9eb0e970976ea Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 13 Feb 2025 11:56:24 +0000 Subject: [PATCH 11/54] async def -> def Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index f9ca7d4f2f86e..8e8c67a953eff 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -354,7 +354,7 @@ async def _generate_parallel_sampling( for task in active: task.cancel() - async def generate( + def generate( self, prompt: PromptType, sampling_params: SamplingParams, @@ -367,10 +367,8 @@ async def generate( n = sampling_params.n _generate = self._generate if n is None or n == 1 \ else self._generate_parallel_sampling - return _generate(prompt, sampling_params, - request_id, lora_request, - trace_headers, - prompt_adapter_request, priority) + return _generate(prompt, sampling_params, request_id, lora_request, + trace_headers, prompt_adapter_request, priority) async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" From b45c41349ce14b6939110b4c2eaa62028a20ad80 Mon Sep 17 00:00:00 2001 From: afeldman-nm 
<156691304+afeldman-nm@users.noreply.github.com> Date: Thu, 13 Feb 2025 06:57:47 -0500 Subject: [PATCH 12/54] Update vllm/v1/engine/parallel_sampling.py Co-authored-by: Nick Hill --- vllm/v1/engine/parallel_sampling.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index e37689bdb0902..b556d8ad34524 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -168,14 +168,14 @@ def process_output( `None`, unless a processed request output is ready to send back to the caller. """ - if self.parent_state.output_kind == RequestOutputKind.FINAL_ONLY: - # stream=false: aggregate child completions - self.parent_state.add_output(child_req_output) - if self.parent_state.num_completions == self.parent_state.n: - # Return aggregated request output after obtaining - # all completions - return self.parent_state.request_output - else: + if self.parent_state.output_kind != RequestOutputKind.FINAL_ONLY: # stream=true: return child completions immediately return self.parent_state.transform_output(child_req_output, index) + + # stream=false: aggregate child completions + self.parent_state.add_output(child_req_output) + if self.parent_state.num_completions == self.parent_state.n: + # Return aggregated request output after obtaining + # all completions + return self.parent_state.request_output return None From 00bb1f2603fdc2863f78fc73ed5b34333e84544d Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Thu, 13 Feb 2025 16:12:13 -0500 Subject: [PATCH 13/54] Update vllm/v1/engine/parallel_sampling.py Co-authored-by: Nick Hill --- vllm/v1/engine/parallel_sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index b556d8ad34524..bce7cfc1771cd 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -68,8 +68,8 @@ def add_output( # Aggregate additional completion into request # output new_completion = child_req_output.outputs[0] - new_completion.index = self.num_completions - self.request_output.outputs.append(new_completion) + new_completion.index = index + self.request_output.outputs[index] = new_completion def transform_output( self, From b16ba2b549258c9aefb42f981fc27e3a66574bd3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 13 Feb 2025 21:50:50 +0000 Subject: [PATCH 14/54] index Signed-off-by: Andrew Feldman --- vllm/v1/engine/parallel_sampling.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index bce7cfc1771cd..4fd377d501197 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -44,6 +44,7 @@ def get_child_sampling_params( def add_output( self, child_req_output: RequestOutput, + index: int, ) -> None: """Aggregate a parallel sampling child request output. @@ -55,7 +56,8 @@ def add_output( Args: child_req_output: a single request output from a parallel sampling - child request. + child request. 
+ index: index within `n` child """ if self.request_output is None: # Save the first request output; reinstate @@ -69,7 +71,15 @@ def add_output( # output new_completion = child_req_output.outputs[0] new_completion.index = index - self.request_output.outputs[index] = new_completion + # Note: will be sorted by index later + self.request_output.outputs.append(new_completion) + + def get_parent_request_output(self) -> RequestOutput: + """Invariant: parent completion outputs sorted by index""" + assert self.request_output is not None + self.request_output.outputs = sorted(self.request_output.outputs, + key=lambda x: x.index) + return self.request_output def transform_output( self, @@ -171,11 +181,11 @@ def process_output( if self.parent_state.output_kind != RequestOutputKind.FINAL_ONLY: # stream=true: return child completions immediately return self.parent_state.transform_output(child_req_output, index) - + # stream=false: aggregate child completions - self.parent_state.add_output(child_req_output) + self.parent_state.add_output(child_req_output, index) if self.parent_state.num_completions == self.parent_state.n: # Return aggregated request output after obtaining # all completions - return self.parent_state.request_output + return self.parent_state.get_parent_request_output() return None From fbcd2136da0a320b0d05bb1aaba1b0593220df19 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 13 Feb 2025 22:13:34 +0000 Subject: [PATCH 15/54] sort by index Signed-off-by: Andrew Feldman --- vllm/v1/engine/parallel_sampling.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 4fd377d501197..dd72829da2b6a 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -30,6 +30,8 @@ def get_warmup_sampling_params(self, ) -> SamplingParams: sampling_params.max_tokens = 1 sampling_params.n = 1 sampling_params.output_kind = RequestOutputKind.FINAL_ONLY + sampling_params.logprobs = None + sampling_params.prompt_logprobs = None return sampling_params def get_child_sampling_params( @@ -125,14 +127,6 @@ def num_completions(self) -> int: def n(self) -> int: return self.sampling_params.n - @property - def logprobs(self) -> Optional[int]: - return self.sampling_params.logprobs - - @property - def prompt_logprobs(self) -> Optional[int]: - return self.sampling_params.prompt_logprobs - @property def output_kind(self) -> RequestOutputKind: return self.sampling_params.output_kind From a4ded4053678ffa4798956d307f8f635c09176d8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 13 Feb 2025 22:44:54 +0000 Subject: [PATCH 16/54] no warmup Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 16 ---------------- vllm/v1/engine/parallel_sampling.py | 9 --------- 2 files changed, 25 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 8e8c67a953eff..11d7e461a5089 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -298,22 +298,6 @@ async def _generate_parallel_sampling( # python/sglang/srt/managers/ # tokenizer_manager.py#L456-L532 - if self.enable_prefix_caching: - # If engine uses APC, generate a “warmup request” with - # max_tokens=1 which populates the APC - w_sampling_params = parent_state.get_warmup_sampling_params() - async for _ in self._generate( - prompt, - w_sampling_params, - parent_state.get_warmup_request_id(), - lora_request, - trace_headers, - prompt_adapter_request, - priority, - ): - # Exhaust 
the generator - pass - # Aggregate generators for n child requests gens: List[AsyncGenerator[RequestOutput, None]] = [] active: Dict[asyncio.Task, int] = {} diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index dd72829da2b6a..d844b50f238b6 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -25,15 +25,6 @@ def __init__(self, request_id: str, self.request_id = request_id self.sampling_params = sampling_params - def get_warmup_sampling_params(self, ) -> SamplingParams: - sampling_params = copy(self.sampling_params) - sampling_params.max_tokens = 1 - sampling_params.n = 1 - sampling_params.output_kind = RequestOutputKind.FINAL_ONLY - sampling_params.logprobs = None - sampling_params.prompt_logprobs = None - return sampling_params - def get_child_sampling_params( self, seed: Optional[int], From 119a77c15269b17a8ff316bb54c14773eb0d7a75 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Feb 2025 12:37:07 +0000 Subject: [PATCH 17/54] refactor transform_output() Signed-off-by: Andrew Feldman --- vllm/v1/engine/parallel_sampling.py | 33 +++-------------------------- 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index d844b50f238b6..42899eee6ce84 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -74,35 +74,6 @@ def get_parent_request_output(self) -> RequestOutput: key=lambda x: x.index) return self.request_output - def transform_output( - self, - child_req_output: RequestOutput, - index: int, - ) -> RequestOutput: - """Transform a parallel sampling child - request output into a parent request output. - - Stream-mode (`output_kind == DELTA`) only. - Inject correct parent request ID and completion - index. - - Args: - child_req_output: a single request output - from a parallel sampling - child request. - index: index within `n` parallel sampling - child requests - - Returns: - Stream-mode parent request output. 
- """ - child_req_output.request_id = self.request_id - child_req_output.outputs[0].index = index - return child_req_output - - def get_warmup_request_id(self) -> str: - return "w_" + self.request_id - def get_child_request_id( self, index: int, @@ -165,7 +136,9 @@ def process_output( """ if self.parent_state.output_kind != RequestOutputKind.FINAL_ONLY: # stream=true: return child completions immediately - return self.parent_state.transform_output(child_req_output, index) + child_req_output.request_id = self.parent_state.request_id + child_req_output.outputs[0].index = index + return child_req_output # stream=false: aggregate child completions self.parent_state.add_output(child_req_output, index) From a64e3b3376e6078cdb881213913b28675823ca4d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Feb 2025 12:45:03 +0000 Subject: [PATCH 18/54] wip Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 213e8b18854f4..055ab2139d76a 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -307,13 +307,13 @@ async def _generate_parallel_sampling( if seed is not None: seed += 1 child_gen = self._generate( - prompt, - c_sampling_params, - parent_state.get_child_request_id(idx), - lora_request, - trace_headers, - prompt_adapter_request, - priority, + prompt=prompt, + sampling_params=c_sampling_params, + request_id=parent_state.get_child_request_id(idx), + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, ) gen = self._parallel_sampling_child_gen(child_gen, output_processor, idx) From 36cd5555e5dac400ff2210179102407f455931cc Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Feb 2025 13:01:17 +0000 Subject: [PATCH 19/54] refactor Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 38 +------------ vllm/v1/engine/parallel_sampling.py | 84 +++++++++++++++++------------ 2 files changed, 51 insertions(+), 71 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 055ab2139d76a..1a9bf1c53eb16 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -24,8 +24,7 @@ from vllm.utils import cdiv, kill_process_tree from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import (ParallelSamplingOutputProcessor, - ParentRequestState) +from vllm.v1.engine.parallel_sampling import ParentRequestState from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, @@ -245,37 +244,6 @@ async def _generate( await self.abort(request_id) raise - async def _parallel_sampling_child_gen( - self, - child_gen: AsyncGenerator[RequestOutput, None], - output_processor: ParallelSamplingOutputProcessor, - index: int, - ) -> AsyncGenerator[RequestOutput, None]: - """A single parallel sampling child request - output generator. - - Each parallel sampling request triggers at - least two child requests. This generator - yields zero or more request outputs to - return to the caller, as they become - available. - - Args: - child_gen: generator for child request - outputs. 
- output_processor: transform child request - outputs into parent - request outputs - index: index within the `n` child requests - - Returns: - Yields zero or more request outputs to return - to the caller. - """ - async for out in child_gen: - if req_out := output_processor.process_output(out, index): - yield req_out - async def _generate_parallel_sampling( self, prompt: PromptType, @@ -289,7 +257,6 @@ async def _generate_parallel_sampling( """Generation completes for parallel sampling requests.""" parent_state = ParentRequestState(request_id, sampling_params) - output_processor = ParallelSamplingOutputProcessor(parent_state) n = parent_state.n # Adapted from sglang: @@ -315,8 +282,7 @@ async def _generate_parallel_sampling( prompt_adapter_request=prompt_adapter_request, priority=priority, ) - gen = self._parallel_sampling_child_gen(child_gen, - output_processor, idx) + gen = parent_state.parallel_sampling_child_gen(child_gen, idx) gens.append(gen) active[asyncio.create_task(gen.__anext__())] = idx # type: ignore diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 42899eee6ce84..64a52fe54ef52 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from copy import copy -from typing import Optional +from typing import AsyncGenerator, Optional from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -34,7 +34,7 @@ def get_child_sampling_params( sampling_params.seed = seed return sampling_params - def add_output( + def _add_output( self, child_req_output: RequestOutput, index: int, @@ -67,7 +67,7 @@ def add_output( # Note: will be sorted by index later self.request_output.outputs.append(new_completion) - def get_parent_request_output(self) -> RequestOutput: + def _get_parent_request_output(self) -> RequestOutput: """Invariant: parent completion outputs sorted by index""" assert self.request_output is not None self.request_output.outputs = sorted(self.request_output.outputs, @@ -80,33 +80,7 @@ def get_child_request_id( ) -> str: return str(index) + "_" + self.request_id - @property - def num_completions(self) -> int: - assert self.request_output is not None - return len(self.request_output.outputs) - - @property - def n(self) -> int: - return self.sampling_params.n - - @property - def output_kind(self) -> RequestOutputKind: - return self.sampling_params.output_kind - - -class ParallelSamplingOutputProcessor: - """For parallel sampling requests, - filter and transform child request - outputs.""" - - def __init__( - self, - parent_state: ParentRequestState, - ) -> None: - """Store parent request state.""" - self.parent_state = parent_state - - def process_output( + def _process_output( self, child_req_output: RequestOutput, index: int, @@ -134,16 +108,56 @@ def process_output( `None`, unless a processed request output is ready to send back to the caller. 
""" - if self.parent_state.output_kind != RequestOutputKind.FINAL_ONLY: + if self.output_kind != RequestOutputKind.FINAL_ONLY: # stream=true: return child completions immediately - child_req_output.request_id = self.parent_state.request_id + child_req_output.request_id = self.request_id child_req_output.outputs[0].index = index return child_req_output # stream=false: aggregate child completions - self.parent_state.add_output(child_req_output, index) - if self.parent_state.num_completions == self.parent_state.n: + self._add_output(child_req_output, index) + if self.num_completions == self.n: # Return aggregated request output after obtaining # all completions - return self.parent_state.get_parent_request_output() + return self._get_parent_request_output() return None + + async def parallel_sampling_child_gen( + self, + child_gen: AsyncGenerator[RequestOutput, None], + index: int, + ) -> AsyncGenerator[RequestOutput, None]: + """Output generator for a single parallel sampling + child request. + + Each parallel sampling request triggers at + least two child requests. This generator + yields zero or more request outputs to + return to the caller, as they become + available. + + Args: + child_gen: generator for child request + outputs. + index: index within the `n` child requests + + Returns: + Yields zero or more request outputs to return + to the caller. + """ + async for out in child_gen: + if req_out := self._process_output(out, index): + yield req_out + + @property + def num_completions(self) -> int: + assert self.request_output is not None + return len(self.request_output.outputs) + + @property + def n(self) -> int: + return self.sampling_params.n + + @property + def output_kind(self) -> RequestOutputKind: + return self.sampling_params.output_kind From 39d3d0b32cc265ab49338f32c197d77d8231733b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Feb 2025 13:32:56 +0000 Subject: [PATCH 20/54] sampling params caching Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 5 +---- vllm/v1/engine/parallel_sampling.py | 34 ++++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 1a9bf1c53eb16..1509d9c4bf127 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -268,11 +268,8 @@ async def _generate_parallel_sampling( # Aggregate generators for n child requests gens: List[AsyncGenerator[RequestOutput, None]] = [] active: Dict[asyncio.Task, int] = {} - seed = sampling_params.seed for idx in range(n): - c_sampling_params = parent_state.get_child_sampling_params(seed) - if seed is not None: - seed += 1 + c_sampling_params = parent_state.get_child_sampling_params(idx) child_gen = self._generate( prompt=prompt, sampling_params=c_sampling_params, diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 64a52fe54ef52..a5fb350e63b84 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -18,21 +18,45 @@ class ParentRequestState: request_id: str sampling_params: SamplingParams + cached_child_sampling_params: Optional[SamplingParams] request_output: Optional[RequestOutput] = None def __init__(self, request_id: str, sampling_params: SamplingParams) -> None: self.request_id = request_id self.sampling_params = sampling_params + self.cached_child_sampling_params = None def get_child_sampling_params( self, - seed: Optional[int], + index: int, ) -> SamplingParams: - sampling_params = 
copy(self.sampling_params) - sampling_params.n = 1 - sampling_params.seed = seed - return sampling_params + """Efficiently obtain child `sampling_params` + + If `sampling_params.seed` is not `None` then + each child request requires a unique clone of + parent `sampling_params` with a unique seed. + + Args: + index: index within `n` child requests + + Returns: + Child `sampling_params` instance. + """ + seed = self.sampling_params.seed + if seed is None and self.cached_child_sampling_params: + # Reuse child sampling_params data structure + return self.cached_child_sampling_params + # Build child sampling_params + c_sampling_params = copy(self.sampling_params) + c_sampling_params.n = 1 + if seed is None: + # Cache child sampling_params for later reuse + self.cached_child_sampling_params = c_sampling_params + else: + # Each child gets a clone with a unique seed + c_sampling_params.seed = seed + index + return c_sampling_params def _add_output( self, From 5462e830ad5ba59ef5917926d1006a7bf173b4c9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Feb 2025 15:04:55 +0000 Subject: [PATCH 21/54] wip Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 51 ++++++++--------------------- vllm/v1/engine/parallel_sampling.py | 7 ++-- 2 files changed, 17 insertions(+), 41 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 1509d9c4bf127..f0c07dd4cad94 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -2,7 +2,7 @@ import asyncio import os -from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union +from typing import AsyncGenerator, List, Mapping, Optional, Type, Union import numpy as np @@ -21,10 +21,10 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import cdiv, kill_process_tree +from vllm.utils import cdiv, kill_process_tree, merge_async_iterators from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import ParentRequestState +from vllm.v1.engine.parallel_sampling import ParallelSamplingRequestManager from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, @@ -244,7 +244,7 @@ async def _generate( await self.abort(request_id) raise - async def _generate_parallel_sampling( + def _generate_parallel_sampling( self, prompt: PromptType, sampling_params: SamplingParams, @@ -254,54 +254,29 @@ async def _generate_parallel_sampling( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> AsyncGenerator[RequestOutput, None]: - """Generation completes for parallel sampling requests.""" - - parent_state = ParentRequestState(request_id, sampling_params) - n = parent_state.n - - # Adapted from sglang: - # https://github.com/sgl-project/sglang/blob/ - # 4fe92bfca5517f3cf5ca967fc5fcfdb7cf335f30/ - # python/sglang/srt/managers/ - # tokenizer_manager.py#L456-L532 + """Generate completions for parallel sampling requests.""" + req_mgr = ParallelSamplingRequestManager(request_id, sampling_params) + n = req_mgr.n # Aggregate generators for n child requests gens: List[AsyncGenerator[RequestOutput, None]] = [] - active: Dict[asyncio.Task, int] = {} for idx in range(n): - c_sampling_params = 
parent_state.get_child_sampling_params(idx) + c_sampling_params = req_mgr.get_child_sampling_params(idx) child_gen = self._generate( prompt=prompt, sampling_params=c_sampling_params, - request_id=parent_state.get_child_request_id(idx), + request_id=req_mgr.get_child_request_id(idx), lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, priority=priority, ) - gen = parent_state.parallel_sampling_child_gen(child_gen, idx) + gen = req_mgr.parallel_sampling_child_gen(child_gen, idx) gens.append(gen) - active[asyncio.create_task(gen.__anext__())] = idx # type: ignore - try: - while active: - done, _ = await asyncio.wait( - active.keys(), return_when=asyncio.FIRST_COMPLETED) - for task in done: - idx = active.pop(task) - try: - result = task.result() - yield result - # Schedule the next result - active[asyncio.create_task( - gens[idx].__anext__())] = idx # type: ignore - except StopAsyncIteration: - continue - finally: - for task in active: - task.cancel() - - def generate( + return merge_async_iterators(*gens) + + async def generate( self, prompt: PromptType, sampling_params: SamplingParams, diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index a5fb350e63b84..baaf0c653fc12 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -7,13 +7,14 @@ from vllm.sampling_params import RequestOutputKind, SamplingParams -class ParentRequestState: - """Info and state for parallel sampling request. +class ParallelSamplingRequestManager: + """Info, state & processing for parallel sampling request. Store parent request ID and sampling params. Facilitate generating child request sampling params. When stream mode is disabled, then `self.request_output` - aggregates completions. + aggregates child request completions & transforms them + into a parent request completion. 
""" request_id: str From c0f8fb1fef889d674a9aaef38d9555fb0e9dbee0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 16 Feb 2025 21:21:00 +0000 Subject: [PATCH 22/54] wip Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 10 ++++++---- vllm/v1/engine/parallel_sampling.py | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index f0c07dd4cad94..3c4916c86d061 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -244,7 +244,7 @@ async def _generate( await self.abort(request_id) raise - def _generate_parallel_sampling( + async def _generate_parallel_sampling( self, prompt: PromptType, sampling_params: SamplingParams, @@ -274,9 +274,11 @@ def _generate_parallel_sampling( gen = req_mgr.parallel_sampling_child_gen(child_gen, idx) gens.append(gen) - return merge_async_iterators(*gens) + # Merge generators + async for out in merge_async_iterators(*gens): + yield out[1] # out[0] is index - async def generate( + def generate( self, prompt: PromptType, sampling_params: SamplingParams, @@ -288,7 +290,7 @@ async def generate( ) -> AsyncGenerator[RequestOutput, None]: n = sampling_params.n _generate = self._generate if n is None or n == 1 \ - else self._generate_parallel_sampling + else self._generate_parallel_sampling # handle parallel sampling return _generate(prompt, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request, priority) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index baaf0c653fc12..5af568f6cd08c 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -12,9 +12,10 @@ class ParallelSamplingRequestManager: Store parent request ID and sampling params. Facilitate generating child request sampling params. + Transform child request outputs into parent request + outputs. When stream mode is disabled, then `self.request_output` - aggregates child request completions & transforms them - into a parent request completion. + aggregates child request completions. """ request_id: str From 103ceb6d05500688f74f83bfac57a11e724d07ad Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 17 Feb 2025 18:32:36 +0000 Subject: [PATCH 23/54] parallel sampling core Signed-off-by: Andrew Feldman --- vllm/v1/engine/parallel_sampling.py | 189 ++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 vllm/v1/engine/parallel_sampling.py diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py new file mode 100644 index 0000000000000..5af568f6cd08c --- /dev/null +++ b/vllm/v1/engine/parallel_sampling.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 + +from copy import copy +from typing import AsyncGenerator, Optional + +from vllm.outputs import RequestOutput +from vllm.sampling_params import RequestOutputKind, SamplingParams + + +class ParallelSamplingRequestManager: + """Info, state & processing for parallel sampling request. + + Store parent request ID and sampling params. + Facilitate generating child request sampling params. + Transform child request outputs into parent request + outputs. + When stream mode is disabled, then `self.request_output` + aggregates child request completions. 
+ """ + + request_id: str + sampling_params: SamplingParams + cached_child_sampling_params: Optional[SamplingParams] + request_output: Optional[RequestOutput] = None + + def __init__(self, request_id: str, + sampling_params: SamplingParams) -> None: + self.request_id = request_id + self.sampling_params = sampling_params + self.cached_child_sampling_params = None + + def get_child_sampling_params( + self, + index: int, + ) -> SamplingParams: + """Efficiently obtain child `sampling_params` + + If `sampling_params.seed` is not `None` then + each child request requires a unique clone of + parent `sampling_params` with a unique seed. + + Args: + index: index within `n` child requests + + Returns: + Child `sampling_params` instance. + """ + seed = self.sampling_params.seed + if seed is None and self.cached_child_sampling_params: + # Reuse child sampling_params data structure + return self.cached_child_sampling_params + # Build child sampling_params + c_sampling_params = copy(self.sampling_params) + c_sampling_params.n = 1 + if seed is None: + # Cache child sampling_params for later reuse + self.cached_child_sampling_params = c_sampling_params + else: + # Each child gets a clone with a unique seed + c_sampling_params.seed = seed + index + return c_sampling_params + + def _add_output( + self, + child_req_output: RequestOutput, + index: int, + ) -> None: + """Aggregate a parallel sampling child + request output. + + Non-stream-mode (`output_kind == FINAL_ONLY`) + only. Inject correct parent request ID and + completion index. + + Args: + child_req_output: a single request output + from a parallel sampling + child request. + index: index within `n` child + """ + if self.request_output is None: + # Save the first request output; reinstate + # original request ID; metrics are not + # supported for parallel sampling + child_req_output.request_id = self.request_id + child_req_output.metrics = None + self.request_output = child_req_output + else: + # Aggregate additional completion into request + # output + new_completion = child_req_output.outputs[0] + new_completion.index = index + # Note: will be sorted by index later + self.request_output.outputs.append(new_completion) + + def _get_parent_request_output(self) -> RequestOutput: + """Invariant: parent completion outputs sorted by index""" + assert self.request_output is not None + self.request_output.outputs = sorted(self.request_output.outputs, + key=lambda x: x.index) + return self.request_output + + def get_child_request_id( + self, + index: int, + ) -> str: + return str(index) + "_" + self.request_id + + def _process_output( + self, + child_req_output: RequestOutput, + index: int, + ) -> Optional[RequestOutput]: + """Filter, aggregate and transform parallel sampling + child request outputs. + + If the parent request has `stream=false` + (`output_kind == FINAL_ONLY`), each child will also have + `output_kind == FINAL_ONLY`. All child request outputs + must be aggregated into a single request output, with + multiple completions. This request output is only returned + once `n` completions are aggregated. + + If the parent request has `stream=true` + (`output_kind == DELTA`), each child will also have + `output_kind == DELTA`. All child request outputs + must be streamed directly to the caller. + + Args: + child_req_output: a single child request output + index: index within `n` child requests + + Returns: + `None`, unless a processed request output is ready to + send back to the caller. 
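+
+        For example (illustrative scenario): with n=2 and
+        stream=false, the first child output to arrive is held;
+        once the second arrives, a single request output carrying
+        both completions (sorted by index) is returned.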
+ """ + if self.output_kind != RequestOutputKind.FINAL_ONLY: + # stream=true: return child completions immediately + child_req_output.request_id = self.request_id + child_req_output.outputs[0].index = index + return child_req_output + + # stream=false: aggregate child completions + self._add_output(child_req_output, index) + if self.num_completions == self.n: + # Return aggregated request output after obtaining + # all completions + return self._get_parent_request_output() + return None + + async def parallel_sampling_child_gen( + self, + child_gen: AsyncGenerator[RequestOutput, None], + index: int, + ) -> AsyncGenerator[RequestOutput, None]: + """Output generator for a single parallel sampling + child request. + + Each parallel sampling request triggers at + least two child requests. This generator + yields zero or more request outputs to + return to the caller, as they become + available. + + Args: + child_gen: generator for child request + outputs. + index: index within the `n` child requests + + Returns: + Yields zero or more request outputs to return + to the caller. + """ + async for out in child_gen: + if req_out := self._process_output(out, index): + yield req_out + + @property + def num_completions(self) -> int: + assert self.request_output is not None + return len(self.request_output.outputs) + + @property + def n(self) -> int: + return self.sampling_params.n + + @property + def output_kind(self) -> RequestOutputKind: + return self.sampling_params.output_kind From e6a11341da46ae5e77ebd87f1a3219dae72e0813 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 17 Feb 2025 20:41:37 +0000 Subject: [PATCH 24/54] wip Signed-off-by: Andrew Feldman --- vllm/entrypoints/llm.py | 56 +++++++++++++++++++++++++---- vllm/v1/engine/parallel_sampling.py | 2 +- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 40b7a529ebfb5..ed64cfaf5b407 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -48,6 +48,13 @@ _R = TypeVar("_R", default=Any) +PromptsArgType=Union[Union[PromptType, Sequence[PromptType]], + Optional[Union[str, List[str]]]] +SamplingParamsArgType=Optional[Union[SamplingParams, + Sequence[SamplingParams]]] +LoRARequestArgType=Optional[Union[List[LoRARequest], LoRARequest]] +PriorityArgType=Optional[List[int]] + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -190,6 +197,8 @@ def __init__( it defaults to False. ''' + self._v1=envs.VLLM_USE_V1 + if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True @@ -377,17 +386,15 @@ def generate( ) def generate( self, - prompts: Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, List[str]]]] = None, - sampling_params: Optional[Union[SamplingParams, - Sequence[SamplingParams]]] = None, + prompts: PromptsArgType = None, + sampling_params: SamplingParamsArgType = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: LoRARequestArgType = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - priority: Optional[List[int]] = None, + priority: PriorityArgType = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -455,8 +462,22 @@ def generate( **guided_options_request) if sampling_params is None: - # Use default sampling params. 
+ # Use default sampling params. Note: n=1 by default sampling_params = self.get_default_sampling_params() + elif self._v1: + # V1 engine only: break out parallel sampling + # requests into `n` child requests + ( + prompts, + sampling_params, + lora_request, + priority, + ) = self._build_parallel_sampling_batch( + prompts, + sampling_params, + lora_request, + priority + ) self._validate_and_add_requests( prompts=parsed_prompts, @@ -467,8 +488,29 @@ def generate( priority=priority) outputs = self._run_engine(use_tqdm=use_tqdm) + + if self._v1: + # V1 engine only: aggregate parallel sampling child request + # outputs into parent request outputs + outputs = self._process_parallel_sampling_outputs(outputs) + return self.engine_class.validate_outputs(outputs, RequestOutput) + def _build_parallel_sampling_batch( + self, + prompts: PromptsArgType, + sampling_params: SamplingParamsArgType, + lora_request: LoRARequestArgType, + priority: PriorityArgType, + ) -> Tuple[PromptsArgType,SamplingParamsArgType, + LoRARequestArgType,PriorityArgType]: + pass + + def _process_parallel_sampling_outputs( + outputs: List[RequestOutput] + )->List[RequestOutput]: + if self._ + def collective_rpc(self, method: Union[str, Callable[..., _R]], timeout: Optional[float] = None, diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 5af568f6cd08c..412c0583fd1ae 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -186,4 +186,4 @@ def n(self) -> int: @property def output_kind(self) -> RequestOutputKind: - return self.sampling_params.output_kind + return self.sampling_params.output_kind \ No newline at end of file From 625e161b506f70b99c63070219a6a40e16c52f2a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 17 Feb 2025 21:02:27 +0000 Subject: [PATCH 25/54] wip Signed-off-by: Andrew Feldman --- vllm/entrypoints/llm.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ed64cfaf5b407..2a8ea8802578b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -55,6 +55,17 @@ LoRARequestArgType=Optional[Union[List[LoRARequest], LoRARequest]] PriorityArgType=Optional[List[int]] +def split_parallel_sampling_batch( + sampling_params: SamplingParamsArgType, +)->List[int]: + if isinstance(sampling_params,SamplingParams) and sampling_params.n>1: + # There is one parallel sampling request + return [0] + if isinstance(sampling_params,Sequence[SamplingParams]): + # Multiple requests with potentially one or more + # parallel sampling requests + return any([sp.n>1 for sp in sampling_params]) + return [] class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -464,7 +475,9 @@ def generate( if sampling_params is None: # Use default sampling params. 
Note: n=1 by default sampling_params = self.get_default_sampling_params() + do_parallel_sampling = False elif self._v1: + do_parallel_sampling = any([sampling_params]) # V1 engine only: break out parallel sampling # requests into `n` child requests ( From e28388e3a770774b4f43aa3d92f8179d65eed65a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 17 Feb 2025 21:33:27 +0000 Subject: [PATCH 26/54] add parallel sampling requests Signed-off-by: Andrew Feldman --- vllm/entrypoints/llm.py | 69 ++++-------------------------------- vllm/v1/engine/llm_engine.py | 52 +++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 62 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 2a8ea8802578b..40b7a529ebfb5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -48,24 +48,6 @@ _R = TypeVar("_R", default=Any) -PromptsArgType=Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, List[str]]]] -SamplingParamsArgType=Optional[Union[SamplingParams, - Sequence[SamplingParams]]] -LoRARequestArgType=Optional[Union[List[LoRARequest], LoRARequest]] -PriorityArgType=Optional[List[int]] - -def split_parallel_sampling_batch( - sampling_params: SamplingParamsArgType, -)->List[int]: - if isinstance(sampling_params,SamplingParams) and sampling_params.n>1: - # There is one parallel sampling request - return [0] - if isinstance(sampling_params,Sequence[SamplingParams]): - # Multiple requests with potentially one or more - # parallel sampling requests - return any([sp.n>1 for sp in sampling_params]) - return [] class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -208,8 +190,6 @@ def __init__( it defaults to False. ''' - self._v1=envs.VLLM_USE_V1 - if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True @@ -397,15 +377,17 @@ def generate( ) def generate( self, - prompts: PromptsArgType = None, - sampling_params: SamplingParamsArgType = None, + prompts: Union[Union[PromptType, Sequence[PromptType]], + Optional[Union[str, List[str]]]] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, - lora_request: LoRARequestArgType = None, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, - priority: PriorityArgType = None, + priority: Optional[List[int]] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -473,24 +455,8 @@ def generate( **guided_options_request) if sampling_params is None: - # Use default sampling params. Note: n=1 by default + # Use default sampling params. 
sampling_params = self.get_default_sampling_params() - do_parallel_sampling = False - elif self._v1: - do_parallel_sampling = any([sampling_params]) - # V1 engine only: break out parallel sampling - # requests into `n` child requests - ( - prompts, - sampling_params, - lora_request, - priority, - ) = self._build_parallel_sampling_batch( - prompts, - sampling_params, - lora_request, - priority - ) self._validate_and_add_requests( prompts=parsed_prompts, @@ -501,29 +467,8 @@ def generate( priority=priority) outputs = self._run_engine(use_tqdm=use_tqdm) - - if self._v1: - # V1 engine only: aggregate parallel sampling child request - # outputs into parent request outputs - outputs = self._process_parallel_sampling_outputs(outputs) - return self.engine_class.validate_outputs(outputs, RequestOutput) - def _build_parallel_sampling_batch( - self, - prompts: PromptsArgType, - sampling_params: SamplingParamsArgType, - lora_request: LoRARequestArgType, - priority: PriorityArgType, - ) -> Tuple[PromptsArgType,SamplingParamsArgType, - LoRARequestArgType,PriorityArgType]: - pass - - def _process_parallel_sampling_outputs( - outputs: List[RequestOutput] - )->List[RequestOutput]: - if self._ - def collective_rpc(self, method: Union[str, Callable[..., _R]], timeout: Optional[float] = None, diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index c9a4c5369dfd8..70e71cd939d7f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -21,6 +21,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor +from vllm.v1.engine.parallel_sampling import ParallelSamplingRequestManager from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -129,6 +130,57 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: + _add_request = (self._add_request if params is None + or isinstance(params, PoolingParams) or params.n == 1 + else self._add_request_parallel_sampling) + return _add_request(request_id=request_id, + prompt=prompt, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) + + def _add_request_parallel_sampling( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + req_mgr = ParallelSamplingRequestManager(request_id, params) + n = req_mgr.n + + # Add n child requests with unique request IDs and + # n=1 + for idx in range(n): + c_params = req_mgr.get_child_sampling_params(idx) + c_request_id = req_mgr.get_child_request_id(idx) + self._add_request(request_id=c_request_id, + prompt=prompt, + params=c_params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) + + def _add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = 
None, + priority: int = 0, + ) -> None: # 1) Process raw inputs into the request. request = self.processor.process_inputs(request_id, prompt, params, From 1c18dc294e80b1c9b1d0d997c528532eb86afe65 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 17 Feb 2025 22:21:38 +0000 Subject: [PATCH 27/54] wip Signed-off-by: Andrew Feldman --- tests/v1/engine/test_llm_engine.py | 5 ++++ vllm/v1/engine/llm_engine.py | 38 ++++++++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 84b634316cb46..7cce944cf9c91 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -5,6 +5,11 @@ from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG from vllm import LLM, SamplingParams +def test_parallel_sampling(monkeypatch): + monkeypatch.setenv("VLLM_USE_V1", "1") + output=LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate( + "Hello, my name is", + SamplingParams(temperature=0.8, top_p=0.95, n=2)) def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch): """Test passes if LLMEngine raises an exception when it is configured diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 70e71cd939d7f..988a4712b2850 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -23,6 +23,7 @@ from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.parallel_sampling import ParallelSamplingRequestManager from vllm.v1.engine.processor import Processor +from vllm.v1.engine import EngineCoreOutputs from vllm.v1.executor.abstract import Executor logger = init_logger(__name__) @@ -48,6 +49,10 @@ def __init__( self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config + # Parallel sampling infra + self.parallel_parent_reqs: Dict[str,ParallelSamplingRequestManager]={} + self.parallel_child_reqs: Dict[str,str]={} + # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, @@ -154,13 +159,14 @@ def _add_request_parallel_sampling( priority: int = 0, ) -> None: req_mgr = ParallelSamplingRequestManager(request_id, params) + self.parallel_parent_reqs[request_id]=req_mgr n = req_mgr.n - # Add n child requests with unique request IDs and - # n=1 + # Add n child requests with unique request IDs and n=1 for idx in range(n): c_params = req_mgr.get_child_sampling_params(idx) c_request_id = req_mgr.get_child_request_id(idx) + self.parallel_child_reqs[c_request_id]=(idx,request_id) self._add_request(request_id=c_request_id, prompt=prompt, params=c_params, @@ -195,10 +201,38 @@ def _add_request( # 3) Add the request to EngineCore. 
self.engine_core.add_request(request) + def _aggregate_parallel_sampling_outputs( + self, + outputs: EngineCoreOutputs, + )->List[RequestOutput]: + agg_outputs=[] + for c_out in outputs.outputs: + c_req_id=c_out.request_id + if cdx_req_id := self.parallel_child_reqs.get(c_req_id,None): + (cdx,req_id)=cdx_req_id + # Update parallel sampling request + req_mgr=self.parallel_parent_reqs[req_id] + if out := req_mgr._process_output(c_out,cdx): + # Return parent request output if complete; + # cleanup parent request + agg_outputs.append(out) + del self.parallel_parent_reqs[req_id] + # Cleanup child request + del self.parallel_child_reqs[c_req_id] + else: + # Not a parallel sampling request output + agg_outputs.append(c_out) + return agg_outputs + + def _is_any_parallel_sampling(self)->bool: + return len(self.parallel_parent_reqs)>0 + def step(self) -> List[RequestOutput]: # 1) Get EngineCoreOutput from the EngineCore. outputs = self.engine_core.get_output() + if self._is_any_parallel_sampling(): + outputs=self._aggregate_parallel_sampling_outputs(outputs) # 2) Process EngineCoreOutputs. processed_outputs = self.output_processor.process_outputs( From c9c3dbbc07e040674f4e90c67da419fc7dffc87b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 17 Feb 2025 23:06:59 +0000 Subject: [PATCH 28/54] working parallel sampling in LLMEngine Signed-off-by: Andrew Feldman --- vllm/v1/engine/llm_engine.py | 52 +++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 988a4712b2850..b188cc5d8d21a 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Mapping, Optional, Type, Union +from typing import Dict, List, Mapping, Optional, Tuple, Type, Union from typing_extensions import TypeVar @@ -23,7 +23,6 @@ from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.parallel_sampling import ParallelSamplingRequestManager from vllm.v1.engine.processor import Processor -from vllm.v1.engine import EngineCoreOutputs from vllm.v1.executor.abstract import Executor logger = init_logger(__name__) @@ -50,8 +49,9 @@ def __init__( self.cache_config = vllm_config.cache_config # Parallel sampling infra - self.parallel_parent_reqs: Dict[str,ParallelSamplingRequestManager]={} - self.parallel_child_reqs: Dict[str,str]={} + self.parallel_parent_reqs: Dict[str, + ParallelSamplingRequestManager] = {} + self.parallel_child_reqs: Dict[str, Tuple[int, str]] = {} # Tokenizer (+ ensure liveness if running in another process). 
self.tokenizer = init_tokenizer_from_configs( @@ -109,7 +109,10 @@ def from_engine_args( multiprocess_mode=enable_multiprocessing) def get_num_unfinished_requests(self) -> int: - return self.output_processor.get_num_unfinished_requests() + num_core_reqs = self.output_processor.get_num_unfinished_requests() + num_child_reqs = self._num_parallel_sampling_child_requests() + num_parent_reqs = self._num_parallel_sampling_requests() + return num_core_reqs + num_parent_reqs - num_child_reqs def has_unfinished_requests(self) -> bool: return self.output_processor.has_unfinished_requests() @@ -159,14 +162,14 @@ def _add_request_parallel_sampling( priority: int = 0, ) -> None: req_mgr = ParallelSamplingRequestManager(request_id, params) - self.parallel_parent_reqs[request_id]=req_mgr + self.parallel_parent_reqs[request_id] = req_mgr n = req_mgr.n # Add n child requests with unique request IDs and n=1 for idx in range(n): c_params = req_mgr.get_child_sampling_params(idx) c_request_id = req_mgr.get_child_request_id(idx) - self.parallel_child_reqs[c_request_id]=(idx,request_id) + self.parallel_child_reqs[c_request_id] = (idx, request_id) self._add_request(request_id=c_request_id, prompt=prompt, params=c_params, @@ -202,17 +205,17 @@ def _add_request( self.engine_core.add_request(request) def _aggregate_parallel_sampling_outputs( - self, - outputs: EngineCoreOutputs, - )->List[RequestOutput]: - agg_outputs=[] - for c_out in outputs.outputs: - c_req_id=c_out.request_id - if cdx_req_id := self.parallel_child_reqs.get(c_req_id,None): - (cdx,req_id)=cdx_req_id + self, + outputs: List[RequestOutput], + ) -> List[RequestOutput]: + agg_outputs = [] + for c_out in outputs: + c_req_id = c_out.request_id + if cdx_req_id := self.parallel_child_reqs.get(c_req_id, None): + (cdx, req_id) = cdx_req_id # Update parallel sampling request - req_mgr=self.parallel_parent_reqs[req_id] - if out := req_mgr._process_output(c_out,cdx): + req_mgr = self.parallel_parent_reqs[req_id] + if out := req_mgr._process_output(c_out, cdx): # Return parent request output if complete; # cleanup parent request agg_outputs.append(out) @@ -224,15 +227,16 @@ def _aggregate_parallel_sampling_outputs( agg_outputs.append(c_out) return agg_outputs - def _is_any_parallel_sampling(self)->bool: - return len(self.parallel_parent_reqs)>0 + def _num_parallel_sampling_requests(self) -> int: + return len(self.parallel_parent_reqs) + + def _num_parallel_sampling_child_requests(self) -> int: + return len(self.parallel_child_reqs) def step(self) -> List[RequestOutput]: # 1) Get EngineCoreOutput from the EngineCore. outputs = self.engine_core.get_output() - if self._is_any_parallel_sampling(): - outputs=self._aggregate_parallel_sampling_outputs(outputs) # 2) Process EngineCoreOutputs. processed_outputs = self.output_processor.process_outputs( @@ -241,7 +245,11 @@ def step(self) -> List[RequestOutput]: # 3) Abort any reqs that finished due to stop strings. 
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) - return processed_outputs.request_outputs + if self._num_parallel_sampling_requests() > 0: + return self._aggregate_parallel_sampling_outputs( + processed_outputs.request_outputs) + else: + return processed_outputs.request_outputs def get_model_config(self): return self.model_config From 0f0075ce765c4a717e550a02eceb5d1c506e2b5c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Feb 2025 19:10:21 +0000 Subject: [PATCH 29/54] unit test Signed-off-by: Andrew Feldman --- tests/v1/engine/test_llm_engine.py | 67 +++++++++++++++++++++++++++--- vllm/v1/engine/llm_engine.py | 33 +++++++++++---- 2 files changed, 87 insertions(+), 13 deletions(-) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 7cce944cf9c91..f3c639548bb53 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -1,15 +1,72 @@ # SPDX-License-Identifier: Apache-2.0 +import random + import pytest from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG from vllm import LLM, SamplingParams -def test_parallel_sampling(monkeypatch): +MODEL = "facebook/opt-125m" +DTYPE = "half" + + +@pytest.fixture( + scope="module", + # Prefix caching + params=[False, True]) +def vllm_model(vllm_runner, request): + with vllm_runner( + MODEL, + dtype=DTYPE, + max_logprobs=7, + # Very small number of batched tokens to ensure + # that we test chunking. + max_num_batched_tokens=16, + max_num_seqs=16, + max_model_len=128, + enforce_eager=True, + enable_prefix_caching=request.param, + gpu_memory_utilization=0.5, + ) as vllm_model: + yield vllm_model + + +def _get_test_sampling_params(prompt_lst): + + def get_mostly_n_gt1() -> int: + """Mostly n>1, sometimes n=1""" + x = random.randint(0, 28) + if x < 10: + return 1 + else: + return x - 8 + + n_list = [get_mostly_n_gt1() for _ in range(len(prompt_lst))] + return [SamplingParams(temperature=0.8, top_p=0.95, n=n) + for n in n_list], n_list + + +def test_parallel_sampling(monkeypatch, vllm_model, example_prompts): monkeypatch.setenv("VLLM_USE_V1", "1") - output=LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate( - "Hello, my name is", - SamplingParams(temperature=0.8, top_p=0.95, n=2)) + sampling_params_list, n_list = _get_test_sampling_params(example_prompts) + model: LLM = vllm_model.model + outputs = model.generate(example_prompts, sampling_params_list) + for out, n in zip(outputs, n_list): + unique_texts = set() + # Correct number of completions + assert len(out.outputs) == n, ( + f"{len(out.outputs)} completions; {n} expected.") + for idx in range(n): + comp = out.outputs[idx] + # Correct completion indices + assert comp.index == idx, (f"Index {comp.index}; expected {idx}.") + unique_texts.add(comp.text) + # Unique completions + assert len(unique_texts) == n, ( + f"{len(unique_texts)} unique completions; expected" + f" {n}.") + def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch): """Test passes if LLMEngine raises an exception when it is configured @@ -20,7 +77,7 @@ def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch): # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now. 
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with pytest.raises(ValueError) as excinfo: - LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate( + LLM(model=MODEL, enable_prefix_caching=True).generate( "Hello, my name is", SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5)) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index b188cc5d8d21a..7608df5168c56 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -48,10 +48,15 @@ def __init__( self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config - # Parallel sampling infra + # Bookkeeping for parallel sampling requests + # - parent req ID -> parent request manager self.parallel_parent_reqs: Dict[str, ParallelSamplingRequestManager] = {} + # - child req ID -> (child req index, parent req ID) self.parallel_child_reqs: Dict[str, Tuple[int, str]] = {} + # - flag to reset parallel sampling bookkeeping logic + # between engine runs + self._do_reset_parallel_sampling = False # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( @@ -127,6 +132,12 @@ def abort_request(self, request_ids: List[str]) -> None: self.engine_core.abort_requests(request_ids) self.output_processor.abort_requests(request_ids) + def _reset_parallel_sampling(self) -> None: + """Reset parallel sampling logic""" + self.parallel_parent_reqs.clear() + self.parallel_child_reqs.clear() + self._do_reset_parallel_sampling = False + def add_request( self, request_id: str, @@ -138,6 +149,8 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: + if self._do_reset_parallel_sampling: + self._reset_parallel_sampling() _add_request = (self._add_request if params is None or isinstance(params, PoolingParams) or params.n == 1 else self._add_request_parallel_sampling) @@ -163,10 +176,8 @@ def _add_request_parallel_sampling( ) -> None: req_mgr = ParallelSamplingRequestManager(request_id, params) self.parallel_parent_reqs[request_id] = req_mgr - n = req_mgr.n - # Add n child requests with unique request IDs and n=1 - for idx in range(n): + for idx in range(req_mgr.n): c_params = req_mgr.get_child_sampling_params(idx) c_request_id = req_mgr.get_child_request_id(idx) self.parallel_child_reqs[c_request_id] = (idx, request_id) @@ -234,6 +245,12 @@ def _num_parallel_sampling_child_requests(self) -> int: return len(self.parallel_child_reqs) def step(self) -> List[RequestOutput]: + num_parallel_reqs = self._num_parallel_sampling_requests() + + # Ensure that parallel sampling logic gets reset after the + # engine finishes processing this batch + self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else + self._do_reset_parallel_sampling) # 1) Get EngineCoreOutput from the EngineCore. outputs = self.engine_core.get_output() @@ -245,11 +262,11 @@ def step(self) -> List[RequestOutput]: # 3) Abort any reqs that finished due to stop strings. 
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) - if self._num_parallel_sampling_requests() > 0: - return self._aggregate_parallel_sampling_outputs( - processed_outputs.request_outputs) + request_outputs = processed_outputs.request_outputs + if num_parallel_reqs > 0 and len(request_outputs) > 0: + return self._aggregate_parallel_sampling_outputs(request_outputs) else: - return processed_outputs.request_outputs + return request_outputs def get_model_config(self): return self.model_config From e4a0e6c6056ed19197fae803998263233f907b6f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Feb 2025 20:38:08 +0000 Subject: [PATCH 30/54] refactoring; bugfix Signed-off-by: Andrew Feldman --- tests/v1/engine/test_llm_engine.py | 67 +++++++++++++++++++---------- vllm/v1/engine/llm_engine.py | 36 +++++++++++++--- vllm/v1/engine/parallel_sampling.py | 9 ++-- 3 files changed, 78 insertions(+), 34 deletions(-) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index f3c639548bb53..be04b41dc49fb 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import random +from typing import List, Optional, Tuple import pytest @@ -16,56 +17,78 @@ # Prefix caching params=[False, True]) def vllm_model(vllm_runner, request): + """VllmRunner test fixture parameterized by APC.""" + enable_prefix_caching = request.param[0] with vllm_runner( MODEL, dtype=DTYPE, - max_logprobs=7, - # Very small number of batched tokens to ensure - # that we test chunking. - max_num_batched_tokens=16, - max_num_seqs=16, max_model_len=128, enforce_eager=True, - enable_prefix_caching=request.param, + enable_prefix_caching=enable_prefix_caching, gpu_memory_utilization=0.5, ) as vllm_model: + # VllmRunner instance is cleaned up after test. yield vllm_model -def _get_test_sampling_params(prompt_lst): +def _get_test_sampling_params( + prompt_list: List[str], + seed: Optional[int] = None, +) -> Tuple[List[SamplingParams], List[int]]: + """Generate random sampling params for a batch.""" def get_mostly_n_gt1() -> int: - """Mostly n>1, sometimes n=1""" + """Mostly n \in [2,20], ~1/3 n=1""" x = random.randint(0, 28) if x < 10: return 1 else: return x - 8 - n_list = [get_mostly_n_gt1() for _ in range(len(prompt_lst))] - return [SamplingParams(temperature=0.8, top_p=0.95, n=n) - for n in n_list], n_list - - -def test_parallel_sampling(monkeypatch, vllm_model, example_prompts): + n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))] + # High temperature to maximize the chance of unique completions + return [ + SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed) + for n in n_list + ], n_list + + +def test_parallel_sampling(monkeypatch, vllm_model, example_prompts) -> None: + """Test passes if parallel sampling `n>1` yields `n` uniques completions. + + Args: + monkeypatch: test fixture for modifying text env, scoped to the test. + vllm_model: VllmRunner instance under test. + example_prompt: test fixture providing prompts for testing. 
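+
+    Each prompt's request must yield exactly `n` completions with
+    indices 0..n-1 and `n` unique completion texts.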
+ """ monkeypatch.setenv("VLLM_USE_V1", "1") + # Generate batch sampling params sampling_params_list, n_list = _get_test_sampling_params(example_prompts) + # Process requests model: LLM = vllm_model.model outputs = model.generate(example_prompts, sampling_params_list) + + # Validate each request response for out, n in zip(outputs, n_list): - unique_texts = set() - # Correct number of completions + completion_counts = {} + # Assert correct number of completions assert len(out.outputs) == n, ( f"{len(out.outputs)} completions; {n} expected.") for idx in range(n): comp = out.outputs[idx] - # Correct completion indices + # Assert correct completion indices assert comp.index == idx, (f"Index {comp.index}; expected {idx}.") - unique_texts.add(comp.text) - # Unique completions - assert len(unique_texts) == n, ( - f"{len(unique_texts)} unique completions; expected" - f" {n}.") + text = comp.text + completion_counts[text] = completion_counts.get(text, 0) + 1 + # Assert unique completions + if len(completion_counts) != n: + repeats = { + txt: num + for (txt, num) in completion_counts.items() if num > 1 + } + raise AssertionError( + f"{len(completion_counts)} unique completions; expected" + f" {n}. Repeats: {repeats}") def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch): diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 7608df5168c56..cb8802b351a3c 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -149,8 +149,12 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: + """Add request.""" if self._do_reset_parallel_sampling: + # Reset parallel sampling logic between + # LLM.generate() calls self._reset_parallel_sampling() + # Handle parallel sampling requests differently. _add_request = (self._add_request if params is None or isinstance(params, PoolingParams) or params.n == 1 else self._add_request_parallel_sampling) @@ -174,16 +178,16 @@ def _add_request_parallel_sampling( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: + """Add request, `n>1`""" req_mgr = ParallelSamplingRequestManager(request_id, params) self.parallel_parent_reqs[request_id] = req_mgr - # Add n child requests with unique request IDs and n=1 + # Add n child requests with unique request IDs & random seeds and n=1 for idx in range(req_mgr.n): - c_params = req_mgr.get_child_sampling_params(idx) c_request_id = req_mgr.get_child_request_id(idx) self.parallel_child_reqs[c_request_id] = (idx, request_id) self._add_request(request_id=c_request_id, prompt=prompt, - params=c_params, + params=req_mgr.get_child_sampling_params(idx), arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, @@ -201,7 +205,7 @@ def _add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: - + """Add request, `n=1`""" # 1) Process raw inputs into the request. request = self.processor.process_inputs(request_id, prompt, params, arrival_time, lora_request, @@ -219,19 +223,36 @@ def _aggregate_parallel_sampling_outputs( self, outputs: List[RequestOutput], ) -> List[RequestOutput]: + """Build parallel sampling request outputs. + + Extract child request outputs, aggregate them + into parent request output, and return parent + output when complete. + + Do not modify `n=1` requests. + + Args: + outputs: step request outputs. Mix of child request + outputs & `n=1` request outputs. 
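+                Child outputs are identified by their child request
+                ID; all other outputs are passed through unchanged.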
+ + Return: + List of parallel sampling parent request outputs & + unmodified `n=1` request outputs passed-thru from input. + """ agg_outputs = [] for c_out in outputs: c_req_id = c_out.request_id if cdx_req_id := self.parallel_child_reqs.get(c_req_id, None): + # For each parallel sampling child request output: (cdx, req_id) = cdx_req_id - # Update parallel sampling request req_mgr = self.parallel_parent_reqs[req_id] + # Update parallel sampling request if out := req_mgr._process_output(c_out, cdx): # Return parent request output if complete; - # cleanup parent request + # cleanup parent request bookkeeping. agg_outputs.append(out) del self.parallel_parent_reqs[req_id] - # Cleanup child request + # Cleanup child request bookkeeping. del self.parallel_child_reqs[c_req_id] else: # Not a parallel sampling request output @@ -264,6 +285,7 @@ def step(self) -> List[RequestOutput]: request_outputs = processed_outputs.request_outputs if num_parallel_reqs > 0 and len(request_outputs) > 0: + # Process parallel sampling child request outputs return self._aggregate_parallel_sampling_outputs(request_outputs) else: return request_outputs diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 412c0583fd1ae..eb16404b35f81 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -78,6 +78,8 @@ def _add_output( child request. index: index within `n` child """ + new_completion = child_req_output.outputs[0] + new_completion.index = index if self.request_output is None: # Save the first request output; reinstate # original request ID; metrics are not @@ -86,10 +88,7 @@ def _add_output( child_req_output.metrics = None self.request_output = child_req_output else: - # Aggregate additional completion into request - # output - new_completion = child_req_output.outputs[0] - new_completion.index = index + # Aggregate additional completion into request output # Note: will be sorted by index later self.request_output.outputs.append(new_completion) @@ -186,4 +185,4 @@ def n(self) -> int: @property def output_kind(self) -> RequestOutputKind: - return self.sampling_params.output_kind \ No newline at end of file + return self.sampling_params.output_kind From 196fc68433434923677fa584de6e53ed5fdde549 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Feb 2025 20:42:05 +0000 Subject: [PATCH 31/54] seed Signed-off-by: Andrew Feldman --- tests/v1/engine/test_llm_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index be04b41dc49fb..526aaa232c307 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -18,7 +18,7 @@ params=[False, True]) def vllm_model(vllm_runner, request): """VllmRunner test fixture parameterized by APC.""" - enable_prefix_caching = request.param[0] + enable_prefix_caching = request.param with vllm_runner( MODEL, dtype=DTYPE, @@ -33,7 +33,7 @@ def vllm_model(vllm_runner, request): def _get_test_sampling_params( prompt_list: List[str], - seed: Optional[int] = None, + seed: Optional[int] = 42, ) -> Tuple[List[SamplingParams], List[int]]: """Generate random sampling params for a batch.""" From f86708a14974100c8a9ec12592cceff1b3767e7e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Feb 2025 20:45:39 +0000 Subject: [PATCH 32/54] refactor Signed-off-by: Andrew Feldman --- tests/v1/engine/test_llm_engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git 
a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 526aaa232c307..1bdb3dbd85e11 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -54,7 +54,7 @@ def get_mostly_n_gt1() -> int: def test_parallel_sampling(monkeypatch, vllm_model, example_prompts) -> None: - """Test passes if parallel sampling `n>1` yields `n` uniques completions. + """Test passes if parallel sampling `n>1` yields `n` unique completions. Args: monkeypatch: test fixture for modifying text env, scoped to the test. @@ -62,9 +62,7 @@ def test_parallel_sampling(monkeypatch, vllm_model, example_prompts) -> None: example_prompt: test fixture providing prompts for testing. """ monkeypatch.setenv("VLLM_USE_V1", "1") - # Generate batch sampling params sampling_params_list, n_list = _get_test_sampling_params(example_prompts) - # Process requests model: LLM = vllm_model.model outputs = model.generate(example_prompts, sampling_params_list) From 6b1be3663a4c55496fbb1e296a6d6506d49b25e9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Feb 2025 21:15:40 +0000 Subject: [PATCH 33/54] Parallel sampling unit tests Signed-off-by: Andrew Feldman --- .../v1/entrypoints/openai/test_completion.py | 56 ++++++++++++++++--- vllm/v1/engine/parallel_sampling.py | 7 +-- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index dc72089572f8d..68424a8c45749 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -265,14 +265,38 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI, n = 3 max_tokens = 5 + # High temperature to maximize chance of unique completions. completion = await client.completions.create(model=model_name, prompt=prompt, max_tokens=max_tokens, n=n, - stream=False) - - for choice in completion.choices: - assert choice.finish_reason is not None + temperature=0.95, + stream=False, + seed=42) + + # Assert `n` completions + num_completions = len(completion.choices) + assert num_completions == n, ( + f"Num completions {num_completions} but expected {n}.") + completion_repeats = {} + for idx, choice in enumerate(completion.choices): + # Assert correct completion index & some finish reason. 
+ assert choice.index == idx, ( + f"Index {choice.index} but expected {idx}.") + assert choice.finish_reason is not None, ( + "None finish_reason is invalid.") + text = choice.text + completion_repeats[text] = completion_repeats.get(text, 0) + 1 + # Assert `n` unique completions + num_unique = len(completion_repeats) + if num_unique != n: + repeats = { + txt: num + for (txt, num) in completion_repeats.items() if num > 1 + } + raise AssertionError( + f"Expected {n} unique completions, got {num_unique};" + f" repeats: {repeats}.") @pytest.mark.asyncio @@ -294,7 +318,9 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): prompt=prompt, max_tokens=max_tokens, n=n, - stream=True) + temperature=0.95, + stream=True, + seed=42) chunks: List[List[str]] = [[] for i in range(n)] finish_reason_count = 0 async for chunk in stream: @@ -303,10 +329,24 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): chunks[index].append(text) if chunk.choices[0].finish_reason is not None: finish_reason_count += 1 - assert finish_reason_count == n + # Assert `n` completions with correct finish reasons + assert finish_reason_count == n, ( + f"Expected {n} completions with valid indices and finish_reason.") + num_repeats = {} for chunk in chunks: - assert len(chunk) == max_tokens - print("".join(chunk)) + chunk_len = len(chunk) + # Assert correct number of completion tokens + assert chunk_len == max_tokens, ( + f"max_tokens={max_tokens} but chunk len is {chunk_len}.") + text = "".join(chunk) + num_repeats[text] = num_repeats.get(text, 0) + 1 + print(text) + # Assert `n` unique completions + num_unique = len(num_repeats) + if num_unique != n: + repeats = {txt: num for (txt, num) in num_repeats.items() if num > 1} + raise AssertionError(f"{num_unique} unique completions, expected {n};" + f" repeats: {repeats}") @pytest.mark.asyncio diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 5af568f6cd08c..eb16404b35f81 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -78,6 +78,8 @@ def _add_output( child request. 
index: index within `n` child """ + new_completion = child_req_output.outputs[0] + new_completion.index = index if self.request_output is None: # Save the first request output; reinstate # original request ID; metrics are not @@ -86,10 +88,7 @@ def _add_output( child_req_output.metrics = None self.request_output = child_req_output else: - # Aggregate additional completion into request - # output - new_completion = child_req_output.outputs[0] - new_completion.index = index + # Aggregate additional completion into request output # Note: will be sorted by index later self.request_output.outputs.append(new_completion) From 933a90eb57d9e37cd3a9f7edd3dcd3f8d9b60539 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Feb 2025 22:00:04 +0000 Subject: [PATCH 34/54] pre-commit hook fixes Signed-off-by: Andrew Feldman --- tests/v1/engine/test_llm_engine.py | 4 ++-- tests/v1/entrypoints/openai/test_completion.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 1bdb3dbd85e11..7f2636d44a08d 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import random -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import pytest @@ -68,7 +68,7 @@ def test_parallel_sampling(monkeypatch, vllm_model, example_prompts) -> None: # Validate each request response for out, n in zip(outputs, n_list): - completion_counts = {} + completion_counts: Dict[str, int] = {} # Assert correct number of completions assert len(out.outputs) == n, ( f"{len(out.outputs)} completions; {n} expected.") diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 68424a8c45749..35e059ccb5480 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -278,7 +278,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI, num_completions = len(completion.choices) assert num_completions == n, ( f"Num completions {num_completions} but expected {n}.") - completion_repeats = {} + completion_repeats: Dict[str, int] = {} for idx, choice in enumerate(completion.choices): # Assert correct completion index & some finish reason. 
assert choice.index == idx, ( @@ -332,19 +332,22 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): # Assert `n` completions with correct finish reasons assert finish_reason_count == n, ( f"Expected {n} completions with valid indices and finish_reason.") - num_repeats = {} + completion_repeats: Dict[str, int] = {} for chunk in chunks: chunk_len = len(chunk) # Assert correct number of completion tokens assert chunk_len == max_tokens, ( f"max_tokens={max_tokens} but chunk len is {chunk_len}.") text = "".join(chunk) - num_repeats[text] = num_repeats.get(text, 0) + 1 + completion_repeats[text] = completion_repeats.get(text, 0) + 1 print(text) # Assert `n` unique completions - num_unique = len(num_repeats) + num_unique = len(completion_repeats) if num_unique != n: - repeats = {txt: num for (txt, num) in num_repeats.items() if num > 1} + repeats = { + txt: num + for (txt, num) in completion_repeats.items() if num > 1 + } raise AssertionError(f"{num_unique} unique completions, expected {n};" f" repeats: {repeats}") From fe9f88c44a72f797befc81e5a2a2078355e7472a Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Wed, 19 Feb 2025 12:27:10 -0500 Subject: [PATCH 35/54] Update vllm/v1/engine/async_llm.py Co-authored-by: Nick Hill --- vllm/v1/engine/async_llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index a079e721a1d28..5a79ae40cf350 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -275,8 +275,8 @@ async def _generate_parallel_sampling( gens.append(gen) # Merge generators - async for out in merge_async_iterators(*gens): - yield out[1] # out[0] is index + async for _, out in merge_async_iterators(*gens): + yield out def generate( self, From d40089ee9c64ada3fd57e81d0176ab9e5f6d3118 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Wed, 19 Feb 2025 12:30:50 -0500 Subject: [PATCH 36/54] Update vllm/v1/engine/parallel_sampling.py Co-authored-by: Nick Hill --- vllm/v1/engine/parallel_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index eb16404b35f81..1198331a6ffbb 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -46,7 +46,7 @@ def get_child_sampling_params( Child `sampling_params` instance. 
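 
         For example (illustrative values): a parent seed of 42 with
         n=3 gives the children seeds 42, 43 and 44; with seed=None a
         single cached n=1 params object is reused for every child.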
""" seed = self.sampling_params.seed - if seed is None and self.cached_child_sampling_params: + if self.cached_child_sampling_params: # Reuse child sampling_params data structure return self.cached_child_sampling_params # Build child sampling_params From cb0a2b44609f1f0b40cfde47df7e890bddd977a3 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Wed, 19 Feb 2025 12:31:11 -0500 Subject: [PATCH 37/54] Update vllm/v1/engine/parallel_sampling.py Co-authored-by: Nick Hill --- vllm/v1/engine/parallel_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 1198331a6ffbb..edcd5651891a2 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -103,7 +103,7 @@ def get_child_request_id( self, index: int, ) -> str: - return str(index) + "_" + self.request_id + return f"{index}_{self.request_id}" def _process_output( self, From ef49ba7464ae65ccb4a2f3d3a5393ae256a44ca8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 19 Feb 2025 17:33:11 +0000 Subject: [PATCH 38/54] rename Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 12 ++++++------ vllm/v1/engine/llm_engine.py | 7 +++---- vllm/v1/engine/parallel_sampling.py | 2 +- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index a079e721a1d28..5712c645cd494 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -24,7 +24,7 @@ from vllm.utils import cdiv, kill_process_tree, merge_async_iterators from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import ParallelSamplingRequestManager +from vllm.v1.engine.parallel_sampling import ParallelSamplingRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, @@ -255,23 +255,23 @@ async def _generate_parallel_sampling( priority: int = 0, ) -> AsyncGenerator[RequestOutput, None]: """Generate completions for parallel sampling requests.""" - req_mgr = ParallelSamplingRequestManager(request_id, sampling_params) - n = req_mgr.n + parent_req = ParallelSamplingRequest(request_id, sampling_params) + n = parent_req.n # Aggregate generators for n child requests gens: List[AsyncGenerator[RequestOutput, None]] = [] for idx in range(n): - c_sampling_params = req_mgr.get_child_sampling_params(idx) + c_sampling_params = parent_req.get_child_sampling_params(idx) child_gen = self._generate( prompt=prompt, sampling_params=c_sampling_params, - request_id=req_mgr.get_child_request_id(idx), + request_id=parent_req.get_child_request_id(idx), lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, priority=priority, ) - gen = req_mgr.parallel_sampling_child_gen(child_gen, idx) + gen = parent_req.parallel_sampling_child_gen(child_gen, idx) gens.append(gen) # Merge generators diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index cb8802b351a3c..d400a410b26c3 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -21,7 +21,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling 
import ParallelSamplingRequestManager +from vllm.v1.engine.parallel_sampling import ParallelSamplingRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -50,8 +50,7 @@ def __init__( # Bookkeeping for parallel sampling requests # - parent req ID -> parent request manager - self.parallel_parent_reqs: Dict[str, - ParallelSamplingRequestManager] = {} + self.parallel_parent_reqs: Dict[str, ParallelSamplingRequest] = {} # - child req ID -> (child req index, parent req ID) self.parallel_child_reqs: Dict[str, Tuple[int, str]] = {} # - flag to reset parallel sampling bookkeeping logic @@ -179,7 +178,7 @@ def _add_request_parallel_sampling( priority: int = 0, ) -> None: """Add request, `n>1`""" - req_mgr = ParallelSamplingRequestManager(request_id, params) + req_mgr = ParallelSamplingRequest(request_id, params) self.parallel_parent_reqs[request_id] = req_mgr # Add n child requests with unique request IDs & random seeds and n=1 for idx in range(req_mgr.n): diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index eb16404b35f81..ca35e98bf1caf 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -7,7 +7,7 @@ from vllm.sampling_params import RequestOutputKind, SamplingParams -class ParallelSamplingRequestManager: +class ParallelSamplingRequest: """Info, state & processing for parallel sampling request. Store parent request ID and sampling params. From 94261e2e808a946b9b22e52dd7c6bcc1c4058d9e Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Wed, 19 Feb 2025 12:35:24 -0500 Subject: [PATCH 39/54] Update vllm/v1/engine/llm_engine.py Co-authored-by: Nick Hill --- vllm/v1/engine/llm_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index cb8802b351a3c..aaa851e32581c 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -270,8 +270,8 @@ def step(self) -> List[RequestOutput]: # Ensure that parallel sampling logic gets reset after the # engine finishes processing this batch - self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else - self._do_reset_parallel_sampling) + if self.parallel_parent_requests: + self._do_reset_parallel_sampling = True # 1) Get EngineCoreOutput from the EngineCore. 
outputs = self.engine_core.get_output() From 9cc19de2313cf301af672550e9bc3c54fe52a8b6 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Wed, 19 Feb 2025 12:39:13 -0500 Subject: [PATCH 40/54] Update vllm/v1/engine/parallel_sampling.py Co-authored-by: Nick Hill --- vllm/v1/engine/parallel_sampling.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index edcd5651891a2..1a3b752121cfd 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -99,10 +99,7 @@ def _get_parent_request_output(self) -> RequestOutput: key=lambda x: x.index) return self.request_output - def get_child_request_id( - self, - index: int, - ) -> str: + def get_child_request_id(self, index: int) -> str: return f"{index}_{self.request_id}" def _process_output( From 150fc93522be41e85ebef4baf490a45f88528a57 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 19 Feb 2025 17:40:07 +0000 Subject: [PATCH 41/54] fix Signed-off-by: Andrew Feldman --- vllm/v1/engine/llm_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 40655f85eb1ae..905b438310efe 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -269,8 +269,8 @@ def step(self) -> List[RequestOutput]: # Ensure that parallel sampling logic gets reset after the # engine finishes processing this batch - if self.parallel_parent_requests: - self._do_reset_parallel_sampling = True + if self.parallel_parent_reqs: + self._do_reset_parallel_sampling = True # 1) Get EngineCoreOutput from the EngineCore. outputs = self.engine_core.get_output() From 9e8d75531f84c82d75132319a6042d875d95f329 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 19 Feb 2025 18:07:02 +0000 Subject: [PATCH 42/54] refactor generate_parallel_sampling_async() into parallel_sampling.py Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 50 ++++++---------------------- vllm/v1/engine/parallel_sampling.py | 51 ++++++++++++++++++++++++++--- 2 files changed, 57 insertions(+), 44 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ba7cd5856c72b..06af5afc878bc 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -21,10 +21,10 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import cdiv, kill_process_tree, merge_async_iterators +from vllm.utils import cdiv, kill_process_tree from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import ParallelSamplingRequest +from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, @@ -244,40 +244,6 @@ async def _generate( await self.abort(request_id) raise - async def _generate_parallel_sampling( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> 
AsyncGenerator[RequestOutput, None]: - """Generate completions for parallel sampling requests.""" - parent_req = ParallelSamplingRequest(request_id, sampling_params) - n = parent_req.n - - # Aggregate generators for n child requests - gens: List[AsyncGenerator[RequestOutput, None]] = [] - for idx in range(n): - c_sampling_params = parent_req.get_child_sampling_params(idx) - child_gen = self._generate( - prompt=prompt, - sampling_params=c_sampling_params, - request_id=parent_req.get_child_request_id(idx), - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - ) - gen = parent_req.parallel_sampling_child_gen(child_gen, idx) - gens.append(gen) - - # Merge generators - async for _, out in merge_async_iterators(*gens): - yield out - def generate( self, prompt: PromptType, @@ -289,10 +255,14 @@ def generate( priority: int = 0, ) -> AsyncGenerator[RequestOutput, None]: n = sampling_params.n - _generate = self._generate if n is None or n == 1 \ - else self._generate_parallel_sampling # handle parallel sampling - return _generate(prompt, sampling_params, request_id, lora_request, - trace_headers, prompt_adapter_request, priority) + if n is None or n == 1: + return self._generate(prompt, sampling_params, request_id, + lora_request, trace_headers, + prompt_adapter_request, priority) + else: + return generate_parallel_sampling_async( + self._generate, prompt, sampling_params, request_id, + lora_request, trace_headers, prompt_adapter_request, priority) async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index fdd1d657de154..5cbefff2fb584 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,10 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 from copy import copy -from typing import AsyncGenerator, Optional +from typing import AsyncGenerator, Callable, List, Mapping, Optional +from vllm.inputs import PromptType +from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.utils import merge_async_iterators + +AsyncGenerateMethodType = Callable[[ + PromptType, SamplingParams, str, Optional[LoRARequest], Optional[Mapping[ + str, str]], Optional[PromptAdapterRequest], int +], AsyncGenerator[RequestOutput, None]] class ParallelSamplingRequest: @@ -92,7 +101,7 @@ def _add_output( # Note: will be sorted by index later self.request_output.outputs.append(new_completion) - def _get_parent_request_output(self) -> RequestOutput: + def _get_final_request_output(self) -> RequestOutput: """Invariant: parent completion outputs sorted by index""" assert self.request_output is not None self.request_output.outputs = sorted(self.request_output.outputs, @@ -144,10 +153,10 @@ def _process_output( if self.num_completions == self.n: # Return aggregated request output after obtaining # all completions - return self._get_parent_request_output() + return self._get_final_request_output() return None - async def parallel_sampling_child_gen( + async def wrap_child_async_generator( self, child_gen: AsyncGenerator[RequestOutput, None], index: int, @@ -186,3 +195,37 @@ def n(self) -> int: @property def output_kind(self) -> RequestOutputKind: return self.sampling_params.output_kind + + +async def 
generate_parallel_sampling_async( + generate: AsyncGenerateMethodType, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, +) -> AsyncGenerator[RequestOutput, None]: + """Generate completions for async parallel sampling requests.""" + parent_req = ParallelSamplingRequest(request_id, sampling_params) + + # Aggregate generators for n child requests + gens: List[AsyncGenerator[RequestOutput, None]] = [] + for idx in range(parent_req.n): + c_sampling_params = parent_req.get_child_sampling_params(idx) + child_gen = generate( + prompt=prompt, + sampling_params=c_sampling_params, + request_id=parent_req.get_child_request_id(idx), + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) # type: ignore + gen = parent_req.wrap_child_async_generator(child_gen, idx) + gens.append(gen) + + # Merge generators + async for _, out in merge_async_iterators(*gens): + yield out From 73ccfb3a976cf39e907b5d0d757ff919e2652db0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 19 Feb 2025 18:15:38 +0000 Subject: [PATCH 43/54] refactor Signed-off-by: Andrew Feldman --- vllm/v1/engine/llm_engine.py | 8 ++++---- vllm/v1/engine/parallel_sampling.py | 23 ++++++++++++++++------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 905b438310efe..eeb6888e334e3 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -182,11 +182,11 @@ def _add_request_parallel_sampling( self.parallel_parent_reqs[request_id] = req_mgr # Add n child requests with unique request IDs & random seeds and n=1 for idx in range(req_mgr.n): - c_request_id = req_mgr.get_child_request_id(idx) - self.parallel_child_reqs[c_request_id] = (idx, request_id) - self._add_request(request_id=c_request_id, + c_req_id, c_params = req_mgr.get_child_info(idx) + self.parallel_child_reqs[c_req_id] = (idx, request_id) + self._add_request(request_id=c_req_id, prompt=prompt, - params=req_mgr.get_child_sampling_params(idx), + params=c_params, arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 1577f818a03a6..866104a135dc4 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from copy import copy -from typing import AsyncGenerator, Callable, List, Mapping, Optional +from typing import AsyncGenerator, Callable, List, Mapping, Optional, Tuple from vllm.inputs import PromptType from vllm.lora.request import LoRARequest @@ -38,7 +38,7 @@ def __init__(self, request_id: str, self.sampling_params = sampling_params self.cached_child_sampling_params = None - def get_child_sampling_params( + def _get_child_sampling_params( self, index: int, ) -> SamplingParams: @@ -108,8 +108,17 @@ def _get_final_request_output(self) -> RequestOutput: key=lambda x: x.index) return self.request_output - def get_child_request_id(self, index: int) -> str: - return f"{index}_{self.request_id}" + def get_child_info(self, index: int) -> Tuple[str, SamplingParams]: + """Get child request ID and sampling params. + + Args: + index: index within `n` child requests. 
+ + Returns: + (request ID, sampling_params) tuple + """ + return (f"{index}_{self.request_id}", + self._get_child_sampling_params(index)) def _process_output( self, @@ -210,11 +219,11 @@ async def generate_parallel_sampling_async( # Aggregate generators for n child requests gens: List[AsyncGenerator[RequestOutput, None]] = [] for idx in range(parent_req.n): - c_sampling_params = parent_req.get_child_sampling_params(idx) + c_req_id, c_params = parent_req.get_child_info(idx) child_gen = generate( prompt=prompt, - sampling_params=c_sampling_params, - request_id=parent_req.get_child_request_id(idx), + sampling_params=c_params, + request_id=c_req_id, lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, From b334d608379c99d83e754539aa58806df2ee1620 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 19 Feb 2025 18:20:29 +0000 Subject: [PATCH 44/54] refactor Signed-off-by: Andrew Feldman --- vllm/v1/engine/llm_engine.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index eeb6888e334e3..7a45b5a6582e7 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -113,10 +113,9 @@ def from_engine_args( multiprocess_mode=enable_multiprocessing) def get_num_unfinished_requests(self) -> int: - num_core_reqs = self.output_processor.get_num_unfinished_requests() - num_child_reqs = self._num_parallel_sampling_child_requests() - num_parent_reqs = self._num_parallel_sampling_requests() - return num_core_reqs + num_parent_reqs - num_child_reqs + """Get num unfinished requests, accounting for parallel sampling.""" + return (self.output_processor.get_num_unfinished_requests() + + len(self.parallel_parent_reqs) - len(self.parallel_child_reqs)) def has_unfinished_requests(self) -> bool: return self.output_processor.has_unfinished_requests() @@ -258,14 +257,11 @@ def _aggregate_parallel_sampling_outputs( agg_outputs.append(c_out) return agg_outputs - def _num_parallel_sampling_requests(self) -> int: - return len(self.parallel_parent_reqs) - def _num_parallel_sampling_child_requests(self) -> int: return len(self.parallel_child_reqs) def step(self) -> List[RequestOutput]: - num_parallel_reqs = self._num_parallel_sampling_requests() + num_parallel_reqs = len(self.parallel_parent_reqs) # Ensure that parallel sampling logic gets reset after the # engine finishes processing this batch @@ -286,8 +282,7 @@ def step(self) -> List[RequestOutput]: if num_parallel_reqs > 0 and len(request_outputs) > 0: # Process parallel sampling child request outputs return self._aggregate_parallel_sampling_outputs(request_outputs) - else: - return request_outputs + return request_outputs def get_model_config(self): return self.model_config From 38ea05785d78c4585f955436a3ff0e8de2bc04e2 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 19 Feb 2025 19:17:47 +0000 Subject: [PATCH 45/54] refactor Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 22 ++++- vllm/v1/engine/llm_engine.py | 136 ++++++-------------------- vllm/v1/engine/parallel_sampling.py | 145 +++++++++++++++++++++++++++- 3 files changed, 189 insertions(+), 114 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 06af5afc878bc..c6d8c2331a2bb 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -256,13 +256,25 @@ def generate( ) -> AsyncGenerator[RequestOutput, None]: n = sampling_params.n if n is None or n == 1: - return 
self._generate(prompt, sampling_params, request_id, - lora_request, trace_headers, - prompt_adapter_request, priority) + return self._generate( + prompt=prompt, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) else: + # Special handling for parallel sampling requests return generate_parallel_sampling_async( - self._generate, prompt, sampling_params, request_id, - lora_request, trace_headers, prompt_adapter_request, priority) + generate=self._generate, + prompt=prompt, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 7a45b5a6582e7..52e825ceb0c79 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Mapping, Optional, Tuple, Type, Union +from typing import Dict, List, Mapping, Optional, Type, Union from typing_extensions import TypeVar @@ -21,7 +21,8 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import ParallelSamplingRequest +from vllm.v1.engine.parallel_sampling import (SyncParallelSamplingManager, + add_request_parallel_sampling) from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -49,13 +50,7 @@ def __init__( self.cache_config = vllm_config.cache_config # Bookkeeping for parallel sampling requests - # - parent req ID -> parent request manager - self.parallel_parent_reqs: Dict[str, ParallelSamplingRequest] = {} - # - child req ID -> (child req index, parent req ID) - self.parallel_child_reqs: Dict[str, Tuple[int, str]] = {} - # - flag to reset parallel sampling bookkeeping logic - # between engine runs - self._do_reset_parallel_sampling = False + self.parallel_mgr = SyncParallelSamplingManager() # Tokenizer (+ ensure liveness if running in another process). 
self.tokenizer = init_tokenizer_from_configs( @@ -113,9 +108,8 @@ def from_engine_args( multiprocess_mode=enable_multiprocessing) def get_num_unfinished_requests(self) -> int: - """Get num unfinished requests, accounting for parallel sampling.""" - return (self.output_processor.get_num_unfinished_requests() + - len(self.parallel_parent_reqs) - len(self.parallel_child_reqs)) + return self.parallel_mgr.get_num_unfinished_requests( + self.output_processor.get_num_unfinished_requests()) def has_unfinished_requests(self) -> bool: return self.output_processor.has_unfinished_requests() @@ -130,12 +124,6 @@ def abort_request(self, request_ids: List[str]) -> None: self.engine_core.abort_requests(request_ids) self.output_processor.abort_requests(request_ids) - def _reset_parallel_sampling(self) -> None: - """Reset parallel sampling logic""" - self.parallel_parent_reqs.clear() - self.parallel_child_reqs.clear() - self._do_reset_parallel_sampling = False - def add_request( self, request_id: str, @@ -148,49 +136,30 @@ def add_request( priority: int = 0, ) -> None: """Add request.""" - if self._do_reset_parallel_sampling: - # Reset parallel sampling logic between - # LLM.generate() calls - self._reset_parallel_sampling() # Handle parallel sampling requests differently. - _add_request = (self._add_request if params is None - or isinstance(params, PoolingParams) or params.n == 1 - else self._add_request_parallel_sampling) - return _add_request(request_id=request_id, - prompt=prompt, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) - - def _add_request_parallel_sampling( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> None: - """Add request, `n>1`""" - req_mgr = ParallelSamplingRequest(request_id, params) - self.parallel_parent_reqs[request_id] = req_mgr - # Add n child requests with unique request IDs & random seeds and n=1 - for idx in range(req_mgr.n): - c_req_id, c_params = req_mgr.get_child_info(idx) - self.parallel_child_reqs[c_req_id] = (idx, request_id) - self._add_request(request_id=c_req_id, + if params is None or isinstance(params, + PoolingParams) or params.n == 1: + self._add_request(request_id=request_id, prompt=prompt, - params=c_params, + params=params, arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, priority=priority) + else: + # Special handling for parallel sampling requests + add_request_parallel_sampling( + add_request=self._add_request, + parallel_mgr=self.parallel_mgr, + request_id=request_id, + prompt=prompt, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) def _add_request( self, @@ -217,56 +186,10 @@ def _add_request( # 3) Add the request to EngineCore. self.engine_core.add_request(request) - def _aggregate_parallel_sampling_outputs( - self, - outputs: List[RequestOutput], - ) -> List[RequestOutput]: - """Build parallel sampling request outputs. - - Extract child request outputs, aggregate them - into parent request output, and return parent - output when complete. 
- - Do not modify `n=1` requests. - - Args: - outputs: step request outputs. Mix of child request - outputs & `n=1` request outputs. - - Return: - List of parallel sampling parent request outputs & - unmodified `n=1` request outputs passed-thru from input. - """ - agg_outputs = [] - for c_out in outputs: - c_req_id = c_out.request_id - if cdx_req_id := self.parallel_child_reqs.get(c_req_id, None): - # For each parallel sampling child request output: - (cdx, req_id) = cdx_req_id - req_mgr = self.parallel_parent_reqs[req_id] - # Update parallel sampling request - if out := req_mgr._process_output(c_out, cdx): - # Return parent request output if complete; - # cleanup parent request bookkeeping. - agg_outputs.append(out) - del self.parallel_parent_reqs[req_id] - # Cleanup child request bookkeeping. - del self.parallel_child_reqs[c_req_id] - else: - # Not a parallel sampling request output - agg_outputs.append(c_out) - return agg_outputs - - def _num_parallel_sampling_child_requests(self) -> int: - return len(self.parallel_child_reqs) - def step(self) -> List[RequestOutput]: - num_parallel_reqs = len(self.parallel_parent_reqs) - - # Ensure that parallel sampling logic gets reset after the - # engine finishes processing this batch - if self.parallel_parent_reqs: - self._do_reset_parallel_sampling = True + # Schedule reset of parallel sampling logic + # in between generate() runs + self.parallel_mgr.schedule_reset() # 1) Get EngineCoreOutput from the EngineCore. outputs = self.engine_core.get_output() @@ -279,9 +202,10 @@ def step(self) -> List[RequestOutput]: self.engine_core.abort_requests(processed_outputs.reqs_to_abort) request_outputs = processed_outputs.request_outputs - if num_parallel_reqs > 0 and len(request_outputs) > 0: - # Process parallel sampling child request outputs - return self._aggregate_parallel_sampling_outputs(request_outputs) + + # 4) Process unfinished parallel sampling requests + request_outputs = self.parallel_mgr.step(request_outputs) + return request_outputs def get_model_config(self): diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 866104a135dc4..91be169582100 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 from copy import copy -from typing import AsyncGenerator, Callable, List, Mapping, Optional, Tuple +from typing import (AsyncGenerator, Callable, Dict, List, Mapping, Optional, + Tuple, Union) from vllm.inputs import PromptType from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput +from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.utils import merge_async_iterators @@ -15,6 +17,12 @@ str, str]], Optional[PromptAdapterRequest], int ], AsyncGenerator[RequestOutput, None]] +SyncAddRequestMethodType = Callable[[ + str, PromptType, Union[SamplingParams, PoolingParams], Optional[float], + Optional[LoRARequest], Optional[Mapping[ + str, str]], Optional[PromptAdapterRequest], int +], None] + class ParallelSamplingRequest: """Info, state & processing for parallel sampling request. 
@@ -120,7 +128,7 @@ def get_child_info(self, index: int) -> Tuple[str, SamplingParams]: return (f"{index}_{self.request_id}", self._get_child_sampling_params(index)) - def _process_output( + def process_output( self, child_req_output: RequestOutput, index: int, @@ -186,7 +194,7 @@ async def wrap_child_async_generator( to the caller. """ async for out in child_gen: - if req_out := self._process_output(out, index): + if req_out := self.process_output(out, index): yield req_out @property @@ -203,6 +211,108 @@ def output_kind(self) -> RequestOutputKind: return self.sampling_params.output_kind +class SyncParallelSamplingManager: + + def __init__(self): + # Parent req ID -> parent request manager + self.parent_reqs: Dict[str, ParallelSamplingRequest] = {} + # Child req ID -> (child req index, parent req ID) + self.child_reqs: Dict[str, Tuple[int, str]] = {} + # Flag to reset parallel sampling bookkeeping logic + # between engine runs + self._do_reset = False + + def _reset_if_needed(self) -> None: + """Reset at beginning of sync generate()""" + if self._do_reset: + self.parent_reqs.clear() + self.child_reqs.clear() + self._do_reset = False + + def schedule_reset(self) -> None: + """Schedule reset for the next time a parent request is added.""" + if self.parent_reqs: + self._do_reset = True + + def register_parent_request(self, req: ParallelSamplingRequest) -> None: + """Register parallel sampling parent request.""" + self._reset_if_needed() + self.parent_reqs[req.request_id] = req + + def register_child_request(self, req_id: str, child_req_id: str, + index: int) -> None: + """Register parallel sampling child request with parent. + + Args: + req_id: parent request ID + child_req_id: child request ID + index: child request index within `n` child requests + """ + self.child_reqs[child_req_id] = (index, req_id) + + def _aggregate_parallel_sampling_outputs( + self, + outputs: List[RequestOutput], + ) -> List[RequestOutput]: + """Build parallel sampling request outputs. + + Extract child request outputs, aggregate them + into parent request output, and return parent + output when complete. + + Do not modify `n=1` requests. + + Args: + outputs: step request outputs. Mix of child request + outputs & `n=1` request outputs. + + Return: + List of parallel sampling parent request outputs & + unmodified `n=1` request outputs passed-thru from input. + """ + agg_outputs = [] + for c_out in outputs: + c_req_id = c_out.request_id + if cdx_req_id := self.child_reqs.get(c_req_id, None): + # For each parallel sampling child request output: + (cdx, req_id) = cdx_req_id + req = self.parent_reqs[req_id] + # Update parallel sampling request + if out := req.process_output(c_out, cdx): + # Return parent request output if complete; + # cleanup parent request bookkeeping. + agg_outputs.append(out) + del self.parent_reqs[req_id] + # Cleanup child request bookkeeping. + del self.child_reqs[c_req_id] + else: + # Not a parallel sampling request output + agg_outputs.append(c_out) + return agg_outputs + + def step(self, + request_outputs: List[RequestOutput]) -> List[RequestOutput]: + """Process parallel sampling child request outputs""" + if self.parent_reqs and request_outputs: + return self._aggregate_parallel_sampling_outputs(request_outputs) + # If there are no parallel sampling child request outputs, + # return unmodified. + return request_outputs + + def get_num_unfinished_requests(self, num_core_reqs: int) -> int: + """Get the number of unfinished requests, correcting for parallel + sampling. 
+ + Args: + num_core_reqs: The number of unfinished requests in the engine core. + + Returns: + Number of unfinished requests, where each parallel sampling req + counts as 1 + """ + return num_core_reqs + len(self.parent_reqs) - len(self.child_reqs) + + async def generate_parallel_sampling_async( generate: AsyncGenerateMethodType, prompt: PromptType, @@ -235,3 +345,32 @@ async def generate_parallel_sampling_async( # Merge generators async for _, out in merge_async_iterators(*gens): yield out + + +def add_request_parallel_sampling( + add_request: SyncAddRequestMethodType, + parallel_mgr: SyncParallelSamplingManager, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, +) -> None: + """Add sync parallel sampling request.""" + req = ParallelSamplingRequest(request_id, params) + parallel_mgr.register_parent_request(req) + # Add n child requests with unique request IDs & random seeds and n=1 + for idx in range(req.n): + c_req_id, c_params = req.get_child_info(idx) + parallel_mgr.register_child_request(request_id, c_req_id, idx) + add_request(request_id=c_req_id, + prompt=prompt, + params=c_params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) # type: ignore From ecb39ae46114362d5c1a279c6c442c79458763e6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 19 Feb 2025 19:21:37 +0000 Subject: [PATCH 46/54] refactor Signed-off-by: Andrew Feldman --- vllm/v1/engine/llm_engine.py | 1 - vllm/v1/engine/parallel_sampling.py | 14 ++++---------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 52e825ceb0c79..d3d88948785c9 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -205,7 +205,6 @@ def step(self) -> List[RequestOutput]: # 4) Process unfinished parallel sampling requests request_outputs = self.parallel_mgr.step(request_outputs) - return request_outputs def get_model_config(self): diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 91be169582100..aa19e80c41ba7 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -250,7 +250,7 @@ def register_child_request(self, req_id: str, child_req_id: str, """ self.child_reqs[child_req_id] = (index, req_id) - def _aggregate_parallel_sampling_outputs( + def step( self, outputs: List[RequestOutput], ) -> List[RequestOutput]: @@ -270,6 +270,9 @@ def _aggregate_parallel_sampling_outputs( List of parallel sampling parent request outputs & unmodified `n=1` request outputs passed-thru from input. """ + if not (self.parent_reqs and outputs): + # Return unmodified + return outputs agg_outputs = [] for c_out in outputs: c_req_id = c_out.request_id @@ -290,15 +293,6 @@ def _aggregate_parallel_sampling_outputs( agg_outputs.append(c_out) return agg_outputs - def step(self, - request_outputs: List[RequestOutput]) -> List[RequestOutput]: - """Process parallel sampling child request outputs""" - if self.parent_reqs and request_outputs: - return self._aggregate_parallel_sampling_outputs(request_outputs) - # If there are no parallel sampling child request outputs, - # return unmodified. 
- return request_outputs - def get_num_unfinished_requests(self, num_core_reqs: int) -> int: """Get the number of unfinished requests, correcting for parallel sampling. From dfb85136d661cfd40aaab27e1233422c5b7787a4 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 20 Feb 2025 17:27:16 +0000 Subject: [PATCH 47/54] reorg Signed-off-by: Andrew Feldman --- vllm/v1/engine/llm_engine.py | 6 +- vllm/v1/engine/parallel_sampling.py | 89 ++++++++++++++--------------- 2 files changed, 46 insertions(+), 49 deletions(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 07ff511b287fc..a7255f67c53dd 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -21,8 +21,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import (SyncParallelSamplingManager, - add_request_parallel_sampling) +from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -149,9 +148,8 @@ def add_request( priority=priority) else: # Special handling for parallel sampling requests - add_request_parallel_sampling( + self.parallel_mgr.add_request_parallel_sampling( add_request=self._add_request, - parallel_mgr=self.parallel_mgr, request_id=request_id, prompt=prompt, params=params, diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index aa19e80c41ba7..14f2767b78841 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -234,13 +234,13 @@ def schedule_reset(self) -> None: if self.parent_reqs: self._do_reset = True - def register_parent_request(self, req: ParallelSamplingRequest) -> None: + def _register_parent_request(self, req: ParallelSamplingRequest) -> None: """Register parallel sampling parent request.""" self._reset_if_needed() self.parent_reqs[req.request_id] = req - def register_child_request(self, req_id: str, child_req_id: str, - index: int) -> None: + def _register_child_request(self, req_id: str, child_req_id: str, + index: int) -> None: """Register parallel sampling child request with parent. Args: @@ -250,6 +250,47 @@ def register_child_request(self, req_id: str, child_req_id: str, """ self.child_reqs[child_req_id] = (index, req_id) + def get_num_unfinished_requests(self, num_core_reqs: int) -> int: + """Get the number of unfinished requests, correcting for parallel + sampling. + + Args: + num_core_reqs: The number of unfinished requests in the engine core. 
+ + Returns: + Number of unfinished requests, where each parallel sampling req + counts as 1 + """ + return num_core_reqs + len(self.parent_reqs) - len(self.child_reqs) + + def add_request_parallel_sampling( + self, + add_request: SyncAddRequestMethodType, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + """Add sync parallel sampling request.""" + req = ParallelSamplingRequest(request_id, params) + self._register_parent_request(req) + # Add n child requests with unique request IDs & random seeds and n=1 + for idx in range(req.n): + c_req_id, c_params = req.get_child_info(idx) + self._register_child_request(request_id, c_req_id, idx) + add_request(request_id=c_req_id, + prompt=prompt, + params=c_params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) # type: ignore + def step( self, outputs: List[RequestOutput], @@ -293,19 +334,6 @@ def step( agg_outputs.append(c_out) return agg_outputs - def get_num_unfinished_requests(self, num_core_reqs: int) -> int: - """Get the number of unfinished requests, correcting for parallel - sampling. - - Args: - num_core_reqs: The number of unfinished requests in the engine core. - - Returns: - Number of unfinished requests, where each parallel sampling req - counts as 1 - """ - return num_core_reqs + len(self.parent_reqs) - len(self.child_reqs) - async def generate_parallel_sampling_async( generate: AsyncGenerateMethodType, @@ -339,32 +367,3 @@ async def generate_parallel_sampling_async( # Merge generators async for _, out in merge_async_iterators(*gens): yield out - - -def add_request_parallel_sampling( - add_request: SyncAddRequestMethodType, - parallel_mgr: SyncParallelSamplingManager, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -) -> None: - """Add sync parallel sampling request.""" - req = ParallelSamplingRequest(request_id, params) - parallel_mgr.register_parent_request(req) - # Add n child requests with unique request IDs & random seeds and n=1 - for idx in range(req.n): - c_req_id, c_params = req.get_child_info(idx) - parallel_mgr.register_child_request(request_id, c_req_id, idx) - add_request(request_id=c_req_id, - prompt=prompt, - params=c_params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) # type: ignore From 53e35dfb109e4e990d7b5f03a3dd010acdf66610 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 20 Feb 2025 19:40:37 +0000 Subject: [PATCH 48/54] refactor Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 29 +++++---------- vllm/v1/engine/llm_engine.py | 38 +++++++------------ vllm/v1/engine/parallel_sampling.py | 58 +++++++++++------------------ 3 files changed, 45 insertions(+), 80 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 24a5adcbf38d6..1e661b04f32b7 100644 --- a/vllm/v1/engine/async_llm.py +++ 
b/vllm/v1/engine/async_llm.py @@ -51,8 +51,6 @@ def __init__( assert start_engine_loop self.model_config = vllm_config.model_config - self.enable_prefix_caching = ( - vllm_config.cache_config.enable_prefix_caching) self.log_requests = log_requests self.log_stats = log_stats @@ -254,27 +252,20 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> AsyncGenerator[RequestOutput, None]: + kwargs = dict(prompt=prompt, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) n = sampling_params.n if n is None or n == 1: - return self._generate( - prompt=prompt, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) + return self._generate(**kwargs) else: # Special handling for parallel sampling requests - return generate_parallel_sampling_async( - generate=self._generate, - prompt=prompt, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) + return generate_parallel_sampling_async(generate=self._generate, + **kwargs) async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index a7255f67c53dd..69881a2aed3e8 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -49,7 +49,7 @@ def __init__( self.cache_config = vllm_config.cache_config # Bookkeeping for parallel sampling requests - self.parallel_mgr = SyncParallelSamplingManager() + self.parallel_manager = SyncParallelSamplingManager() # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( @@ -107,7 +107,7 @@ def from_engine_args( multiprocess_mode=enable_multiprocessing) def get_num_unfinished_requests(self) -> int: - return self.parallel_mgr.get_num_unfinished_requests( + return self.parallel_manager.get_num_unfinished_requests( self.output_processor.get_num_unfinished_requests()) def has_unfinished_requests(self) -> bool: @@ -135,29 +135,22 @@ def add_request( priority: int = 0, ) -> None: """Add request.""" + kwargs = dict(request_id=request_id, + prompt=prompt, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority) # Handle parallel sampling requests differently. 
if params is None or isinstance(params, PoolingParams) or params.n == 1: - self._add_request(request_id=request_id, - prompt=prompt, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) + self._add_request(**kwargs) else: # Special handling for parallel sampling requests - self.parallel_mgr.add_request_parallel_sampling( - add_request=self._add_request, - request_id=request_id, - prompt=prompt, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) + self.parallel_manager.add_request_parallel_sampling( + add_request=self._add_request, **kwargs) def _add_request( self, @@ -185,9 +178,6 @@ def _add_request( self.engine_core.add_request(request) def step(self) -> List[RequestOutput]: - # Schedule reset of parallel sampling logic - # in between generate() runs - self.parallel_mgr.schedule_reset() # 1) Get EngineCoreOutput from the EngineCore. outputs = self.engine_core.get_output() @@ -202,7 +192,7 @@ def step(self) -> List[RequestOutput]: request_outputs = processed_outputs.request_outputs # 4) Process unfinished parallel sampling requests - request_outputs = self.parallel_mgr.step(request_outputs) + request_outputs = self.parallel_manager.step(request_outputs) return request_outputs def get_model_config(self): diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 14f2767b78841..4f2cff41f76a7 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -67,15 +67,15 @@ def _get_child_sampling_params( # Reuse child sampling_params data structure return self.cached_child_sampling_params # Build child sampling_params - c_sampling_params = copy(self.sampling_params) - c_sampling_params.n = 1 + child_sampling_params = copy(self.sampling_params) + child_sampling_params.n = 1 if seed is None: # Cache child sampling_params for later reuse - self.cached_child_sampling_params = c_sampling_params + self.cached_child_sampling_params = child_sampling_params else: # Each child gets a clone with a unique seed - c_sampling_params.seed = seed + index - return c_sampling_params + child_sampling_params.seed = seed + index + return child_sampling_params def _add_output( self, @@ -218,25 +218,9 @@ def __init__(self): self.parent_reqs: Dict[str, ParallelSamplingRequest] = {} # Child req ID -> (child req index, parent req ID) self.child_reqs: Dict[str, Tuple[int, str]] = {} - # Flag to reset parallel sampling bookkeeping logic - # between engine runs - self._do_reset = False - - def _reset_if_needed(self) -> None: - """Reset at beginning of sync generate()""" - if self._do_reset: - self.parent_reqs.clear() - self.child_reqs.clear() - self._do_reset = False - - def schedule_reset(self) -> None: - """Schedule reset for the next time a parent request is added.""" - if self.parent_reqs: - self._do_reset = True def _register_parent_request(self, req: ParallelSamplingRequest) -> None: """Register parallel sampling parent request.""" - self._reset_if_needed() self.parent_reqs[req.request_id] = req def _register_child_request(self, req_id: str, child_req_id: str, @@ -280,11 +264,11 @@ def add_request_parallel_sampling( self._register_parent_request(req) # Add n child requests with unique request IDs & random seeds and n=1 for idx in range(req.n): - c_req_id, c_params = req.get_child_info(idx) - 
self._register_child_request(request_id, c_req_id, idx) - add_request(request_id=c_req_id, + child_req_id, child_params = req.get_child_info(idx) + self._register_child_request(request_id, child_req_id, idx) + add_request(request_id=child_req_id, prompt=prompt, - params=c_params, + params=child_params, arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, @@ -315,23 +299,23 @@ def step( # Return unmodified return outputs agg_outputs = [] - for c_out in outputs: - c_req_id = c_out.request_id - if cdx_req_id := self.child_reqs.get(c_req_id, None): + for output in outputs: + req_id = output.request_id + if child_req_entry := self.child_reqs.get(req_id, None): # For each parallel sampling child request output: - (cdx, req_id) = cdx_req_id - req = self.parent_reqs[req_id] + (index, parent_req_id) = child_req_entry + req = self.parent_reqs[parent_req_id] # Update parallel sampling request - if out := req.process_output(c_out, cdx): + if out := req.process_output(output, index): # Return parent request output if complete; # cleanup parent request bookkeeping. agg_outputs.append(out) - del self.parent_reqs[req_id] + del self.parent_reqs[parent_req_id] # Cleanup child request bookkeeping. - del self.child_reqs[c_req_id] + del self.child_reqs[req_id] else: # Not a parallel sampling request output - agg_outputs.append(c_out) + agg_outputs.append(output) return agg_outputs @@ -351,11 +335,11 @@ async def generate_parallel_sampling_async( # Aggregate generators for n child requests gens: List[AsyncGenerator[RequestOutput, None]] = [] for idx in range(parent_req.n): - c_req_id, c_params = parent_req.get_child_info(idx) + child_req_id, child_params = parent_req.get_child_info(idx) child_gen = generate( prompt=prompt, - sampling_params=c_params, - request_id=c_req_id, + sampling_params=child_params, + request_id=child_req_id, lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, From 382edd6f8ec6bb9a5e62beda4e013b4cad5544d5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 20 Feb 2025 21:30:31 +0000 Subject: [PATCH 49/54] stream mode finished flag Signed-off-by: Andrew Feldman --- vllm/v1/engine/parallel_sampling.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 4f2cff41f76a7..df156c43b056a 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -38,13 +38,16 @@ class ParallelSamplingRequest: request_id: str sampling_params: SamplingParams cached_child_sampling_params: Optional[SamplingParams] - request_output: Optional[RequestOutput] = None + request_output: Optional[RequestOutput] + num_completions: int def __init__(self, request_id: str, sampling_params: SamplingParams) -> None: self.request_id = request_id self.sampling_params = sampling_params self.cached_child_sampling_params = None + self.request_output = None + self.num_completions = 0 def _get_child_sampling_params( self, @@ -95,6 +98,7 @@ def _add_output( child request. 
index: index within `n` child """ + self.num_completions += 1 new_completion = child_req_output.outputs[0] new_completion.index = index if self.request_output is None: @@ -112,6 +116,7 @@ def _add_output( def _get_final_request_output(self) -> RequestOutput: """Invariant: parent completion outputs sorted by index""" assert self.request_output is not None + self.request_output.finished = True self.request_output.outputs = sorted(self.request_output.outputs, key=lambda x: x.index) return self.request_output @@ -160,6 +165,11 @@ def process_output( # stream=true: return child completions immediately child_req_output.request_id = self.request_id child_req_output.outputs[0].index = index + if child_req_output.finished: + # Parent request is complete if all child requests are + # complete. + self.num_completions += 1 + child_req_output.finished = (self.num_completions == self.n) return child_req_output # stream=false: aggregate child completions @@ -197,11 +207,6 @@ async def wrap_child_async_generator( if req_out := self.process_output(out, index): yield req_out - @property - def num_completions(self) -> int: - assert self.request_output is not None - return len(self.request_output.outputs) - @property def n(self) -> int: return self.sampling_params.n From 2001bef02ca9c4e5b5488f6a61c06f68aff54eb6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 20 Feb 2025 21:40:44 +0000 Subject: [PATCH 50/54] refactor Signed-off-by: Andrew Feldman --- vllm/v1/engine/async_llm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 1e661b04f32b7..36a02628f405d 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -259,8 +259,7 @@ def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, priority=priority) - n = sampling_params.n - if n is None or n == 1: + if sampling_params.n is None or sampling_params.n == 1: return self._generate(**kwargs) else: # Special handling for parallel sampling requests From 892429c3e8a4b5300ea4d0edfc3ac381cf17c3ad Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 20 Feb 2025 21:45:56 +0000 Subject: [PATCH 51/54] refactor Signed-off-by: Andrew Feldman --- vllm/v1/engine/llm_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 69881a2aed3e8..599e836367c56 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -192,8 +192,7 @@ def step(self) -> List[RequestOutput]: request_outputs = processed_outputs.request_outputs # 4) Process unfinished parallel sampling requests - request_outputs = self.parallel_manager.step(request_outputs) - return request_outputs + return self.parallel_manager.step(request_outputs) def get_model_config(self): return self.model_config From 267c1b80a82807d04103bced187c71125c84bd10 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 20 Feb 2025 21:49:47 +0000 Subject: [PATCH 52/54] rename Signed-off-by: Andrew Feldman --- vllm/v1/engine/parallel_sampling.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index df156c43b056a..c8a103e1a8cb9 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -39,7 +39,7 @@ class ParallelSamplingRequest: sampling_params: SamplingParams cached_child_sampling_params: Optional[SamplingParams] request_output: 
Optional[RequestOutput] - num_completions: int + num_finished_completions: int def __init__(self, request_id: str, sampling_params: SamplingParams) -> None: @@ -47,7 +47,7 @@ def __init__(self, request_id: str, self.sampling_params = sampling_params self.cached_child_sampling_params = None self.request_output = None - self.num_completions = 0 + self.num_finished_completions = 0 def _get_child_sampling_params( self, @@ -98,7 +98,7 @@ def _add_output( child request. index: index within `n` child """ - self.num_completions += 1 + self.num_finished_completions += 1 new_completion = child_req_output.outputs[0] new_completion.index = index if self.request_output is None: @@ -168,13 +168,14 @@ def process_output( if child_req_output.finished: # Parent request is complete if all child requests are # complete. - self.num_completions += 1 - child_req_output.finished = (self.num_completions == self.n) + self.num_finished_completions += 1 + child_req_output.finished = ( + self.num_finished_completions == self.n) return child_req_output # stream=false: aggregate child completions self._add_output(child_req_output, index) - if self.num_completions == self.n: + if self.num_finished_completions == self.n: # Return aggregated request output after obtaining # all completions return self._get_final_request_output() From bef174ed0dbd5e7a4b0e54125e9cdf160c749fb3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 20 Feb 2025 22:07:25 +0000 Subject: [PATCH 53/54] protocol-based types Signed-off-by: Andrew Feldman --- vllm/v1/engine/parallel_sampling.py | 36 +++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index c8a103e1a8cb9..5d4ea111abfc9 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from copy import copy -from typing import (AsyncGenerator, Callable, Dict, List, Mapping, Optional, +from typing import (AsyncGenerator, Dict, List, Mapping, Optional, Protocol, Tuple, Union) from vllm.inputs import PromptType @@ -12,16 +12,32 @@ from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.utils import merge_async_iterators -AsyncGenerateMethodType = Callable[[ - PromptType, SamplingParams, str, Optional[LoRARequest], Optional[Mapping[ - str, str]], Optional[PromptAdapterRequest], int -], AsyncGenerator[RequestOutput, None]] -SyncAddRequestMethodType = Callable[[ - str, PromptType, Union[SamplingParams, PoolingParams], Optional[float], - Optional[LoRARequest], Optional[Mapping[ - str, str]], Optional[PromptAdapterRequest], int -], None] +class AsyncGenerateMethodType(Protocol): + + def __call__(self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0) -> AsyncGenerator[RequestOutput, None]: + ... + + +class SyncAddRequestMethodType(Protocol): + + def __call__(self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0) -> None: + ... 
class ParallelSamplingRequest: From 7e845cba21882037b6fa9308dcbbadb89315b923 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 24 Feb 2025 07:39:05 +0000 Subject: [PATCH 54/54] llm_engine test fixtures Signed-off-by: Andrew Feldman --- tests/v1/engine/test_llm_engine.py | 54 ++++++++++++++++++------------ 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 7f2636d44a08d..de2a39ee9c083 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -12,22 +12,37 @@ DTYPE = "half" +def _vllm_model(apc: bool, vllm_runner, monkeypatch): + """Set up VllmRunner instance.""" + monkeypatch.setenv("VLLM_USE_V1", "1") + # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now. + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + return vllm_runner( + MODEL, + dtype=DTYPE, + max_model_len=128, + enforce_eager=True, + enable_prefix_caching=apc, + gpu_memory_utilization=0.5, + ) + + @pytest.fixture( - scope="module", + # Function scope decouples tests & allows + # env var adjustment via monkeypatch + scope="function", # Prefix caching params=[False, True]) -def vllm_model(vllm_runner, request): - """VllmRunner test fixture parameterized by APC.""" - enable_prefix_caching = request.param - with vllm_runner( - MODEL, - dtype=DTYPE, - max_model_len=128, - enforce_eager=True, - enable_prefix_caching=enable_prefix_caching, - gpu_memory_utilization=0.5, - ) as vllm_model: - # VllmRunner instance is cleaned up after test. +def vllm_model(vllm_runner, request, monkeypatch): + """VllmRunner test fixture parameterized by APC True/False.""" + with _vllm_model(request.param, vllm_runner, monkeypatch) as vllm_model: + yield vllm_model + + +@pytest.fixture(scope="function") +def vllm_model_apc(vllm_runner, monkeypatch): + """VllmRunner test fixture with APC.""" + with _vllm_model(True, vllm_runner, monkeypatch) as vllm_model: yield vllm_model @@ -53,15 +68,13 @@ def get_mostly_n_gt1() -> int: ], n_list -def test_parallel_sampling(monkeypatch, vllm_model, example_prompts) -> None: +def test_parallel_sampling(vllm_model, example_prompts) -> None: """Test passes if parallel sampling `n>1` yields `n` unique completions. Args: - monkeypatch: test fixture for modifying text env, scoped to the test. vllm_model: VllmRunner instance under test. example_prompt: test fixture providing prompts for testing. """ - monkeypatch.setenv("VLLM_USE_V1", "1") sampling_params_list, n_list = _get_test_sampling_params(example_prompts) model: LLM = vllm_model.model outputs = model.generate(example_prompts, sampling_params_list) @@ -89,16 +102,13 @@ def test_parallel_sampling(monkeypatch, vllm_model, example_prompts) -> None: f" {n}. Repeats: {repeats}") -def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch): +def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc): """Test passes if LLMEngine raises an exception when it is configured for automatic prefix caching and it receives a request with prompt_logprobs enabled, which is incompatible.""" - - monkeypatch.setenv("VLLM_USE_V1", "1") - # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now. - monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + model: LLM = vllm_model_apc.model with pytest.raises(ValueError) as excinfo: - LLM(model=MODEL, enable_prefix_caching=True).generate( + model.generate( "Hello, my name is", SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))