Commit

Fix unit tests reference on served_model_names
Jeffwan committed Sep 6, 2024
1 parent c454206 commit af59780
Showing 3 changed files with 11 additions and 4 deletions.
tests/entrypoints/openai/test_serving_chat.py (7 changes: 5 additions & 2 deletions)
@@ -7,11 +7,14 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.transformers_utils.tokenizer import get_tokenizer

 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"

+BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
+

 @dataclass
 class MockModelConfig:
@@ -37,7 +40,7 @@ async def _async_serving_chat_init():

     serving_completion = OpenAIServingChat(engine,
                                            model_config,
-                                           served_model_names=[MODEL_NAME],
+                                           BASE_MODEL_PATHS,
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE,
                                            lora_modules=None,
@@ -57,7 +60,7 @@ def test_serving_chat_should_set_correct_max_tokens():

     serving_chat = OpenAIServingChat(mock_engine,
                                      MockModelConfig(),
-                                     served_model_names=[MODEL_NAME],
+                                     BASE_MODEL_PATHS,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      lora_modules=None,
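
The diff above replaces the plain served_model_names list in the chat-serving tests with a list of BaseModelPath entries. A minimal, self-contained sketch of the new shape, assuming BaseModelPath is a simple dataclass whose fields match the keyword arguments used above (the definition below is an inference, not the actual vLLM source):

from dataclasses import dataclass


# Assumed shape of BaseModelPath, inferred from BaseModelPath(name=..., model_path=...)
# in the diff above; the real definition lives in vllm/entrypoints/openai/serving_engine.py.
@dataclass
class BaseModelPath:
    name: str        # name the model is served under (what API clients request)
    model_path: str  # path or Hugging Face identifier of the underlying model


MODEL_NAME = "openai-community/gpt2"

# Old argument style: a bare list of served names.
served_model_names = [MODEL_NAME]

# New argument style: each served name is paired with its model path.
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
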
vllm/entrypoints/openai/run_batch.py (7 changes: 6 additions & 1 deletion)
@@ -17,6 +17,7 @@
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, random_uuid
 from vllm.version import __version__ as VLLM_VERSION
@@ -140,6 +141,10 @@ async def main(args):

     # When using single vLLM without engine_use_ray
     model_config = await engine.get_model_config()
+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]

     if args.disable_log_requests:
         request_logger = None
@@ -150,7 +155,7 @@
     openai_serving_chat = OpenAIServingChat(
         engine,
         model_config,
-        served_model_names,
+        base_model_paths,
         args.response_role,
         lora_modules=None,
         prompt_adapters=None,
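
In run_batch.py the served model names are likewise wrapped before being handed to OpenAIServingChat: every entry of served_model_names is mapped onto the single model given by args.model. A standalone sketch of that mapping, with an assumed BaseModelPath dataclass and illustrative values (the helper function is not part of run_batch.py):

from dataclasses import dataclass
from typing import List


@dataclass
class BaseModelPath:  # assumed shape, as in the sketch above
    name: str
    model_path: str


def build_base_model_paths(served_model_names: List[str], model: str) -> List[BaseModelPath]:
    # Mirrors the list comprehension in the diff: every served alias
    # points at the same underlying model (args.model).
    return [BaseModelPath(name=name, model_path=model) for name in served_model_names]


# Illustrative values only.
paths = build_base_model_paths(["gpt2", "my-gpt2-alias"], "openai-community/gpt2")
for p in paths:
    print(f"{p.name} -> {p.model_path}")
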
vllm/entrypoints/openai/serving_engine.py (1 change: 0 additions & 1 deletion)
@@ -495,4 +495,3 @@ async def unload_lora_adapter(

     def _is_model_supported(self, model_name):
         return any(model.name == model_name for model in self.base_model_paths)
-
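
The deletion in serving_engine.py is only a trailing blank line, but the surrounding _is_model_supported method shows how the new list is consumed: a requested model name is supported when it matches the .name of some configured BaseModelPath. A self-contained sketch of that check, with an assumed BaseModelPath dataclass, an illustrative class name, and illustrative values:

from dataclasses import dataclass
from typing import List


@dataclass
class BaseModelPath:  # assumed shape, see the earlier sketch
    name: str
    model_path: str


class ServingEngineSketch:
    def __init__(self, base_model_paths: List[BaseModelPath]):
        self.base_model_paths = base_model_paths

    def _is_model_supported(self, model_name: str) -> bool:
        # Same logic as the method shown in the diff above.
        return any(model.name == model_name for model in self.base_model_paths)


engine = ServingEngineSketch(
    [BaseModelPath(name="gpt2", model_path="openai-community/gpt2")])
print(engine._is_model_supported("gpt2"))     # True
print(engine._is_model_supported("llama-3"))  # False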