From f634ba755923f7938e3bdcd67e209c6a828f9596 Mon Sep 17 00:00:00 2001
From: Xin Yang
Date: Thu, 30 Jan 2025 20:21:50 -0800
Subject: [PATCH] [python] Add tool calling support

---
 .../chat_completions/vllm_chat_utils.py       |  26 ++-
 .../python/setup/djl_python/input_parser.py   |   1 -
 .../setup/djl_python/output_formatter.py      |  94 ++++++++++-
 .../properties_manager/vllm_rb_properties.py  |  16 +-
 engines/python/setup/djl_python/request_io.py |   4 +-
 .../djl_python/rolling_batch/rolling_batch.py |   6 +
 .../rolling_batch/rolling_batch_vllm_utils.py |  16 ++
 .../rolling_batch/vllm_rolling_batch.py       |  17 +-
 .../djl_python/tests/test_rolling_batch.py    |   2 +-
 engines/python/setup/setup.py                 |   2 +-
 tests/integration/llm/client.py               | 152 +++++++++++++++++-
 tests/integration/llm/prepare.py              |  16 +-
 tests/integration/tests.py                    |  12 ++
 13 files changed, 345 insertions(+), 19 deletions(-)

diff --git a/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py b/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py
index 8cea73c56..fddd85585 100644
--- a/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py
+++ b/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py
@@ -14,6 +14,7 @@
 from djl_python.chat_completions.vllm_chat_properties import ChatProperties
 from djl_python.properties_manager.properties import Properties
+from djl_python.rolling_batch.rolling_batch_vllm_utils import maybe_serialize_tool_calls
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_hf_chat_template,
                                          apply_mistral_chat_template,
@@ -30,7 +31,6 @@ def parse_chat_completions_request_vllm(
         rolling_batch,
         tokenizer,
         chat_template: Optional[str] = None,
-        image_token: Optional[str] = None,
         configs: Properties = None,
         is_mistral_tokenizer: bool = False,
 ):
@@ -47,10 +47,30 @@ def parse_chat_completions_request_vllm(
             f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
             f"please ensure that your tokenizer supports chat templates.")
 
+    tool_parser = rolling_batch.get_tool_parser()
     chat_params = ChatProperties(**input_map)
+
+    if chat_params.tool_choice == "required":
+        raise ValueError("tool_choice = \"required\" is not supported!")
+
+    if is_mistral_tokenizer:
+        maybe_serialize_tool_calls(chat_params)
+    elif chat_params.tool_choice == "auto" and tool_parser is None:
+        raise ValueError(
+            "\"auto\" tool choice requires tool_call_parser to be available")
+
+    should_parse_tools = tool_parser is not None and (hasattr(
+        chat_params, "tool_choice") and chat_params.tool_choice != "none")
+    if should_parse_tools:
+        chat_params = tool_parser.adjust_request(request=chat_params)
+
     exclude = {"messages"}
     param = chat_params.model_dump(exclude_none=True, exclude=exclude)
+    tool_dicts = None if chat_params.tools is None else [
+        tool.model_dump() for tool in chat_params.tools
+    ]
+
     conversation, mm_data = parse_chat_messages(
         chat_params.messages, rolling_batch.get_model_config(), tokenizer)
@@ -61,6 +81,7 @@ def parse_chat_completions_request_vllm(
             messages=chat_params.messages,
             chat_template=chat_template,
             add_generation_prompt=True,
+            tools=tool_dicts,
         )
     else:
         text_inputs = apply_hf_chat_template(
@@ -68,11 +89,14 @@ def parse_chat_completions_request_vllm(
             conversation=conversation,
             chat_template=chat_template,
             add_generation_prompt=True,
+            tools=tool_dicts,
         )
     param["details"] = True  # Enable details for chat completions
     param[
         "output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"
+    param["tool_parser"] = tool_parser
+    param["chat_params"] = chat_params
     if mm_data:
         param["mm_data"] = mm_data
diff --git a/engines/python/setup/djl_python/input_parser.py b/engines/python/setup/djl_python/input_parser.py
index f4832ddbe..f21626546 100644
--- a/engines/python/setup/djl_python/input_parser.py
+++ b/engines/python/setup/djl_python/input_parser.py
@@ -150,7 +150,6 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
             kwargs.get("is_rolling_batch"),
             rolling_batch,
             tokenizer,
-            image_token=image_token,
             configs=configs,
             is_mistral_tokenizer=is_mistral_tokenizer,
         )
diff --git a/engines/python/setup/djl_python/output_formatter.py b/engines/python/setup/djl_python/output_formatter.py
index 7b97dab37..8abe39e66 100644
--- a/engines/python/setup/djl_python/output_formatter.py
+++ b/engines/python/setup/djl_python/output_formatter.py
@@ -110,7 +110,7 @@ def _json_output_formatter(request_output: TextGenerationOutput):
         request_output.best_sequence_index]
     # TODO: Fix this so it is not required. Right now, this call is needed to
     # advance the token iterator, which is needed for rolling batch to work properly
-    next_token, _, is_last_token = best_sequence.get_next_token()
+    next_token, _, _, is_last_token = best_sequence.get_next_token()
     if not request_output.finished:
         return ""
     details = get_details_dict(request_output, include_tokens=True)
@@ -141,7 +141,7 @@ def _json_3p_output_formatter(request_output: TextGenerationOutput):
         request_output.best_sequence_index]
     # TODO: Fix this so it is not required. Right now, this call is needed to
     # advance the token iterator, which is needed for rolling batch to work properly
-    next_token, first_token, last_token = best_sequence.get_next_token()
+    next_token, index, first_token, last_token = best_sequence.get_next_token()
     if not request_output.finished:
         return ""
 
@@ -221,7 +221,7 @@ def _jsonlines_output_formatter(request_output: TextGenerationOutput):
     parameters = request_output.input.parameters
     best_sequence = request_output.sequences[
         request_output.best_sequence_index]
-    next_token, _, last_token = best_sequence.get_next_token()
+    next_token, _, _, last_token = best_sequence.get_next_token()
     # with chunked prefill, we don't generate any tokens until the full prompt has been processed.
     # that means we sometimes don't have a token to return
     if next_token is None:
@@ -242,7 +242,7 @@ def _jsonlines_3p_output_formatter(request_output: TextGenerationOutput):
     best_sequence = request_output.sequences[
         request_output.best_sequence_index]
-    next_token, first_token, last_token = best_sequence.get_next_token()
+    next_token, index, first_token, last_token = best_sequence.get_next_token()
     # with chunked prefill, we don't generate any tokens until the full prompt has been processed.
     # that means we sometimes don't have a token to return
     if next_token is None:
@@ -282,6 +282,8 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
     :return: formatted output
     """
     parameters = request_output.input.parameters
+    chat_params = parameters.get("chat_params")
+    tool_parser = parameters.get("tool_parser")
     best_sequence = request_output.sequences[
         request_output.best_sequence_index]
     generated_text = get_generated_text(best_sequence, request_output)
@@ -299,6 +301,46 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
         "logprobs": None,
         "finish_reason": best_sequence.finish_reason,
     }
+    if chat_params and chat_params.tool_choice and type(
+            chat_params.tool_choice
+    ).__name__ == "ChatCompletionNamedToolChoiceParam":
+        tool_calls = [{
+            "id": f"chatcmpl-tool-{id(request_output)}",
+            "type": "function",
+            "function": {
+                "name": chat_params.tool_choice.function.name,
+                "arguments": generated_text
+            }
+        }]
+        choice = {
+            "index": 0,
+            "message": {
+                "role": "assistant",
+                "content": "",
+            },
+            "tool_calls": tool_calls,
+            "logprobs": None,
+            "finish_reason": best_sequence.finish_reason,
+        }
+    elif parameters.get("tools") and (parameters.get("tool_choice") == "auto"
+                                      or parameters.get("tool_choice") is None
+                                      ) and parameters.get("tool_parser"):
+        tool_call_info = tool_parser.extract_tool_calls(generated_text,
+                                                        request=chat_params)
+        auto_tools_called = tool_call_info.tools_called
+        if auto_tools_called:
+            tool_calls = [t.model_dump() for t in tool_call_info.tool_calls]
+            choice = {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": tool_call_info.content,
+                },
+                "tool_calls": tool_calls,
+                "logprobs": None,
+                "finish_reason": "tool_calls",
+            }
+
     if parameters.get("logprobs"):
         logprobs = {
             "content": [
@@ -317,6 +359,7 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
             ]
         }
         choice["logprobs"] = logprobs
+
     prompt_tokens = len(request_output.prompt_tokens_details)
     completion_tokens = len(best_sequence.tokens)
     usage = {
@@ -341,16 +384,52 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):
     :return: formatted output
     """
     parameters = request_output.input.parameters
+    chat_params = parameters.get("chat_params")
+    tool_parser = parameters.get("tool_parser")
     best_sequence = request_output.sequences[
         request_output.best_sequence_index]
-    next_token, first_token, last_token = best_sequence.get_next_token()
+    next_token, index, first_token, last_token = best_sequence.get_next_token()
     # with chunked prefill, we don't generate any tokens until the full prompt has been processed.
     # that means we sometimes don't have a token to return
     if next_token is None:
         return ""
     created = int(time.time())
-    delta = {"content": next_token.text}
+
+    if chat_params and chat_params.tool_choice and type(
+            chat_params.tool_choice
+    ).__name__ == "ChatCompletionNamedToolChoiceParam":
+        tool_calls = [{
+            "index": 0,
+            "function": {
+                "name": chat_params.tool_choice.function.name,
+                "arguments": next_token.text
+            }
+        }]
+        delta = {"tool_calls": tool_calls}
+    elif parameters.get("tools") and (parameters.get("tool_choice") == "auto"
+                                      or parameters.get("tool_choice") is None
+                                      ) and parameters.get("tool_parser"):
+        current_text = get_generated_text(best_sequence, request_output)
+        previous_text = current_text[0:-len(next_token.text)]
+        current_token_ids = [t.id for t in best_sequence.tokens]
+        previous_token_ids = current_token_ids[:-1]
+        tool_call_info = tool_parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=next_token.text,
+            previous_token_ids=previous_token_ids,
+            current_token_ids=current_token_ids,
+            delta_token_ids=[next_token.id],
+            request=chat_params)
+        if tool_call_info is None:
+            return ""
+        tool_calls = [
+            t.model_dump(exclude_none=True) for t in tool_call_info.tool_calls
+        ]
+        delta = {"tool_calls": tool_calls}
+    else:
+        delta = {"content": next_token.text}
     if first_token:
         delta["role"] = "assistant"
@@ -423,7 +502,8 @@ def adapt_legacy_output_formatter(request_output: TextGenerationOutput) -> str:
         elif best_sequence.finish_reason == "error":
             details_dict["finish_reason"] = best_sequence.finish_reason
 
-        next_token, first_token, last_token = best_sequence.get_next_token()
+        next_token, index, first_token, last_token = best_sequence.get_next_token(
+        )
         if last_token:
             for token in best_sequence.tokens:
                 generated_text += token.text
diff --git a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
index 2d391ce3a..5d85e6d52 100644
--- a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
+++ b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
@@ -11,13 +11,12 @@
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
 import ast
-import json
-from enum import Enum
 from typing import Optional, Any, Mapping, Tuple, Dict
 
 from pydantic import field_validator, model_validator
 
 from djl_python.properties_manager.properties import Properties
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 
 
 class VllmRbProperties(Properties):
@@ -74,6 +73,10 @@ class VllmRbProperties(Properties):
     qlora_adapter_name_or_path: Optional[str] = None
     disable_logprobs_during_spec_decoding: Optional[bool] = None
 
+    # Tool calling properties
+    enable_auto_tool_choice: Optional[bool] = False
+    tool_call_parser: Optional[str] = None
+
     @field_validator('engine')
     def validate_engine(cls, engine):
         if engine != "Python":
@@ -147,3 +150,12 @@ def validate_pipeline_parallel(self):
                 "Pipeline parallelism is not supported in vLLM's LLMEngine used in rolling_batch implementation"
             )
         return self
+
+    @model_validator(mode='after')
+    def validate_tool_call_parser(self):
+        valid_tool_parses = ToolParserManager.tool_parsers.keys()
+        if self.enable_auto_tool_choice \
+                and self.tool_call_parser not in valid_tool_parses:
+            raise ValueError(
+                f"Invalid tool call parser: {self.tool_call_parser} "
+                f"(choose from {{ {','.join(valid_tool_parses)} }})")
diff --git a/engines/python/setup/djl_python/request_io.py b/engines/python/setup/djl_python/request_io.py
index 20c85fb33..7f66fa9a2 100644
--- a/engines/python/setup/djl_python/request_io.py
+++ b/engines/python/setup/djl_python/request_io.py
@@ -118,8 +118,8 @@ def get_next_token(self) -> (Token, bool, bool):
             index = self._tokens_iterator.next_index()
             first_token = index == 0
             last_token = index == self._last_token_index
-            return self.tokens[index], first_token, last_token
-        return None, False, False
+            return self.tokens[index], index, first_token, last_token
+        return None, 0, False, False
 
     def get_last_token(self) -> Optional[Token]:
         if self._last_token_index is not None:
diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
index 66e566a68..764fac394 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -120,6 +120,12 @@ def use_vllm_chat_completions(self):
         """
         return False
 
+    def get_tool_parser(self):
+        """
+        :return: the tool call parser if available
+        """
+        return None
+
     @abstractmethod
     def inference(self, new_requests: List[Request]) -> List:
         """
diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
index 688c00484..22b6b2f73 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
@@ -290,3 +290,19 @@ def get_prompt_inputs(request: Request):
     if multi_modal_data is not None:
         prompt["multi_modal_data"] = multi_modal_data
     return prompt
+
+
+def maybe_serialize_tool_calls(request):
+    # Adapted from https://github.com/vllm-project/vllm/blob/v0.7.0/vllm/transformers_utils/tokenizers/mistral.py#L34-L68
+    for i, message in enumerate(request.messages):
+        if message.get("role") == 'assistant':
+            tool_calls_validator = message.get("tool_calls", ().__iter__())
+            validated_tool_calls = []
+            while True:
+                try:
+                    tool_call = next(tool_calls_validator)  # type: ignore
+                    validated_tool_calls.append(tool_call)
+                except StopIteration:
+                    break
+
+            request.messages[i]["tool_calls"] = validated_tool_calls
diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
index 33022176f..7a73dbcc7 100644
--- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
@@ -10,11 +10,12 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-import logging
 from collections import OrderedDict, defaultdict
 
 from vllm import LLMEngine, SamplingParams
 from vllm.sampling_params import RequestOutputKind
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid, AtomicCounter
 
 from djl_python.request import Request
@@ -23,7 +24,7 @@
     update_request_cache_with_output, create_lora_request, get_lora_request,
     get_engine_args_from_config, get_prompt_inputs)
 from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
-from typing import List, Optional
+from typing import Callable, List, Optional
 
 # FIXME: Once all vllm versions are past 0.6.0 we can move to just struct_fields
 VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__) if hasattr(
@@ -55,6 +56,15 @@ def __init__(self, model_id_or_path: str, properties: dict,
         self.lora_id_counter = AtomicCounter(0)
         self.lora_requests = {}
         self.is_mistral_tokenizer = self.vllm_configs.tokenizer_mode == 'mistral'
+        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
+        if self.vllm_configs.enable_auto_tool_choice:
+            try:
+                self.tool_parser = ToolParserManager.get_tool_parser(
+                    self.vllm_configs.tool_call_parser)
+                self.tool_parser = self.tool_parser(
+                    self.engine.tokenizer.tokenizer)
+            except Exception as e:
+                raise TypeError("Error in tool parser creation.") from e
 
     def get_tokenizer(self):
         return self.engine.tokenizer.tokenizer
@@ -68,6 +78,9 @@ def get_huggingface_model_config(self):
     def use_vllm_chat_completions(self):
         return True
 
+    def get_tool_parser(self):
+        return self.tool_parser
+
     def reset(self) -> None:
         """
         Aborts all requests
diff --git a/engines/python/setup/djl_python/tests/test_rolling_batch.py b/engines/python/setup/djl_python/tests/test_rolling_batch.py
index 63403da5a..372368f2f 100644
--- a/engines/python/setup/djl_python/tests/test_rolling_batch.py
+++ b/engines/python/setup/djl_python/tests/test_rolling_batch.py
@@ -1047,7 +1047,7 @@ def custom_fmt_wait(request_output: TextGenerationOutput):
             sequence_index = request_output.best_sequence_index
             best_sequence = request_output.sequences[
                 request_output.best_sequence_index]
-            _, _, last_token = best_sequence.get_next_token()
+            _, _, _, last_token = best_sequence.get_next_token()
             if last_token:
                 tokens = best_sequence.tokens
                 generated_text = ""
diff --git a/engines/python/setup/setup.py b/engines/python/setup/setup.py
index 1cf04e748..679228f0d 100644
--- a/engines/python/setup/setup.py
+++ b/engines/python/setup/setup.py
@@ -58,7 +58,7 @@ def run(self):
 test_requirements = [
     'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops',
     'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf',
-    'pydantic>=2.0', "objgraph", "vllm==0.6.3.post1"
+    'pydantic>=2.0', "objgraph", "vllm"
 ]
 
 setup(name='djl_python',
diff --git a/tests/integration/llm/client.py b/tests/integration/llm/client.py
index 2728f5325..ca61a0b0c 100644
--- a/tests/integration/llm/client.py
+++ b/tests/integration/llm/client.py
@@ -601,7 +601,20 @@ def get_model_name():
         "batch_size": [1, 4],
         "seq_length": [256],
         "tokenizer": "TheBloke/Llama-2-7B-Chat-fp16"
-    }
+    },
+}
+
+vllm_tool_model_spec = {
+    "llama3-1-8b-instruct-tool": {
+        "batch_size": [1, 4],
+        "seq_length": [256],
+        "tokenizer": "unsloth/Meta-Llama-3.1-8B-Instruct"
+    },
+    "mistral-7b-instruct-v03-tool": {
+        "batch_size": [1, 4],
+        "seq_length": [256],
+        "tokenizer": "unsloth/mistral-7b-instruct-v0.3"
+    },
 }
 
 lmi_dist_aiccl_model_spec = {
@@ -1281,6 +1294,111 @@ def batch_generation_pair(batch_size):
     return data[:batch_size]
 
 
+def batch_generation_tool(batch_size):
+    data = [{
+        "messages": [{
+            "role": "user",
+            "content": "Hi! How are you doing today?"
+        }, {
+            "role": "assistant",
+            "content": "I'm doing well! How can I help you?"
+        }, {
+            "role": "user",
+            "content":
+            "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
+        }],
+        "tools": [{
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "city": {
+                            "type": "string",
+                            "description":
+                            "The city to find the weather for, e.g. 'San Francisco'"
+                        },
+                        "state": {
+                            "type": "string",
+                            "description":
+                            "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'"
+                        },
+                        "unit": {
+                            "type": "string",
+                            "description":
+                            "The unit to fetch the temperature in",
+                            "enum": ["celsius", "fahrenheit"]
+                        }
+                    },
+                    "required": ["city", "state", "unit"]
+                }
+            }
+        }],
+        "tool_choice": "auto"
+    }, {
+        "messages": [{
+            "role": "user",
+            "content": "Hi! How are you doing today?"
+        }, {
+            "role": "assistant",
+            "content": "I'm doing well! How can I help you?"
+        }, {
+            "role": "user",
+            "content":
+            "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
+        }],
+        "tools": [{
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "city": {
+                            "type": "string",
+                            "description":
+                            "The city to find the weather for, e.g. 'San Francisco'"
+                        },
+                        "state": {
+                            "type": "string",
+                            "description":
+                            "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'"
'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": + "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["city", "state", "unit"] + } + } + }], + "tool_choice": { + "type": "function", + "function": { + "name": "get_current_weather" + } + }, + }] + + if batch_size > len(data): + # dynamically extend to support larger bs by repetition + data *= math.ceil(batch_size / len(data)) + return data[:batch_size] + + def t5_batch_generation(batch_size): input_sentences = [ "translate English to German: The house is wonderful.", @@ -1548,6 +1666,36 @@ def test_handler_rolling_batch_chat(model, model_spec): awscurl_run(req, spec.get("tokenizer", None), batch_size) +def test_handler_rolling_batch_tool(model, model_spec): + modelspec_checker(model, model_spec) + spec = model_spec[args.model] + if "worker" in spec: + check_worker_number(spec["worker"]) + stream_values = spec.get("stream", [False, True]) + # dryrun phase + req = batch_generation_tool(1)[0] + seq_length = 100 + req["max_tokens"] = seq_length + req["logprobs"] = True + req["top_logprobs"] = 1 + if "adapters" in spec: + req["adapters"] = spec.get("adapters")[0] + + for stream in stream_values: + req["stream"] = stream + LOGGER.info(f"req {req}") + res = send_json(req) + LOGGER.info(f"res: {res.content}") + # awscurl little benchmark phase + for i, batch_size in enumerate(spec["batch_size"]): + for seq_length in spec["seq_length"]: + LOGGER.info( + f"Little benchmark: concurrency {batch_size} seq_len {seq_length}" + ) + req["max_tokens"] = seq_length + awscurl_run(req, spec.get("tokenizer", None), batch_size) + + def test_handler(model, model_spec): modelspec_checker(model, model_spec) spec = model_spec[args.model] @@ -1920,6 +2068,8 @@ def run(raw_args): test_handler_rolling_batch_chat(args.model, lmi_dist_chat_model_spec) elif args.handler == "vllm_chat": test_handler_rolling_batch_chat(args.model, vllm_chat_model_spec) + elif args.handler == "vllm_tool": + test_handler_rolling_batch_tool(args.model, vllm_tool_model_spec) elif args.handler == "vllm_neo": test_handler_rolling_batch(args.model, vllm_neo_model_spec) elif args.handler == "handler_performance": diff --git a/tests/integration/llm/prepare.py b/tests/integration/llm/prepare.py index 14e392001..8bf55aa7e 100644 --- a/tests/integration/llm/prepare.py +++ b/tests/integration/llm/prepare.py @@ -1072,7 +1072,21 @@ "option.max_model_len": 8192, "option.max_rolling_batch_size": 16, "option.enforce_eager": True, - } + }, + "llama3-1-8b-instruct-tool": { + "option.model_id": "s3://djl-llm/llama-3.1-8b-instruct-hf/", + "option.tensor_parallel_degree": 4, + "option.max_rolling_batch_size": 4, + "option.enable_auto_tool_choice": True, + "option.tool_call_parser": "llama3_json", + }, + "mistral-7b-instruct-v03-tool": { + "option.model_id": "s3://djl-llm/mistral-7b-instruct-v03/", + "option.tensor_parallel_degree": 4, + "option.max_rolling_batch_size": 4, + "option.enable_auto_tool_choice": True, + "option.tool_call_parser": "mistral", + }, } vllm_neo_model_list = { diff --git a/tests/integration/tests.py b/tests/integration/tests.py index e52f18c21..f020af0ae 100644 --- a/tests/integration/tests.py +++ b/tests/integration/tests.py @@ -630,6 +630,18 @@ def test_llama_68m_speculative_eagle(self): r.launch() client.run("vllm llama-68m-speculative-eagle".split()) + def test_llama3_1_8b_instruct_tool(self): + with Runner('lmi', 'llama3-1-8b-instruct-tool') as r: + 
prepare.build_vllm_model("llama3-1-8b-instruct-tool") + r.launch() + client.run("vllm_tool llama3-1-8b-instruct-tool".split()) + + def test_mistral_7b_instruct_v03_tool(self): + with Runner('lmi', 'mistral-7b-instruct-v03-tool') as r: + prepare.build_vllm_model("mistral-7b-instruct-v03-tool") + r.launch() + client.run("vllm_tool mistral-7b-instruct-v03-tool".split()) + @pytest.mark.vllm @pytest.mark.lora