From 7e36adeaa39872b723fc68b569e95ab18bb12f93 Mon Sep 17 00:00:00 2001 From: Siddharth Venkatesan Date: Tue, 4 Feb 2025 09:52:02 -0800 Subject: [PATCH] [vllm] update vllm chat processing --- .../chat_completions/vllm_chat_properties.py | 24 -- .../chat_completions/vllm_chat_utils.py | 263 ++++++++++++++---- .../python/setup/djl_python/input_parser.py | 3 - .../properties_manager/vllm_rb_properties.py | 6 +- .../rolling_batch/lmi_dist_rolling_batch.py | 2 +- .../rolling_batch/rolling_batch_vllm_utils.py | 19 +- .../rolling_batch/vllm_rolling_batch.py | 31 ++- 7 files changed, 252 insertions(+), 96 deletions(-) delete mode 100644 engines/python/setup/djl_python/chat_completions/vllm_chat_properties.py diff --git a/engines/python/setup/djl_python/chat_completions/vllm_chat_properties.py b/engines/python/setup/djl_python/chat_completions/vllm_chat_properties.py deleted file mode 100644 index b8daee9e0..000000000 --- a/engines/python/setup/djl_python/chat_completions/vllm_chat_properties.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file -# except in compliance with the License. A copy of the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" -# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for -# the specific language governing permissions and limitations under the License. -from typing import Optional -from pydantic import Field -from vllm.entrypoints.openai.protocol import ChatCompletionRequest - - -class ChatProperties(ChatCompletionRequest): - """ - Chat input parameters for chat completions API. - See https://platform.openai.com/docs/api-reference/chat/create - """ - - model: Optional[str] = Field(default=None, exclude=True) # Unused diff --git a/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py b/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py index 68b646052..8efbfe58e 100644 --- a/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py +++ b/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py @@ -10,91 +10,240 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. 
-from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Any, Callable, Annotated, Tuple, Sequence -from djl_python.chat_completions.vllm_chat_properties import ChatProperties -from djl_python.properties_manager.properties import Properties -from djl_python.rolling_batch.rolling_batch_vllm_utils import maybe_serialize_tool_calls -from vllm.entrypoints.chat_utils import (apply_hf_chat_template, - apply_mistral_chat_template, - parse_chat_messages, - resolve_chat_template_content_format) +from pydantic import Field +from vllm import TokensPrompt +from vllm.entrypoints.openai.serving_engine import RequestPrompt, TextTokensPrompt +from vllm.entrypoints.openai.tool_parsers import ToolParser +from vllm.transformers_utils.tokenizers.mistral import maybe_serialize_tool_calls +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.entrypoints.openai.protocol import ChatCompletionRequest +from vllm.entrypoints.chat_utils import ( + apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages, + resolve_chat_template_content_format, ChatCompletionMessageParam, + ChatTemplateContentFormatOption, ConversationMessage) + +from djl_python.rolling_batch.vllm_rolling_batch import VLLMRollingBatch + +# The logic in this file is heavily inspired by https://github.com/vllm-project/vllm/blob/v0.7.1/vllm/entrypoints/openai/serving_chat.py#L109 +# Many of the utilities and validation logic are modified directly from vLLM's code +# TODO: Figure out a way to integrate with vLLM at a higher level than we do now to avoid this code def parse_chat_completions_request_vllm( input_map: Dict, - is_rolling_batch: bool, - rolling_batch, + rolling_batch: VLLMRollingBatch, tokenizer, - configs: Properties = None, - is_mistral_tokenizer: bool = False, ): - # Chat completions can either be a rolling batch or no-batching . - if not (is_rolling_batch or configs.batch_size == 1): - raise ValueError( - "chat completions support is not currently available for dynamic batching. " - "You must enable rolling batch to use the chat completions format." 
- ) tool_parser = rolling_batch.get_tool_parser() - chat_params = ChatProperties(**input_map) + model = input_map.pop("model", "lmi") + chat_params = ChatCompletionRequest(**input_map, model=model) if chat_params.tool_choice == "required": raise ValueError("tool_choice = \"required\" is not supported!") - if is_mistral_tokenizer: + if rolling_batch.is_mistral_tokenizer: maybe_serialize_tool_calls(chat_params) - elif chat_params.tool_choice == "auto" and tool_parser is None: + if (chat_params.tool_choice == "auto" + and not (rolling_batch.vllm_configs.enable_auto_tool_choice + and tool_parser is not None) + and not rolling_batch.is_mistral_tokenizer): raise ValueError( - "\"auto\" tool choice requires tool_call_parser to be available") - - should_parse_tools = tool_parser is not None and (hasattr( - chat_params, "tool_choice") and chat_params.tool_choice != "none") - if should_parse_tools: - chat_params = tool_parser.adjust_request(request=chat_params) - - exclude = {"messages"} - param = chat_params.model_dump(exclude_none=True, exclude=exclude) + "\"auto\" tool choice requires " + "--enable-auto-tool-choice and --tool-call-parser to be set") tool_dicts = None if chat_params.tools is None else [ tool.model_dump() for tool in chat_params.tools ] - # TODO - figure out what we need to pass for given format - content_format = resolve_chat_template_content_format( - chat_template=None, - given_format="auto", - tokenizer=tokenizer, + + conversation, request_prompt, engine_prompt, input_text = _preprocess_chat( + chat_params, + tokenizer, + chat_params.messages, + chat_params.chat_template or rolling_batch.get_chat_template(), + rolling_batch.get_chat_template_content_format(), + rolling_batch, + add_generation_prompt=chat_params.add_generation_prompt, + continue_final_message=chat_params.continue_final_message, + tool_dicts=tool_dicts, + documents=chat_params.documents, + tool_parser=tool_parser, + truncate_prompt_tokens=chat_params.truncate_prompt_tokens, + add_special_tokens=chat_params.add_special_tokens, ) + default_sampling_params = rolling_batch.get_default_sampling_params() + default_max_new_tokens = rolling_batch.engine.model_config.max_model_len - len( + engine_prompt["prompt_token_ids"]) + sampling_params = chat_params.to_sampling_params( + default_max_new_tokens, + rolling_batch.engine.model_config.logits_processor_pattern, + default_sampling_params) + params = { + "stream": chat_params.stream, + "output_formatter": + "jsonlines_chat" if chat_params.stream else "json_chat", + "sampling_params": sampling_params, + "conversation": conversation, + "request_prompts": request_prompt, + "engine_prompt": engine_prompt + } + return input_text, params + + +def _preprocess_chat( + request: ChatCompletionRequest, + tokenizer: AnyTokenizer, + messages: List[ChatCompletionMessageParam], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + rolling_batch: VLLMRollingBatch, + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tool_dicts: Optional[List[Dict[str, Any]]] = None, + documents: Optional[List[Dict[str, str]]] = None, + tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = False, +) -> Tuple[List[ConversationMessage], RequestPrompt, TokensPrompt, str]: + resolved_content_format = resolve_chat_template_content_format( + chat_template, chat_template_content_format, tokenizer) conversation, mm_data = 
parse_chat_messages( - chat_params.messages, rolling_batch.get_model_config(), tokenizer, - content_format) + messages, + rolling_batch.engine.model_config, + tokenizer, + content_format=resolved_content_format, + ) + chat_template_kwargs: Dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tool_dicts, + documents=documents, + ) + + request_prompt: Union[str, List[int]] + if rolling_batch.is_mistral_tokenizer: + request_prompt = apply_mistral_chat_template(tokenizer, + messages=messages, + **chat_template_kwargs) + else: + request_prompt = apply_hf_chat_template(tokenizer, + conversation=conversation, + **chat_template_kwargs) + + should_parse_tools = tool_parser is not None and request.tool_choice != "none" + if should_parse_tools: + request = tool_parser(tokenizer).adjust_request(request=request) - prompt_data: Union[str, List[int]] - if is_mistral_tokenizer: - text_inputs = apply_mistral_chat_template( + if isinstance(request_prompt, str): + # Hf tokenizer case + prompt_inputs = tokenize_prompt_input( + request, tokenizer, - chat_params.messages, - None, - tools=tool_dicts, + request_prompt, + rolling_batch.engine.model_config.max_model_len, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, ) else: - text_inputs = apply_hf_chat_template( + # MistralTokenizer case + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(request_prompt), + prompt_token_ids=request_prompt) + + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if mm_data is not None: + engine_prompt["multi_modal_data"] = mm_data + return conversation, request_prompt, engine_prompt, prompt_inputs["prompt"] + + +def tokenize_prompt_input(request: ChatCompletionRequest, + tokenizer: AnyTokenizer, + prompt_input: Union[str, List[int]], + max_model_len: int, + truncate_prompt_tokens: Optional[Annotated[ + int, Field(ge=1)]] = None, + add_special_tokens: bool = True) -> TextTokensPrompt: + if isinstance(prompt_input, str): + return normalize_prompt_text_to_input( + request, + tokenizer, + prompt_input, + truncate_prompt_tokens, + add_special_tokens, + max_model_len, + ) + else: + return normalize_prompt_tokens_to_input( + request, tokenizer, - conversation, - None, - add_generation_prompt=True, - tools=tool_dicts, + prompt_input, + truncate_prompt_tokens, + max_model_len, ) - param["details"] = True # Enable details for chat completions - param[ - "output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat" - param["tool_parser"] = tool_parser - param["chat_params"] = chat_params - if mm_data: - param["mm_data"] = mm_data +def normalize_prompt_text_to_input( + request: ChatCompletionRequest, + tokenizer: AnyTokenizer, + prompt: str, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + add_special_tokens: bool, + max_model_len: int, +) -> TextTokensPrompt: + if truncate_prompt_tokens is None: + encoded = tokenizer(prompt, add_special_tokens=add_special_tokens) + else: + encoded = tokenizer(prompt, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=truncate_prompt_tokens) + + return validate_input(request, encoded.input_ids, prompt, max_model_len) + + +def normalize_prompt_tokens_to_input( + request: ChatCompletionRequest, + tokenizer: AnyTokenizer, + prompt_ids: List[int], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + max_model_len: int, +) -> TextTokensPrompt: + if 
truncate_prompt_tokens is None: + input_ids = prompt_ids + else: + input_ids = prompt_ids[-truncate_prompt_tokens:] + input_text = tokenizer.decode(input_ids) + return validate_input(request, input_ids, input_text, max_model_len) + + +def validate_input( + request: ChatCompletionRequest, + input_ids: List[int], + input_text: str, + max_model_len: int, +) -> TextTokensPrompt: + token_num = len(input_ids) + + # chat completion endpoint supports max_completion_tokens + max_tokens = request.max_completion_tokens or request.max_tokens + if max_tokens is None: + if token_num >= max_model_len: + raise ValueError(f"This model's maximum context length is " + f"{max_model_len} tokens. However, you requested " + f"{token_num} tokens in the messages, " + f"Please reduce the length of the messages.") + elif token_num + max_tokens > max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{max_model_len} tokens. However, you requested " + f"{max_tokens + token_num} tokens " + f"({token_num} in the messages, " + f"{max_tokens} in the completion). " + f"Please reduce the length of the messages or completion.") - # In the case of mistral, text_inputs = List[TokenIds], else = str - return text_inputs, param + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) diff --git a/engines/python/setup/djl_python/input_parser.py b/engines/python/setup/djl_python/input_parser.py index f21626546..5a4849637 100644 --- a/engines/python/setup/djl_python/input_parser.py +++ b/engines/python/setup/djl_python/input_parser.py @@ -147,11 +147,8 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input, from djl_python.chat_completions.vllm_chat_utils import parse_chat_completions_request_vllm inputs, param = parse_chat_completions_request_vllm( input_map, - kwargs.get("is_rolling_batch"), rolling_batch, tokenizer, - configs=configs, - is_mistral_tokenizer=is_mistral_tokenizer, ) else: inputs, param = parse_chat_completions_request( diff --git a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py index 8fae3f5f6..9488fcd06 100644 --- a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py +++ b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py @@ -12,7 +12,7 @@ # the specific language governing permissions and limitations under the License. 
import ast import logging -from typing import Optional, Any, Dict, Tuple +from typing import Optional, Any, Dict, Tuple, Literal from pydantic import field_validator, model_validator, ConfigDict, Field from vllm import EngineArgs from vllm.utils import FlexibleArgumentParser @@ -80,6 +80,10 @@ class VllmRbProperties(Properties): generation_config: Optional[Any] = None override_neuron_config: Optional[Dict] = None + # Non engine arg properties + chat_template: Optional[str] = None + chat_template_content_format: Literal["auto", "string", "openai"] = "auto" + # This allows generic vllm engine args to be passed in and set with vllm model_config = ConfigDict(extra='allow', populate_by_name=True) diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index 1ef86d98a..962e49645 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -132,7 +132,7 @@ def get_model_config(self): return self.engine.preprocessor.model_config if not self.is_t5_model else None def use_vllm_chat_completions(self): - # vllm chat parsing requires 0.7.0 currently, lmi-dist is on 0.6.3.post1 + # vllm chat parsing requires 0.7.1 currently, lmi-dist is on 0.6.3.post1 return False def get_huggingface_model_config(self): diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py index 9c6c86e51..c4ea46f47 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py @@ -60,16 +60,18 @@ def update_request_cache_with_output(request_cache: OrderedDict, if not request_output.prompt_tokens_details: # TODO: Temp check adding the check for T5. if isinstance(vllm_request_output.prompt_token_ids, list): - converted_texts_from_ids = tokenizer.convert_ids_to_tokens( - vllm_request_output.prompt_token_ids) for index, prompt_token_id in enumerate( vllm_request_output.prompt_token_ids): log_prob = None if vllm_request_output.prompt_logprobs and index > 0: log_prob = vllm_request_output.prompt_logprobs[index][ prompt_token_id].logprob + # TODO: this is inefficient, but it works for now. There are implications for multimodal + # and mistral models (mistral tokenizer) when doing batch decodes. 
Favoring readability here + # over performance + text = tokenizer.decode(prompt_token_id) prompt_token = Token(id=prompt_token_id, - text=converted_texts_from_ids[index], + text=text, log_prob=log_prob) request_output.prompt_tokens_details.append(prompt_token) @@ -208,9 +210,18 @@ def get_lora_request(lora_name: str, lora_requests: dict) -> dict: return lora_requests[lora_name] +def get_multi_modal_data(request: Request) -> Optional[dict]: + parameters = request.parameters + images = parameters.pop("images", None) + multi_modal_data = None + if images: + multi_modal_data = {"image": images} + return multi_modal_data + + def get_prompt_inputs(request: Request): text_prompt = request.request_input.input_text - multi_modal_data = request.parameters.pop("mm_data", None) + multi_modal_data = get_multi_modal_data(request) # TODO: In chat cases, we need to apply the chat template to the messages object to get a string # In both HuggingFace and mistral cases, that process can also yield token-ids directly # that we may want to consider passing directly to the engine diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index 13f722e88..3479a8ab6 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -22,7 +22,7 @@ update_request_cache_with_output, create_lora_request, get_lora_request, get_prompt_inputs) from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties -from typing import Callable, List, Optional +from typing import List, Optional VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__) @@ -57,8 +57,6 @@ def __init__(self, model_id_or_path: str, properties: dict, try: self.tool_parser = ToolParserManager.get_tool_parser( self.vllm_configs.tool_call_parser) - self.tool_parser = self.tool_parser( - self.engine.tokenizer.tokenizer) except Exception as e: raise TypeError("Error in tool parser creation.") from e @@ -77,6 +75,20 @@ def use_vllm_chat_completions(self): def get_tool_parser(self): return self.tool_parser + def get_chat_template(self): + if self.is_mistral_tokenizer: + # Mistral tokenizer chat template cannot be overridden + return None + if self.vllm_configs.chat_template is None: + return self.get_tokenizer().chat_template + return self.vllm_configs.chat_template + + def get_chat_template_content_format(self): + return self.vllm_configs.chat_template_content_format + + def get_default_sampling_params(self): + return self.engine.model_config.get_diff_sampling_param() + def reset(self) -> None: """ Aborts all requests @@ -142,9 +154,16 @@ def inference(self, new_requests: List[Request]) -> List: # step 0: register new requests to engine for request in new_requests: request_id = random_uuid() - prompt_inputs = get_prompt_inputs(request) - params = self.translate_vllm_params(request.parameters) - sampling_params = SamplingParams(**params) + # Chat completions request route + if request.parameters.get("sampling_params") is not None: + prompt_inputs = request.parameters.get("engine_prompt") + sampling_params = request.parameters.get("sampling_params") + sampling_params.output_kind = RequestOutputKind.DELTA + # LMI request route + else: + prompt_inputs = get_prompt_inputs(request) + params = self.translate_vllm_params(request.parameters) + sampling_params = SamplingParams(**params) request_params = dict() if request.adapter is not None: adapter_name = 
request.adapter.get_property("name")