[python] Add tool calling support #2675

Merged: 1 commit, merged Jan 31, 2025
@@ -14,6 +14,7 @@

from djl_python.chat_completions.vllm_chat_properties import ChatProperties
from djl_python.properties_manager.properties import Properties
from djl_python.rolling_batch.rolling_batch_vllm_utils import maybe_serialize_tool_calls
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
@@ -30,7 +31,6 @@ def parse_chat_completions_request_vllm(
rolling_batch,
tokenizer,
chat_template: Optional[str] = None,
image_token: Optional[str] = None,
configs: Properties = None,
is_mistral_tokenizer: bool = False,
):
@@ -47,10 +47,30 @@ def parse_chat_completions_request_vllm(
f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
f"please ensure that your tokenizer supports chat templates.")

tool_parser = rolling_batch.get_tool_parser()
chat_params = ChatProperties(**input_map)

if chat_params.tool_choice == "required":
raise ValueError("tool_choice = \"required\" is not supported!")

if is_mistral_tokenizer:
maybe_serialize_tool_calls(chat_params)
elif chat_params.tool_choice == "auto" and tool_parser is None:
raise ValueError(
"\"auto\" tool choice requires tool_call_parser to be available")

should_parse_tools = tool_parser is not None and (hasattr(
chat_params, "tool_choice") and chat_params.tool_choice != "none")
if should_parse_tools:
chat_params = tool_parser.adjust_request(request=chat_params)

exclude = {"messages"}
param = chat_params.model_dump(exclude_none=True, exclude=exclude)

tool_dicts = None if chat_params.tools is None else [
tool.model_dump() for tool in chat_params.tools
]

conversation, mm_data = parse_chat_messages(
chat_params.messages, rolling_batch.get_model_config(), tokenizer)

@@ -61,18 +81,22 @@ def parse_chat_completions_request_vllm(
messages=chat_params.messages,
chat_template=chat_template,
add_generation_prompt=True,
tools=tool_dicts,
)
else:
text_inputs = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=True,
tools=tool_dicts,
)

param["details"] = True # Enable details for chat completions
param[
"output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"
param["tool_parser"] = tool_parser
param["chat_params"] = chat_params

if mm_data:
param["mm_data"] = mm_data
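For context, a minimal, hypothetical request payload that exercises the new fields: ChatProperties is built from this map, tools is forwarded to the chat template as tool_dicts, and tool_choice of "auto" (or unset) routes output through the tool parser, while "none" bypasses it and "required" is rejected. The get_weather function is invented for illustration.

input_map = {
    "messages": [
        {"role": "user", "content": "What is the weather in Paris today?"}
    ],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    # "auto" (or omitting tool_choice) enables tool parsing when a parser is
    # configured; "none" skips it; "required" raises ValueError above.
    "tool_choice": "auto",
}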
1 change: 0 additions & 1 deletion engines/python/setup/djl_python/input_parser.py
@@ -150,7 +150,6 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
kwargs.get("is_rolling_batch"),
rolling_batch,
tokenizer,
image_token=image_token,
configs=configs,
is_mistral_tokenizer=is_mistral_tokenizer,
)
94 changes: 87 additions & 7 deletions engines/python/setup/djl_python/output_formatter.py
@@ -110,7 +110,7 @@ def _json_output_formatter(request_output: TextGenerationOutput):
request_output.best_sequence_index]
# TODO: Fix this so it is not required. Right now, this call is needed to
# advance the token iterator, which is needed for rolling batch to work properly
next_token, _, is_last_token = best_sequence.get_next_token()
next_token, _, _, is_last_token = best_sequence.get_next_token()
if not request_output.finished:
return ""
details = get_details_dict(request_output, include_tokens=True)
@@ -141,7 +141,7 @@ def _json_3p_output_formatter(request_output: TextGenerationOutput):
request_output.best_sequence_index]
# TODO: Fix this so it is not required. Right now, this call is needed to
# advance the token iterator, which is needed for rolling batch to work properly
next_token, first_token, last_token = best_sequence.get_next_token()
next_token, index, first_token, last_token = best_sequence.get_next_token()
if not request_output.finished:
return ""

@@ -221,7 +221,7 @@ def _jsonlines_output_formatter(request_output: TextGenerationOutput):
parameters = request_output.input.parameters
best_sequence = request_output.sequences[
request_output.best_sequence_index]
next_token, _, last_token = best_sequence.get_next_token()
next_token, _, _, last_token = best_sequence.get_next_token()
# with chunked prefill, we don't generate any tokens until the full prompt has been processed.
# that means we sometimes don't have a token to return
if next_token is None:
@@ -242,7 +242,7 @@
def _jsonlines_3p_output_formatter(request_output: TextGenerationOutput):
best_sequence = request_output.sequences[
request_output.best_sequence_index]
next_token, first_token, last_token = best_sequence.get_next_token()
next_token, index, first_token, last_token = best_sequence.get_next_token()
# with chunked prefill, we don't generate any tokens until the full prompt has been processed.
# that means we sometimes don't have a token to return
if next_token is None:
@@ -282,6 +282,8 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
:return: formatted output
"""
parameters = request_output.input.parameters
chat_params = parameters.get("chat_params")
tool_parser = parameters.get("tool_parser")
best_sequence = request_output.sequences[
request_output.best_sequence_index]
generated_text = get_generated_text(best_sequence, request_output)
@@ -299,6 +301,46 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
"logprobs": None,
"finish_reason": best_sequence.finish_reason,
}
if chat_params and chat_params.tool_choice and type(
chat_params.tool_choice
).__name__ == "ChatCompletionNamedToolChoiceParam":
tool_calls = [{
"id": f"chatcmpl-tool-{id(request_output)}",
"type": "function",
"function": {
"name": chat_params.tool_choice.function.name,
"arguments": generated_text
}
}]
choice = {
"index": 0,
"message": {
"role": "assistant",
"content": "",
},
"tool_calls": tool_calls,
"logprobs": None,
"finish_reason": best_sequence.finish_reason,
}
elif parameters.get("tools") and (parameters.get("tool_choice") == "auto"
or parameters.get("tool_choice") is None
) and parameters.get("tool_parser"):
tool_call_info = tool_parser.extract_tool_calls(generated_text,
request=chat_params)
auto_tools_called = tool_call_info.tools_called
if auto_tools_called:
tool_calls = [t.model_dump() for t in tool_call_info.tool_calls]
choice = {
"index": 0,
"message": {
"role": "assistant",
"content": tool_call_info.content,
},
"tool_calls": tool_calls,
"logprobs": None,
"finish_reason": "tool_calls",
}

if parameters.get("logprobs"):
logprobs = {
"content": [
@@ -317,6 +359,7 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
]
}
choice["logprobs"] = logprobs

prompt_tokens = len(request_output.prompt_tokens_details)
completion_tokens = len(best_sequence.tokens)
usage = {
@@ -341,16 +384,52 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):
:return: formatted output
"""
parameters = request_output.input.parameters
chat_params = parameters.get("chat_params")
tool_parser = parameters.get("tool_parser")
best_sequence = request_output.sequences[
request_output.best_sequence_index]
next_token, first_token, last_token = best_sequence.get_next_token()
next_token, index, first_token, last_token = best_sequence.get_next_token()
# with chunked prefill, we don't generate any tokens until the full prompt has been processed.
# that means we sometimes don't have a token to return
if next_token is None:
return ""

created = int(time.time())
delta = {"content": next_token.text}

if chat_params and chat_params.tool_choice and type(
chat_params.tool_choice
).__name__ == "ChatCompletionNamedToolChoiceParam":
tool_calls = [{
"index": 0,
"function": {
"name": chat_params.tool_choice.function.name,
"arguments": next_token.text
}
}]
delta = {"tool_calls": tool_calls}
elif parameters.get("tools") and (parameters.get("tool_choice") == "auto"
or parameters.get("tool_choice") is None
) and parameters.get("tool_parser"):
current_text = get_generated_text(best_sequence, request_output)
previous_text = current_text[0:-len(next_token.text)]
current_token_ids = [t.id for t in best_sequence.tokens]
previous_token_ids = current_token_ids[:-1]
tool_call_info = tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=next_token.text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=[next_token.id],
request=chat_params)
if tool_call_info is None:
return ""
tool_calls = [
t.model_dump(exclude_none=True) for t in tool_call_info.tool_calls
]
delta = {"tool_calls": tool_calls}
else:
delta = {"content": next_token.text}
if first_token:
delta["role"] = "assistant"

@@ -423,7 +502,8 @@ def adapt_legacy_output_formatter(request_output: TextGenerationOutput) -> str:
elif best_sequence.finish_reason == "error":
details_dict["finish_reason"] = best_sequence.finish_reason

next_token, first_token, last_token = best_sequence.get_next_token()
next_token, index, first_token, last_token = best_sequence.get_next_token(
)
if last_token:
for token in best_sequence.tokens:
generated_text += token.text
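For reference, when the tool parser reports tools_called, the non-streaming formatter above emits a choice of roughly this shape. All values are illustrative; on the named-tool-choice path the call id is instead derived from id(request_output).

choice = {
    "index": 0,
    "message": {
        "role": "assistant",
        # content extracted alongside the call, if any; may be None
        "content": None,
    },
    "tool_calls": [{
        "id": "chatcmpl-tool-140234",  # hypothetical id
        "type": "function",
        "function": {
            "name": "get_weather",
            "arguments": "{\"city\": \"Paris\"}",
        },
    }],
    "logprobs": None,
    "finish_reason": "tool_calls",
}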
@@ -11,13 +11,12 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import ast
import json
from enum import Enum
from typing import Optional, Any, Mapping, Tuple, Dict

from pydantic import field_validator, model_validator

from djl_python.properties_manager.properties import Properties
from vllm.entrypoints.openai.tool_parsers import ToolParserManager


class VllmRbProperties(Properties):
@@ -74,6 +73,10 @@ class VllmRbProperties(Properties):
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None

# Tool calling properties
enable_auto_tool_choice: Optional[bool] = False
tool_call_parser: Optional[str] = None

@field_validator('engine')
def validate_engine(cls, engine):
if engine != "Python":
@@ -147,3 +150,12 @@ def validate_pipeline_parallel(self):
"Pipeline parallelism is not supported in vLLM's LLMEngine used in rolling_batch implementation"
)
return self

@model_validator(mode='after')
def validate_tool_call_parser(self):
valid_tool_parses = ToolParserManager.tool_parsers.keys()
if self.enable_auto_tool_choice \
and self.tool_call_parser not in valid_tool_parses:
raise ValueError(
f"Invalid tool call parser: {self.tool_call_parser} "
f"(chose from {{ {','.join(valid_tool_parses)} }})")
4 changes: 2 additions & 2 deletions engines/python/setup/djl_python/request_io.py
@@ -118,8 +118,8 @@ def get_next_token(self) -> (Token, bool, bool):
index = self._tokens_iterator.next_index()
first_token = index == 0
last_token = index == self._last_token_index
return self.tokens[index], first_token, last_token
return None, False, False
return self.tokens[index], index, first_token, last_token
return None, 0, False, False

def get_last_token(self) -> Optional[Token]:
if self._last_token_index is not None:
@@ -120,6 +120,12 @@ def use_vllm_chat_completions(self):
"""
return False

def get_tool_parser(self):
"""
:return: the tool call parser if available
"""
return None

@abstractmethod
def inference(self, new_requests: List[Request]) -> List:
"""
@@ -290,3 +290,19 @@ def get_prompt_inputs(request: Request):
if multi_modal_data is not None:
prompt["multi_modal_data"] = multi_modal_data
return prompt


def maybe_serialize_tool_calls(request):
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.0/vllm/transformers_utils/tokenizers/mistral.py#L34-L68
for i, message in enumerate(request.messages):
if message.get("role") == 'assistant':
tool_calls_validator = message.get("tool_calls", ().__iter__())
validated_tool_calls = []
while True:
try:
tool_call = next(tool_calls_validator) # type: ignore
validated_tool_calls.append(tool_call)
except StopIteration:
break

request.messages[i]["tool_calls"] = validated_tool_calls
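A small usage sketch of this helper with a stand-in request object: an assistant message whose tool_calls arrive as a lazy iterator is materialized into a plain list so the Mistral tokenizer can serialize it. The call id and function are invented for illustration.

from types import SimpleNamespace

from djl_python.rolling_batch.rolling_batch_vllm_utils import maybe_serialize_tool_calls

request = SimpleNamespace(messages=[{
    "role": "assistant",
    "tool_calls": iter([{
        "id": "call_0",  # hypothetical call id
        "type": "function",
        "function": {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"},
    }]),
}])
maybe_serialize_tool_calls(request)
assert isinstance(request.messages[0]["tool_calls"], list)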
@@ -10,11 +10,12 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
from collections import OrderedDict, defaultdict

from vllm import LLMEngine, SamplingParams
from vllm.sampling_params import RequestOutputKind
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid, AtomicCounter

from djl_python.request import Request
@@ -23,7 +24,7 @@
update_request_cache_with_output, create_lora_request, get_lora_request,
get_engine_args_from_config, get_prompt_inputs)
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
from typing import List, Optional
from typing import Callable, List, Optional

# FIXME: Once all vllm versions are past 0.6.0 we can move to just struct_fields
VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__) if hasattr(
@@ -55,6 +56,15 @@ def __init__(self, model_id_or_path: str, properties: dict,
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = {}
self.is_mistral_tokenizer = self.vllm_configs.tokenizer_mode == 'mistral'
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.vllm_configs.enable_auto_tool_choice:
try:
self.tool_parser = ToolParserManager.get_tool_parser(
self.vllm_configs.tool_call_parser)
self.tool_parser = self.tool_parser(
self.engine.tokenizer.tokenizer)
except Exception as e:
raise TypeError("Error in tool parser creation.") from e

def get_tokenizer(self):
return self.engine.tokenizer.tokenizer
@@ -68,6 +78,9 @@ def get_huggingface_model_config(self):
def use_vllm_chat_completions(self):
return True

def get_tool_parser(self):
return self.tool_parser

def reset(self) -> None:
"""
Aborts all requests
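A rough sketch of the parser lookup done in __init__ above, assuming vLLM and a Hugging Face tokenizer are available. "mistral" is only an example parser name and the model id is illustrative; ToolParserManager.get_tool_parser returns a ToolParser subclass that is then instantiated with the tokenizer.

from transformers import AutoTokenizer
from vllm.entrypoints.openai.tool_parsers import ToolParserManager

# Any chat tokenizer the chosen parser understands; gated repos need HF auth.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
parser_cls = ToolParserManager.get_tool_parser("mistral")
tool_parser = parser_cls(tokenizer)
print(type(tool_parser).__name__)  # e.g. MistralToolParser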
@@ -1047,7 +1047,7 @@ def custom_fmt_wait(request_output: TextGenerationOutput):
sequence_index = request_output.best_sequence_index
best_sequence = request_output.sequences[
request_output.best_sequence_index]
_, _, last_token = best_sequence.get_next_token()
_, _, _, last_token = best_sequence.get_next_token()
if last_token:
tokens = best_sequence.tokens
generated_text = ""
2 changes: 1 addition & 1 deletion engines/python/setup/setup.py
@@ -58,7 +58,7 @@ def run(self):
test_requirements = [
'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops',
'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf',
'pydantic>=2.0', "objgraph", "vllm==0.6.3.post1"
'pydantic>=2.0', "objgraph", "vllm"
]

setup(name='djl_python',