[Fix] Add model sequence length into model config (vllm-project#575)
zhuohan123 authored Jul 26, 2023
1 parent 82ad323 commit 58a072b
Showing 3 changed files with 27 additions and 18 deletions.
20 changes: 20 additions & 0 deletions vllm/config.py
@@ -109,6 +109,26 @@ def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
 
+    def get_max_model_len(self) -> int:
+        max_model_len = float("inf")
+        possible_keys = [
+            # OPT
+            "max_position_embeddings",
+            # GPT-2
+            "n_positions",
+            # MPT
+            "max_seq_len",
+            # Others
+            "max_sequence_length",
+            "max_seq_length",
+            "seq_len",
+        ]
+        for key in possible_keys:
+            max_len_key = getattr(self.hf_config, key, None)
+            if max_len_key is not None:
+                max_model_len = min(max_model_len, max_len_key)
+        return max_model_len
+
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size
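For illustration only (not part of the commit): a minimal sketch of how the new get_max_model_len helper resolves a context length from a Hugging Face config object. The SimpleNamespace stand-in for hf_config is hypothetical; the key list mirrors the diff above.

from types import SimpleNamespace

# Hypothetical GPT-2-style config: exposes "n_positions"; unrelated fields are ignored.
dummy_hf_config = SimpleNamespace(n_positions=1024, vocab_size=50257)

max_model_len = float("inf")
possible_keys = [
    "max_position_embeddings",  # OPT
    "n_positions",              # GPT-2
    "max_seq_len",              # MPT
    "max_sequence_length",      # Others
    "max_seq_length",
    "seq_len",
]
for key in possible_keys:
    value = getattr(dummy_hf_config, key, None)
    if value is not None:
        # Take the smallest advertised limit; stays inf if no known key exists.
        max_model_len = min(max_model_len, value)

print(max_model_len)  # 1024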
5 changes: 2 additions & 3 deletions vllm/engine/arg_utils.py
@@ -155,10 +155,9 @@ def create_engine_configs(
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
-        max_model_len = getattr(model_config.hf_config,
-                                'max_position_embeddings', float('inf'))
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_model_len)
+                                           self.max_num_seqs,
+                                           model_config.get_max_model_len())
         return model_config, cache_config, parallel_config, scheduler_config
 
 
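A hedged sketch of the wiring change above: the scheduler now receives the length computed by ModelConfig instead of arg_utils re-reading max_position_embeddings itself. The stub classes below are hypothetical stand-ins for vllm.config.ModelConfig and SchedulerConfig; only the call shape follows the diff.

class StubModelConfig:
    """Hypothetical stand-in for vllm.config.ModelConfig after this commit."""
    def __init__(self, max_len: float) -> None:
        self._max_len = max_len

    def get_max_model_len(self) -> float:
        return self._max_len


class StubSchedulerConfig:
    """Hypothetical stand-in for vllm.config.SchedulerConfig."""
    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
                 max_model_len: float) -> None:
        self.max_model_len = max_model_len


model_config = StubModelConfig(max_len=2048)
scheduler_config = StubSchedulerConfig(2560, 256,
                                       model_config.get_max_model_len())
print(scheduler_config.max_model_len)  # 2048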
20 changes: 5 additions & 15 deletions vllm/entrypoints/openai/api_server.py
@@ -107,25 +107,14 @@ async def get_gen_prompt(request) -> str:
     return prompt
 
 
-async def check_length(request, prompt, model_config):
-    if hasattr(model_config.hf_config, "max_sequence_length"):
-        context_len = model_config.hf_config.max_sequence_length
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    elif hasattr(model_config.hf_config, "max_position_embeddings"):
-        context_len = model_config.hf_config.max_position_embeddings
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    else:
-        context_len = 2048
-
+async def check_length(request, prompt):
     input_ids = tokenizer(prompt).input_ids
     token_num = len(input_ids)
 
-    if token_num + request.max_tokens > context_len:
+    if token_num + request.max_tokens > max_model_len:
         return create_error_response(
             HTTPStatus.BAD_REQUEST,
-            f"This model's maximum context length is {context_len} tokens. "
+            f"This model's maximum context length is {max_model_len} tokens. "
             f"However, you requested {request.max_tokens + token_num} tokens "
             f"({token_num} in the messages, "
             f"{request.max_tokens} in the completion). "
@@ -194,7 +183,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt, engine_model_config)
+    error_check_ret = await check_length(request, prompt)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -591,6 +580,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
+    max_model_len = engine_model_config.get_max_model_len()
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(engine_args.tokenizer,
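A simplified, hedged sketch of the resulting flow in api_server.py: max_model_len is computed once at startup from the engine's model config, and check_length reads that module-level value instead of probing hf_config attributes per request. Tokenization and the error response are mocked here (FakeRequest and the tuple return are hypothetical) so the example is self-contained.

import asyncio
from dataclasses import dataclass
from http import HTTPStatus

# In the real server this comes from engine_model_config.get_max_model_len().
max_model_len = 2048


@dataclass
class FakeRequest:
    """Hypothetical stand-in for a completion request body."""
    max_tokens: int


async def check_length(request, prompt):
    # Stand-in for tokenizer(prompt).input_ids.
    token_num = len(prompt.split())
    if token_num + request.max_tokens > max_model_len:
        # The real code returns create_error_response(...); a tuple suffices here.
        return (HTTPStatus.BAD_REQUEST,
                f"This model's maximum context length is {max_model_len} tokens.")
    return None  # Within budget: the request proceeds.


print(asyncio.run(check_length(FakeRequest(max_tokens=4000), "hello world")))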
