diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py
new file mode 100644
index 0000000000000..565dfa01346cc
--- /dev/null
+++ b/tests/entrypoints/llm/test_prompt_validation.py
@@ -0,0 +1,9 @@
+import pytest
+
+from vllm import LLM
+
+
+def test_empty_prompt():
+    llm = LLM(model="gpt2")
+    with pytest.raises(ValueError, match='Prompt cannot be empty'):
+        llm.generate([""])
diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py
new file mode 100644
index 0000000000000..0a573a0066d32
--- /dev/null
+++ b/tests/entrypoints/openai/test_prompt_validation.py
@@ -0,0 +1,22 @@
+# imports for prompt validation tests
+import re
+
+import openai
+import pytest
+
+from ...utils import RemoteOpenAIServer
+
+
+@pytest.mark.asyncio
+async def test_empty_prompt():
+    model_name = "gpt2"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+
+        with pytest.raises(openai.BadRequestError,
+                           match=re.compile('.+Prompt cannot be empty.+')):
+            await client.completions.create(model=model_name,
+                                            prompt="",
+                                            max_tokens=5,
+                                            temperature=0.0)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 3526c3f6e898e..84eb778cd2a22 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -594,6 +594,7 @@ def _add_processed_request(
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
     ) -> None:
+        self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
@@ -1691,3 +1692,10 @@ def is_encoder_decoder_model(self):
 
     def is_embedding_model(self):
         return self.model_config.is_embedding_model
+
+    def _validate_model_inputs(self, inputs: Union[LLMInputs,
+                                                   EncoderDecoderLLMInputs]):
+        prompt_key = "encoder_prompt_token_ids" \
+            if self.is_encoder_decoder_model() else "prompt_token_ids"
+        if not inputs.get(prompt_key):
+            raise ValueError("Prompt cannot be empty")
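
For reference, a minimal standalone sketch of the validation path added above. This is not part of the patch: the plain dict and the explicit is_encoder_decoder flag are hypothetical stand-ins for vLLM's LLMInputs / EncoderDecoderLLMInputs TypedDicts and the engine's self.is_encoder_decoder_model() check.

# Hypothetical standalone sketch of _validate_model_inputs: a plain dict
# stands in for vLLM's input TypedDicts, and the encoder-decoder flag is
# passed explicitly instead of being read from engine state.
from typing import Any, Dict


def validate_model_inputs(inputs: Dict[str, Any],
                          is_encoder_decoder: bool) -> None:
    # Encoder-decoder models carry the user prompt in the encoder token
    # IDs, so the emptiness check targets that key instead.
    prompt_key = ("encoder_prompt_token_ids"
                  if is_encoder_decoder else "prompt_token_ids")
    # Rejects both a missing key and an empty token list (which is what
    # an empty string tokenizes to).
    if not inputs.get(prompt_key):
        raise ValueError("Prompt cannot be empty")


# Empty prompt: rejected before any sequences would be created.
try:
    validate_model_inputs({"prompt_token_ids": []}, is_encoder_decoder=False)
except ValueError as err:
    print(err)  # Prompt cannot be empty

# Non-empty prompt: passes silently.
validate_model_inputs({"prompt_token_ids": [50256]}, is_encoder_decoder=False)

Keying the check on token IDs rather than the raw prompt string means both an empty string and a prompt that tokenizes to nothing are rejected before any sequences are allocated, which is why _add_processed_request calls the validator first.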