From 90b4e5945ef5cee94928f3dc7e244ce730829c81 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 19 Oct 2024 02:31:58 +0800 Subject: [PATCH] [Model] Add user-configurable task for models that support both generation and embedding (#9424) Signed-off-by: Sumit Dubey --- docs/source/models/supported_models.rst | 8 ++ docs/source/models/vlm.rst | 4 +- ...ine_inference_vision_language_embedding.py | 1 + examples/openai_api_client_for_multimodal.py | 4 +- tests/conftest.py | 4 +- tests/core/test_chunked_prefill_scheduler.py | 15 ++- tests/core/test_scheduler.py | 56 ++++++----- tests/core/test_scheduler_encoder_decoder.py | 7 +- tests/distributed/test_pipeline_parallel.py | 23 ++++- tests/entrypoints/llm/test_chat.py | 92 +++++++++++++++++++ tests/entrypoints/llm/test_generate.py | 88 ------------------ tests/entrypoints/llm/test_init.py | 22 +++++ tests/entrypoints/openai/test_serving_chat.py | 2 +- tests/entrypoints/openai/test_vision.py | 2 + tests/entrypoints/test_chat_utils.py | 3 +- tests/lora/test_worker.py | 5 +- .../vision_language/test_phi3v.py | 1 + .../embedding/vision_language/test_phi3v.py | 1 + tests/models/utils.py | 6 +- tests/multimodal/test_mapper.py | 4 + tests/multimodal/test_processor_kwargs.py | 7 +- tests/quantization/test_configs.py | 3 +- tests/test_config.py | 57 ++++++++++-- tests/test_utils.py | 12 +-- tests/utils.py | 8 +- vllm/config.py | 77 +++++++++++----- vllm/core/scheduler.py | 2 +- vllm/engine/arg_utils.py | 17 +++- vllm/engine/llm_engine.py | 7 +- vllm/entrypoints/llm.py | 56 ++++++++--- vllm/entrypoints/openai/serving_embedding.py | 3 +- vllm/utils.py | 50 +++++++++- vllm/worker/worker.py | 5 +- 33 files changed, 451 insertions(+), 201 deletions(-) create mode 100644 tests/entrypoints/llm/test_chat.py create mode 100644 tests/entrypoints/llm/test_init.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index b5fa83b437ac4..ee2844c8b27a0 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -294,6 +294,10 @@ Text Embedding - - ✅︎ +.. important:: + Some model architectures support both generation and embedding tasks. + In this case, you have to pass :code:`--task embedding` to run the model in embedding mode. + Reward Modeling --------------- @@ -482,6 +486,10 @@ Multimodal Embedding - 🚧 - ✅︎ +.. important:: + Some model architectures support both generation and embedding tasks. + In this case, you have to pass :code:`--task embedding` to run the model in embedding mode. + ---- If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 7dd42ec1bb9c9..a7b55d1c0c1ff 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -181,8 +181,8 @@ Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruc .. code-block:: bash - vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \ - --trust-remote-code --limit-mm-per-prompt image=2 + vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ + --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 .. 
important:: Since OpenAI Vision API is based on `Chat Completions `_ API, diff --git a/examples/offline_inference_vision_language_embedding.py b/examples/offline_inference_vision_language_embedding.py index 8e62199e1db7b..cfedd145a015d 100644 --- a/examples/offline_inference_vision_language_embedding.py +++ b/examples/offline_inference_vision_language_embedding.py @@ -7,6 +7,7 @@ # Create an LLM. llm = LLM( model="TIGER-Lab/VLM2Vec-Full", + task="embedding", trust_remote_code=True, max_model_len=4096, max_num_seqs=2, diff --git a/examples/openai_api_client_for_multimodal.py b/examples/openai_api_client_for_multimodal.py index 704236be72d03..beb83e494ed0b 100644 --- a/examples/openai_api_client_for_multimodal.py +++ b/examples/openai_api_client_for_multimodal.py @@ -7,8 +7,8 @@ vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja (multi-image inference with Phi-3.5-vision-instruct) -vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \ - --trust-remote-code --limit-mm-per-prompt image=2 +vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ + --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 (audio inference with Ultravox) vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096 diff --git a/tests/conftest.py b/tests/conftest.py index 5df7da9ee64e2..ea7156c60e334 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,7 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config import TokenizerPoolConfig +from vllm.config import TaskOption, TokenizerPoolConfig from vllm.connections import global_http_connection from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel, @@ -619,6 +619,7 @@ class VllmRunner: def __init__( self, model_name: str, + task: TaskOption = "auto", tokenizer_name: Optional[str] = None, # Use smaller max model length, otherwise bigger model cannot run due # to kv cache size limit. 
@@ -634,6 +635,7 @@ def __init__( ) -> None: self.model = LLM( model=model_name, + task=task, tokenizer=tokenizer_name, trust_remote_code=True, dtype=dtype, diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index f97caa06ff02d..308dad1850c9a 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -33,7 +33,8 @@ def test_simple(): num_seq_group = 4 max_model_len = 16 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, + scheduler_config = SchedulerConfig("generate", + max_num_batched_tokens, num_seq_group, max_model_len, enable_chunked_prefill=True) @@ -78,6 +79,7 @@ def test_chunk(): max_model_len = 80 max_num_batched_tokens = 64 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, @@ -126,6 +128,7 @@ def test_complex(): max_model_len = 80 max_num_batched_tokens = 64 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, @@ -196,6 +199,7 @@ def test_maximal_decoding(): max_model_len = 8 max_num_batched_tokens = 2 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, @@ -289,6 +293,7 @@ def test_prompt_limit(): max_model_len = 64 max_num_batched_tokens = 32 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, @@ -321,7 +326,8 @@ def test_prompt_limit_exceed(): max_seqs = 64 max_model_len = 32 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, + scheduler_config = SchedulerConfig("generate", + max_num_batched_tokens, max_seqs, max_model_len, enable_chunked_prefill=True) @@ -348,6 +354,7 @@ def test_swap(): max_model_len = 200 max_num_batched_tokens = 30 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, @@ -404,6 +411,7 @@ def test_running_prefill_prioritized_over_swap(): max_model_len = 200 max_num_batched_tokens = 30 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, @@ -498,6 +506,7 @@ def test_chunked_prefill_preempt(): max_model_len = 200 max_num_batched_tokens = 30 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, @@ -563,6 +572,7 @@ def test_chunked_prefill_max_seqs(): max_model_len = 80 max_num_batched_tokens = 64 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, @@ -617,6 +627,7 @@ def test_perfix_caching(): max_model_len = 80 max_num_batched_tokens = 64 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index defa6c1bdaf78..00b6349b9f8c5 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -20,9 +20,10 @@ def test_scheduler_add_seq_group(): block_size = 4 scheduler_config = SchedulerConfig( - 100, - 64, - 1, + "generate", + max_num_batched_tokens=100, + max_num_seqs=64, + max_model_len=1, ) cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") cache_config.num_cpu_blocks = 4 @@ -42,9 +43,10 @@ def test_scheduler_add_seq_group(): def test_scheduler_abort_seq_group(): block_size = 4 scheduler_config = SchedulerConfig( - 100, - 64, - 1, + "generate", + max_num_batched_tokens=100, + max_num_seqs=64, + max_model_len=1, ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") 
cache_config.num_cpu_blocks = 4 @@ -70,9 +72,10 @@ def test_scheduler_schedule_simple(): num_seq_group = 4 max_model_len = 16 scheduler_config = SchedulerConfig( - 64, - num_seq_group, - max_model_len, + "generate", + max_num_batched_tokens=64, + max_num_seqs=num_seq_group, + max_model_len=max_model_len, ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 @@ -114,9 +117,10 @@ def test_scheduler_prefill_prioritized(): max_model_len = 30 max_batched_num_tokens = 30 scheduler_config = SchedulerConfig( - max_batched_num_tokens, - 2, - max_model_len, + "generate", + max_num_batched_tokens=max_batched_num_tokens, + max_num_seqs=2, + max_model_len=max_model_len, ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 @@ -145,9 +149,10 @@ def test_scheduler_schedule_preempt_abort(): block_size = 4 max_model_len = 16 scheduler_config = SchedulerConfig( - 64, - 2, - max_model_len, + "generate", + max_num_batched_tokens=64, + max_num_seqs=2, + max_model_len=max_model_len, ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 @@ -204,9 +209,10 @@ def test_scheduler_max_seqs(): max_seq_group = 2 max_model_len = 16 scheduler_config = SchedulerConfig( - 64, - max_seq_group, - max_model_len, + "generate", + max_num_batched_tokens=64, + max_num_seqs=max_seq_group, + max_model_len=max_model_len, ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 @@ -248,9 +254,10 @@ def test_scheduler_max_seqs(): def test_scheduler_delay_factor(): block_size = 4 scheduler_config = SchedulerConfig( - 100, - 64, - 16, + "generate", + max_num_batched_tokens=100, + max_num_seqs=64, + max_model_len=16, delay_factor=0.5, ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") @@ -350,9 +357,10 @@ def initialize_scheduler( ): block_size = block_size scheduler_config = SchedulerConfig( - max_token_budget, - max_num_seqs, - max_model_len, + "generate", + max_num_batched_tokens=max_token_budget, + max_num_seqs=max_num_seqs, + max_model_len=max_model_len, ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = num_cpu_blocks diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py index 50c047f30b80d..7cd0416d321ef 100644 --- a/tests/core/test_scheduler_encoder_decoder.py +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -36,7 +36,12 @@ def test_scheduler_schedule_simple_encoder_decoder(): block_size = 4 num_seq_group = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) + scheduler_config = SchedulerConfig( + task="generate", + max_num_batched_tokens=64, + max_num_seqs=num_seq_group, + max_model_len=max_model_len, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 88d0a4ba7f57b..fee201850f203 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -11,6 +11,7 @@ import pytest +from vllm.config import TaskOption from vllm.logger import init_logger from ..utils import compare_two_settings, fork_new_process_for_each_test @@ -31,6 +32,7 @@ class ParallelSetup(NamedTuple): class PPTestSettings: parallel_setups: List[ParallelSetup] distributed_backends: List[str] + task: 
TaskOption trust_remote_code: bool tokenizer_mode: Optional[str] @@ -39,6 +41,7 @@ def detailed( *, tp_base: int = 1, pp_base: int = 2, + task: TaskOption = "auto", trust_remote_code: bool = False, tokenizer_mode: Optional[str] = None, ): @@ -66,6 +69,7 @@ def detailed( chunked_prefill=False), ], distributed_backends=["mp", "ray"], + task=task, trust_remote_code=trust_remote_code, tokenizer_mode=tokenizer_mode, ) @@ -75,6 +79,7 @@ def fast( *, tp_base: int = 1, pp_base: int = 2, + task: TaskOption = "auto", trust_remote_code: bool = False, tokenizer_mode: Optional[str] = None, ): @@ -86,6 +91,7 @@ def fast( chunked_prefill=False), ], distributed_backends=["mp"], + task=task, trust_remote_code=trust_remote_code, tokenizer_mode=tokenizer_mode, ) @@ -94,7 +100,7 @@ def iter_params(self, model_name: str): for parallel_setup in self.parallel_setups: for distributed_backend in self.distributed_backends: yield (model_name, parallel_setup, distributed_backend, - self.trust_remote_code, self.tokenizer_mode) + self.task, self.trust_remote_code, self.tokenizer_mode) # NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU @@ -213,6 +219,7 @@ def _compare_tp( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, + task: TaskOption, trust_remote_code: bool, tokenizer_mode: Optional[str], num_gpus_available: int, @@ -240,6 +247,8 @@ def _compare_tp( common_args.append("--enable-chunked-prefill") if eager_mode: common_args.append("--enforce-eager") + if task != "auto": + common_args.extend(["--task", task]) if trust_remote_code: common_args.append("--trust-remote-code") if tokenizer_mode: @@ -297,7 +306,7 @@ def _compare_tp( @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", + ("model_name", "parallel_setup", "distributed_backend", "task", "trust_remote_code", "tokenizer_mode"), [ params for model_name, settings in GENERATION_MODEL_SETTINGS.items() @@ -310,6 +319,7 @@ def test_tp_language_generation( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, + task: TaskOption, trust_remote_code: bool, tokenizer_mode: Optional[str], num_gpus_available, @@ -317,6 +327,7 @@ def test_tp_language_generation( _compare_tp(model_name, parallel_setup, distributed_backend, + task, trust_remote_code, tokenizer_mode, num_gpus_available, @@ -324,7 +335,7 @@ def test_tp_language_generation( @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", + ("model_name", "parallel_setup", "distributed_backend", "task", "trust_remote_code", "tokenizer_mode"), [ params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items() @@ -337,6 +348,7 @@ def test_tp_language_embedding( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, + task: TaskOption, trust_remote_code: bool, tokenizer_mode: Optional[str], num_gpus_available, @@ -344,6 +356,7 @@ def test_tp_language_embedding( _compare_tp(model_name, parallel_setup, distributed_backend, + task, trust_remote_code, tokenizer_mode, num_gpus_available, @@ -351,7 +364,7 @@ def test_tp_language_embedding( @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", + ("model_name", "parallel_setup", "distributed_backend", "task", "trust_remote_code", "tokenizer_mode"), [ params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items() @@ -364,6 +377,7 @@ def test_tp_multimodal_generation( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, + task: TaskOption, trust_remote_code: bool, tokenizer_mode: 
Optional[str], num_gpus_available, @@ -371,6 +385,7 @@ def test_tp_multimodal_generation( _compare_tp(model_name, parallel_setup, distributed_backend, + task, trust_remote_code, tokenizer_mode, num_gpus_available, diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py new file mode 100644 index 0000000000000..b57348a4d9a58 --- /dev/null +++ b/tests/entrypoints/llm/test_chat.py @@ -0,0 +1,92 @@ +from typing import List + +import pytest + +from vllm import LLM + +from ..openai.test_vision import TEST_IMAGE_URLS + + +def test_chat(): + llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") + + prompt1 = "Explain the concept of entropy." + messages = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ] + outputs = llm.chat(messages) + assert len(outputs) == 1 + + +def test_multi_chat(): + llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") + + prompt1 = "Explain the concept of entropy." + prompt2 = "Explain what among us is." + + conversation1 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ] + + conversation2 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt2 + }, + ] + + messages = [conversation1, conversation2] + + outputs = llm.chat(messages) + assert len(outputs) == 2 + + +@pytest.mark.parametrize("image_urls", + [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) +def test_chat_multi_image(image_urls: List[str]): + llm = LLM( + model="microsoft/Phi-3.5-vision-instruct", + dtype="bfloat16", + max_model_len=4096, + max_num_seqs=5, + enforce_eager=True, + trust_remote_code=True, + limit_mm_per_prompt={"image": 2}, + ) + + messages = [{ + "role": + "user", + "content": [ + *({ + "type": "image_url", + "image_url": { + "url": image_url + } + } for image_url in image_urls), + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + outputs = llm.chat(messages) + assert len(outputs) >= 0 diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 6543c4bb1b58e..5e32d7baabe4b 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -6,7 +6,6 @@ from vllm import LLM, RequestOutput, SamplingParams from ...conftest import cleanup -from ..openai.test_vision import TEST_IMAGE_URLS MODEL_NAME = "facebook/opt-125m" @@ -104,90 +103,3 @@ def test_multiple_sampling_params(llm: LLM): # sampling_params is None, default params should be applied outputs = llm.generate(PROMPTS, sampling_params=None) assert len(PROMPTS) == len(outputs) - - -def test_chat(): - - llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") - - prompt1 = "Explain the concept of entropy." - messages = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, - ] - outputs = llm.chat(messages) - assert len(outputs) == 1 - - -def test_multi_chat(): - - llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") - - prompt1 = "Explain the concept of entropy." - prompt2 = "Explain what among us is." 
- - conversation1 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, - ] - - conversation2 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt2 - }, - ] - - messages = [conversation1, conversation2] - - outputs = llm.chat(messages) - assert len(outputs) == 2 - - -@pytest.mark.parametrize("image_urls", - [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) -def test_chat_multi_image(image_urls: List[str]): - llm = LLM( - model="microsoft/Phi-3.5-vision-instruct", - dtype="bfloat16", - max_model_len=4096, - max_num_seqs=5, - enforce_eager=True, - trust_remote_code=True, - limit_mm_per_prompt={"image": 2}, - ) - - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "image_url", - "image_url": { - "url": image_url - } - } for image_url in image_urls), - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] - outputs = llm.chat(messages) - assert len(outputs) >= 0 diff --git a/tests/entrypoints/llm/test_init.py b/tests/entrypoints/llm/test_init.py new file mode 100644 index 0000000000000..c9a4ad44fea30 --- /dev/null +++ b/tests/entrypoints/llm/test_init.py @@ -0,0 +1,22 @@ +import pytest + +from vllm import LLM + +from ...utils import error_on_warning + +MODEL_NAME = "facebook/opt-125m" + + +def test_pos_args_deprecated(): + with error_on_warning(DeprecationWarning): + LLM(model=MODEL_NAME, tokenizer=MODEL_NAME) + + with error_on_warning(DeprecationWarning): + LLM(MODEL_NAME, tokenizer=MODEL_NAME) + + with pytest.warns(DeprecationWarning, match="'tokenizer'"): + LLM(MODEL_NAME, MODEL_NAME) + + with pytest.warns(DeprecationWarning, + match="'tokenizer', 'tokenizer_mode'"): + LLM(MODEL_NAME, MODEL_NAME, "auto") diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index ec550fe82c70f..d9342fad9f018 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -22,12 +22,12 @@ class MockHFConfig: @dataclass class MockModelConfig: + task = "generate" tokenizer = MODEL_NAME trust_remote_code = False tokenizer_mode = "auto" max_model_len = 100 tokenizer_revision = None - embedding_mode = False multimodal_config = MultiModalConfig() hf_config = MockHFConfig() diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 81d79601124a7..8311a5cb3c2d4 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -23,6 +23,8 @@ @pytest.fixture(scope="module") def server(): args = [ + "--task", + "generate", "--dtype", "bfloat16", "--max-model-len", diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 6ded5102c9314..9165a1d397137 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -18,7 +18,8 @@ @pytest.fixture(scope="module") def phi3v_model_config(): return ModelConfig(PHI3V_MODEL_ID, - PHI3V_MODEL_ID, + task="generate", + tokenizer=PHI3V_MODEL_ID, tokenizer_mode="auto", trust_remote_code=True, dtype="bfloat16", diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 732e91a52c0a9..2f7ac85507425 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -15,7 +15,8 @@ def test_worker_apply_lora(sql_lora_files): worker = Worker( model_config=ModelConfig( "meta-llama/Llama-2-7b-hf", - "meta-llama/Llama-2-7b-hf", + task="auto", + 
tokenizer="meta-llama/Llama-2-7b-hf", tokenizer_mode="auto", trust_remote_code=False, seed=0, @@ -27,7 +28,7 @@ def test_worker_apply_lora(sql_lora_files): load_format="dummy", ), parallel_config=ParallelConfig(1, 1, False), - scheduler_config=SchedulerConfig(32, 32, 32), + scheduler_config=SchedulerConfig("generate", 32, 32, 32), device_config=DeviceConfig("cuda"), cache_config=CacheConfig(block_size=16, gpu_memory_utilization=1., diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index 12e8a961877cd..808421abd9103 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -89,6 +89,7 @@ def run_test( # max_model_len should be greater than image_feature_size with vllm_runner(model, + task="generate", max_model_len=4096, max_num_seqs=2, dtype=dtype, diff --git a/tests/models/embedding/vision_language/test_phi3v.py b/tests/models/embedding/vision_language/test_phi3v.py index ea6b56cd02625..0ca90e6bfa52e 100644 --- a/tests/models/embedding/vision_language/test_phi3v.py +++ b/tests/models/embedding/vision_language/test_phi3v.py @@ -28,6 +28,7 @@ def test_models( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). with vllm_runner(model, + task="embedding", max_model_len=4096, max_num_seqs=2, dtype=dtype, diff --git a/tests/models/utils.py b/tests/models/utils.py index 86a624483c58a..2ea233a9a599c 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -3,7 +3,7 @@ import torch -from vllm.config import ModelConfig +from vllm.config import ModelConfig, TaskOption from vllm.inputs import InputContext from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs from vllm.utils import is_cpu @@ -248,6 +248,7 @@ def check_logprobs_close( def build_model_context(model_name: str, + task: TaskOption = "auto", tokenizer_name: Optional[str] = None, trust_remote_code: bool = False, dtype: Optional[Union[str, torch.dtype]] = None, @@ -273,7 +274,8 @@ def build_model_context(model_name: str, model_config = ModelConfig( model_name, - tokenizer_name, + task=task, + tokenizer=tokenizer_name, tokenizer_mode="auto", trust_remote_code=trust_remote_code, dtype=dtype, diff --git a/tests/multimodal/test_mapper.py b/tests/multimodal/test_mapper.py index 7d09b81060efd..13ad4a7966b9d 100644 --- a/tests/multimodal/test_mapper.py +++ b/tests/multimodal/test_mapper.py @@ -24,6 +24,7 @@ def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor): model_config = ModelConfig( model=MODEL_NAME, + task="auto", tokenizer=MODEL_NAME, tokenizer_mode="auto", trust_remote_code=False, @@ -67,6 +68,7 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype, model_config = ModelConfig( model=MODEL_NAME, + task="auto", tokenizer=MODEL_NAME, tokenizer_mode="auto", trust_remote_code=False, @@ -109,6 +111,7 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid): model_config = ModelConfig( model=MODEL_NAME, + task="auto", tokenizer=MODEL_NAME, tokenizer_mode="auto", trust_remote_code=False, @@ -139,6 +142,7 @@ def test_image_mapper_multi(image_assets, mm_registry, num_images): model_config = ModelConfig( model=MODEL_NAME, + task="auto", tokenizer=MODEL_NAME, tokenizer_mode="auto", trust_remote_code=False, diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 
7b9e0b6e5234b..5044740c3e734 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -221,6 +221,7 @@ def test_max_tokens_kwarg_overrides(num_crops): expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops ctx = build_model_context(MULTIMODAL_MODEL_ID, + task="generate", trust_remote_code=True, mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) @@ -256,6 +257,7 @@ def test_max_tokens_kwarg_overrides(num_crops): def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): """Ensure that max token calcs filters out invalid mm_processor_kwargs""" ctx = build_model_context(MULTIMODAL_MODEL_ID, + task="generate", trust_remote_code=True, mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) @@ -278,12 +280,13 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): ### Test overrides for the mapper @pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE]) -def test_default_mapper_with_processer_kwargs(image_assets, num_crops): +def test_default_mapper_with_processor_kwargs(image_assets, num_crops): """Ensure that the mapper processor kwargs can fall back to HF models.""" # NOTE - we don't validate bad inputs for the default mapper, because it's # through the automodel interface in transformers, so we can't easily # inspect what kwargs are or are not allowed. ctx = build_model_context(MULTIMODAL_MODEL_ID, + task="generate", trust_remote_code=True, mm_processor_kwargs={"num_crops": num_crops}, limit_mm_per_prompt={"image": 1}) @@ -311,6 +314,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, init_num_crops, inference_num_crops) ctx = build_model_context(MULTIMODAL_MODEL_ID, + task="generate", trust_remote_code=True, mm_processor_kwargs=init_kwargs, limit_mm_per_prompt={"image": 1}) @@ -348,6 +352,7 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets, """Ensure that custom mappers filters out invalid mm_processor_kwargs""" # Should filter out the init time kwargs ctx = build_model_context(MULTIMODAL_MODEL_ID, + task="generate", trust_remote_code=True, mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py index d18233fe1aeae..cf77ccec7a191 100644 --- a/tests/quantization/test_configs.py +++ b/tests/quantization/test_configs.py @@ -57,7 +57,8 @@ def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None: try: model_config = ModelConfig(model_path, - model_path, + task="auto", + tokenizer=model_path, tokenizer_mode="auto", trust_remote_code=False, seed=0, diff --git a/tests/test_config.py b/tests/test_config.py index b89429005e1d0..69918b67607d9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,6 +2,42 @@ from vllm.config import ModelConfig + +@pytest.mark.parametrize(("model_id", "expected_task"), [ + ("facebook/opt-125m", "generate"), + ("intfloat/e5-mistral-7b-instruct", "embedding"), +]) +def test_auto_task(model_id, expected_task): + config = ModelConfig( + model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + ) + + assert config.task == expected_task + + +@pytest.mark.parametrize(("model_id", "bad_task"), [ + ("facebook/opt-125m", "embedding"), + ("intfloat/e5-mistral-7b-instruct", "generate"), +]) +def test_incorrect_task(model_id, bad_task): + with pytest.raises(ValueError, match=r"does not support the 
.* task"): + ModelConfig( + model_id, + task=bad_task, + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + ) + + MODEL_IDS_EXPECTED = [ ("Qwen/Qwen1.5-7B", 32768), ("mistralai/Mistral-7B-v0.1", 4096), @@ -14,7 +50,8 @@ def test_disable_sliding_window(model_id_expected): model_id, expected = model_id_expected model_config = ModelConfig( model_id, - model_id, + task="auto", + tokenizer=model_id, tokenizer_mode="auto", trust_remote_code=False, seed=0, @@ -32,7 +69,8 @@ def test_get_sliding_window(): # when use_sliding_window is False. qwen2_model_config = ModelConfig( "Qwen/Qwen1.5-7B", - "Qwen/Qwen1.5-7B", + task="auto", + tokenizer="Qwen/Qwen1.5-7B", tokenizer_mode="auto", trust_remote_code=False, seed=0, @@ -49,7 +87,8 @@ def test_get_sliding_window(): mistral_model_config = ModelConfig( "mistralai/Mistral-7B-v0.1", - "mistralai/Mistral-7B-v0.1", + task="auto", + tokenizer="mistralai/Mistral-7B-v0.1", tokenizer_mode="auto", trust_remote_code=False, seed=0, @@ -70,7 +109,8 @@ def test_rope_customization(): llama_model_config = ModelConfig( "meta-llama/Meta-Llama-3-8B-Instruct", - "meta-llama/Meta-Llama-3-8B-Instruct", + task="auto", + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct", tokenizer_mode="auto", trust_remote_code=False, dtype="float16", @@ -82,7 +122,8 @@ def test_rope_customization(): llama_model_config = ModelConfig( "meta-llama/Meta-Llama-3-8B-Instruct", - "meta-llama/Meta-Llama-3-8B-Instruct", + task="auto", + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct", tokenizer_mode="auto", trust_remote_code=False, dtype="float16", @@ -98,7 +139,8 @@ def test_rope_customization(): longchat_model_config = ModelConfig( "lmsys/longchat-13b-16k", - "lmsys/longchat-13b-16k", + task="auto", + tokenizer="lmsys/longchat-13b-16k", tokenizer_mode="auto", trust_remote_code=False, dtype="float16", @@ -112,7 +154,8 @@ def test_rope_customization(): longchat_model_config = ModelConfig( "lmsys/longchat-13b-16k", - "lmsys/longchat-13b-16k", + task="auto", + tokenizer="lmsys/longchat-13b-16k", tokenizer_mode="auto", trust_remote_code=False, dtype="float16", diff --git a/tests/test_utils.py b/tests/test_utils.py index 268e6f8194abb..0fed8e678fc76 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -59,7 +59,7 @@ def dummy(*, old_arg: object = None, new_arg: object = None): with pytest.warns(DeprecationWarning, match="'old_arg'"): dummy(old_arg=1) - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(new_arg=1) @@ -69,10 +69,10 @@ def test_deprecate_kwargs_never(): def dummy(*, old_arg: object = None, new_arg: object = None): pass - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(old_arg=1) - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(new_arg=1) @@ -86,15 +86,15 @@ def dummy(*, old_arg: object = None, new_arg: object = None): with pytest.warns(DeprecationWarning, match="'old_arg'"): dummy(old_arg=1) - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(new_arg=1) is_deprecated = False - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(old_arg=1) - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(new_arg=1) diff --git a/tests/utils.py b/tests/utils.py index 115cab80691f0..2ab7329485dfc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,7 +8,7 @@ import warnings from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, 
Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union import openai import pytest @@ -454,13 +454,13 @@ def multi_process_parallel( @contextmanager -def error_on_warning(): +def error_on_warning(category: Type[Warning] = Warning): """ Within the scope of this context manager, tests will fail if any warning - is emitted. + of the given category is emitted. """ with warnings.catch_warnings(): - warnings.simplefilter("error") + warnings.filterwarnings("error", category=category) yield diff --git a/vllm/config.py b/vllm/config.py index 4533fb017188c..7f8f936428543 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,8 +1,8 @@ import enum import json from dataclasses import dataclass, field, fields -from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping, - Optional, Tuple, Type, Union) +from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal, + Mapping, Optional, Set, Tuple, Type, Union) import torch from transformers import PretrainedConfig @@ -33,6 +33,9 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 +Task = Literal["generate", "embedding"] +TaskOption = Literal["auto", Task] + class ModelConfig: """Configuration for the model. @@ -40,7 +43,11 @@ class ModelConfig: Args: model: Name or path of the huggingface model to use. It is also used as the content for `model_name` tag in metrics - output when `served_model_name` is not specified. + output when `served_model_name` is not specified. + task: The task to use the model for. Each vLLM instance only supports + one task, even if the same model can be used for multiple tasks. + When the model only supports one task, "auto" can be used to select + it; otherwise, you must specify explicitly which task to use. tokenizer: Name or path of the huggingface tokenizer to use. tokenizer_mode: Tokenizer mode. 
"auto" will use the fast tokenizer if available, "slow" will always use the slow tokenizer, and @@ -108,6 +115,7 @@ class ModelConfig: def __init__(self, model: str, + task: TaskOption, tokenizer: str, tokenizer_mode: str, trust_remote_code: bool, @@ -207,7 +215,11 @@ def __init__(self, self.override_neuron_config = override_neuron_config if is_neuron( ) else None - self._verify_embedding_mode() + + supported_tasks, task = self._resolve_task(task, self.hf_config) + self.supported_tasks = supported_tasks + self.task: Final = task + self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() @@ -241,18 +253,41 @@ def _verify_tokenizer_mode(self) -> None: "either 'auto', 'slow' or 'mistral'.") self.tokenizer_mode = tokenizer_mode - def _verify_embedding_mode(self) -> None: - architectures = getattr(self.hf_config, "architectures", []) + def _resolve_task( + self, + task_option: TaskOption, + hf_config: PretrainedConfig, + ) -> Tuple[Set[Task], Task]: + architectures = getattr(hf_config, "architectures", []) + + task_support: Dict[Task, bool] = { + # NOTE: Listed from highest to lowest priority, + # in case the model supports multiple of them + "generate": ModelRegistry.is_text_generation_model(architectures), + "embedding": ModelRegistry.is_embedding_model(architectures), + } + supported_tasks_lst: List[Task] = [ + task for task, is_supported in task_support.items() if is_supported + ] + supported_tasks = set(supported_tasks_lst) + + if task_option == "auto": + selected_task = next(iter(supported_tasks_lst)) - # TODO: Allow the same model architecture to be specified as either - # generation or embedding model - if "Phi3VForCausalLM" in architectures: - # Match both remote and local names - embedding_mode = "/VLM2Vec" in self.model + if len(supported_tasks) > 1: + logger.info( + "This model supports multiple tasks: %s. " + "Defaulting to '%s'.", supported_tasks, selected_task) else: - embedding_mode = ModelRegistry.is_embedding_model(architectures) + if task_option not in supported_tasks: + msg = ( + f"This model does not support the '{task_option}' task. " + f"Supported tasks: {supported_tasks}") + raise ValueError(msg) + + selected_task = task_option - self.embedding_mode = embedding_mode + return supported_tasks, selected_task def _parse_quant_hf_config(self): quant_cfg = getattr(self.hf_config, "quantization_config", None) @@ -401,7 +436,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Async postprocessor is not necessary with embedding mode # since there is no token generation - if self.embedding_mode: + if self.task == "embedding": self.use_async_output_proc = False # Reminder: Please update docs/source/serving/compatibility_matrix.rst @@ -582,11 +617,6 @@ def is_encoder_decoder_model(self) -> bool: (hasattr(self.hf_config, "text_config") and getattr( self.hf_config.text_config, "is_encoder_decoder", False))) - @property - def is_embedding_model(self) -> bool: - """Extract the embedding model flag.""" - return self.embedding_mode - @property def is_multimodal_model(self) -> bool: return self.multimodal_config is not None @@ -943,6 +973,7 @@ class SchedulerConfig: """Scheduler configuration. Args: + task: The task to use the model for. max_num_batched_tokens: Maximum number of tokens to be processed in a single iteration. max_num_seqs: Maximum number of sequences to be processed in a single @@ -957,7 +988,6 @@ class SchedulerConfig: prompt latency) before scheduling next prompt. 
enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. - embedding_mode: Whether the running model is for embedding. preemption_mode: Whether to perform preemption by swapping or recomputation. If not specified, we determine the mode as follows: We use recomputation by default since it incurs lower overhead than @@ -972,13 +1002,13 @@ class SchedulerConfig: """ def __init__(self, + task: Task, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, - embedding_mode: bool = False, is_multimodal_model: bool = False, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, @@ -1002,7 +1032,7 @@ def __init__(self, # for higher throughput. max_num_batched_tokens = max(max_model_len, 2048) - if embedding_mode: + if task == "embedding": # For embedding, choose specific value for higher throughput max_num_batched_tokens = max( max_num_batched_tokens, @@ -1022,12 +1052,12 @@ def __init__(self, "Chunked prefill is enabled with max_num_batched_tokens=%d.", self.max_num_batched_tokens) + self.task: Final = task self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill - self.embedding_mode = embedding_mode self.preemption_mode = preemption_mode self.num_scheduler_steps = num_scheduler_steps self.multi_step_stream_outputs = multi_step_stream_outputs @@ -1239,6 +1269,7 @@ def maybe_create_spec_config( ngram_prompt_lookup_min = 0 draft_model_config = ModelConfig( model=speculative_model, + task=target_model_config.task, tokenizer=target_model_config.tokenizer, tokenizer_mode=target_model_config.tokenizer_mode, trust_remote_code=target_model_config.trust_remote_code, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f0c8e6bab4862..8d3fce106dd2c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -313,7 +313,7 @@ def __init__( self.lora_config = lora_config version = "selfattn" - if (self.scheduler_config.embedding_mode + if (self.scheduler_config.task == "embedding" or self.cache_config.is_attention_free): version = "placeholder" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 41963dcb16922..480d3709224ba 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, - Tuple, Type, Union, cast) + Tuple, Type, Union, cast, get_args) import torch @@ -12,7 +12,7 @@ DeviceConfig, EngineConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig, TokenizerPoolConfig) + SpeculativeConfig, TaskOption, TokenizerPoolConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -84,6 +84,7 @@ class EngineArgs: model: str = 'facebook/opt-125m' served_model_name: Optional[Union[str, List[str]]] = None tokenizer: Optional[str] = None + task: TaskOption = "auto" skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' trust_remote_code: bool = False @@ -198,6 +199,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default=EngineArgs.model, help='Name 
or path of the huggingface model to use.') + parser.add_argument( + '--task', + default=EngineArgs.task, + choices=get_args(TaskOption), + help='The task to use the model for. Each vLLM instance only ' + 'supports one task, even if the same model can be used for ' + 'multiple tasks. When the model only supports one task, "auto" ' + 'can be used to select it; otherwise, you must specify explicitly ' + 'which task to use.') parser.add_argument( '--tokenizer', type=nullable_str, @@ -838,6 +848,7 @@ def from_cli_args(cls, args: argparse.Namespace): def create_model_config(self) -> ModelConfig: return ModelConfig( model=self.model, + task=self.task, # We know this is not None because we set it in __post_init__ tokenizer=cast(str, self.tokenizer), tokenizer_mode=self.tokenizer_mode, @@ -1026,13 +1037,13 @@ def create_engine_config(self) -> EngineConfig: " please file an issue with detailed information.") scheduler_config = SchedulerConfig( + task=model_config.task, max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, num_lookahead_slots=num_lookahead_slots, delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, - embedding_mode=model_config.embedding_mode, is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 61c21887e6816..eede3486e5e8f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -344,7 +344,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: observability_config=self.observability_config, ) - if not self.model_config.embedding_mode: + if self.model_config.task != "embedding": self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. 
@@ -1116,7 +1116,7 @@ def _process_model_outputs(self, seq_group.metrics.model_execute_time = ( o.model_execute_time) - if self.model_config.embedding_mode: + if self.model_config.task == "embedding": self._process_sequence_group_outputs(seq_group, output) else: self.output_processor.process_prompt_logprob(seq_group, output) @@ -1855,9 +1855,6 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: def is_encoder_decoder_model(self): return self.input_preprocessor.is_encoder_decoder_model() - def is_embedding_model(self): - return self.model_config.is_embedding_model - def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]): if self.model_config.is_multimodal_model: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 088ec35798de8..1f7893d54de68 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -8,7 +8,7 @@ from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) -from vllm.engine.arg_utils import EngineArgs +from vllm.engine.arg_utils import EngineArgs, TaskOption from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, apply_hf_chat_template, @@ -29,7 +29,7 @@ get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter, deprecate_kwargs, is_list_of +from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of logger = init_logger(__name__) @@ -108,6 +108,12 @@ class LLM: DEPRECATE_LEGACY: ClassVar[bool] = False """A flag to toggle whether to deprecate the legacy generate/encode API.""" + DEPRECATE_INIT_POSARGS: ClassVar[bool] = True + """ + A flag to toggle whether to deprecate positional arguments in + :meth:`LLM.__init__`. + """ + @classmethod @contextmanager def deprecate_legacy_api(cls): @@ -117,6 +123,13 @@ def deprecate_legacy_api(cls): cls.DEPRECATE_LEGACY = False + @deprecate_args( + start_index=2, # Ignore self and model + is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS, + additional_message=( + "All positional arguments other than `model` will be " + "replaced with keyword arguments in an upcoming version."), + ) def __init__( self, model: str, @@ -139,6 +152,8 @@ def __init__( disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, mm_processor_kwargs: Optional[Dict[str, Any]] = None, + # After positional args are removed, move this right below `model` + task: TaskOption = "auto", **kwargs, ) -> None: ''' @@ -153,6 +168,7 @@ def __init__( engine_args = EngineArgs( model=model, + task=task, tokenizer=tokenizer, tokenizer_mode=tokenizer_mode, skip_tokenizer_init=skip_tokenizer_init, @@ -316,10 +332,21 @@ def generate( considered legacy and may be deprecated in the future. You should instead pass them via the ``inputs`` parameter. """ - if self.llm_engine.model_config.embedding_mode: - raise ValueError( + task = self.llm_engine.model_config.task + if task != "generate": + messages = [ "LLM.generate() is only supported for (conditional) generation " - "models (XForCausalLM, XForConditionalGeneration).") + "models (XForCausalLM, XForConditionalGeneration).", + ] + + supported_tasks = self.llm_engine.model_config.supported_tasks + if "generate" in supported_tasks: + messages.append( + "Your model supports the 'generate' task, but is " + f"currently initialized for the '{task}' task. 
Please " + "initialize the model using `--task generate`.") + + raise ValueError(" ".join(messages)) if prompt_token_ids is not None: parsed_prompts = self._convert_v1_inputs( @@ -692,10 +719,18 @@ def encode( considered legacy and may be deprecated in the future. You should instead pass them via the ``inputs`` parameter. """ - if not self.llm_engine.model_config.embedding_mode: - raise ValueError( - "LLM.encode() is only supported for embedding models (XModel)." - ) + task = self.llm_engine.model_config.task + if task != "embedding": + messages = ["LLM.encode() is only supported for embedding models."] + + supported_tasks = self.llm_engine.model_config.supported_tasks + if "embedding" in supported_tasks: + messages.append( + "Your model supports the 'embedding' task, but is " + f"currently initialized for the '{task}' task. Please " + "initialize the model using `--task embedding`.") + + raise ValueError(" ".join(messages)) if prompt_token_ids is not None: parsed_prompts = self._convert_v1_inputs( @@ -905,6 +940,3 @@ def _run_engine( def _is_encoder_decoder_model(self): return self.llm_engine.is_encoder_decoder_model() - - def _is_embedding_model(self): - return self.llm_engine.is_embedding_model() diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index e9504cfa64b65..6c46aae2838f6 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -83,7 +83,8 @@ def __init__( lora_modules=None, prompt_adapters=None, request_logger=request_logger) - self._enabled = self._check_embedding_mode(model_config.embedding_mode) + self._enabled = self._check_embedding_mode( + model_config.task == "embedding") async def create_embedding( self, diff --git a/vllm/utils.py b/vllm/utils.py index 07769da3c86d4..0147d595fec70 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1034,10 +1034,54 @@ def identity(value: T) -> T: F = TypeVar('F', bound=Callable[..., Any]) +def deprecate_args( + start_index: int, + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None, +) -> Callable[[F], F]: + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + + params = inspect.signature(fn).parameters + pos_types = ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + pos_kws = [ + kw for kw, param in params.items() if param.kind in pos_types + ] + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_args = pos_kws[start_index:len(args)] + if deprecated_args: + msg = ( + f"The positional arguments {deprecated_args} are " + "deprecated and will be removed in a future update.") + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper + + def deprecate_kwargs( - *kws: str, - is_deprecated: Union[bool, Callable[[], bool]] = True, - additional_message: Optional[str] = None) -> Callable[[F], F]: + *kws: str, + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None, +) -> Callable[[F], F]: deprecated_kws = set(kws) if not callable(is_deprecated): diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9c46bb4258609..018ab5b828786 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -92,7 +92,7 @@ def 
__init__( ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner if model_runner_cls is not None: ModelRunnerClass = model_runner_cls - elif self._is_embedding_model(): + elif model_config.task == "embedding": ModelRunnerClass = EmbeddingModelRunner elif self._is_encoder_decoder_model(): ModelRunnerClass = EncoderDecoderModelRunner @@ -147,9 +147,6 @@ def stop_profile(self): def _is_encoder_decoder_model(self): return self.model_config.is_encoder_decoder_model - def _is_embedding_model(self): - return self.model_config.is_embedding_model - def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until
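
Usage sketch (illustrative): a minimal example of the user-facing ``task`` option introduced by this patch, assembled from the examples and docs touched by the change; the model name, flags, and values below are taken from those files rather than newly introduced.

    from vllm import LLM

    # Select the embedding task for a model architecture that also supports
    # generation (mirrors examples/offline_inference_vision_language_embedding.py
    # as updated by this patch).
    llm = LLM(
        model="TIGER-Lab/VLM2Vec-Full",
        task="embedding",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
    )

    # With task="embedding", LLM.encode() is available; calling LLM.generate()
    # instead raises a ValueError suggesting re-initialization with
    # `--task generate`.
    #
    # The same option is exposed on the CLI (from docs/source/models/vlm.rst):
    #   vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
    #       --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2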
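
A second sketch, of the ``deprecate_args`` helper added to ``vllm/utils.py``: the decorator usage and the expected warnings are inferred from its application to ``LLM.__init__`` and from ``tests/entrypoints/llm/test_init.py``; the ``Example`` class itself is hypothetical.

    from vllm.utils import deprecate_args

    class Example:

        @deprecate_args(
            start_index=2,  # skip `self` and the first positional parameter
            is_deprecated=True,
            additional_message="Pass these as keyword arguments instead.",
        )
        def __init__(self, model, tokenizer=None, tokenizer_mode="auto"):
            self.model = model
            self.tokenizer = tokenizer
            self.tokenizer_mode = tokenizer_mode

    Example("facebook/opt-125m", tokenizer="facebook/opt-125m")  # no warning
    Example("facebook/opt-125m", "facebook/opt-125m")            # DeprecationWarning for ['tokenizer']
    Example("facebook/opt-125m", "facebook/opt-125m", "auto")    # DeprecationWarning for ['tokenizer', 'tokenizer_mode']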