From 51f8644d290935c6e1b3a7efef39e5c6dee6e97d Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 12 Sep 2024 20:02:00 +0100 Subject: [PATCH] [Core] Add engine option to return only deltas or final output (#7381) --- .buildkite/test-pipeline.yaml | 1 + tests/async_engine/test_async_llm_engine.py | 161 ++++++++++++++++-- vllm/engine/llm_engine.py | 24 +-- vllm/entrypoints/llm.py | 23 +-- vllm/entrypoints/openai/protocol.py | 7 +- vllm/entrypoints/openai/serving_chat.py | 125 ++++++++------ vllm/entrypoints/openai/serving_completion.py | 32 ++-- vllm/outputs.py | 79 ++++++--- vllm/sampling_params.py | 17 +- vllm/sequence.py | 39 ++++- 10 files changed, 371 insertions(+), 137 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 25f18cc57793e..d0732ec3fe2fb 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -50,6 +50,7 @@ steps: - tests/worker commands: - pytest -v -s async_engine # Async Engine + - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py - pytest -v -s multimodal - pytest -v -s test_utils.py # Utils diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 3bf11fbcfb3b8..bab42942d311f 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -1,7 +1,10 @@ import asyncio +import os +import uuid from asyncio import CancelledError +from copy import copy from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import pytest import pytest_asyncio @@ -11,6 +14,7 @@ from vllm.config import ParallelConfig from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine from vllm.outputs import RequestOutput as RealRequestOutput +from vllm.sampling_params import RequestOutputKind from ..conftest import cleanup from ..utils import wait_for_gpu_memory_to_clear @@ -122,8 +126,17 @@ def 
start_engine(): timeout_s=60, ) + num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1")) + print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}") + return AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True)) + AsyncEngineArgs(model="facebook/opt-125m", + enforce_eager=True, + num_scheduler_steps=num_scheduler_steps)) + + +def uid() -> str: + return str(uuid.uuid4()) @pytest_asyncio.fixture(scope="module") @@ -148,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool: @pytest.mark.asyncio(scope="module") async def test_asyncio_run(async_engine): + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps + async def run(prompt: str): sampling_params = SamplingParams( temperature=0, max_tokens=32, + min_tokens=32, ) + output_count = 0 + final_output = None async for output in async_engine.generate(prompt, sampling_params, - request_id=prompt): + request_id=uid()): + output_count += 1 final_output = output - return final_output + return final_output, output_count results = await asyncio.gather( run("test0"), - run("test1"), + run("test0"), ) assert len(results) == 2 + first, second = results + + # remove nondeterministic fields for comparison + first[0].metrics = None + second[0].metrics = None + first[0].request_id = None + second[0].request_id = None + + assert str(first) == str(second) + + output_count = results[0][1] + if num_scheduler_steps == 1: + assert output_count == 32 + else: + assert 1 < output_count < 32 + + +@pytest.mark.asyncio(scope="module") +async def test_output_kinds(async_engine): + """Test that output_kind works as expected and that + results are equivalent across different kinds.""" + + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps + + sampling_params = SamplingParams( + temperature=0, + max_tokens=32, + 
min_tokens=32, + ) + + async def run(prompt: str, kind: RequestOutputKind): + params = copy(sampling_params) + params.output_kind = kind + + output_count = 0 + final_output = None + async for output in async_engine.generate(prompt, + params, + request_id=uid()): + output_count += 1 + final_output = output + + assert final_output is not None + return (final_output.prompt_token_ids, + final_output.outputs[0].token_ids, + final_output.outputs[0].text, output_count) + + async def run_deltas(prompt: str): + params = copy(sampling_params) + params.output_kind = RequestOutputKind.DELTA + + prompt_tokens = None + output_tokens: List[int] = [] + output_text = "" + output_count = 0 + async for output in async_engine.generate(prompt, + params, + request_id=uid()): + token_ids = output.outputs[0].token_ids + text = output.outputs[0].text + + # Ensure we get prompt ids iff we haven't yet received output tokens + if output_tokens: + assert 1 <= len(token_ids) <= num_scheduler_steps + assert text + assert not output.prompt_token_ids + else: + assert output.prompt_token_ids + prompt_tokens = output.prompt_token_ids + + output_tokens.extend(token_ids) + output_text += text + + output_count += 1 + return prompt_tokens, output_tokens, output_text, output_count + + results = await asyncio.gather( + run("common input prompt", RequestOutputKind.CUMULATIVE), + run("common input prompt", RequestOutputKind.FINAL_ONLY), + run_deltas("common input prompt")) + + # Make sure outputs are the same + prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results) + assert len(prompt_set) == 1 + + text_set = set(text for _, _, text, _ in results) + assert len(text_set) == 1 + + tokens_set = set(tuple(ids) for _, ids, _, _ in results) + assert len(tokens_set) == 1 + + cumulative, final, deltas = results + + # output message counts + assert cumulative[3] == deltas[3] + + if num_scheduler_steps == 1: + assert cumulative[3] == 32 + else: + assert 1 < cumulative[3] < 32 + + assert final[3] == 1 
@pytest.mark.asyncio(scope="module") async def test_cancellation(async_engine): + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps + sampling_params = SamplingParams( temperature=0, - min_tokens=10, - max_tokens=10, + min_tokens=13, + max_tokens=13, ) + stop_at = 5 if num_scheduler_steps == 1 else 1 + + request_id = uid() + i = 0 with pytest.raises(CancelledError): async for output in async_engine.generate("test2", sampling_params, - request_id="test2"): + request_id=request_id): assert not output.finished i += 1 - if i == 5: - await async_engine.abort("test2") + if i == stop_at: + await async_engine.abort(request_id) - assert i == 5 + assert i == stop_at @pytest.mark.asyncio(scope="module") async def test_delayed_generator(async_engine): + scheduler_config = await async_engine.get_scheduler_config() + + if scheduler_config.num_scheduler_steps != 1: + pytest.skip("no need to test this one with multistep") + sampling_params = SamplingParams( temperature=0, min_tokens=10, max_tokens=10, ) - stream = async_engine.generate("test3", - sampling_params, - request_id="test3") + stream = async_engine.generate("test3", sampling_params, request_id=uid()) i = 0 final_output: Optional[RealRequestOutput] = None async for output in stream: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 92e46c7af5162..e07893b29ec38 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -39,7 +39,7 @@ RequestOutputFactory) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -225,9 +225,6 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, 
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, - # To improve performance, only final requests outputs may be required. - # If this set to true, then no intermediate outputs will be returned. - step_return_finished_only: bool = False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -295,7 +292,6 @@ def __init__( self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats - self.step_return_finished_only = step_return_finished_only if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -1273,7 +1269,7 @@ def _process_model_outputs(self, ctx: The virtual engine context to work on request_id: If provided, then only this request is going to be processed - + """ now = time.time() @@ -1378,7 +1374,8 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) request_output = RequestOutputFactory.create(seq_group) - ctx.request_outputs.append(request_output) + if request_output: + ctx.request_outputs.append(request_output) # When we process a single request, we skip it for the next time, # and invoke the request output callback (if there was final output) @@ -1415,14 +1412,19 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - if (seq_group.is_finished() - if self.step_return_finished_only else True): - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create(seq_group) + if request_output: ctx.request_outputs.append(request_output) for seq_group in scheduler_outputs.ignored_seq_groups: + params = seq_group.sampling_params + if params is not None and params.output_kind == ( + RequestOutputKind.DELTA) and not seq_group.is_finished(): + continue + request_output = RequestOutputFactory.create(seq_group) - 
ctx.request_outputs.append(request_output) + if request_output: + ctx.request_outputs.append(request_output) # Immediately process request outputs here (if callback is given) if (ctx.request_outputs diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b1d9f386b6c3e..c01bffeb4289d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -19,7 +19,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -642,14 +642,12 @@ def _validate_and_add_requests( raise ValueError("The lengths of prompts and lora_request " "must be the same.") - if isinstance(params, list): - params = [ - self._add_guided_processor(param, guided_options) - if isinstance(param, SamplingParams) else param - for param in params - ] - elif isinstance(params, SamplingParams): - params = self._add_guided_processor(params, guided_options) + for sp in params if isinstance(params, list) else (params, ): + if isinstance(sp, SamplingParams): + self._add_guided_processor(sp, guided_options) + + # We only care about the final output + sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. for i, request_inputs in enumerate(inputs): @@ -709,9 +707,6 @@ def _run_engine( f"output: {0:.2f} toks/s"), ) - # In the loop below, only finished outputs are used - self.llm_engine.step_return_finished_only = True - # Run the engine. 
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_in_toks = 0 @@ -724,6 +719,7 @@ def _run_engine( if use_tqdm: if isinstance(output, RequestOutput): # Calculate tokens only for RequestOutput + assert output.prompt_token_ids is not None total_in_toks += len(output.prompt_token_ids) in_spd = total_in_toks / pbar.format_dict["elapsed"] total_out_toks += sum( @@ -735,9 +731,6 @@ def _run_engine( f"output: {out_spd:.2f} toks/s") pbar.update(1) - # Restore original behavior - self.llm_engine.step_return_finished_only = False - if use_tqdm: pbar.close() # Sort the outputs by request ID. diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 374196044b7e8..7e9f53b1816d1 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -12,7 +12,8 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams -from vllm.sampling_params import LogitsProcessor, SamplingParams +from vllm.sampling_params import (LogitsProcessor, RequestOutputKind, + SamplingParams) from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -316,6 +317,8 @@ def to_sampling_params( length_penalty=self.length_penalty, logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, ) @model_validator(mode="before") @@ -559,6 +562,8 @@ def to_sampling_params( length_penalty=self.length_penalty, logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, ) @model_validator(mode="before") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py 
index 8ac4caffb37f0..58e42fb5363fb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -246,8 +246,7 @@ async def create_chat_completion( def get_chat_request_role(self, request: ChatCompletionRequest) -> str: if request.add_generation_prompt: return self.response_role - else: - return request.messages[-1]["role"] + return request.messages[-1]["role"] async def chat_completion_stream_generator( self, @@ -264,15 +263,37 @@ async def chat_completion_stream_generator( # Send response for each token for each request.n (index) num_choices = 1 if request.n is None else request.n - previous_texts = [""] * num_choices previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices + num_prompt_tokens = 0 + tool_parser: Optional[ToolParser] = self.tool_parser( tokenizer) if self.tool_parser else None + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): + tool_choice_function_name = request.tool_choice.function.name + else: + tool_choice_function_name = None + + # Determine whether tools are in use with "auto" tool choice + tool_choice_auto = ( + not tool_choice_function_name + and self._should_stream_with_auto_tool_parsing(request)) + + all_previous_token_ids: Optional[List[List[int]]] + if tool_choice_auto: + # These are only required in "auto" tool choice case + previous_texts = [""] * num_choices + all_previous_token_ids = [[]] * num_choices + else: + previous_texts, all_previous_token_ids = None, None + try: async for res in result_generator: + if res.prompt_token_ids is not None: + num_prompt_tokens = len(res.prompt_token_ids) + # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST # response (by the try...catch). 
@@ -305,10 +326,10 @@ async def chat_completion_stream_generator( and request.stream_options.include_usage): # if continuous usage stats are requested, add it if request.stream_options.continuous_usage_stats: - prompt_tokens = len(res.prompt_token_ids) - usage = UsageInfo(prompt_tokens=prompt_tokens, - completion_tokens=0, - total_tokens=prompt_tokens) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) chunk.usage = usage # otherwise don't else: @@ -344,12 +365,10 @@ async def chat_completion_stream_generator( request.stream_options.include_usage): if (request.stream_options. continuous_usage_stats): - prompt_tokens = len( - res.prompt_token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=prompt_tokens) + total_tokens=num_prompt_tokens) chunk.usage = usage else: chunk.usage = None @@ -360,65 +379,66 @@ async def chat_completion_stream_generator( first_iteration = False for output in res.outputs: - i = output.index if finish_reason_sent[i]: continue - delta_token_ids = output.token_ids[previous_num_tokens[i]:] - out_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None - if request.logprobs and request.top_logprobs is not None: - assert out_logprobs is not None, ( + assert output.logprobs is not None, ( "Did not output logprobs") logprobs = self._create_chat_logprobs( - token_ids=delta_token_ids, - top_logprobs=out_logprobs, + token_ids=output.token_ids, + top_logprobs=output.logprobs, tokenizer=tokenizer, num_output_top_logprobs=request.top_logprobs, ) else: logprobs = None - delta_text = output.text[len(previous_texts[i]):] - delta_message: Optional[DeltaMessage] = None + delta_text = output.text + delta_message: Optional[DeltaMessage] # handle streaming deltas for tools with named tool_choice - if (request.tool_choice and type(request.tool_choice) is - ChatCompletionNamedToolChoiceParam): + if 
tool_choice_function_name: delta_message = DeltaMessage(tool_calls=[ DeltaToolCall(function=DeltaFunctionCall( - name=request.tool_choice.function.name, + name=tool_choice_function_name, arguments=delta_text), index=i) ]) # handle streaming deltas for tools with "auto" tool choice - elif (self._should_stream_with_auto_tool_parsing(request) - and tool_parser): + elif tool_choice_auto: + assert previous_texts is not None + assert all_previous_token_ids is not None + assert tool_parser is not None + #TODO optimize manipulation of these lists + previous_text = previous_texts[i] + previous_token_ids = all_previous_token_ids[i] + current_text = previous_text + delta_text + current_token_ids = previous_token_ids + list( + output.token_ids) + delta_message = ( tool_parser.extract_tool_calls_streaming( - previous_text=previous_texts[i], - current_text=output.text, + previous_text=previous_text, + current_text=current_text, delta_text=delta_text, - previous_token_ids= \ - output.token_ids[ - :-1 * len(delta_token_ids) - ], - current_token_ids=output.token_ids, - delta_token_ids=delta_token_ids - ) - ) + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids)) + + # update the previous values for the next iteration + previous_texts[i] = current_text + all_previous_token_ids[i] = current_token_ids # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) # set the previous values for the next iteration - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + previous_num_tokens[i] += len(output.token_ids) # if the message delta is None (e.g. 
because it was a # "control token" for tool calls or the parser otherwise @@ -445,13 +465,12 @@ async def chat_completion_stream_generator( # handle usage stats if requested & if continuous if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): - prompt_tokens = len(res.prompt_token_ids) + if request.stream_options.continuous_usage_stats: completion_tokens = len(output.token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + + total_tokens=num_prompt_tokens + completion_tokens, ) chunk.usage = usage @@ -482,7 +501,7 @@ async def chat_completion_stream_generator( tool_parser.prev_tool_call_arr[index].get( "arguments", {})) - # get what we've streamed so for for arguments + # get what we've streamed so far for arguments # for the current tool actual_call = tool_parser.streamed_args_for_tool[ index] @@ -500,7 +519,6 @@ async def chat_completion_stream_generator( ]) # Send the finish response for each request.n only once - prompt_tokens = len(res.prompt_token_ids) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, @@ -518,13 +536,12 @@ async def chat_completion_stream_generator( model=model_name) if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): - prompt_tokens = len(res.prompt_token_ids) + if request.stream_options.continuous_usage_stats: completion_tokens = len(output.token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + + total_tokens=num_prompt_tokens + completion_tokens, ) chunk.usage = usage @@ -538,10 +555,11 @@ async def chat_completion_stream_generator( # is sent, send the usage if (request.stream_options and request.stream_options.include_usage): + completion_tokens = 
previous_num_tokens[i] final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=previous_num_tokens[i], - total_tokens=prompt_tokens + previous_num_tokens[i], + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, ) final_usage_chunk = ChatCompletionStreamResponse( @@ -680,6 +698,7 @@ async def chat_completion_full_generator( or "") choice.message.content = full_message + assert final_res.prompt_token_ids is not None num_prompt_tokens = len(final_res.prompt_token_ids) num_generated_tokens = sum( len(output.token_ids) for output in final_res.outputs) @@ -789,9 +808,9 @@ def _should_check_for_unstreamed_tool_arg_tokens( return bool( # if there is a delta message that includes tool calls which # include a function that has arguments - self.enable_auto_tools and self.tool_parser and delta_message + output.finish_reason is not None + and self.enable_auto_tools and self.tool_parser and delta_message and delta_message.tool_calls and delta_message.tool_calls[0] and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments is not None - and output.finish_reason is not None ) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 34f1200753f8d..42142efb5f23e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -223,9 +223,10 @@ async def completion_stream_generator( tokenizer: AnyTokenizer, ) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n - previous_texts = [""] * num_choices * num_prompts + previous_text_lens = [0] * num_choices * num_prompts previous_num_tokens = [0] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts + num_prompt_tokens = [0] * num_prompts try: async for prompt_idx, res in result_generator: @@ -233,6 +234,10 @@ async def completion_stream_generator( 
prompt_logprobs = res.prompt_logprobs prompt_text = res.prompt + # Prompt details are excluded from later streamed outputs + if res.prompt_token_ids is not None: + num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids) + delta_token_ids: GenericSequence[int] out_logprobs: Optional[GenericSequence[Optional[Dict[ int, Logprob]]]] @@ -244,6 +249,7 @@ async def completion_stream_generator( assert request.max_tokens is not None if request.echo and request.max_tokens == 0: + assert prompt_token_ids is not None assert prompt_text is not None # only return the prompt delta_text = prompt_text @@ -252,6 +258,7 @@ async def completion_stream_generator( has_echoed[i] = True elif (request.echo and request.max_tokens > 0 and not has_echoed[i]): + assert prompt_token_ids is not None assert prompt_text is not None assert prompt_logprobs is not None # echo the prompt and first token @@ -266,11 +273,9 @@ async def completion_stream_generator( has_echoed[i] = True else: # return just the delta - delta_text = output.text[len(previous_texts[i]):] - delta_token_ids = output.token_ids[ - previous_num_tokens[i]:] - out_logprobs = output.logprobs[previous_num_tokens[ - i]:] if output.logprobs else None + delta_text = output.text + delta_token_ids = output.token_ids + out_logprobs = output.logprobs if request.logprobs is not None: assert out_logprobs is not None, ( @@ -280,13 +285,13 @@ async def completion_stream_generator( top_logprobs=out_logprobs, num_output_top_logprobs=request.logprobs, tokenizer=tokenizer, - initial_text_offset=len(previous_texts[i]), + initial_text_offset=previous_text_lens[i], ) else: logprobs = None - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + previous_text_lens[i] += len(output.text) + previous_num_tokens[i] += len(output.token_ids) finish_reason = output.finish_reason stop_reason = output.stop_reason @@ -307,8 +312,8 @@ async def completion_stream_generator( and request.stream_options.include_usage): if 
(request.stream_options.continuous_usage_stats or output.finish_reason is not None): - prompt_tokens = len(prompt_token_ids) - completion_tokens = len(output.token_ids) + prompt_tokens = num_prompt_tokens[prompt_idx] + completion_tokens = previous_num_tokens[i] usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, @@ -356,6 +361,7 @@ def request_output_to_completion_response( for final_res in final_res_batch: prompt_token_ids = final_res.prompt_token_ids + assert prompt_token_ids is not None prompt_logprobs = final_res.prompt_logprobs prompt_text = final_res.prompt @@ -411,9 +417,9 @@ def request_output_to_completion_response( ) choices.append(choice_data) + num_generated_tokens += len(output.token_ids) + num_prompt_tokens += len(prompt_token_ids) - num_generated_tokens += sum( - len(output.token_ids) for output in final_res.outputs) usage = UsageInfo( prompt_tokens=num_prompt_tokens, diff --git a/vllm/outputs.py b/vllm/outputs.py index e091b576f5972..85ea9196b25df 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -5,6 +5,7 @@ from typing import Union from vllm.lora.request import LoRARequest +from vllm.sampling_params import RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceStatus) @@ -92,7 +93,7 @@ def __init__( self, request_id: str, prompt: Optional[str], - prompt_token_ids: List[int], + prompt_token_ids: Optional[List[int]], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, @@ -113,19 +114,26 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod - def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": - if seq_group.sampling_params is None: + def from_seq_group(cls, + seq_group: SequenceGroup) -> Optional["RequestOutput"]: + sampling_params = seq_group.sampling_params + if sampling_params is None: raise ValueError( "Sampling parameters are missing for a 
CompletionRequest.") + finished = seq_group.is_finished() + if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( + not finished): + return None + seqs = seq_group.get_seqs() if len(seqs) == 1: top_n_seqs = seqs else: # Get the top-n sequences. - n = seq_group.sampling_params.n - if seq_group.sampling_params.use_beam_search: + n = sampling_params.n + if sampling_params.use_beam_search: sorting_key = lambda seq: seq.get_beam_search_score( - seq_group.sampling_params.length_penalty) + sampling_params.length_penalty) else: sorting_key = lambda seq: seq.get_cumulative_logprob() sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) @@ -135,26 +143,49 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # NOTE: We need omit logprobs here explicitly because the sequence # always has the logprobs of the sampled tokens even if the # logprobs are not requested. - include_logprobs = seq_group.sampling_params.logprobs is not None - text_buffer_length = seq_group.sampling_params.output_text_buffer_length - outputs = [ - CompletionOutput( - seqs.index(seq), - seq.get_output_text_to_return(text_buffer_length), - seq.data._output_token_ids, - seq.get_cumulative_logprob() if include_logprobs else None, - seq.output_logprobs if include_logprobs else None, - SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason) for seq in top_n_seqs - ] + include_logprobs = sampling_params.logprobs is not None + text_buffer_length = sampling_params.output_text_buffer_length + delta = sampling_params.output_kind == RequestOutputKind.DELTA + + outputs = [] + include_prompt = True + for seq in top_n_seqs: + output_text = seq.get_output_text_to_return( + text_buffer_length, delta) + output_token_ids = seq.get_output_token_ids_to_return(delta) + output_logprobs = seq.output_logprobs if include_logprobs else None + + if delta: + # Slice logprobs delta if applicable + if output_logprobs: + output_logprobs = output_logprobs[-len(output_token_ids):] + # 
Don't include prompt if this is after the first output + # containing decode token ids + if include_prompt and seq.get_output_len() > len( + output_token_ids): + include_prompt = False + + outputs.append( + CompletionOutput( + seqs.index(seq), output_text, output_token_ids, + seq.get_cumulative_logprob() if include_logprobs else None, + output_logprobs, + SequenceStatus.get_finished_reason(seq.status), + seq.stop_reason)) # Every sequence in the sequence group should have the same prompt. - prompt = seq_group.prompt - prompt_token_ids = seq_group.prompt_token_ids - encoder_prompt = seq_group.encoder_prompt - encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids - prompt_logprobs = seq_group.prompt_logprobs - finished = seq_group.is_finished() + if include_prompt: + prompt = seq_group.prompt + prompt_token_ids = seq_group.prompt_token_ids + encoder_prompt = seq_group.encoder_prompt + encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids + prompt_logprobs = seq_group.prompt_logprobs + else: + prompt = None + prompt_token_ids = None + encoder_prompt = None + encoder_prompt_token_ids = None + prompt_logprobs = None finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) return cls(seq_group.request_id, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index c83ed5cca6791..5edbc8e424e81 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,6 +1,6 @@ """Sampling parameters for text generation.""" import copy -from enum import IntEnum +from enum import Enum, IntEnum from functools import cached_property from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -33,6 +33,15 @@ class SamplingType(IntEnum): to sample from.""" +class RequestOutputKind(Enum): + # Return entire output so far in every RequestOutput + CUMULATIVE = 0 + # Return only deltas in each RequestOutput + DELTA = 1 + # Do not return intermediate RequestOutputs + FINAL_ONLY = 2 + + class SamplingParams( 
msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] @@ -147,6 +156,7 @@ class SamplingParams( logits_processors: Optional[Any] = None include_stop_str_in_output: bool = False truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None + output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE # The below fields are not supposed to be used as an input. # They are set in post_init. @@ -182,6 +192,7 @@ def from_optional( logits_processors: Optional[List[LogitsProcessor]] = None, truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None, + output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, ) -> "SamplingParams": return SamplingParams( n=1 if n is None else n, @@ -213,6 +224,7 @@ def from_optional( spaces_between_special_tokens=spaces_between_special_tokens, logits_processors=logits_processors, truncate_prompt_tokens=truncate_prompt_tokens, + output_kind=output_kind, ) def __post_init__(self) -> None: @@ -317,6 +329,9 @@ def _verify_args(self) -> None: raise ValueError( "stop strings are only supported when detokenize is True. 
" "Set detokenize=True to use stop.") + if self.best_of != self.n and self.output_kind == ( + RequestOutputKind.DELTA): + raise ValueError("best_of must equal n to use output_kind=DELTA") def _verify_beam_search(self) -> None: if self.best_of == 1: diff --git a/vllm/sequence.py b/vllm/sequence.py index 135586831e680..98a8b73586062 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,8 +5,9 @@ from array import array from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, - Optional, Set, Tuple, Union, cast) +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple, Union, cast import msgspec import torch @@ -407,6 +408,10 @@ def __init__( self.status = SequenceStatus.WAITING self.stop_reason: Union[int, str, None] = None + # These are used to keep track of delta outputs + self._last_token_ids_offset: int = 0 + self._last_output_text_offset: int = 0 + # Used for incremental detokenization self.prefix_offset = 0 self.read_offset = 0 @@ -462,11 +467,35 @@ def prompt_adapter_id(self) -> int: return self.prompt_adapter_request.prompt_adapter_id \ if self.prompt_adapter_request else 0 - def get_output_text_to_return(self, buffer_length: int): + def get_output_text_to_return(self, buffer_length: int, + delta: bool) -> str: + """If delta is True, only new text since the last call to + this method is returned""" + # We return the full output text if the sequence is finished. 
truncate = buffer_length and not self.is_finished() - return self.output_text[:-buffer_length] if truncate else ( - self.output_text) + if not delta: + return self.output_text[:-buffer_length] if truncate else ( + self.output_text) + length = len(self.output_text) - buffer_length + last_offset = self._last_output_text_offset + if last_offset < length: + self._last_output_text_offset = length + return self.output_text[last_offset:length] + return "" + + def get_output_token_ids_to_return(self, + delta: bool) -> GenericSequence[int]: + """If delta is True, only new tokens since the last call to + this method are returned""" + if not delta: + return self.get_output_token_ids() + length = self.get_output_len() + last_offset = self._last_token_ids_offset + if last_offset < length: + self._last_token_ids_offset = length + return self.data._output_token_ids[last_offset:] + return () def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size