[Frontend][Core] Move logits processor construction to engine #7666

Closed
wants to merge 1 commit into from
17 changes: 7 additions & 10 deletions tests/entrypoints/openai/test_guided_processors.py
@@ -4,9 +4,10 @@
import torch
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)

@@ -44,10 +45,8 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = CompletionRequest(model='test',
prompt=token_ids,
guided_regex=sample_regex)
regex_lp = await get_guided_decoding_logits_processor(
regex_request = GuidedDecodingRequest(guided_regex=sample_regex)
regex_lp = get_local_guided_decoding_logits_processor(
backend, regex_request, tokenizer)
assert regex_lp is not None
tensor = torch.rand(32000)
@@ -59,10 +58,8 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = CompletionRequest(model='test',
prompt=token_ids,
guided_json=sample_json_schema)
json_lp = await get_guided_decoding_logits_processor(
json_request = GuidedDecodingRequest(guided_json=sample_json_schema)
json_lp = get_local_guided_decoding_logits_processor(
backend, json_request, tokenizer)
assert json_lp is not None
tensor = torch.rand(32000)
32 changes: 32 additions & 0 deletions vllm/engine/llm_engine.py
@@ -22,6 +22,7 @@
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
@@ -30,6 +31,8 @@
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
@@ -1056,6 +1059,35 @@ def _create_sequence_group_with_sampling(
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")

# Construct logits processors and add them to the sampling params
if (lp_params := sampling_params.logits_processor_params) is not None:
logits_processors = []

tokenizer = self.get_tokenizer(lora_request=lora_request)

# Guided decoding processor
if (gdr := lp_params.guided_decoding_request) is not None:
decoding_backend = gdr.guided_decoding_backend or \
self.decoding_config.guided_decoding_backend
processor = get_local_guided_decoding_logits_processor(
guided_decoding_backend=decoding_backend,
guided_options=gdr,
tokenizer=tokenizer)
if processor:
logits_processors.append(processor)

# Logit bias + allowed token IDs processors
processors = get_logits_processors(
logit_bias=lp_params.logit_bias,
allowed_token_ids=lp_params.allowed_token_ids,
tokenizer=tokenizer)
logits_processors.extend(processors)

if len(logits_processors) > 0:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.extend(logits_processors)

# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
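For orientation, a minimal sketch of the intended handoff, assuming `LogitsProcessorParams` exposes exactly the fields read in this hunk (`logit_bias`, `allowed_token_ids`, `guided_decoding_request`) rather than a published API: the frontend only describes the processors it wants, and the engine code above materializes the callables with its own tokenizer.

    from vllm.model_executor.guided_decoding.guided_fields import GuidedDecodingRequest
    from vllm.sampling_params import LogitsProcessorParams, SamplingParams

    # Describe the desired constraints instead of building LogitsProcessor
    # callables in the frontend process.
    lp_params = LogitsProcessorParams(
        logit_bias={1234: -100.0},   # down-weight one token id
        allowed_token_ids=None,      # no allow-list in this example
        guided_decoding_request=GuidedDecodingRequest(
            guided_regex=r"\d+\.\d+\.\d+\.\d+"),  # IPv4-shaped output
    )

    # _create_sequence_group_with_sampling (above) resolves these descriptions
    # into concrete logits processors once the request reaches the engine.
    sampling_params = SamplingParams(max_tokens=32,
                                     logits_processor_params=lp_params)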
31 changes: 11 additions & 20 deletions vllm/entrypoints/openai/protocol.py
@@ -6,13 +6,13 @@

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sampling_params import LogitsProcessorParams, SamplingParams
from vllm.utils import random_uuid

# torch is mocked during docs generation,
@@ -233,21 +233,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params

def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
self, guided_decoding_request: Optional[GuidedDecodingRequest],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

# We now allow logprobs being true without top_logrobs.
logits_processors = get_logits_processors(
logits_processor_params = LogitsProcessorParams(
logit_bias=self.logit_bias,
allowed_token_ids=None,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
guided_decoding_request=guided_decoding_request)

return SamplingParams(
n=self.n,
@@ -273,7 +268,7 @@ def to_sampling_params(
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
logits_processor_params=logits_processor_params,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)

@@ -418,22 +413,18 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params

def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
self, guided_decoding_request: Optional[GuidedDecodingRequest],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

echo_without_generation = self.echo and self.max_tokens == 0

logits_processors = get_logits_processors(
logits_processor_params = LogitsProcessorParams(
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
guided_decoding_request=guided_decoding_request)

return SamplingParams(
n=self.n,
@@ -459,7 +450,7 @@ def to_sampling_params(
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
logits_processor_params=logits_processor_params,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)

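One consequence of the new `to_sampling_params` signatures is that the frontend no longer needs a tokenizer at this point. A rough illustration, with the constructor arguments assumed from the request fields visible in this file:

    from vllm.entrypoints.openai.protocol import CompletionRequest
    from vllm.model_executor.guided_decoding.guided_fields import GuidedDecodingRequest

    req = CompletionRequest(model="test",
                            prompt="Give an example IPv4 address",
                            logit_bias={"1234": -50.0},
                            guided_regex=r"\d+\.\d+\.\d+\.\d+")
    params = req.to_sampling_params(
        guided_decoding_request=GuidedDecodingRequest(guided_regex=req.guided_regex),
        default_max_tokens=64)

    # The constraints travel with the SamplingParams as plain data; no tokenizer
    # work or LogitsProcessor construction happens until the engine sees them.
    print(params.logits_processor_params)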
35 changes: 31 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
@@ -5,6 +5,7 @@
from typing import Union

from fastapi import Request
from pydantic import BaseModel
from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
@@ -127,8 +128,12 @@ async def create_chat_completion(

request_id = f"chat-{random_uuid()}"
try:
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
guided_decoding_request = await \
self._create_guided_decoding_request(api_request=request)
# Some requests for tools will use guided decoding
if (guided_json :=
self._get_guided_json_from_tool(request)) is not None:
guided_decoding_request.guided_json = guided_json

prompt_inputs = self._tokenize_prompt_input(
request,
@@ -139,8 +144,7 @@
)

sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
guided_decoding_request,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))

@@ -563,3 +567,26 @@ def _create_chat_logprobs(
tokenizer)))

return ChatCompletionLogProbs(content=logprobs_content)

@staticmethod
def _get_guided_json_from_tool(
request: ChatCompletionRequest
) -> Optional[Union[str, dict, BaseModel]]:
# user has chosen to not use any tool
if request.tool_choice == "none" or request.tools is None:
return None

# user has chosen to use a named tool
if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = request.tool_choice.function.name
tools = {
tool.function.name: tool.function
for tool in request.tools
}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
return tool.parameters

return None
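A short sketch of the tool path; the class name `OpenAIServingChat` and the dict-style request construction are assumptions on my part, and the schema is purely illustrative:

    from vllm.entrypoints.openai.protocol import ChatCompletionRequest
    from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

    request = ChatCompletionRequest(
        model="test",
        messages=[{"role": "user", "content": "What is the weather in Berlin?"}],
        tools=[{
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {"type": "object",
                               "properties": {"city": {"type": "string"}}},
            },
        }],
        tool_choice={"type": "function", "function": {"name": "get_weather"}},
    )

    # For a named tool choice, the tool's parameter schema becomes the guided
    # JSON constraint that create_chat_completion copies onto
    # guided_decoding_request.guided_json.
    guided_json = OpenAIServingChat._get_guided_json_from_tool(request)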
7 changes: 3 additions & 4 deletions vllm/entrypoints/openai/serving_completion.py
@@ -95,8 +95,8 @@ async def create_completion(self, request: CompletionRequest,
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)

guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
guided_decoding_request = await \
self._create_guided_decoding_request(api_request=request)
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
@@ -108,8 +108,7 @@

for i, prompt_inputs in enumerate(prompts):
sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
guided_decoding_request,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))

35 changes: 23 additions & 12 deletions vllm/entrypoints/openai/serving_engine.py
@@ -25,11 +25,11 @@
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer_group import AnyTokenizer

@@ -152,15 +152,6 @@ def create_streaming_error_response(
})
return json_str

async def _guided_decode_logits_processor(
self, request: Union[ChatCompletionRequest, CompletionRequest],
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
decoding_config = await self.async_engine_client.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
return await get_guided_decoding_logits_processor(
guided_decoding_backend, request, tokenizer)

async def _check_model(
self,
request: AnyRequest,
@@ -403,3 +394,23 @@ def _get_decoded_token(logprob: Logprob,
if logprob.decoded_token is not None:
return logprob.decoded_token
return tokenizer.decode(token_id)

@staticmethod
async def _create_guided_decoding_request(
api_request: Union[CompletionRequest, ChatCompletionRequest]
) -> GuidedDecodingRequest:
"""Extract all of the guided decoding parameters from a frontend api
request"""
guided_json_object = None
if (api_request.response_format is not None
and api_request.response_format.type == "json_object"):
guided_json_object = True

return GuidedDecodingRequest(
guided_json=api_request.guided_json,
guided_choice=api_request.guided_choice,
guided_decoding_backend=api_request.guided_decoding_backend,
guided_grammar=api_request.guided_grammar,
guided_regex=api_request.guided_regex,
guided_whitespace_pattern=api_request.guided_whitespace_pattern,
guided_json_object=guided_json_object)
51 changes: 2 additions & 49 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -1,36 +1,12 @@
from typing import Optional, Union
from typing import Optional

from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)
get_local_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor


async def get_guided_decoding_logits_processor(
guided_decoding_backend: str, request: Union[CompletionRequest,
ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
request = _adapt_request_for_tool_use(request)

if guided_decoding_backend == 'outlines':
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_lm_format_enforcer_guided_decoding_logits_processor)
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)

raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")


def get_local_guided_decoding_logits_processor(
guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
tokenizer) -> Optional[LogitsProcessor]:
@@ -48,26 +24,3 @@ def get_local_guided_decoding_logits_processor(
raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]):
# the legacy completion API does not support tool use
if type(request) is CompletionRequest:
return request

# user has chosen to not use any tool
if request.tool_choice == "none":
return request

# user has chosen to use a named tool
if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = request.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in request.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
request.guided_json = tool.parameters

return request
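For reference, a minimal sketch of the surviving synchronous entry point, mirroring the updated test at the top of this diff; the call convention of the returned processor (previously generated token ids plus a row of logits) is assumed rather than shown here:

    import torch
    from transformers import AutoTokenizer

    from vllm.model_executor.guided_decoding import (
        get_local_guided_decoding_logits_processor)
    from vllm.model_executor.guided_decoding.guided_fields import (
        GuidedDecodingRequest)

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    request = GuidedDecodingRequest(guided_regex=r"\d+\.\d+\.\d+\.\d+")
    lp = get_local_guided_decoding_logits_processor("outlines", request, tokenizer)

    logits = torch.rand(32000)
    constrained = lp([], logits)
    # Tokens that cannot continue a regex match should now be pushed to -inf.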