diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index c3a6c65be1d90..cf6a6e140f564 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -7,11 +7,14 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
+BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
+
 
 
 @dataclass
 class MockModelConfig:
@@ -37,7 +40,7 @@ async def _async_serving_chat_init():
 
     serving_completion = OpenAIServingChat(engine,
                                            model_config,
-                                           served_model_names=[MODEL_NAME],
+                                           BASE_MODEL_PATHS,
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE,
                                            lora_modules=None,
@@ -57,7 +60,7 @@ def test_serving_chat_should_set_correct_max_tokens():
 
     serving_chat = OpenAIServingChat(mock_engine,
                                      MockModelConfig(),
-                                     served_model_names=[MODEL_NAME],
+                                     BASE_MODEL_PATHS,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      lora_modules=None,
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 32bbade256973..5ad35b6497154 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -17,6 +17,7 @@
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, random_uuid
 from vllm.version import __version__ as VLLM_VERSION
@@ -140,6 +141,10 @@ async def main(args):
 
     # When using single vLLM without engine_use_ray
     model_config = await engine.get_model_config()
+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]
 
     if args.disable_log_requests:
         request_logger = None
@@ -150,7 +155,7 @@ async def main(args):
     openai_serving_chat = OpenAIServingChat(
         engine,
         model_config,
-        served_model_names,
+        base_model_paths,
         args.response_role,
         lora_modules=None,
         prompt_adapters=None,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index dd7e3aba56b74..84cf613e5fc7d 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -495,4 +495,3 @@ async def unload_lora_adapter(
     def _is_model_supported(self, model_name):
         return any(model.name == model_name
                    for model in self.base_model_paths)
-
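
Reviewer note: below is a minimal usage sketch of the interface these hunks migrate to, assuming only what the diff itself shows (a BaseModelPath with name and model_path fields, passed positionally where served_model_names used to go). The alias "gpt2-alias" is a hypothetical placeholder and not part of the change.

    from vllm.entrypoints.openai.serving_engine import BaseModelPath

    # Map each served model name (alias) to its underlying model path,
    # mirroring the list comprehension added to run_batch.py above.
    served_model_names = ["gpt2-alias"]  # hypothetical --served-model-name value
    base_model_paths = [
        BaseModelPath(name=name, model_path="openai-community/gpt2")
        for name in served_model_names
    ]

    # Callers now pass base_model_paths where they previously passed
    # served_model_names, e.g. OpenAIServingChat(engine, model_config,
    # base_model_paths, args.response_role, lora_modules=None, ...),
    # and _is_model_supported() matches an incoming request's model name
    # against BaseModelPath.name.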