log probs seem to work for both stream and not
root committed Feb 19, 2024
1 parent 786b7f1 commit c3e2a7e
Showing 2 changed files with 70 additions and 7 deletions.
5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -62,6 +62,7 @@ class ChatCompletionRequest(BaseModel):
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
logprobs: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
@@ -93,6 +94,8 @@ def to_sampling_params(self) -> SamplingParams:
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
logprobs=self.logprobs,
prompt_logprobs=self.logprobs if self.echo else None,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
@@ -208,6 +211,7 @@ class ChatMessage(BaseModel):
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


@@ -228,6 +232,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


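With these protocol additions, a chat completion request can ask for the top-N log probabilities per generated token by passing an integer logprobs value; since prompt_logprobs is derived from the same field, combining it with echo also returns log probabilities for the prompt tokens. Below is a minimal, non-authoritative sketch of such a request against a locally running vLLM OpenAI-compatible server; the host, port, model name, and message content are illustrative assumptions, not part of this commit.

# Sketch only: exercise the new integer `logprobs` request field.
# Host, port, model name, and message are assumptions for illustration.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "meta-llama/Llama-2-7b-chat-hf",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 16,
        "logprobs": 5,  # number of top log probabilities to return per token
    },
)
choice = resp.json()["choices"][0]
print(choice["message"]["content"])
print(choice["logprobs"])  # populated only when logprobs was requested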
72 changes: 65 additions & 7 deletions vllm/entrypoints/openai/serving_chat.py
@@ -1,20 +1,25 @@
import time
import codecs
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union, Dict, Callable
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
LogProbs,
UsageInfo)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA

logger = init_logger(__name__)

TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]

class OpenAIServingChat(OpenAIServing):

@@ -77,10 +82,10 @@ async def create_chat_completion(
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id)
request, result_generator, request_id, self._create_logprobs)
else:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id)
request, raw_request, result_generator, request_id, self._create_logprobs)

def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
@@ -90,7 +95,8 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:

async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str
result_generator: AsyncIterator[RequestOutput], request_id: str,
create_logprobs_fn: TypeCreateLogProbsFn
) -> Union[ErrorResponse, AsyncGenerator[str, None]]:

model_name = request.model
@@ -101,7 +107,7 @@ async def chat_completion_stream_generator(
role = self.get_chat_request_role(request)
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i, delta=DeltaMessage(role=role), finish_reason=None)
index=i, delta=DeltaMessage(role=role), logprobs=None, finish_reason=None)
chunk = ChatCompletionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
@@ -118,6 +124,7 @@ async def chat_completion_stream_generator(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]

if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
@@ -129,6 +136,7 @@ async def chat_completion_stream_generator(
object=chunk_object_type,
created=created_time,
choices=[choice_data],
logprobs=None,
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
@@ -147,13 +155,37 @@

delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)

if request.echo and request.max_tokens == 0:
delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs
elif request.echo and request.max_tokens > 0:
delta_token_ids = res.prompt_token_ids + output.token_ids
top_logprobs = res.prompt_logprobs + (output.logprobs or [])
else:
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None

if request.logprobs is not None:
assert(top_logprobs is not None),\
"top_logprobs must be provided when logprobs is requested"
logprobs = create_logprobs_fn(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None

previous_num_tokens[i] = len(output.token_ids)
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
@@ -174,6 +206,7 @@
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=output.finish_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
@@ -193,7 +226,8 @@ async def chat_completion_stream_generator(
async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput],
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
request_id: str,
create_logprobs_fn: TypeCreateLogProbsFn) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = request.model
created_time = int(time.monotonic())
@@ -208,11 +242,35 @@ async def chat_completion_full_generator(
assert final_res is not None

choices = []

prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs

role = self.get_chat_request_role(request)
for output in final_res.outputs:
if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
top_logprobs = prompt_logprobs + output.logprobs
else:
token_ids = output.token_ids
top_logprobs = output.logprobs

if request.logprobs is not None:
logprobs = create_logprobs_fn(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None

choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
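In the streaming path above, each chunk's choice now carries a LogProbs object covering only the tokens in that delta (or the prompt tokens, when echo is set with max_tokens == 0). A rough sketch of consuming that stream follows; it assumes the server is running locally and that the LogProbs payload exposes the completions-style tokens / token_logprobs / top_logprobs fields, and the host, model name, and prompt are illustrative, not taken from this commit.

# Sketch only: read per-chunk logprobs from the streaming endpoint.
# Server address, model name, and prompt are assumptions for illustration.
import json
import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "meta-llama/Llama-2-7b-chat-hf",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 16,
        "logprobs": 5,
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        choice = json.loads(payload)["choices"][0]
        if choice.get("logprobs"):
            # Log probabilities for the tokens contained in this delta only.
            print(choice["logprobs"]["tokens"],
                  choice["logprobs"]["token_logprobs"])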
