diff --git a/tests/test_config.py b/tests/test_config.py
index c45cbb367326c..dc1cddd45c34e 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -35,7 +35,7 @@ def test_ambiguous_task(model_id):
             seed=0,
             dtype="float16",
         )
-    
+
 
 MODEL_IDS_EXPECTED = [
     ("Qwen/Qwen1.5-7B", 32768),
diff --git a/vllm/config.py b/vllm/config.py
index 8efa48483765a..b04bcf9be8c06 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -33,7 +33,6 @@
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
 
-
 Task = Literal["generate", "embed"]
 TaskOption = Literal["auto", Task]
 
@@ -264,22 +263,25 @@ def _resolve_task(
             "embed": ModelRegistry.is_embedding_model(architectures)
         }
         supported_tasks: Set[Task] = {
-            task for task, is_supported in task_support.items() if is_supported
+            task
+            for task, is_supported in task_support.items() if is_supported
         }
 
         if task_option == "auto":
             if len(supported_tasks) > 1:
-                msg = (f"This model supports multiple tasks: {supported_tasks}."
-                       " Please specify one explicitly via `--task`.")
+                msg = (
+                    f"This model supports multiple tasks: {supported_tasks}."
+                    " Please specify one explicitly via `--task`.")
                 raise ValueError(msg)
-
+
             task = next(iter(supported_tasks))
         else:
             if task_option not in supported_tasks:
-                msg = (f"This model does not support the '{task_option}' task. "
-                       f"Supported tasks: {supported_tasks}")
+                msg = (
+                    f"This model does not support the '{task_option}' task. "
+                    f"Supported tasks: {supported_tasks}")
                 raise ValueError(msg)
-
+
             task = task_option
 
         return task
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 16f7a106bc1c8..ba2fec65a4762 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -10,10 +10,9 @@
 import vllm.envs as envs
 from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
                          DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
-                         LoRAConfig, ModelConfig, TaskOption,
-                         ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig, TokenizerPoolConfig)
+                         LoRAConfig, ModelConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig, TaskOption, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 4cdb7ac745d8a..00cfaae6bc83d 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -83,7 +83,8 @@ def __init__(
                          lora_modules=None,
                          prompt_adapters=None,
                          request_logger=request_logger)
-        self._enabled = self._check_embedding_mode(model_config.task == "embed")
+        self._enabled = self._check_embedding_mode(
+            model_config.task == "embed")
 
     async def create_embedding(
         self,