
[python] Use vllm chat object #2659

Merged · 7 commits · Jan 17, 2025
Changes from 4 commits
engines/python/setup/djl_python/chat_completions/vllm_chat_properties.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from typing import Optional
from pydantic import Field
from vllm.entrypoints.openai.protocol import ChatCompletionRequest


class ChatProperties(ChatCompletionRequest):
    """
    Chat input parameters for chat completions API.
    See https://platform.openai.com/docs/api-reference/chat/create
    """

    model: Optional[str] = Field(default=None, exclude=True)  # Unused
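
A minimal usage sketch (editorial, not part of the diff), assuming only the class above: ChatProperties accepts a standard OpenAI-style payload, and the overridden model field is dropped when the parameters are serialized.

# Sketch: the `model` field is accepted for OpenAI compatibility, then excluded.
from djl_python.chat_completions.vllm_chat_properties import ChatProperties

request = {
    "model": "ignored-by-djl",  # hypothetical value; the endpoint already knows its model
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
}
props = ChatProperties(**request)
params = props.model_dump(exclude_none=True, exclude={"messages"})
assert "model" not in params  # excluded via Field(exclude=True)
assert params["stream"] is True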
engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py
@@ -0,0 +1,81 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
from typing import Dict, List, Optional, Union

from djl_python.chat_completions.vllm_chat_properties import ChatProperties
from djl_python.properties_manager.properties import Properties
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
                                         parse_chat_messages)


def is_chat_completions_request(inputs: Dict) -> bool:
    return "messages" in inputs


def parse_chat_completions_request_vllm(
    input_map: Dict,
    is_rolling_batch: bool,
    rolling_batch,
    tokenizer,
    chat_template: Optional[str] = None,
    image_token: Optional[str] = None,
    configs: Properties = None,
    is_mistral_tokenizer: bool = False,
):
    # Chat completions are supported only with rolling batch or no batching
    # (batch_size == 1).
    if not (is_rolling_batch or configs.batch_size == 1):
        raise ValueError(
            "chat completions support is not currently available for dynamic batching. "
            "You must enable rolling batch to use the chat completions format.")

    if not is_mistral_tokenizer and not hasattr(tokenizer,
                                                "apply_chat_template"):
        raise AttributeError(
            f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
            f"please ensure that your tokenizer supports chat templates.")

    chat_params = ChatProperties(**input_map)
    exclude = {"messages"}
    param = chat_params.model_dump(exclude_none=True, exclude=exclude)

    conversation, mm_data = parse_chat_messages(
        chat_params.messages, rolling_batch.get_model_config(), tokenizer)

    text_inputs: Union[str, List[int]]
    if is_mistral_tokenizer:
        text_inputs = apply_mistral_chat_template(
            tokenizer,
            messages=chat_params.messages,
            chat_template=chat_template,
            add_generation_prompt=True,
        )
    else:
        text_inputs = apply_hf_chat_template(
            tokenizer,
            conversation=conversation,
            chat_template=chat_template,
            add_generation_prompt=True,
        )

    param["details"] = True  # Enable details for chat completions
    param["output_formatter"] = (
        "jsonlines_chat" if chat_params.stream else "json_chat")

    if mm_data:
        param.update(mm_data)

    # For mistral tokenizers, text_inputs is a list of token ids; otherwise a str.
    return text_inputs, param
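
An illustrative view of the parser's contract (values assumed, not taken from the diff): for a non-streaming request, the returned params always enable details and select the json_chat formatter, while text_inputs carries the rendered prompt.

# Sketch of the return values for
# input_map = {"messages": [...], "temperature": 0.7, "stream": False}:
text_inputs = "<prompt rendered by the chat template>"  # a List[int] for mistral tokenizers
param = {
    "temperature": 0.7,               # sampling params pass through from ChatProperties
    "details": True,                  # always set for chat completions
    "output_formatter": "json_chat",  # "jsonlines_chat" when stream=True
}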
30 changes: 22 additions & 8 deletions engines/python/setup/djl_python/input_parser.py
@@ -16,6 +16,7 @@

from djl_python import Input
from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
from djl_python.chat_completions.vllm_chat_utils import parse_chat_completions_request_vllm
from djl_python.encode_decode import decode
from djl_python.properties_manager.properties import is_rolling_batch_enabled
from djl_python.request import Request
@@ -140,14 +141,27 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
    if configs is not None:
        is_bedrock = configs.bedrock_compat
    if is_chat_completions_request(input_map):
        inputs, param = parse_chat_completions_request(
            input_map,
            kwargs.get("is_rolling_batch"),
            tokenizer,
            image_token=image_token,
            configs=configs,
            is_mistral_tokenizer=is_mistral_tokenizer,
        )
        if type(kwargs.get("rolling_batch")).__name__ in [
                "LmiDistRollingBatch", "VLLMRollingBatch"
        ]:
Contributor: Is it possible to base this choice on the config option.rolling_batch=x?

@xyang16 (Author), Jan 16, 2025: option.rolling_batch may be auto, which resolves to lmi-dist or trtllm depending on the container, so it's hard to tell which rolling batch it is.

Contributor: Maybe we could set a config within the RB class, like use_vllm_chat_completions? I think I would prefer that, since I'm not sure whether using VllmRollingBatch with Neuron (a valid use case) supports some of the utilities we are using from vllm, since we're pulling those from Neuron's vllm repo.

@xyang16 (Author): Sounds good. Added use_vllm_chat_completions().
            inputs, param = parse_chat_completions_request_vllm(
                input_map,
                kwargs.get("is_rolling_batch"),
                kwargs.get("rolling_batch"),
                tokenizer,
                image_token=image_token,
                configs=configs,
                is_mistral_tokenizer=is_mistral_tokenizer,
            )
        else:
            inputs, param = parse_chat_completions_request(
                input_map,
                kwargs.get("is_rolling_batch"),
                tokenizer,
                image_token=image_token,
                configs=configs,
                is_mistral_tokenizer=is_mistral_tokenizer,
            )
    elif is_bedrock:
        inputs, param = parse_3p_request(input_map,
                                         kwargs.get("is_rolling_batch"),
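Per the review discussion above, later commits replace the class-name check with a capability hook on the rolling batch handler. A minimal sketch of that pattern (the method bodies are an assumption; the thread only names use_vllm_chat_completions()):

class RollingBatch:
    # Base handler: by default, the vLLM chat parsing path is not used.
    def use_vllm_chat_completions(self) -> bool:
        return False


class VLLMRollingBatch(RollingBatch):
    # vLLM-backed handlers opt in to parse_chat_completions_request_vllm.
    def use_vllm_chat_completions(self) -> bool:
        return True

The dispatch in input_parser.py then becomes a capability check (rolling_batch.use_vllm_chat_completions()) rather than a type-name comparison, which also leaves room for backends such as Neuron's vllm fork to opt out.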
@@ -126,6 +126,11 @@ def get_tokenizer(self):
            return self.engine.preprocessor.tokenizer
        return self.engine.preprocessor.tokenizer.tokenizer

    def get_model_config(self):
        # TODO: this is a hack right now to get the model config from the engine. We should expose this as
        # an interface method and retrieve it from there after v12
        return self.engine.preprocessor.model_config if not self.is_t5_model else None

    def get_huggingface_model_config(self):
        # TODO: this is a hack right now to get the model config from the engine. We should expose this as
        # an interface method and retrieve it from there after v12
@@ -57,6 +57,9 @@ def __init__(self, model_id_or_path: str, properties: dict,
    def get_tokenizer(self):
        return self.engine.tokenizer.tokenizer

    def get_model_config(self):
        return self.engine.model_config

    def get_huggingface_model_config(self):
        return self.engine.model_config.hf_config

2 changes: 1 addition & 1 deletion engines/python/setup/setup.py
@@ -58,7 +58,7 @@ def run(self):
test_requirements = [
    'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops',
    'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf',
    'pydantic>=2.0', "objgraph"
    'pydantic>=2.0', "objgraph", "vllm==0.6.3.post1"
]

setup(name='djl_python',