[Frontend] Generate valid tool call IDs when using tokenizer-mode=mistral #12332

Merged: 19 commits, Feb 12, 2025
40 changes: 40 additions & 0 deletions tests/mistral_tool_use/conftest.py
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download

from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform

from .utils import ARGS, CONFIGS, ServerConfig


# for each server config, download the model and return the config
@pytest.fixture(scope="session", params=CONFIGS.keys())
def server_config(request):
    config = CONFIGS[request.param]

    if current_platform.is_rocm() and not config.get("supports_rocm", True):
        pytest.skip("The {} model can't be tested on the ROCm platform".format(
            config["model"]))

    # download model and tokenizer using transformers
    snapshot_download(config["model"])
    yield CONFIGS[request.param]


# run this for each server config
@pytest.fixture(scope="session")
def server(request, server_config: ServerConfig):
    model = server_config["model"]
    args_for_model = server_config["arguments"]
    with RemoteOpenAIServer(model, ARGS + args_for_model,
                            max_wait_seconds=480) as server:
        yield server


@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
    async with server.get_async_client() as async_client:
        yield async_client
29 changes: 29 additions & 0 deletions tests/mistral_tool_use/test_mistral_tool_calls.py
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0

import openai
import pytest

from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL


# test: a tool_choice with mistral-tokenizer results in an ID of length 9
@pytest.mark.asyncio
async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_TOOLS,
        temperature=0,
        max_completion_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL],
        tool_choice=WEATHER_TOOL,
        logprobs=False)

    choice = chat_completion.choices[0]

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert choice.message.role == "assistant"
    assert choice.message.tool_calls is None \
        or len(choice.message.tool_calls) == 1
    assert len(choice.message.tool_calls[0].id) == 9  # length of 9 for mistral
33 changes: 33 additions & 0 deletions tests/mistral_tool_use/utils.py
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Optional

from typing_extensions import TypedDict


class ServerConfig(TypedDict, total=False):
    model: str
    arguments: List[str]
    system_prompt: Optional[str]
    supports_parallel: Optional[bool]
    supports_rocm: Optional[bool]


ARGS: List[str] = ["--max-model-len", "1024"]

CONFIGS: Dict[str, ServerConfig] = {
    "mistral": {
        "model":
        "mistralai/Mistral-7B-Instruct-v0.3",
        "arguments": [
            "--tokenizer-mode", "mistral",
            "--ignore-patterns=\"consolidated.safetensors\""
        ],
        "system_prompt":
        "You are a helpful assistant with access to tools. If a tool"
        " that you have would be helpful to answer a user query, "
        "call the tool. Otherwise, answer the user's query directly "
        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
        "to the user's question - just respond to it normally."
    },
}
16 changes: 11 additions & 5 deletions vllm/entrypoints/openai/serving_chat.py
@@ -28,12 +28,15 @@
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
+    MistralToolCall)
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
+from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
+                                                truncate_tool_call_ids)
 
 logger = init_logger(__name__)

@@ -150,11 +153,12 @@ async def create_chat_completion(
             return self.create_error_response(
                 "tool_choice = \"required\" is not supported!")
 
-        # because of issues with pydantic we need to potentially
-        # re-serialize the tool_calls field of the request
-        # for more info: see comment in `maybe_serialize_tool_calls`
         if isinstance(tokenizer, MistralTokenizer):
+            # because of issues with pydantic we need to potentially
+            # re-serialize the tool_calls field of the request
+            # for more info: see comment in `maybe_serialize_tool_calls`
             maybe_serialize_tool_calls(request)
+            truncate_tool_call_ids(request)
 
         if (request.tool_choice == "auto" and
                 not (self.enable_auto_tools and tool_parser is not None)
@@ -745,11 +749,13 @@ async def chat_completion_full_generator(
             elif request.tool_choice and type(
                     request.tool_choice) is ChatCompletionNamedToolChoiceParam:
 
+                tool_call_class = MistralToolCall if isinstance(
+                    tokenizer, MistralTokenizer) else ToolCall
                 message = ChatMessage(
                     role=role,
                     content="",
                     tool_calls=[
-                        ToolCall(function=FunctionCall(
+                        tool_call_class(function=FunctionCall(
                             name=request.tool_choice.function.name,
                             arguments=output.text))
                     ])
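For context on the hunk above: with a named tool_choice, the response's tool call previously always used the default ToolCall ID format, which Mistral's validation rejects when the ID is sent back on a follow-up request. Below is a minimal sketch of the selection logic using simplified stand-ins for the real pydantic models; the "chatcmpl-tool-" prefix for default IDs is an assumption, not something taken from this diff.

# Sketch only: ToolCall / MistralToolCall here are stand-ins, not the classes
# from vllm.entrypoints.openai.protocol or the Mistral tool parser.
import string
from random import choices
from uuid import uuid4


class ToolCall:

    @staticmethod
    def generate_id() -> str:
        # Assumed default format: a long, UUID-based identifier.
        return f"chatcmpl-tool-{uuid4().hex}"


class MistralToolCall(ToolCall):

    @staticmethod
    def generate_id() -> str:
        # Mistral requires exactly nine alphanumeric characters.
        return "".join(choices(string.ascii_letters + string.digits, k=9))


def select_tool_call_class(is_mistral_tokenizer: bool) -> type:
    # Mirrors the dispatch in the hunk above.
    return MistralToolCall if is_mistral_tokenizer else ToolCall


print(len(select_tool_call_class(True).generate_id()))   # 9
print(len(select_tool_call_class(False).generate_id()))  # much longer than 9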
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -33,7 +33,7 @@ class MistralToolCall(ToolCall):
 
     @staticmethod
     def generate_random_id():
-        # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
+        # Mistral Tool Call Ids must be alphanumeric with a length of 9.
         # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
         return "".join(choices(ALPHANUMERIC, k=9))

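For reference, a quick standalone check of the ID format the comment above describes (nine alphanumeric characters). ALPHANUMERIC is written out here as an assumption about the module-level constant, and the regex is derived from the comment rather than copied from the linked validator.

import re
import string
from random import choices

ALPHANUMERIC = string.ascii_letters + string.digits  # assumed definition of the constant

tool_call_id = "".join(choices(ALPHANUMERIC, k=9))
# Exactly nine alphanumeric characters, as the comment requires.
assert re.fullmatch(r"[a-zA-Z0-9]{9}", tool_call_id) is not None
print(tool_call_id)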
7 changes: 5 additions & 2 deletions vllm/transformers_utils/tokenizers/__init__.py
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from .mistral import MistralTokenizer, maybe_serialize_tool_calls
+from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
+                      truncate_tool_call_ids)
 
-__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
+__all__ = [
+    "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids"
+]
30 changes: 30 additions & 0 deletions vllm/transformers_utils/tokenizers/mistral.py
@@ -67,6 +67,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
            request.messages[i]["tool_calls"] = validated_tool_calls


def truncate_tool_call_ids(request: "ChatCompletionRequest"):
    """Truncates tool call IDs for Mistral's ID requirements."""
    for i, message in enumerate(request.messages):
        if message.get("role") == 'assistant':
            tool_calls = message.get("tool_calls", [])
            for tool_call in tool_calls:
                if len(tool_call["id"]) > 9:
                    logger.warning(
                        "Truncating tool call ID: %s to %s",
                        tool_call["id"],
                        tool_call["id"][-9:],
                    )
                    tool_call["id"] = tool_call["id"][-9:]

            request.messages[i]["tool_calls"] = tool_calls

        elif message.get("role") in {"tool_results", "tool"}:
            if "tool_call_id" in message:
                tool_call_id = message["tool_call_id"]

                if len(tool_call_id) > 9:
                    logger.warning(
                        "Truncating tool_call_id: %s to %s",
                        tool_call_id,
                        tool_call_id[-9:],
                    )
                    tool_call_id = tool_call_id[-9:]
                request.messages[i]["tool_call_id"] = tool_call_id


def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
    repo_cache = os.path.join(
        huggingface_hub.constants.HF_HUB_CACHE,
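As a usage illustration of truncate_tool_call_ids, here is a sketch on plain dicts; the example messages and the call_abc123XYZ789 ID are made up, and the real helper operates on the validated request object rather than a bare list.

# Hypothetical message history carrying an OpenAI-style tool call ID that is
# longer than Mistral's nine-character limit.
messages = [
    {
        "role": "assistant",
        "tool_calls": [{
            "id": "call_abc123XYZ789",
            "type": "function",
            "function": {"name": "get_weather", "arguments": "{}"},
        }],
    },
    {"role": "tool", "tool_call_id": "call_abc123XYZ789", "content": "20 C"},
]

# Same truncation rule as the helper above: keep only the last nine characters.
for message in messages:
    if message.get("role") == "assistant":
        for tool_call in message.get("tool_calls", []):
            if len(tool_call["id"]) > 9:
                tool_call["id"] = tool_call["id"][-9:]
    elif message.get("role") in {"tool_results", "tool"}:
        tool_call_id = message.get("tool_call_id", "")
        if len(tool_call_id) > 9:
            message["tool_call_id"] = tool_call_id[-9:]

assert messages[0]["tool_calls"][0]["id"] == "123XYZ789"
assert messages[1]["tool_call_id"] == "123XYZ789"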