Commit

format

DarkLight1337 committed Oct 16, 2024
1 parent 62a15fd commit 8ec8d5c

Showing 4 changed files with 16 additions and 14 deletions.

2 changes: 1 addition & 1 deletion tests/test_config.py
@@ -35,7 +35,7 @@ def test_ambiguous_task(model_id):
         seed=0,
         dtype="float16",
     )
-
+
 
 MODEL_IDS_EXPECTED = [
     ("Qwen/Qwen1.5-7B", 32768),

18 changes: 10 additions & 8 deletions vllm/config.py
@@ -33,7 +33,6 @@
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
 
-
 Task = Literal["generate", "embed"]
 TaskOption = Literal["auto", Task]
 
@@ -264,22 +263,25 @@ def _resolve_task(
             "embed": ModelRegistry.is_embedding_model(architectures)
         }
         supported_tasks: Set[Task] = {
-            task for task, is_supported in task_support.items() if is_supported
+            task
+            for task, is_supported in task_support.items() if is_supported
         }
 
         if task_option == "auto":
             if len(supported_tasks) > 1:
-                msg = (f"This model supports multiple tasks: {supported_tasks}."
-                       " Please specify one explicitly via `--task`.")
+                msg = (
+                    f"This model supports multiple tasks: {supported_tasks}."
+                    " Please specify one explicitly via `--task`.")
                 raise ValueError(msg)
 
             task = next(iter(supported_tasks))
         else:
             if task_option not in supported_tasks:
-                msg = (f"This model does not support the '{task_option}' task. "
-                       f"Supported tasks: {supported_tasks}")
+                msg = (
+                    f"This model does not support the '{task_option}' task. "
+                    f"Supported tasks: {supported_tasks}")
                 raise ValueError(msg)
 
             task = task_option
 
         return task
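
For context, the hunk above only reflows existing logic; a minimal standalone sketch of the same task-resolution behavior (the hardcoded task_support dict and the resolve_task name are illustrative stand-ins, not vllm's actual ModelRegistry API) looks like this:

    from typing import Dict, Literal, Set

    Task = Literal["generate", "embed"]
    TaskOption = Literal["auto", "generate", "embed"]

    def resolve_task(task_option: TaskOption,
                     task_support: Dict[Task, bool]) -> Task:
        # Keep only the tasks the model actually supports.
        supported_tasks: Set[Task] = {
            task
            for task, is_supported in task_support.items() if is_supported
        }

        if task_option == "auto":
            # "auto" is only unambiguous when exactly one task is supported.
            if len(supported_tasks) > 1:
                msg = (
                    f"This model supports multiple tasks: {supported_tasks}."
                    " Please specify one explicitly via `--task`.")
                raise ValueError(msg)
            return next(iter(supported_tasks))

        if task_option not in supported_tasks:
            msg = (
                f"This model does not support the '{task_option}' task. "
                f"Supported tasks: {supported_tasks}")
            raise ValueError(msg)
        return task_option

    # A generation-only model resolves "auto" to "generate".
    assert resolve_task("auto", {"generate": True, "embed": False}) == "generate"
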
7 changes: 3 additions & 4 deletions vllm/engine/arg_utils.py
@@ -10,10 +10,9 @@
 import vllm.envs as envs
 from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
                          DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
-                         LoRAConfig, ModelConfig, TaskOption,
-                         ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig, TokenizerPoolConfig)
+                         LoRAConfig, ModelConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig, TaskOption, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/serving_embedding.py
@@ -83,7 +83,8 @@ def __init__(
                          lora_modules=None,
                          prompt_adapters=None,
                          request_logger=request_logger)
-        self._enabled = self._check_embedding_mode(model_config.task == "embed")
+        self._enabled = self._check_embedding_mode(
+            model_config.task == "embed")
 
     async def create_embedding(
         self,
