[Core] Add engine option to return only deltas or final output #7381

Merged · 22 commits · Sep 12, 2024
Changes shown are from 15 of the 22 commits.

Commits:
1eb9991
[Core] Add engine option to return only deltas or final output
njhill Aug 9, 2024
9bc3fdd
Fixes
njhill Aug 10, 2024
ef2e59f
Fix ignored sequence case
njhill Aug 10, 2024
dc1f3f2
Also exclude prompt details in subsequent outputs in delta mode
njhill Aug 13, 2024
9d35a00
Fix prompt token counts in streaming cases
njhill Aug 13, 2024
b7ff44e
Simplification suggestion from @joerunde
njhill Aug 14, 2024
34df9bd
Make tests more robust
njhill Aug 15, 2024
a68506f
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Aug 15, 2024
cfe7118
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Aug 27, 2024
45fd069
Post-merge wip
njhill Aug 27, 2024
3f21ad6
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 6, 2024
d59ffd1
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 8, 2024
d2f36dd
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 10, 2024
2843365
Fix delta computation, remove unrelated changes
njhill Sep 10, 2024
2736ab1
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 10, 2024
a045dff
Address Alex's comments, fix include_prompt logic
njhill Sep 10, 2024
58f6112
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 11, 2024
e7a2b55
Add tests
njhill Sep 11, 2024
6b1f355
Some rework/simplification
njhill Sep 12, 2024
3233a92
Remove obsolete engine.step_return_finished_only field
njhill Sep 12, 2024
f351ed2
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 12, 2024
75814bd
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 12, 2024
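
The changes below add a `RequestOutputKind` option on `SamplingParams` so the engine can return per-step deltas (`DELTA`) or only the final result (`FINAL_ONLY`) instead of cumulative outputs. A minimal usage sketch follows; the async-engine wiring and the `facebook/opt-125m` model name are assumptions chosen for illustration, not code from this PR.

```python
import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.sampling_params import RequestOutputKind


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    # DELTA: each RequestOutput carries only the tokens/text generated since
    # the previous one (prompt details only on the first output).
    delta_params = SamplingParams(max_tokens=32,
                                  output_kind=RequestOutputKind.DELTA)

    # FINAL_ONLY: a single RequestOutput is produced once the request finishes.
    final_params = SamplingParams(max_tokens=32,
                                  output_kind=RequestOutputKind.FINAL_ONLY)

    async for out in engine.generate("Hello, my name is",
                                     delta_params,
                                     request_id="req-0"):
        # In delta mode, out.outputs[0].text is only the newly generated piece.
        print(out.outputs[0].text, end="", flush=True)

    async for out in engine.generate("The capital of France is",
                                     final_params,
                                     request_id="req-1"):
        # In final-only mode this loop yields once, with the complete output.
        print(out.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())
```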
20 changes: 13 additions & 7 deletions vllm/engine/llm_engine.py
@@ -39,7 +39,7 @@
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
@@ -1273,7 +1273,7 @@ def _process_model_outputs(self,

ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed

"""
now = time.time()

@@ -1378,7 +1378,8 @@ def _process_model_outputs(self,
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)
if request_output:
ctx.request_outputs.append(request_output)

# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
@@ -1415,14 +1416,19 @@ def _process_model_outputs(self,

seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if (seq_group.is_finished()
if self.step_return_finished_only else True):
request_output = RequestOutputFactory.create(seq_group)
request_output = RequestOutputFactory.create(seq_group)
if request_output:
ctx.request_outputs.append(request_output)

for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params
if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished():
continue

request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)
if request_output:
ctx.request_outputs.append(request_output)

# Immediately process request outputs here (if callback is given)
if (ctx.request_outputs
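
The net effect of this hunk is that `RequestOutputFactory.create()` may now return `None` and callers append only non-None outputs, while ignored sequence groups in delta mode are skipped until they finish. The self-contained sketch below mimics that shape with hypothetical stand-in types (`OutputKind`, `FakeSeqGroup`, `maybe_create_output`); it is not the engine's actual code.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional


class OutputKind(Enum):  # stand-in for vllm.sampling_params.RequestOutputKind
    DELTA = auto()
    FINAL_ONLY = auto()


@dataclass
class FakeSeqGroup:  # stand-in for a scheduled sequence group
    output_kind: OutputKind
    finished: bool
    new_text: str = ""


def maybe_create_output(group: FakeSeqGroup) -> Optional[str]:
    """Mimics the factory returning None when there is nothing to emit
    for this step (illustrative only)."""
    if group.output_kind is OutputKind.FINAL_ONLY and not group.finished:
        return None  # FINAL_ONLY: suppress all intermediate outputs
    if (group.output_kind is OutputKind.DELTA and not group.new_text
            and not group.finished):
        return None  # DELTA: nothing new to stream this step
    return group.new_text or "<final>"


request_outputs: List[str] = []
for group in [
        FakeSeqGroup(OutputKind.FINAL_ONLY, finished=False),
        FakeSeqGroup(OutputKind.DELTA, finished=False, new_text="hi"),
        FakeSeqGroup(OutputKind.FINAL_ONLY, finished=True),
]:
    out = maybe_create_output(group)
    if out:  # append only non-None outputs, as in the diff above
        request_outputs.append(out)

print(request_outputs)  # ['hi', '<final>']
```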
17 changes: 8 additions & 9 deletions vllm/entrypoints/llm.py
@@ -19,7 +19,7 @@
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -642,14 +642,12 @@ def _validate_and_add_requests(
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")

if isinstance(params, list):
params = [
self._add_guided_processor(param, guided_options)
if isinstance(param, SamplingParams) else param
for param in params
]
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)
for sp in params if isinstance(params, list) else (params, ):
if isinstance(sp, SamplingParams):
self._add_guided_processor(sp, guided_options)

# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY

# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
@@ -724,6 +722,7 @@ def _run_engine(
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
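
For the synchronous `LLM` entrypoint, the change above forces `output_kind` to `RequestOutputKind.FINAL_ONLY` on every request, since `LLM.generate()` only ever returns completed outputs. A small offline sketch, with the model name chosen purely for illustration:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

# Whatever the caller sets, _validate_and_add_requests() overrides
# output_kind to RequestOutputKind.FINAL_ONLY for this entrypoint,
# so each element of `outputs` is the finished result for one prompt.
outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    SamplingParams(temperature=0.8, max_tokens=32),
)

for out in outputs:
    print(out.prompt, "->", out.outputs[0].text)
```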
7 changes: 6 additions & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -12,7 +12,8 @@
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
SamplingParams)
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
@@ -316,6 +317,8 @@ def to_sampling_params(
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)

@model_validator(mode="before")
Expand Down Expand Up @@ -559,6 +562,8 @@ def to_sampling_params(
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)

@model_validator(mode="before")
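
Both OpenAI request types now derive the engine-side output kind from the `stream` flag. Distilled into a standalone helper (a paraphrase of the diff, not the server's actual code path):

```python
from vllm.sampling_params import RequestOutputKind


def output_kind_for(stream: bool) -> RequestOutputKind:
    # stream=True  -> per-token deltas from the engine (DELTA)
    # stream=False -> one final RequestOutput per request (FINAL_ONLY)
    return RequestOutputKind.DELTA if stream else RequestOutputKind.FINAL_ONLY


print(output_kind_for(True), output_kind_for(False))
```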
124 changes: 71 additions & 53 deletions vllm/entrypoints/openai/serving_chat.py
@@ -246,8 +246,7 @@ async def create_chat_completion(
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
return self.response_role
else:
return request.messages[-1]["role"]
return request.messages[-1]["role"]

async def chat_completion_stream_generator(
self,
@@ -264,15 +263,36 @@ async def chat_completion_stream_generator(

# Send response for each token for each request.n (index)
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices

num_prompt_tokens = 0

tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None

if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
else:
tool_choice_function_name = None

tool_choice_auto = (
not tool_choice_function_name
and self._should_stream_with_auto_tool_parsing(request))

all_previous_token_ids: Optional[List[List[int]]]
# These are only used in "auto" tool choice case
if tool_choice_auto:
previous_texts = [""] * num_choices
all_previous_token_ids = [[]] * num_choices
else:
previous_texts, all_previous_token_ids = None, None

try:
async for res in result_generator:
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)

# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
@@ -305,10 +325,10 @@ async def chat_completion_stream_generator(
and request.stream_options.include_usage):
# if continuous usage stats are requested, add it
if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
# otherwise don't
else:
@@ -344,12 +364,10 @@ async def chat_completion_stream_generator(
request.stream_options.include_usage):
if (request.stream_options.
continuous_usage_stats):
prompt_tokens = len(
res.prompt_token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
total_tokens=num_prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
@@ -360,65 +378,66 @@ async def chat_completion_stream_generator(
first_iteration = False

for output in res.outputs:

i = output.index

if finish_reason_sent[i]:
continue

delta_token_ids = output.token_ids[previous_num_tokens[i]:]
out_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None

if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, (
assert output.logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
token_ids=output.token_ids,
top_logprobs=output.logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
logprobs = None

delta_text = output.text[len(previous_texts[i]):]
delta_message: Optional[DeltaMessage] = None
delta_text = output.text
delta_message: Optional[DeltaMessage]

# handle streaming deltas for tools with named tool_choice
if (request.tool_choice and type(request.tool_choice) is
ChatCompletionNamedToolChoiceParam):
if tool_choice_function_name:
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=request.tool_choice.function.name,
name=tool_choice_function_name,
arguments=delta_text),
index=i)
])

# handle streaming deltas for tools with "auto" tool choice
elif (self._should_stream_with_auto_tool_parsing(request)
and tool_parser):
elif tool_choice_auto:
assert previous_texts is not None
assert all_previous_token_ids is not None
assert tool_parser is not None
#TODO optimize manipulation of these lists
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list(
output.token_ids)

delta_message = (
tool_parser.extract_tool_calls_streaming(
previous_text=previous_texts[i],
current_text=output.text,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids= \
output.token_ids[
:-1 * len(delta_token_ids)
],
current_token_ids=output.token_ids,
delta_token_ids=delta_token_ids
)
)
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids))

# update the previous values for the next iteration
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids

# handle streaming just a content delta
else:
delta_message = DeltaMessage(content=delta_text)

# set the previous values for the next iteration
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
previous_num_tokens[i] += len(output.token_ids)

# if the message delta is None (e.g. because it was a
# "control token" for tool calls or the parser otherwise
@@ -445,13 +464,12 @@
# handle usage stats if requested & if continuous
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
Expand Down Expand Up @@ -482,7 +500,7 @@ async def chat_completion_stream_generator(
tool_parser.prev_tool_call_arr[index].get(
"arguments", {}))

# get what we've streamed so for for arguments
# get what we've streamed so far for arguments
# for the current tool
actual_call = tool_parser.streamed_args_for_tool[
index]
@@ -500,7 +518,6 @@ async def chat_completion_stream_generator(
])

# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
@@ -518,13 +535,12 @@ async def chat_completion_stream_generator(
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
@@ -538,10 +554,11 @@ async def chat_completion_stream_generator(
# is sent, send the usage
if (request.stream_options
and request.stream_options.include_usage):
completion_tokens = previous_num_tokens[i]
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)

final_usage_chunk = ChatCompletionStreamResponse(
Expand Down Expand Up @@ -680,6 +697,7 @@ async def chat_completion_full_generator(
or "")
choice.message.content = full_message

assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
@@ -789,9 +807,9 @@ def _should_check_for_unstreamed_tool_arg_tokens(
return bool(
# if there is a delta message that includes tool calls which
# include a function that has arguments
self.enable_auto_tools and self.tool_parser and delta_message
output.finish_reason is not None
and self.enable_auto_tools and self.tool_parser and delta_message
and delta_message.tool_calls and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
and output.finish_reason is not None
)
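
Because the engine now hands the streaming handler deltas rather than cumulative outputs, the per-choice `previous_texts` / `all_previous_token_ids` state above is built by concatenation instead of being sliced out of a growing string. A self-contained sketch of that accumulation pattern, using a hypothetical `Chunk` stand-in for vLLM's `CompletionOutput`:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Chunk:  # hypothetical stand-in for a CompletionOutput in DELTA mode
    index: int
    text: str             # only the newly generated text
    token_ids: List[int]  # only the newly generated token ids


num_choices = 1
previous_texts = [""] * num_choices
all_previous_token_ids: List[List[int]] = [[] for _ in range(num_choices)]

stream = [Chunk(0, "Hel", [15496]), Chunk(0, "lo", [75]), Chunk(0, "!", [0])]

for chunk in stream:
    i = chunk.index
    delta_text = chunk.text  # already a delta; no slicing against prior text
    current_text = previous_texts[i] + delta_text
    current_token_ids = all_previous_token_ids[i] + list(chunk.token_ids)
    # ... a tool parser would be given (previous, current, delta) here ...
    previous_texts[i] = current_text
    all_previous_token_ids[i] = current_token_ids

print(previous_texts[0])          # Hello!
print(all_previous_token_ids[0])  # [15496, 75, 0]
```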