-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Frontend] Generate valid tool call IDs when using `tokenizer-mode=mistral` (#12332)
- Loading branch information
Showing 8 changed files with 149 additions and 8 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import pytest_asyncio | ||
from huggingface_hub import snapshot_download | ||
|
||
from tests.utils import RemoteOpenAIServer | ||
from vllm.platforms import current_platform | ||
|
||
from .utils import ARGS, CONFIGS, ServerConfig | ||
|
||
|
||
# Parametrized over every key in CONFIGS: fetch the model for the selected
# entry, then hand that entry to the test session.
@pytest.fixture(scope="session", params=CONFIGS.keys())
def server_config(request):
    """Yield the server config selected by the current fixture parameter.

    When running on ROCm and the config sets ``supports_rocm`` to a falsy
    value, the test is skipped. Otherwise the model snapshot is downloaded
    before the config is yielded.
    """
    selected = CONFIGS[request.param]

    if current_platform.is_rocm():
        # A config without a "supports_rocm" key is treated as supported.
        if not selected.get("supports_rocm", True):
            skip_reason = "The {} model can't be tested on the ROCm platform"
            pytest.skip(skip_reason.format(selected["model"]))

    # Fetch the model repository snapshot via huggingface_hub.
    snapshot_download(selected["model"])
    yield selected
|
||
|
||
# One server per session, launched from the active server_config.
@pytest.fixture(scope="session")
def server(request, server_config: ServerConfig):
    """Start a RemoteOpenAIServer for the configured model and yield it.

    The launch arguments are the shared ``ARGS`` followed by the config's
    own ``arguments``. The server is used as a context manager around the
    ``yield``, so it is exited once the fixture is torn down.
    """
    model_name = server_config["model"]
    launch_args = ARGS + server_config["arguments"]
    remote = RemoteOpenAIServer(model_name, launch_args, max_wait_seconds=480)
    with remote as running_server:
        yield running_server
|
||
|
||
@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
    """Yield an async client obtained from the ``server`` fixture.

    The client is entered as an async context manager and exited after the
    test that requested it finishes.
    """
    async with server.get_async_client() as openai_client:
        yield openai_client
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import openai | ||
import pytest | ||
|
||
from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL | ||
|
||
|
||
# test: a tool_choice with mistral-tokenizer results in an ID of length 9
@pytest.mark.asyncio
async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
    """Check that a forced tool call gets a 9-character tool call ID.

    Sends a chat completion request with ``tool_choice`` set to the weather
    tool against the first model the server lists, then verifies the
    response carries exactly one tool call whose ``id`` has length 9.
    """
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_TOOLS,
        temperature=0,
        max_completion_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL],
        tool_choice=WEATHER_TOOL,
        logprobs=False)

    choice = chat_completion.choices[0]

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert choice.message.role == "assistant"

    tool_calls = choice.message.tool_calls
    # The ID check below indexes tool_calls[0], so a missing or empty list
    # must fail here with a readable message rather than a TypeError or
    # IndexError on the next line.
    assert tool_calls is not None and len(tool_calls) == 1, (
        "expected exactly one tool call, got {!r}".format(tool_calls))
    assert len(tool_calls[0].id) == 9  # length of 9 for mistral
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Dict, List, Optional | ||
|
||
from typing_extensions import TypedDict | ||
|
||
|
||
class ServerConfig(TypedDict, total=False):
    """Typed description of one server configuration entry.

    Declared with ``total=False``, so every key may be omitted.
    """

    # Model identifier for this configuration.
    model: str
    # Extra command-line arguments specific to this configuration.
    arguments: List[str]
    # Optional system prompt text associated with the model.
    system_prompt: Optional[str]
    # NOTE(review): presumably whether parallel tool calls are supported;
    # not read anywhere in the visible code, so confirm against callers.
    supports_parallel: Optional[bool]
    # Whether the model can be tested on the ROCm platform.
    supports_rocm: Optional[bool]
|
||
|
||
# Base command-line arguments, kept separate from per-config "arguments".
ARGS: List[str] = ["--max-model-len", "1024"]

# Server configurations keyed by a short name.
CONFIGS: Dict[str, ServerConfig] = {
    "mistral": {
        "model": "mistralai/Mistral-7B-Instruct-v0.3",
        "arguments": [
            "--tokenizer-mode",
            "mistral",
            # NOTE(review): the double quotes are part of the argument value
            # itself. If the server is launched without a shell they reach
            # the option parser literally - confirm the pattern still
            # matches the intended file.
            "--ignore-patterns=\"consolidated.safetensors\"",
        ],
        "system_prompt": (
            "You are a helpful assistant with access to tools. If a tool"
            " that you have would be helpful to answer a user query, "
            "call the tool. Otherwise, answer the user's query directly "
            "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
            "to the user's question - just respond to it normally."
        ),
    },
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .mistral import MistralTokenizer, maybe_serialize_tool_calls | ||
from .mistral import (MistralTokenizer, maybe_serialize_tool_calls, | ||
truncate_tool_call_ids) | ||
|
||
__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"] | ||
__all__ = [ | ||
"MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters