[Frontend] Set server's maximum number of generated tokens using generation_config.json #12242

Merged: 34 commits merged on Jan 26, 2025

Changes from 29 commits

Commits (34):
5c85448
Adding max_new_tokens support to generation_config.json
mhendrey Jan 20, 2025
4ad6b45
Changed default_max_tokens to server_max_tokens
mhendrey Jan 20, 2025
95f9c97
Renamed default_max_tokens to server_max_tokens
mhendrey Jan 20, 2025
4786e56
Removed the float("inf") bug
mhendrey Jan 20, 2025
4980a73
Renamed default_max_tokens to server_max_tokens
mhendrey Jan 20, 2025
39d7d76
Rearranged lines to make the changes with existing as small as possible
mhendrey Jan 20, 2025
b6a24c4
Limit generated tokens by server's max_tokens setting when available
mhendrey Jan 20, 2025
aa7cff1
Changed syntax to pass format.sh tests
mhendrey Jan 20, 2025
2f6e43b
[Bugfix] Fix num_heads value for simple connector when tp enabled (#1…
ShangmingCai Jan 20, 2025
6baa0ea
[torch.compile] fix sym_tensor_indices (#12191)
youkaichao Jan 20, 2025
35b5948
Move linting to `pre-commit` (#11975)
hmellor Jan 20, 2025
0c2f332
[DOC] Fix typo in docstring and assert message (#12194)
terrytangyuan Jan 20, 2025
46249e5
[DOC] Add missing docstring in LLMEngine.add_request() (#12195)
terrytangyuan Jan 20, 2025
0b2e3de
[Bugfix] Fix incorrect types in LayerwiseProfileResults (#12196)
terrytangyuan Jan 20, 2025
090eca3
[Model] Add Qwen2 PRM model support (#12202)
Isotr0py Jan 20, 2025
5d36c1f
[Core] Interface for accessing model from `VllmRunner` (#10353)
DarkLight1337 Jan 20, 2025
df331a7
[misc] add placeholder format.sh (#12206)
youkaichao Jan 20, 2025
881964d
[CI/Build] Remove dummy CI steps (#12208)
DarkLight1337 Jan 20, 2025
5cc6a09
[CI/Build] Make pre-commit faster (#12212)
DarkLight1337 Jan 20, 2025
9f3d5a6
[Model] Upgrade Aria to transformers 4.48 (#12203)
DarkLight1337 Jan 20, 2025
957ca23
[misc] print a message to suggest how to bypass commit hooks (#12217)
youkaichao Jan 20, 2025
399d224
[core][bugfix] configure env var during import vllm (#12209)
youkaichao Jan 20, 2025
df06503
[V1] Remove `_get_cache_block_size` (#12214)
heheda12345 Jan 20, 2025
b89529b
[Misc] Pass `attention` to impl backend (#12218)
wangxiyuan Jan 20, 2025
a5d57f1
[Bugfix] Fix `HfExampleModels.find_hf_info` (#12223)
DarkLight1337 Jan 20, 2025
b1af379
[CI] Pass local python version explicitly to pre-commit mypy.sh (#12224)
heheda12345 Jan 20, 2025
0e3a719
Added tests to check max_tokens is properly set
mhendrey Jan 23, 2025
6867b37
Merge branch 'server_max_tokens'
mhendrey Jan 23, 2025
99243cf
Mucked up the rebasing. Fixing that now.
mhendrey Jan 23, 2025
1a15431
Reverting the serving_chat & serving_completion back and putting all …
mhendrey Jan 23, 2025
c10eb1f
Didn't quite revert back. Deleting empty line from both
mhendrey Jan 23, 2025
a3fc62b
Changed to using one-liner and edited engine arg for generation-config
mhendrey Jan 24, 2025
98949f6
Merge branch 'vllm-project:main' into main
mhendrey Jan 24, 2025
c71f429
Converted to a one-liner for taking minimum value & added to generati…
mhendrey Jan 24, 2025
110 changes: 110 additions & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -103,6 +103,116 @@ def test_serving_chat_should_set_correct_max_tokens():

assert mock_engine.generate.call_args.args[1].max_tokens == 10

# Setting server's max_tokens in the generation_config.json
# lower than context_window - prompt_tokens
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
"max_tokens": 10 # Setting server-side max_tokens limit
}

# Reinitialize the engine with new settings
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)

# Test Case 1: No max_tokens specified in request
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?"
}],
guided_decoding_backend="outlines",
)

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 10

# Test Case 2: Request's max_tokens set higher than server accepts
req.max_tokens = 15

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 10

# Test Case 3: Request's max_tokens set lower than server accepts
req.max_tokens = 5

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 5

# Setting server's max_tokens in the generation_config.json
# higher than context_window - prompt_tokens
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
"max_tokens": 200 # Setting server-side max_tokens limit
}

# Reinitialize the engine with new settings
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)

# Test case 1: No max_tokens specified, defaults to context_window
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?"
}],
guided_decoding_backend="outlines",
)

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 93

# Test Case 2: Request's max_tokens set higher than server accepts
req.max_tokens = 100

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 93

# Test Case 3: Request's max_tokens set lower than server accepts
req.max_tokens = 5

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 5


def test_serving_chat_could_load_correct_generation_config():

6 changes: 6 additions & 0 deletions vllm/config.py
@@ -918,12 +918,18 @@ def get_diff_sampling_param(self) -> Dict[str, Any]:
"top_k",
"top_p",
"min_p",
"max_new_tokens",
]
if any(p in config for p in available_params):
diff_sampling_param = {
p: config.get(p)
for p in available_params if config.get(p) is not None
}
+ # Huggingface definition of max_new_tokens is equivalent
+ # to vLLM's max_tokens
+ if "max_new_tokens" in diff_sampling_param:
+ diff_sampling_param["max_tokens"] = diff_sampling_param.pop(
+ "max_new_tokens")
else:
diff_sampling_param = {}
return diff_sampling_param
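For context, here is a minimal standalone sketch of the mapping this hunk adds. The generation_config.json contents and the helper name are illustrative assumptions, not the actual ModelConfig code, and the parameter list is abridged to the names visible in this hunk:

```python
from typing import Any, Dict

def extract_sampling_overrides(generation_config: Dict[str, Any]) -> Dict[str, Any]:
    """Illustrative stand-in for ModelConfig.get_diff_sampling_param (abridged)."""
    available_params = ["top_k", "top_p", "min_p", "max_new_tokens"]
    overrides = {
        p: generation_config[p]
        for p in available_params if generation_config.get(p) is not None
    }
    # Hugging Face's max_new_tokens maps onto vLLM's max_tokens.
    if "max_new_tokens" in overrides:
        overrides["max_tokens"] = overrides.pop("max_new_tokens")
    return overrides

# e.g. a generation_config.json that contains {"top_p": 0.9, "max_new_tokens": 10}
print(extract_sampling_overrides({"top_p": 0.9, "max_new_tokens": 10}))
# -> {'top_p': 0.9, 'max_tokens': 10}
```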
28 changes: 20 additions & 8 deletions vllm/entrypoints/openai/protocol.py
@@ -375,13 +375,16 @@ class ChatCompletionRequest(OpenAIBaseModel):

def to_beam_search_params(
self,
- default_max_tokens: int,
+ server_max_tokens: int,
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
- max_tokens = default_max_tokens
+ max_tokens = server_max_tokens
+ # Don't allow user to exceed server limit. Should this notify user?
+ else:
+ max_tokens = min(max_tokens, server_max_tokens)

if default_sampling_params is None:
default_sampling_params = {}
@@ -401,13 +404,16 @@ def to_beam_search_params(

def to_sampling_params(
self,
- default_max_tokens: int,
+ server_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
- max_tokens = default_max_tokens
+ max_tokens = server_max_tokens
+ # Don't allow user to exceed server limit. Should this notify user?
+ else:
+ max_tokens = min(max_tokens, server_max_tokens)

if default_sampling_params is None:
default_sampling_params = {}
@@ -736,12 +742,15 @@ class CompletionRequest(OpenAIBaseModel):

def to_beam_search_params(
self,
- default_max_tokens: int,
+ server_max_tokens: int,
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
max_tokens = self.max_tokens
if max_tokens is None:
- max_tokens = default_max_tokens
+ max_tokens = server_max_tokens
+ # Don't allow user to exceed server limit. Should this notify user?
+ else:
+ max_tokens = min(max_tokens, server_max_tokens)

if default_sampling_params is None:
default_sampling_params = {}
@@ -760,12 +769,15 @@ def to_beam_search_params(

def to_sampling_params(
self,
- default_max_tokens: int,
+ server_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
- max_tokens = default_max_tokens
+ max_tokens = server_max_tokens
+ # Don't allow user to exceed server limit. Should this notify user?
+ else:
+ max_tokens = min(max_tokens, server_max_tokens)

if default_sampling_params is None:
default_sampling_params = {}
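The clamping logic repeated in these four methods reduces to one rule. The helper below is a standalone sketch of that rule (not code from the diff), exercised with the same values the tests above use:

```python
from typing import Optional

def resolve_max_tokens(requested: Optional[int], server_max_tokens: int) -> int:
    """Clamp the request's max_tokens to the server-side limit."""
    if requested is None:
        # Nothing in the request: fall back to the server limit.
        return server_max_tokens
    # Never let the request exceed the server limit.
    return min(requested, server_max_tokens)

# With a server limit of 10, mirroring the test cases:
assert resolve_max_tokens(None, 10) == 10  # no max_tokens in the request
assert resolve_max_tokens(15, 10) == 10    # request asks for more than allowed
assert resolve_max_tokens(5, 10) == 5      # request stays under the limit
```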
13 changes: 10 additions & 3 deletions vllm/entrypoints/openai/serving_chat.py
@@ -187,17 +187,24 @@ async def create_chat_completion(
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
- default_max_tokens = self.max_model_len - len(
+ server_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
# Build default sampling params
default_sampling_params = (
self.model_config.get_diff_sampling_param())

+ # Limit set by architecture or value in generation_config.json
+ if "max_tokens" in default_sampling_params:
+ server_max_tokens = min(
+ server_max_tokens,
+ default_sampling_params["max_tokens"])

if request.use_beam_search:
sampling_params = request.to_beam_search_params(
- default_max_tokens, default_sampling_params)
+ server_max_tokens, default_sampling_params)
else:
sampling_params = request.to_sampling_params(
- default_max_tokens,
+ server_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params)
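Combined with the protocol change, the limit handed to to_sampling_params is the smaller of the remaining context window and any max_tokens coming from generation_config.json. The sketch below restates that computation outside the class; the numeric values are hypothetical and only chosen to line up with the test expectations:

```python
from typing import Any, Dict

def effective_server_max_tokens(max_model_len: int, prompt_len: int,
                                default_sampling_params: Dict[str, Any]) -> int:
    # What the context window still leaves room for.
    server_max_tokens = max_model_len - prompt_len
    # Tighten further if generation_config.json supplied a max_tokens value.
    if "max_tokens" in default_sampling_params:
        server_max_tokens = min(server_max_tokens,
                                default_sampling_params["max_tokens"])
    return server_max_tokens

# Hypothetical numbers in the spirit of the tests: a 100-token context window
# and a 7-token prompt leave 93 tokens, so a generation_config.json limit of
# 200 has no effect, while a limit of 10 becomes the binding constraint.
assert effective_server_max_tokens(100, 7, {"max_tokens": 200}) == 93
assert effective_server_max_tokens(100, 7, {"max_tokens": 10}) == 10
assert effective_server_max_tokens(100, 7, {}) == 93
```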

13 changes: 10 additions & 3 deletions vllm/entrypoints/openai/serving_completion.py
@@ -115,17 +115,24 @@ async def create_completion(
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
- default_max_tokens = self.max_model_len - len(
+ server_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
# Build default sampling params
default_sampling_params = (
self.model_config.get_diff_sampling_param())

+ # Limit set by architecture or value in generation_config.json
+ if "max_tokens" in default_sampling_params:
+ server_max_tokens = min(
+ server_max_tokens,
+ default_sampling_params["max_tokens"])
Reviewer comment: Since default_sampling_params is also passed to request.to_beam_search_params and request.to_sampling_params, let's handle this inside those methods instead.


if request.use_beam_search:
sampling_params = request.to_beam_search_params(
- default_max_tokens, default_sampling_params)
+ server_max_tokens, default_sampling_params)
else:
sampling_params = request.to_sampling_params(
- default_max_tokens,
+ server_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params)
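As a rough end-to-end illustration of the behavior this PR enables, assuming a vLLM OpenAI-compatible server on localhost:8000 serving a model whose generation_config.json sets max_new_tokens to 10 (the model name, port, and that config value are all hypothetical):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "what is 1+1?"}],
    max_tokens=100,    # deliberately above the assumed server-side limit of 10
)

# With this PR the server clamps generation to the generation_config.json
# limit, so no more than 10 completion tokens should come back.
print(resp.usage.completion_tokens)
```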
