
Commit

[vllm] update vllm chat processing
siddvenk committed Feb 5, 2025
1 parent ba3950c commit 7e36ade
Showing 7 changed files with 252 additions and 96 deletions.


engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py (263 changes: 206 additions & 57 deletions)
@@ -10,91 +10,240 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Any, Callable, Annotated, Tuple, Sequence

-from djl_python.chat_completions.vllm_chat_properties import ChatProperties
-from djl_python.properties_manager.properties import Properties
-from djl_python.rolling_batch.rolling_batch_vllm_utils import maybe_serialize_tool_calls
-from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
-                                         apply_mistral_chat_template,
-                                         parse_chat_messages,
-                                         resolve_chat_template_content_format)
+from pydantic import Field
+from vllm import TokensPrompt
+from vllm.entrypoints.openai.serving_engine import RequestPrompt, TextTokensPrompt
+from vllm.entrypoints.openai.tool_parsers import ToolParser
+from vllm.transformers_utils.tokenizers.mistral import maybe_serialize_tool_calls
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+from vllm.entrypoints.chat_utils import (
+    apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages,
+    resolve_chat_template_content_format, ChatCompletionMessageParam,
+    ChatTemplateContentFormatOption, ConversationMessage)

+from djl_python.rolling_batch.vllm_rolling_batch import VLLMRollingBatch
+
+# The logic in this file is heavily inspired by https://github.com/vllm-project/vllm/blob/v0.7.1/vllm/entrypoints/openai/serving_chat.py#L109
+# Many of the utilities and validation logic are modified directly from vLLM's code
+# TODO: Figure out a way to integrate with vLLM at a higher level than we do now to avoid this code


def parse_chat_completions_request_vllm(
        input_map: Dict,
-        is_rolling_batch: bool,
-        rolling_batch,
+        rolling_batch: VLLMRollingBatch,
        tokenizer,
-        configs: Properties = None,
-        is_mistral_tokenizer: bool = False,
):
-    # Chat completions can either be a rolling batch or no-batching .
-    if not (is_rolling_batch or configs.batch_size == 1):
-        raise ValueError(
-            "chat completions support is not currently available for dynamic batching. "
-            "You must enable rolling batch to use the chat completions format."
-        )

    tool_parser = rolling_batch.get_tool_parser()
-    chat_params = ChatProperties(**input_map)
+    model = input_map.pop("model", "lmi")
+    chat_params = ChatCompletionRequest(**input_map, model=model)

    if chat_params.tool_choice == "required":
        raise ValueError("tool_choice = \"required\" is not supported!")

-    if is_mistral_tokenizer:
+    if rolling_batch.is_mistral_tokenizer:
        maybe_serialize_tool_calls(chat_params)
-    elif chat_params.tool_choice == "auto" and tool_parser is None:
+    if (chat_params.tool_choice == "auto"
+            and not (rolling_batch.vllm_configs.enable_auto_tool_choice
+                     and tool_parser is not None)
+            and not rolling_batch.is_mistral_tokenizer):
        raise ValueError(
-            "\"auto\" tool choice requires tool_call_parser to be available")
-
-    should_parse_tools = tool_parser is not None and (hasattr(
-        chat_params, "tool_choice") and chat_params.tool_choice != "none")
-    if should_parse_tools:
-        chat_params = tool_parser.adjust_request(request=chat_params)
-
-    exclude = {"messages"}
-    param = chat_params.model_dump(exclude_none=True, exclude=exclude)
+            "\"auto\" tool choice requires "
+            "--enable-auto-tool-choice and --tool-call-parser to be set")

    tool_dicts = None if chat_params.tools is None else [
        tool.model_dump() for tool in chat_params.tools
    ]
-    # TODO - figure out what we need to pass for given format
-    content_format = resolve_chat_template_content_format(
-        chat_template=None,
-        given_format="auto",
-        tokenizer=tokenizer,

+    conversation, request_prompt, engine_prompt, input_text = _preprocess_chat(
+        chat_params,
+        tokenizer,
+        chat_params.messages,
+        chat_params.chat_template or rolling_batch.get_chat_template(),
+        rolling_batch.get_chat_template_content_format(),
+        rolling_batch,
+        add_generation_prompt=chat_params.add_generation_prompt,
+        continue_final_message=chat_params.continue_final_message,
+        tool_dicts=tool_dicts,
+        documents=chat_params.documents,
+        tool_parser=tool_parser,
+        truncate_prompt_tokens=chat_params.truncate_prompt_tokens,
+        add_special_tokens=chat_params.add_special_tokens,
    )

+    default_sampling_params = rolling_batch.get_default_sampling_params()
+    default_max_new_tokens = rolling_batch.engine.model_config.max_model_len - len(
+        engine_prompt["prompt_token_ids"])
+    sampling_params = chat_params.to_sampling_params(
+        default_max_new_tokens,
+        rolling_batch.engine.model_config.logits_processor_pattern,
+        default_sampling_params)
+    params = {
+        "stream": chat_params.stream,
+        "output_formatter":
+        "jsonlines_chat" if chat_params.stream else "json_chat",
+        "sampling_params": sampling_params,
+        "conversation": conversation,
+        "request_prompts": request_prompt,
+        "engine_prompt": engine_prompt
+    }
+    return input_text, params
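Note: the reworked entry point takes the raw OpenAI-style chat payload, wraps it in vLLM's ChatCompletionRequest, and hands back the rendered prompt plus the engine-ready parameters. A minimal sketch of a call, assuming `rolling_batch` (a VLLMRollingBatch) and `tokenizer` are supplied by the serving handler as in input_parser.py further down; the payload values are illustrative only:

from djl_python.chat_completions.vllm_chat_utils import parse_chat_completions_request_vllm

# OpenAI-style chat completions payload (illustrative values)
input_map = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "max_completion_tokens": 128,
    "stream": False,
}

# input_text is the rendered prompt text; params carries the SamplingParams,
# conversation, request prompt, and tokenized engine prompt that the rolling
# batch later submits to the vLLM engine.
input_text, params = parse_chat_completions_request_vllm(input_map, rolling_batch, tokenizer)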


+def _preprocess_chat(
+        request: ChatCompletionRequest,
+        tokenizer: AnyTokenizer,
+        messages: List[ChatCompletionMessageParam],
+        chat_template: Optional[str],
+        chat_template_content_format: ChatTemplateContentFormatOption,
+        rolling_batch: VLLMRollingBatch,
+        add_generation_prompt: bool = True,
+        continue_final_message: bool = False,
+        tool_dicts: Optional[List[Dict[str, Any]]] = None,
+        documents: Optional[List[Dict[str, str]]] = None,
+        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        add_special_tokens: bool = False,
+) -> Tuple[List[ConversationMessage], RequestPrompt, TokensPrompt, str]:
+    resolved_content_format = resolve_chat_template_content_format(
+        chat_template, chat_template_content_format, tokenizer)
    conversation, mm_data = parse_chat_messages(
-        chat_params.messages, rolling_batch.get_model_config(), tokenizer,
-        content_format)
+        messages,
+        rolling_batch.engine.model_config,
+        tokenizer,
+        content_format=resolved_content_format,
+    )
+    chat_template_kwargs: Dict[str, Any] = dict(
+        chat_template=chat_template,
+        add_generation_prompt=add_generation_prompt,
+        continue_final_message=continue_final_message,
+        tools=tool_dicts,
+        documents=documents,
+    )

+    request_prompt: Union[str, List[int]]
+    if rolling_batch.is_mistral_tokenizer:
+        request_prompt = apply_mistral_chat_template(tokenizer,
+                                                     messages=messages,
+                                                     **chat_template_kwargs)
+    else:
+        request_prompt = apply_hf_chat_template(tokenizer,
+                                                conversation=conversation,
+                                                **chat_template_kwargs)

+    should_parse_tools = tool_parser is not None and request.tool_choice != "none"
+    if should_parse_tools:
+        request = tool_parser(tokenizer).adjust_request(request=request)

-    prompt_data: Union[str, List[int]]
-    if is_mistral_tokenizer:
-        text_inputs = apply_mistral_chat_template(
+    if isinstance(request_prompt, str):
+        # Hf tokenizer case
+        prompt_inputs = tokenize_prompt_input(
+            request,
            tokenizer,
-            chat_params.messages,
-            None,
-            tools=tool_dicts,
+            request_prompt,
+            rolling_batch.engine.model_config.max_model_len,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            add_special_tokens=add_special_tokens,
        )
    else:
-        text_inputs = apply_hf_chat_template(
+        # MistralTokenizer case
+        prompt_inputs = TextTokensPrompt(
+            prompt=tokenizer.decode(request_prompt),
+            prompt_token_ids=request_prompt)

+    engine_prompt = TokensPrompt(
+        prompt_token_ids=prompt_inputs["prompt_token_ids"])
+    if mm_data is not None:
+        engine_prompt["multi_modal_data"] = mm_data
+    return conversation, request_prompt, engine_prompt, prompt_inputs["prompt"]
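Note: _preprocess_chat mirrors vLLM's own serving_chat preprocessing: resolve the chat-template content format, parse the messages (collecting any multi-modal data), render the prompt with the Mistral or HF template, then tokenize and validate it. For the HF path, apply_hf_chat_template is essentially a wrapper around the tokenizer's own chat template. A standalone sketch with plain transformers, independent of this code (the model name is illustrative and assumed to ship a chat template):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

conversation = [{"role": "user", "content": "Name three prime numbers."}]

# add_generation_prompt=True appends the assistant header so generation starts
# a fresh assistant turn, matching the default used above.
request_prompt = tokenizer.apply_chat_template(conversation,
                                               tokenize=False,
                                               add_generation_prompt=True)
print(request_prompt)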


+def tokenize_prompt_input(request: ChatCompletionRequest,
+                          tokenizer: AnyTokenizer,
+                          prompt_input: Union[str, List[int]],
+                          max_model_len: int,
+                          truncate_prompt_tokens: Optional[Annotated[
+                              int, Field(ge=1)]] = None,
+                          add_special_tokens: bool = True) -> TextTokensPrompt:
+    if isinstance(prompt_input, str):
+        return normalize_prompt_text_to_input(
+            request,
+            tokenizer,
+            prompt_input,
+            truncate_prompt_tokens,
+            add_special_tokens,
+            max_model_len,
+        )
+    else:
+        return normalize_prompt_tokens_to_input(
+            request,
            tokenizer,
-            conversation,
-            None,
-            add_generation_prompt=True,
-            tools=tool_dicts,
+            prompt_input,
+            truncate_prompt_tokens,
+            max_model_len,
        )

-    param["details"] = True  # Enable details for chat completions
-    param[
-        "output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"
-    param["tool_parser"] = tool_parser
-    param["chat_params"] = chat_params
-
-    if mm_data:
-        param["mm_data"] = mm_data

+def normalize_prompt_text_to_input(
+        request: ChatCompletionRequest,
+        tokenizer: AnyTokenizer,
+        prompt: str,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
+        add_special_tokens: bool,
+        max_model_len: int,
+) -> TextTokensPrompt:
+    if truncate_prompt_tokens is None:
+        encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
+    else:
+        encoded = tokenizer(prompt,
+                            add_special_tokens=add_special_tokens,
+                            truncation=True,
+                            max_length=truncate_prompt_tokens)
+
+    return validate_input(request, encoded.input_ids, prompt, max_model_len)


+def normalize_prompt_tokens_to_input(
+        request: ChatCompletionRequest,
+        tokenizer: AnyTokenizer,
+        prompt_ids: List[int],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
+        max_model_len: int,
+) -> TextTokensPrompt:
+    if truncate_prompt_tokens is None:
+        input_ids = prompt_ids
+    else:
+        input_ids = prompt_ids[-truncate_prompt_tokens:]
+    input_text = tokenizer.decode(input_ids)
+    return validate_input(request, input_ids, input_text, max_model_len)


+def validate_input(
+        request: ChatCompletionRequest,
+        input_ids: List[int],
+        input_text: str,
+        max_model_len: int,
+) -> TextTokensPrompt:
+    token_num = len(input_ids)
+
+    # chat completion endpoint supports max_completion_tokens
+    max_tokens = request.max_completion_tokens or request.max_tokens
+    if max_tokens is None:
+        if token_num >= max_model_len:
+            raise ValueError(f"This model's maximum context length is "
+                             f"{max_model_len} tokens. However, you requested "
+                             f"{token_num} tokens in the messages, "
+                             f"Please reduce the length of the messages.")
+    elif token_num + max_tokens > max_model_len:
+        raise ValueError(
+            f"This model's maximum context length is "
+            f"{max_model_len} tokens. However, you requested "
+            f"{max_tokens + token_num} tokens "
+            f"({token_num} in the messages, "
+            f"{max_tokens} in the completion). "
+            f"Please reduce the length of the messages or completion.")

-    # In the case of mistral, text_inputs = List[TokenIds], else = str
-    return text_inputs, param
+    return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
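Note: the truncation and context-window checks above are plain arithmetic: truncate_prompt_tokens keeps only the last N prompt tokens, and validate_input rejects requests where the prompt plus the requested completion would exceed max_model_len. A toy illustration of those two rules (fits_context is a hypothetical helper, not part of this commit):

def fits_context(prompt_ids, max_tokens, max_model_len, truncate_prompt_tokens=None):
    if truncate_prompt_tokens is not None:
        prompt_ids = prompt_ids[-truncate_prompt_tokens:]
    token_num = len(prompt_ids)
    if max_tokens is None:
        return token_num < max_model_len
    return token_num + max_tokens <= max_model_len


# 100 prompt tokens + 28 completion tokens fit a 128-token window exactly
assert fits_context(list(range(100)), max_tokens=28, max_model_len=128)
# one more completion token pushes the request over the limit
assert not fits_context(list(range(100)), max_tokens=29, max_model_len=128)
# truncate_prompt_tokens=100 keeps only the last 100 of 200 prompt tokens
assert fits_context(list(range(200)), max_tokens=28, max_model_len=128,
                    truncate_prompt_tokens=100)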
engines/python/setup/djl_python/input_parser.py (3 changes: 0 additions & 3 deletions)
@@ -147,11 +147,8 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
        from djl_python.chat_completions.vllm_chat_utils import parse_chat_completions_request_vllm
        inputs, param = parse_chat_completions_request_vllm(
            input_map,
-            kwargs.get("is_rolling_batch"),
            rolling_batch,
            tokenizer,
-            configs=configs,
-            is_mistral_tokenizer=is_mistral_tokenizer,
        )
    else:
        inputs, param = parse_chat_completions_request(
@@ -12,7 +12,7 @@
# the specific language governing permissions and limitations under the License.
import ast
import logging
-from typing import Optional, Any, Dict, Tuple
+from typing import Optional, Any, Dict, Tuple, Literal
from pydantic import field_validator, model_validator, ConfigDict, Field
from vllm import EngineArgs
from vllm.utils import FlexibleArgumentParser
@@ -80,6 +80,10 @@ class VllmRbProperties(Properties):
    generation_config: Optional[Any] = None
    override_neuron_config: Optional[Dict] = None

+    # Non engine arg properties
+    chat_template: Optional[str] = None
+    chat_template_content_format: Literal["auto", "string", "openai"] = "auto"
+
    # This allows generic vllm engine args to be passed in and set with vllm
    model_config = ConfigDict(extra='allow', populate_by_name=True)
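Note: these two new properties back the rolling-batch accessors used in vllm_chat_utils.py above (rolling_batch.get_chat_template() and rolling_batch.get_chat_template_content_format()). Those accessors are not shown in this view; a hypothetical sketch of what they would look like, assuming VLLMRollingBatch exposes its VllmRbProperties as vllm_configs (as the tool-choice check above suggests):

class VLLMRollingBatchSketch:
    """Hypothetical accessors; the real implementation is not part of this diff."""

    def __init__(self, vllm_configs: VllmRbProperties):
        self.vllm_configs = vllm_configs

    def get_chat_template(self):
        # None falls back to the tokenizer's built-in chat template
        return self.vllm_configs.chat_template

    def get_chat_template_content_format(self):
        # "auto", "string", or "openai", per the Literal declared above
        return self.vllm_configs.chat_template_content_format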

@@ -132,7 +132,7 @@ def get_model_config(self):
        return self.engine.preprocessor.model_config if not self.is_t5_model else None

    def use_vllm_chat_completions(self):
-        # vllm chat parsing requires 0.7.0 currently, lmi-dist is on 0.6.3.post1
+        # vllm chat parsing requires 0.7.1 currently, lmi-dist is on 0.6.3.post1
        return False

    def get_huggingface_model_config(self):
