[Frontend] Generate valid tool call IDs when using `tokenizer-mode=mistral` (#12332)
rafvasq authored Feb 12, 2025
1 parent 985b4a2 commit 314cfad
Showing 8 changed files with 149 additions and 8 deletions.
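For context on the fix: Mistral's instruct validator (linked in the mistral_tool_parser.py hunk below) only accepts tool call IDs that are exactly nine alphanumeric characters, whereas vLLM's default OpenAI-style IDs (e.g. chatcmpl-tool-<uuid>) are much longer. A minimal sketch of the constraint, assuming a plain regex check rather than mistral-common's actual validator:

import re

# Assumed restatement of Mistral's rule: exactly nine alphanumeric characters.
MISTRAL_TOOL_CALL_ID = re.compile(r"[a-zA-Z0-9]{9}")


def is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
    # Hypothetical helper, not part of this commit.
    return MISTRAL_TOOL_CALL_ID.fullmatch(tool_call_id) is not None


assert is_valid_mistral_tool_call_id("ab12CD34e")
assert not is_valid_mistral_tool_call_id("chatcmpl-tool-9f3a1c2b7d")  # too long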
0 changes: 0 additions & 0 deletions tests/mistral_tool_use/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions tests/mistral_tool_use/conftest.py
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download

from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform

from .utils import ARGS, CONFIGS, ServerConfig


# for each server config, download the model and return the config
@pytest.fixture(scope="session", params=CONFIGS.keys())
def server_config(request):
    config = CONFIGS[request.param]

    if current_platform.is_rocm() and not config.get("supports_rocm", True):
        pytest.skip("The {} model can't be tested on the ROCm platform".format(
            config["model"]))

    # download model and tokenizer using transformers
    snapshot_download(config["model"])
    yield CONFIGS[request.param]


# run this for each server config
@pytest.fixture(scope="session")
def server(request, server_config: ServerConfig):
    model = server_config["model"]
    args_for_model = server_config["arguments"]
    with RemoteOpenAIServer(model, ARGS + args_for_model,
                            max_wait_seconds=480) as server:
        yield server


@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
    async with server.get_async_client() as async_client:
        yield async_client
29 changes: 29 additions & 0 deletions tests/mistral_tool_use/test_mistral_tool_calls.py
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0

import openai
import pytest

from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL


# test: a tool_choice with mistral-tokenizer results in an ID of length 9
@pytest.mark.asyncio
async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_TOOLS,
        temperature=0,
        max_completion_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL],
        tool_choice=WEATHER_TOOL,
        logprobs=False)

    choice = chat_completion.choices[0]

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert choice.message.role == "assistant"
    assert choice.message.tool_calls is None \
        or len(choice.message.tool_calls) == 1
    assert len(choice.message.tool_calls[0].id) == 9  # length of 9 for mistral
33 changes: 33 additions & 0 deletions tests/mistral_tool_use/utils.py
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Optional

from typing_extensions import TypedDict


class ServerConfig(TypedDict, total=False):
    model: str
    arguments: List[str]
    system_prompt: Optional[str]
    supports_parallel: Optional[bool]
    supports_rocm: Optional[bool]


ARGS: List[str] = ["--max-model-len", "1024"]

CONFIGS: Dict[str, ServerConfig] = {
    "mistral": {
        "model":
        "mistralai/Mistral-7B-Instruct-v0.3",
        "arguments": [
            "--tokenizer-mode", "mistral",
            "--ignore-patterns=\"consolidated.safetensors\""
        ],
        "system_prompt":
        "You are a helpful assistant with access to tools. If a tool"
        " that you have would be helpful to answer a user query, "
        "call the tool. Otherwise, answer the user's query directly "
        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
        "to the user's question - just respond to it normally."
    },
}
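Usage note: the conftest above passes ARGS plus the per-model arguments to RemoteOpenAIServer, so the effective extra server flags for the Mistral config compose as follows (a sketch of the composition only, not the exact server invocation):

from tests.mistral_tool_use.utils import ARGS, CONFIGS

full_args = ARGS + CONFIGS["mistral"]["arguments"]
print(full_args)
# ['--max-model-len', '1024', '--tokenizer-mode', 'mistral',
#  '--ignore-patterns="consolidated.safetensors"']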
16 changes: 11 additions & 5 deletions vllm/entrypoints/openai/serving_chat.py
@@ -28,12 +28,15 @@
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
+    MistralToolCall)
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
+from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
+                                                truncate_tool_call_ids)

 logger = init_logger(__name__)

@@ -150,11 +153,12 @@ async def create_chat_completion(
             return self.create_error_response(
                 "tool_choice = \"required\" is not supported!")

-        # because of issues with pydantic we need to potentially
-        # re-serialize the tool_calls field of the request
-        # for more info: see comment in `maybe_serialize_tool_calls`
         if isinstance(tokenizer, MistralTokenizer):
+            # because of issues with pydantic we need to potentially
+            # re-serialize the tool_calls field of the request
+            # for more info: see comment in `maybe_serialize_tool_calls`
             maybe_serialize_tool_calls(request)
+            truncate_tool_call_ids(request)

         if (request.tool_choice == "auto" and
                 not (self.enable_auto_tools and tool_parser is not None)

@@ -745,11 +749,13 @@ async def chat_completion_full_generator(
         elif request.tool_choice and type(
                 request.tool_choice) is ChatCompletionNamedToolChoiceParam:

+            tool_call_class = MistralToolCall if isinstance(
+                tokenizer, MistralTokenizer) else ToolCall
             message = ChatMessage(
                 role=role,
                 content="",
                 tool_calls=[
-                    ToolCall(function=FunctionCall(
+                    tool_call_class(function=FunctionCall(
                         name=request.tool_choice.function.name,
                         arguments=output.text))
                 ])
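The last hunk only swaps the tool-call class on the named tool_choice path; a condensed sketch of that dispatch, with pick_tool_call_class as a hypothetical helper rather than code from this commit:

from vllm.entrypoints.openai.protocol import ToolCall
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
    MistralToolCall)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer


def pick_tool_call_class(tokenizer: AnyTokenizer) -> type:
    # Mistral-format 9-character IDs for MistralTokenizer requests,
    # default OpenAI-style ToolCall IDs otherwise.
    return MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall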
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -33,7 +33,7 @@ class MistralToolCall(ToolCall):

     @staticmethod
     def generate_random_id():
-        # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
+        # Mistral Tool Call Ids must be alphanumeric with a length of 9.
         # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
         return "".join(choices(ALPHANUMERIC, k=9))
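A self-contained rendering of the generator above, assuming ALPHANUMERIC is the usual letters-plus-digits alphabet (the constant is defined elsewhere in the module):

import string
from random import choices

ALPHANUMERIC = string.ascii_letters + string.digits  # assumed definition


def generate_random_id() -> str:
    # Nine alphanumeric characters, matching mistral-common's validator.
    return "".join(choices(ALPHANUMERIC, k=9))


print(generate_random_id())  # e.g. 'x7Qa9dK2M'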
7 changes: 5 additions & 2 deletions vllm/transformers_utils/tokenizers/__init__.py
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

-from .mistral import MistralTokenizer, maybe_serialize_tool_calls
+from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
+                      truncate_tool_call_ids)

-__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
+__all__ = [
+    "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids"
+]
30 changes: 30 additions & 0 deletions vllm/transformers_utils/tokenizers/mistral.py
@@ -68,6 +68,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
         request.messages[i]["tool_calls"] = validated_tool_calls


+def truncate_tool_call_ids(request: "ChatCompletionRequest"):
+    """Truncates tool call IDs for Mistral's ID requirements."""
+    for i, message in enumerate(request.messages):
+        if message.get("role") == 'assistant':
+            tool_calls = message.get("tool_calls", [])
+            for tool_call in tool_calls:
+                if len(tool_call["id"]) > 9:
+                    logger.warning(
+                        "Truncating tool call ID: %s to %s",
+                        tool_call["id"],
+                        tool_call["id"][-9:],
+                    )
+                    tool_call["id"] = tool_call["id"][-9:]
+
+            request.messages[i]["tool_calls"] = tool_calls
+
+        elif message.get("role") in {"tool_results", "tool"}:
+            if "tool_call_id" in message:
+                tool_call_id = message["tool_call_id"]
+
+                if len(tool_call_id) > 9:
+                    logger.warning(
+                        "Truncating tool_call_id: %s to %s",
+                        tool_call_id,
+                        tool_call_id[-9:],
+                    )
+                    tool_call_id = tool_call_id[-9:]
+                request.messages[i]["tool_call_id"] = tool_call_id
+
+
 def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
     repo_cache = os.path.join(
         huggingface_hub.constants.HF_HUB_CACHE,
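A minimal sketch of the truncation behavior, using a hypothetical FakeChatRequest stand-in for the real pydantic ChatCompletionRequest (the function only touches request.messages, so duck typing suffices):

from dataclasses import dataclass, field
from typing import Any, Dict, List

from vllm.transformers_utils.tokenizers import truncate_tool_call_ids


@dataclass
class FakeChatRequest:  # hypothetical stand-in for ChatCompletionRequest
    messages: List[Dict[str, Any]] = field(default_factory=list)


request = FakeChatRequest(messages=[
    {"role": "assistant", "tool_calls": [{"id": "chatcmpl-tool-9f3a1c2b7d"}]},
    {"role": "tool", "tool_call_id": "chatcmpl-tool-9f3a1c2b7d"},
])
truncate_tool_call_ids(request)

# Both IDs are truncated to their last nine characters.
assert request.messages[0]["tool_calls"][0]["id"] == "f3a1c2b7d"
assert request.messages[1]["tool_call_id"] == "f3a1c2b7d"

Note that the truncation keeps the last nine characters rather than the first, which preserves the variable suffix of vLLM's default IDs instead of the constant chatcmpl-tool- prefix.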
