From c3e2a7e7fee23e71f64a2a382b4c024f0b1a4da7 Mon Sep 17 00:00:00 2001
From: root
Date: Sun, 18 Feb 2024 18:56:51 -0800
Subject: [PATCH] log probs seem to work for both stream and not

---
 vllm/entrypoints/openai/protocol.py     |  5 ++
 vllm/entrypoints/openai/serving_chat.py | 72 ++++++++++++++++++++++---
 2 files changed, 70 insertions(+), 7 deletions(-)

diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index fc15b7833ecf2..6bebf38c1a161 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -62,6 +62,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    logprobs: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
@@ -93,6 +94,8 @@ def to_sampling_params(self) -> SamplingParams:
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
+            logprobs=self.logprobs,
+            prompt_logprobs=self.logprobs if self.echo else None,
             best_of=self.best_of,
             top_k=self.top_k,
             ignore_eos=self.ignore_eos,
@@ -208,6 +211,7 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
 
 
@@ -228,6 +232,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
 
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 850797ae4b9b6..2e8424a6c5494 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -1,7 +1,7 @@
 import time
 import codecs
 from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
+from typing import AsyncGenerator, AsyncIterator, Optional, List, Union, Dict, Callable
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -9,12 +9,17 @@
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    LogProbs,
     UsageInfo)
 from vllm.outputs import RequestOutput
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
 
 logger = init_logger(__name__)
 
+TypeTokenIDs = List[int]
+TypeTopLogProbs = List[Optional[Dict[int, float]]]
+TypeCreateLogProbsFn = Callable[
+    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
 
 class OpenAIServingChat(OpenAIServing):
 
@@ -77,10 +82,10 @@ async def create_chat_completion(
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id)
+                request, result_generator, request_id, self._create_logprobs)
         else:
             return await self.chat_completion_full_generator(
-                request, raw_request, result_generator, request_id)
+                request, raw_request, result_generator, request_id, self._create_logprobs)
 
     def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
         if request.add_generation_prompt:
@@ -90,7 +95,8 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
 
     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput], request_id: str
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            create_logprobs_fn: TypeCreateLogProbsFn
     ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
 
         model_name = request.model
@@ -101,7 +107,7 @@ async def chat_completion_stream_generator(
         role = self.get_chat_request_role(request)
         for i in range(request.n):
             choice_data = ChatCompletionResponseStreamChoice(
-                index=i, delta=DeltaMessage(role=role), finish_reason=None)
+                index=i, delta=DeltaMessage(role=role), logprobs=None, finish_reason=None)
             chunk = ChatCompletionStreamResponse(id=request_id,
                                                  object=chunk_object_type,
                                                  created=created_time,
@@ -118,6 +124,7 @@ async def chat_completion_stream_generator(
                         "content") and request.messages[-1].get(
                             "role") == role:
                 last_msg_content = request.messages[-1]["content"]
+
             if last_msg_content:
                 for i in range(request.n):
                     choice_data = ChatCompletionResponseStreamChoice(
@@ -129,6 +136,7 @@ async def chat_completion_stream_generator(
                         object=chunk_object_type,
                         created=created_time,
                         choices=[choice_data],
+                        logprobs=None,
                         model=model_name)
                     data = chunk.model_dump_json(exclude_unset=True)
                     yield f"data: {data}\n\n"
@@ -147,13 +155,37 @@ async def chat_completion_stream_generator(
 
                 delta_text = output.text[len(previous_texts[i]):]
                 previous_texts[i] = output.text
-                previous_num_tokens[i] = len(output.token_ids)
+
+                if request.echo and request.max_tokens == 0:
+                    delta_token_ids = res.prompt_token_ids
+                    top_logprobs = res.prompt_logprobs
+                elif request.echo and request.max_tokens > 0:
+                    delta_token_ids = res.prompt_token_ids + output.token_ids
+                    top_logprobs = res.prompt_logprobs + (output.logprobs or [])
+                else:
+                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                    top_logprobs = output.logprobs[
+                        previous_num_tokens[i]:] if output.logprobs else None
+
+                if request.logprobs is not None:
+                    assert top_logprobs is not None, \
+                        "top_logprobs must be provided when logprobs is requested"
+                    logprobs = create_logprobs_fn(
+                        token_ids=delta_token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=len(previous_texts[i]),
+                    )
+                else:
+                    logprobs = None
+
+                previous_num_tokens[i] = len(output.token_ids)
                 if output.finish_reason is None:
                     # Send token-by-token response for each request.n
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=None)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
@@ -174,6 +206,7 @@ async def chat_completion_stream_generator(
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=output.finish_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
@@ -193,7 +226,8 @@ async def chat_completion_stream_generator(
     async def chat_completion_full_generator(
            self, request: ChatCompletionRequest, raw_request: Request,
            result_generator: AsyncIterator[RequestOutput],
-            request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
+            request_id: str,
+            create_logprobs_fn: TypeCreateLogProbsFn) -> Union[ErrorResponse, ChatCompletionResponse]:
 
         model_name = request.model
         created_time = int(time.monotonic())
@@ -208,11 +242,35 @@ async def chat_completion_full_generator(
         assert final_res is not None
 
         choices = []
+
+        prompt_token_ids = final_res.prompt_token_ids
+        prompt_logprobs = final_res.prompt_logprobs
+
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
+            if request.echo and request.max_tokens == 0:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+            elif request.echo and request.max_tokens > 0:
+                token_ids = prompt_token_ids + output.token_ids
+                top_logprobs = prompt_logprobs + output.logprobs
+            else:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+
+            if request.logprobs is not None:
+                logprobs = create_logprobs_fn(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                )
+            else:
+                logprobs = None
+
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
                 message=ChatMessage(role=role, content=output.text),
+                logprobs=logprobs,
                 finish_reason=output.finish_reason,
             )
             choices.append(choice_data)
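
For reference, a minimal client-side sketch of how the new `logprobs` field could be exercised once a vLLM OpenAI-compatible server with this patch applied is running. The base URL and model name below are assumptions and are not part of the patch; per the protocol change above, `logprobs` is the integer number of top log probabilities to return per token, and combining it with `echo` also returns prompt logprobs.

# Sketch only: assumes a local vLLM server at localhost:8000 serving the
# (assumed) model below, with the patch above applied.
import json

import requests

BASE_URL = "http://localhost:8000/v1/chat/completions"  # assumed local server

# Non-streaming request: ask for the top 5 logprobs per generated token.
payload = {
    "model": "meta-llama/Llama-2-7b-chat-hf",  # assumed served model
    "messages": [{"role": "user", "content": "Say hello in one word."}],
    "max_tokens": 8,
    "logprobs": 5,
}
resp = requests.post(BASE_URL, json=payload).json()
choice = resp["choices"][0]
print(choice["message"]["content"])
print(choice["logprobs"])  # filled via ChatCompletionResponseChoice.logprobs

# Streaming request: each SSE chunk's choice now carries a logprobs delta.
payload["stream"] = True
with requests.post(BASE_URL, json=payload, stream=True) as stream_resp:
    for line in stream_resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        data = line[len(b"data: "):]
        if data == b"[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["choices"][0].get("logprobs"))

In the streaming case each chunk reports logprobs only for the newly emitted tokens, mirroring how `delta_token_ids` is sliced by `previous_num_tokens[i]` in `chat_completion_stream_generator` above.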