diff --git a/vllm/config.py b/vllm/config.py
index 38d9108f250f0..bb84fdde1247e 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -109,6 +109,26 @@ def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
 
+    def get_max_model_len(self) -> int:
+        max_model_len = float("inf")
+        possible_keys = [
+            # OPT
+            "max_position_embeddings",
+            # GPT-2
+            "n_positions",
+            # MPT
+            "max_seq_len",
+            # Others
+            "max_sequence_length",
+            "max_seq_length",
+            "seq_len",
+        ]
+        for key in possible_keys:
+            max_len_key = getattr(self.hf_config, key, None)
+            if max_len_key is not None:
+                max_model_len = min(max_model_len, max_len_key)
+        return max_model_len
+
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ce1f0f4ece877..99fe593b4cb01 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -155,10 +155,9 @@ def create_engine_configs(
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
-        max_model_len = getattr(model_config.hf_config,
-                                'max_position_embeddings', float('inf'))
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_model_len)
+                                           self.max_num_seqs,
+                                           model_config.get_max_model_len())
         return model_config, cache_config, parallel_config, scheduler_config
 
 
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 309a5ee85d05a..783e96d2b5d8b 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -107,25 +107,14 @@ async def get_gen_prompt(request) -> str:
     return prompt
 
 
-async def check_length(request, prompt, model_config):
-    if hasattr(model_config.hf_config, "max_sequence_length"):
-        context_len = model_config.hf_config.max_sequence_length
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    elif hasattr(model_config.hf_config, "max_position_embeddings"):
-        context_len = model_config.hf_config.max_position_embeddings
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    else:
-        context_len = 2048
-
+async def check_length(request, prompt):
     input_ids = tokenizer(prompt).input_ids
     token_num = len(input_ids)
 
-    if token_num + request.max_tokens > context_len:
+    if token_num + request.max_tokens > max_model_len:
         return create_error_response(
             HTTPStatus.BAD_REQUEST,
-            f"This model's maximum context length is {context_len} tokens. "
+            f"This model's maximum context length is {max_model_len} tokens. "
             f"However, you requested {request.max_tokens + token_num} tokens "
             f"({token_num} in the messages, "
             f"{request.max_tokens} in the completion). "
@@ -194,7 +183,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt, engine_model_config)
+    error_check_ret = await check_length(request, prompt)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -591,6 +580,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
+    max_model_len = engine_model_config.get_max_model_len()
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(engine_args.tokenizer,