diff --git a/benchmarks/structured_schemas/structured_schema_1.json b/benchmarks/structured_schemas/structured_schema_1.json index 6003698469e8d..1bd189c9e704f 100644 --- a/benchmarks/structured_schemas/structured_schema_1.json +++ b/benchmarks/structured_schemas/structured_schema_1.json @@ -1,113 +1,25 @@ { - "$schema": - "https://json-schema.org/draft/2020-12/schema", - "title": - "User Profile", - "type": - "object", + "type": "array", + "items": { + "type": "object", "properties": { - "userId": { - "type": "string", - "description": "Unique identifier for the user." - }, - "personalInfo": { - "type": "object", - "properties": { - "firstName": { - "type": "string", - "description": "The user's first name." - }, - "lastName": { - "type": "string", - "description": "The user's last name." - }, - "age": { - "type": "integer", - "minimum": 0, - "description": "The user's age." - }, - "phoneNumbers": { - "type": - "array", - "items": { - "type": "object", - "properties": { - "type": { - "type": "string", - "enum": ["home", "work", "mobile"], - "description": "Type of phone number." - }, - "number": { - "type": "string", - "pattern": "^\\+?[1-9]\\d{1,14}$", - "description": "Phone number in E.164 format." - } - }, - "required": ["type", "number"] - }, - "description": - "List of phone numbers associated with the user." - } - }, - "required": ["firstName", "lastName"] - }, - "address": { - "type": "object", - "properties": { - "street": { - "type": "string", - "description": "Street address." - }, - "city": { - "type": "string", - "description": "City name." - }, - "state": { - "type": "string", - "description": "State or province." - }, - "postalCode": { - "type": "string", - "pattern": "^\\d{5}(-\\d{4})?$", - "description": "Postal code." - }, - "country": { - "type": "string", - "description": "Country name." - } - }, - "required": ["street", "city", "state", "postalCode", "country"] - }, - "preferences": { - "type": "object", - "properties": { - "newsletterSubscribed": { - "type": - "boolean", - "description": - "Indicates if the user is subscribed to the newsletter." - }, - "favoriteCategories": { - "type": "array", - "items": { - "type": "string" - }, - "description": "List of user's favorite categories." - } - }, - "required": ["newsletterSubscribed"] - }, - "accountStatus": { - "type": "string", - "enum": ["active", "inactive", "suspended"], - "description": "Current status of the user's account." - }, - "registrationDate": { - "type": "string", - "format": "date-time", - "description": "ISO 8601 formatted date-time of user registration." 
- } + "name": { "type": "string" }, + "race": { "type": "string" }, + "class": { "type": "string" }, + "level": { "type": "integer" }, + "background": { "type": "string" }, + "alignment": { "type": "string" }, + "backstory": { "type": "string" } }, - "required": - ["userId", "personalInfo", "address", "accountStatus", "registrationDate"] -} \ No newline at end of file + "required": [ + "name", + "race", + "class", + "level", + "background", + "alignment", + "backstory" + ] + } +} + diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index b00e168db9d32..6d4278b4c8719 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -29,6 +29,7 @@ def sample_regex(): r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") +# Note: Ensure this only uses attributes compatible with xgrammar @pytest.fixture def sample_json_schema(): return { @@ -44,9 +45,7 @@ def sample_json_schema(): "type": "array", "items": { "type": "string", - "maxLength": 10 - }, - "minItems": 3 + } }, "work_history": { "type": "array", @@ -71,8 +70,9 @@ def sample_json_schema(): } +# A schema unsupported by xgrammar @pytest.fixture -def sample_complex_json_schema(): +def unsupported_json_schema(): return { "type": "object", "properties": { @@ -150,7 +150,19 @@ def sample_guided_choice(): @pytest.fixture -def sample_sql_statements(): +def sample_sql_ebnf(): + return """ +root ::= select_statement +select_statement ::= "SELECT" column "from" table "where" condition +column ::= "col_1" | "col_2" +table ::= "table_1" | "table_2" +condition ::= column "=" number +number ::= "1" | "2" +""" + + +@pytest.fixture +def sample_sql_lark(): return (""" start: select_statement select_statement: "SELECT" column "from" table "where" condition diff --git a/tests/v1/entrypoints/llm/__init__.py b/tests/v1/entrypoints/llm/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py new file mode 100644 index 0000000000000..871739bcf1640 --- /dev/null +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -0,0 +1,269 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json + +import jsonschema +import pytest + +from vllm.entrypoints.llm import LLM +from vllm.outputs import RequestOutput +from vllm.sampling_params import GuidedDecodingParams, SamplingParams + +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" +GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_json_completion(monkeypatch, sample_json_schema, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_json_schema, + backend=guided_decoding_backend)) + outputs = llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + 
jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_json_object(monkeypatch, guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=1.0, + max_tokens=100, + n=2, + guided_decoding=GuidedDecodingParams( + json_object=True, + backend=guided_decoding_backend)) + + outputs = llm.generate( + prompts=("Generate a JSON object with curly braces for a person with " + "name and age fields for John Smith who is 31 years old."), + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + + for i in range(2): + generated_text = output.outputs[i].text + print(generated_text) + assert generated_text is not None + + # Parse to verify it is valid JSON + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=unsupported_json_schema, + backend=guided_decoding_backend)) + with pytest.raises(ValueError, + match="The provided JSON schema contains features " + "not supported by xgrammar."): + llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {unsupported_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + grammar=sample_sql_ebnf, + backend=guided_decoding_backend)) + outputs = llm.generate( + prompts=("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1"), + sampling_params=sampling_params, + use_tqdm=True, + ) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") + + assert generated_text.strip() == ground_truth + + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_grammar_lark(monkeypatch, sample_sql_lark, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + 
grammar=sample_sql_lark, + backend=guided_decoding_backend)) + outputs = llm.generate( + prompts=("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1"), + sampling_params=sampling_params, + use_tqdm=True, + ) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_lark) + parser.parse(generated_text) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") + + assert generated_text.strip() == ground_truth + + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_grammar_ebnf_invalid(monkeypatch, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + grammar="not a grammar", + backend=guided_decoding_backend)) + with pytest.raises(ValueError, + match="Failed to convert the grammar " + "from Lark to EBNF."): + llm.generate( + prompts=("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1"), + sampling_params=sampling_params, + use_tqdm=True, + ) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + regex=sample_regex, + backend=guided_decoding_backend)) + with pytest.raises(ValueError, + match="Regex guided decoding is not supported."): + llm.generate(prompts=[ + f"Give an example IPv4 address with this regex: {sample_regex}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + # Once regex is supported -- + #assert outputs is not None + #for output in outputs: + # assert output is not None + # assert isinstance(output, RequestOutput) + # prompt = output.prompt + # generated_text = output.outputs[0].text + # print(generated_text) + # assert generated_text is not None + # assert re.fullmatch(sample_regex, generated_text) is not None + # print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_choice_completion(monkeypatch, sample_guided_choice, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + choice=sample_guided_choice, + backend=guided_decoding_backend)) + outputs = llm.generate( + prompts="The best language for type-safe systems programming is ", + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) 
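+ # The constrained output must decode to exactly one of the
+ # provided choices, which is asserted below.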
+ prompt = output.prompt + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + assert generated_text in sample_guided_choice + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/v1/guided_decoding/__init__.py b/tests/v1/guided_decoding/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/guided_decoding/test_utils.py b/tests/v1/guided_decoding/test_utils.py new file mode 100644 index 0000000000000..edc304714e2a0 --- /dev/null +++ b/tests/v1/guided_decoding/test_utils.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +from vllm.v1.guided_decoding.utils import ( + has_xgrammar_unsupported_json_features) + + +def test_has_xgrammar_unsupported_json_features(): + schemas_with_unsupported_features: List[dict] = [{ + "type": "string", + "pattern": "^[a-zA-Z]+$" + }, { + "type": + "string", + "enum": ["active", "inactive", "pending"] + }, { + "type": "integer", + "minimum": 0 + }, { + "type": "integer", + "maximum": 120 + }, { + "type": "integer", + "exclusiveMinimum": 120 + }, { + "type": "integer", + "exclusiveMaximum": 120 + }, { + "type": "integer", + "multipleOf": 120 + }, { + "type": "number", + "minimum": 0 + }, { + "type": "number", + "maximum": 120 + }, { + "type": "number", + "exclusiveMinimum": 120 + }, { + "type": "number", + "exclusiveMaximum": 120 + }, { + "type": "number", + "multipleOf": 120 + }, { + "type": "array", + "uniqueItems": True + }, { + "type": "array", + "contains": { + "type": "string" + } + }, { + "type": "array", + "minContains": 1 + }, { + "type": "array", + "maxContains": 5 + }, { + "type": "array", + "minItems": 1 + }, { + "type": "array", + "maxItems": 10 + }, { + "type": "string", + "minLength": 1 + }, { + "type": "string", + "maxLength": 100 + }, { + "type": "string", + "format": "email" + }, { + "type": "object", + "minProperties": 1 + }, { + "type": "object", + "maxProperties": 5 + }, { + "type": "object", + "propertyNames": { + "pattern": "^[a-z]+$" + } + }, { + "type": "object", + "patternProperties": { + "^S": { + "type": "string" + } + } + }] + + for schema in schemas_with_unsupported_features: + assert has_xgrammar_unsupported_json_features(schema) + + schema_without_unsupported_features: dict = { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "status": { + "type": "string" + }, + "scores": { + "type": "array", + "items": { + "type": "number" + } + }, + "address": { + "type": "object", + "properties": { + "street": { + "type": "string" + }, + "city": { + "type": "string" + } + } + } + } + } + + assert not has_xgrammar_unsupported_json_features( + schema_without_unsupported_features) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 87c9c0cd12b7b..47056bff0d144 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import time -from collections import deque -from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, SpeculativeConfig) @@ -14,6 +13,7 @@ SchedulerOutput) from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) +from vllm.v1.guided_decoding import GuidedDecodingManager from vllm.v1.metrics.stats import SchedulerStats from 
vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -31,12 +31,14 @@ def __init__( lora_config: Optional[LoRAConfig], speculative_config: Optional[SpeculativeConfig], log_stats: bool, + guided_decoding_manager: GuidedDecodingManager, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config self.speculative_config = speculative_config self.log_stats = log_stats + self.guided_decoding_manager = guided_decoding_manager # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -58,8 +60,12 @@ def __init__( # req_id -> Request self.requests: Dict[str, Request] = {} - # Priority queues for requests. - self.waiting: Deque[Request] = deque() + # NOTE: Priority queues for requests. + # With list, we can safely pop the index + # of a request that are yet to be ready (in this case, + # the one that uses guided decoding) while still maintaining + # the order of all requests in existing waiting queue. + self.waiting: List[Request] = [] self.running: List[Request] = [] # The requests that have been scheduled and are being executed # by the executor. @@ -113,6 +119,14 @@ def schedule(self) -> "SchedulerOutput": scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] + # NOTE: guided_decoding_request_ids maps + # guided request's (request that use structured decoding) + # request_id to the running request index. + # This will helps us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + guided_decoding_request_ids: Dict[str, int] = {} + req_to_new_block_ids: Dict[str, List[int]] = {} num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens @@ -134,11 +148,14 @@ def schedule(self) -> "SchedulerOutput": req_index += 1 continue - num_new_tokens = (request.num_tokens_with_spec - - request.num_computed_tokens) + num_new_tokens = request.num_tokens - request.num_computed_tokens num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 + # Guided decoding related. + if request.use_guided_decoding: + guided_decoding_request_ids[request.request_id] = req_index + # Schedule encoder inputs. encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( self._try_schedule_encoder_inputs(request, @@ -166,7 +183,7 @@ def schedule(self) -> "SchedulerOutput": preempted_req.num_computed_tokens = 0 self.request_preempted(preempted_req, scheduled_timestamp) - self.waiting.appendleft(preempted_req) + self.waiting.insert(0, preempted_req) preempted_reqs.append(preempted_req) if preempted_req == request: # No more request to preempt. @@ -220,12 +237,24 @@ def schedule(self) -> "SchedulerOutput": # Next, schedule the WAITING requests. if not preempted_reqs: - while self.waiting and token_budget > 0: + # NOTE: We uses num_to_skip to determine + # which guided request within the waiting queue to skip + # over if the FSM of said request are yet to be ready. + num_to_skip: int = 0 + while num_to_skip < len(self.waiting) and token_budget > 0: if len(self.running) == self.max_num_running_reqs: break - request = self.waiting[0] + request = self.waiting[num_to_skip] + + if request.status == RequestStatus.WAITING_FOR_FSM: + if request.grammar and request.is_grammar_ready: + request.status = RequestStatus.WAITING + else: + num_to_skip += 1 + continue + # # Check that adding the request still respects the max_loras # constraint. 
if self.lora_config and request.lora_request: @@ -279,7 +308,10 @@ def schedule(self) -> "SchedulerOutput": # The request cannot be scheduled. break - self.waiting.popleft() + self.waiting.pop(num_to_skip) + if request.use_guided_decoding: + guided_decoding_request_ids[request.request_id] = req_index + req_index += 1 self.running.append(request) self.scheduled_req_ids.add(request.request_id) self.request_scheduled(request, scheduled_timestamp) @@ -330,6 +362,22 @@ def schedule(self) -> "SchedulerOutput": self.kv_cache_manager.get_num_common_prefix_blocks( any_request, len(self.running))) + # Prepare the guided decoding bitmask for this batch. + grammar_bitmask = None + if guided_decoding_request_ids: + # Fill the bitmask using the index of each request equal to its + # position in the batch. Resize the bitmask down to the size of + # the batch. + grammar_bitmask = self.guided_decoding_manager.grammar_bitmask + assert grammar_bitmask is not None + for req_id, batch_index in guided_decoding_request_ids.items(): + request = self.requests[req_id] + assert request.grammar is not None + if not request.grammar.matcher.is_terminated(): + request.grammar.fill_bitmask(grammar_bitmask, batch_index) + if len(self.running) < grammar_bitmask.shape[0]: + grammar_bitmask = grammar_bitmask[:len(self.running)] + # Construct the scheduler output. new_reqs_data = [ NewRequestData.from_request(req, @@ -368,6 +416,8 @@ def schedule(self) -> "SchedulerOutput": # the previous and the current steps. finished_req_ids=self.finished_req_ids, free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + guided_decoding_request_ids=guided_decoding_request_ids, + grammar_bitmask=grammar_bitmask, ) self.finished_req_ids = set() @@ -545,6 +595,27 @@ def update_from_output( new_logprobs = None new_token_ids: List[int] = [] + # Handle guided decoding FSM advancement if applicable + # NOTE: For all requests that uses guided decoding, the grammar + # should be ready at this point. + # PERF: This is currently expensive given that FSM is being + # advanced here. + if request.use_guided_decoding: + grammar = request.grammar + assert grammar is not None + if len(generated_token_ids) > 1: + logger.error( + "Structured output does not currently support " + "more than one token at a time. Only the first " + "token will be used.") + # accept_token advances the FSM + accepted = grammar.accept_token(generated_token_ids[0]) + if not accepted: + logger.error( + "Failed to advance FSM for request %s " + "for tokens %s. Please file an issue.", + req_id, generated_token_ids[0]) + if request.num_computed_tokens >= request.num_tokens: for output_token_id in generated_token_ids: request.append_output_token_ids(output_token_id) diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index 47413527c32f2..9cea11d88067a 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple if TYPE_CHECKING: + import torch + from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.base import PlaceholderRange @@ -111,3 +113,9 @@ class SchedulerOutput: # List of (req_id, encoder_input_index) tuples. # Used to free the encoder cache. 
free_encoder_input_ids: List[Tuple[str, int]] + + # Dict of request ids to their index within the batch + # for filling the next token bitmask + guided_decoding_request_ids: Dict[str, int] + # the bitmask for the whole batch + grammar_bitmask: Optional["torch.Tensor"] diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 0c04e14cec2f6..ddfe3444e5826 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -74,6 +74,7 @@ def __init__( model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, lora_config=vllm_config.lora_config, + decoding_config=vllm_config.decoding_config, tokenizer=self.tokenizer, input_registry=input_registry, ) @@ -188,8 +189,8 @@ async def _generate( * 3) Adding the Request to the Detokenizer. * 4) Adding the Request to the EngineCore (separate process). - A separate output_handler loop runs in a background AsyncIO task, - pulling outputs from EngineCore and putting them into the + A separate output_handler loop runs in a background AsyncIO task, + pulling outputs from EngineCore and putting them into the per-request AsyncStream. The caller of generate() iterates the returned AsyncGenerator, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 041896f1c7cc5..1375a2b7a5222 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -26,6 +26,7 @@ EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.mm_input_cache import MMInputCacheServer from vllm.v1.executor.abstract import Executor +from vllm.v1.guided_decoding import GuidedDecodingManager from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -61,6 +62,8 @@ def __init__( vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks + self.guided_decoding_manager = GuidedDecodingManager(vllm_config) + # Setup scheduler. self.scheduler = Scheduler( scheduler_config=vllm_config.scheduler_config, @@ -69,6 +72,7 @@ def __init__( lora_config=vllm_config.lora_config, speculative_config=vllm_config.speculative_config, log_stats=self.log_stats, + guided_decoding_manager=self.guided_decoding_manager, ) # Setup MM Input Mapper. @@ -131,6 +135,9 @@ def add_request(self, request: EngineCoreRequest): request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) + if req.use_guided_decoding: + # Start grammar compilation asynchronously + self.guided_decoding_manager.populate_cache(req) self.scheduler.add_request(req) @@ -143,6 +150,8 @@ def abort_requests(self, request_ids: List[str]): self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) + self.guided_decoding_manager.remove_requests(request_ids) + def step(self) -> EngineCoreOutputs: """Schedule, execute, and make output.""" @@ -150,10 +159,30 @@ def step(self) -> EngineCoreOutputs: return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) + # Check cache for compiled grammars and add them to requests + # when they're ready. + self.guided_decoding_manager.setup_grammars() + scheduler_output = self.scheduler.schedule() + + # This case may occur when the only unfinished requests are + # guided decoding requests where the grammar has not finished + # compiling yet, so there's nothing to run. 
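+ # In that case, return empty outputs (with current scheduler stats)
+ # so the engine loop keeps polling until grammar compilation finishes.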
+ if scheduler_output.total_num_scheduled_tokens == 0: + return EngineCoreOutputs( + outputs=[], scheduler_stats=self.scheduler.make_stats()) + + # Currently we will broadcast the bitmask. It is populated during + # each schedule() run. + if len(self.guided_decoding_manager.requests) > 0: + scheduler_output.grammar_bitmask = \ + self.guided_decoding_manager.grammar_bitmask + output = self.model_executor.execute_model(scheduler_output) + engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) # type: ignore + return engine_core_outputs def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9f36e11d12d76..ff901b381bd5d 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -24,6 +24,7 @@ EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.executor.abstract import Executor +from vllm.v1.guided_decoding.utils import validate_guided_decoding_request from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.utils import BackgroundProcHandle @@ -66,6 +67,15 @@ def make_client( return InprocClient(vllm_config, executor_class, log_stats) + @staticmethod + def _validate_request(request: EngineCoreRequest) -> None: + """Validate request before sending to EngineCore. + + Raises ValueError if request contents are known to be invalid or + unsupported. + """ + validate_guided_decoding_request(request.sampling_params) + @abstractmethod def shutdown(self): ... @@ -160,6 +170,7 @@ def get_output(self) -> EngineCoreOutputs: return self.engine_core.step() def add_request(self, request: EngineCoreRequest) -> None: + self._validate_request(request) self.engine_core.add_request(request) def abort_requests(self, request_ids: List[str]) -> None: @@ -368,6 +379,7 @@ def _call_utility(self, method: str, *args) -> Any: return future.result() def add_request(self, request: EngineCoreRequest) -> None: + self._validate_request(request) # NOTE: text prompt is not needed in the core engine as it has been # tokenized. request.prompt = None @@ -466,6 +478,7 @@ async def _call_utility_async(self, method: str, *args) -> Any: return await future async def add_request_async(self, request: EngineCoreRequest) -> None: + self._validate_request(request) # NOTE: text prompt is not needed in the core engine as it has been # tokenized. 
request.prompt = None diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index ccf52250c1d6f..117a80e3ef0c6 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -71,6 +71,7 @@ def __init__( self.processor = Processor(model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, lora_config=vllm_config.lora_config, + decoding_config=vllm_config.decoding_config, tokenizer=self.tokenizer, input_registry=input_registry, mm_registry=mm_registry) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 2547cebaede7c..8da1768440c2c 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -3,7 +3,7 @@ import time from typing import Mapping, Optional, Union -from vllm.config import CacheConfig, LoRAConfig, ModelConfig +from vllm.config import CacheConfig, DecodingConfig, LoRAConfig, ModelConfig from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, PromptType, SingletonInputsAdapter) from vllm.inputs.parse import is_encoder_decoder_inputs @@ -27,6 +27,7 @@ def __init__( model_config: ModelConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], + decoding_config: DecodingConfig, tokenizer: BaseTokenizerGroup, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, @@ -35,6 +36,7 @@ def __init__( self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config + self.decoding_config = decoding_config self.tokenizer = tokenizer self.generation_config_fields = model_config.try_get_generation_config( @@ -83,6 +85,18 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + def _validate_guided_decoding( + self, params: Union[SamplingParams, PoolingParams]) -> None: + if not isinstance(params, SamplingParams): + return + if self.decoding_config.guided_decoding_backend != "xgrammar": + raise ValueError( + "Only xgrammar guided decoding is supported in V1.") + if (params.guided_decoding and params.guided_decoding.backend + and params.guided_decoding.backend != 'xgrammar'): + raise ValueError( + "Only xgrammar guided decoding is supported in V1.") + def _validate_allowed_token_ids( self, params: Union[SamplingParams, PoolingParams], @@ -113,6 +127,7 @@ def process_inputs( self._validate_logprobs(params) self._validate_lora(lora_request) + self._validate_guided_decoding(params) self._validate_allowed_token_ids(params) if arrival_time is None: diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py new file mode 100644 index 0000000000000..27936e64841e8 --- /dev/null +++ b/vllm/v1/guided_decoding/__init__.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import copy +import enum +import threading +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Set, Tuple + +import torch +import xgrammar as xgr + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +if TYPE_CHECKING: + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class GuidedDecodingOptions(enum.Enum): + json = enum.auto() + json_object = enum.auto() + regex = enum.auto() + grammar = enum.auto() + choice = 
enum.auto() + + +GuidedDecodingKey = Tuple[GuidedDecodingOptions, str] +MAX_ROLLBACK_TOKENS = 100 + + +def apply_bitmask( + logits: torch.Tensor, + vocab_mask: torch.Tensor, + indices: List[int], +) -> None: + xgr.apply_token_bitmask_inplace(logits, vocab_mask, indices=indices) + + +@dataclass(slots=True, unsafe_hash=True) # type: ignore[call-overload] +class Grammar: + # NOTE: This would be a generic-enough class for + # supporting different backends, in the future. + # For now, just xgrammar. + # + # TODO: support max_rollback_tokens + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string + # for jump-forward decoding + + vocab_size: int + matcher: xgr.GrammarMatcher = field(hash=False) + ctx: xgr.CompiledGrammar = field(hash=False) + max_rollback_tokens: int = field(default=MAX_ROLLBACK_TOKENS, kw_only=True) + num_processed_tokens: int = field( + default_factory=lambda: 0, + repr=False, + hash=False, + init=False, + ) + + def accept_token(self, token: int) -> bool: + # NOTE: accept_token will determines whether we accept this token + # and will also update the machine state + self.num_processed_tokens += 1 + return self.matcher.accept_token(token) + + # this should be ran in parallel with model decoding + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool: + return self.matcher.fill_next_token_bitmask(bitmask, idx) + + def rollback(self, num_tokens: int): + if self.num_processed_tokens > 0: + self.num_processed_tokens -= num_tokens + self.matcher.rollback(num_tokens) + + def reset(self): + self.num_processed_tokens = 0 + self.matcher.reset() + + def __copy__(self): + return Grammar(matcher=xgr.GrammarMatcher(self.ctx), + vocab_size=self.vocab_size, + ctx=self.ctx, + max_rollback_tokens=self.max_rollback_tokens) + + +class GuidedDecodingManager: + + def __init__(self, vllm_config: VllmConfig, max_cache_size: int = 500): + tokenizer_group = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + lora_config=vllm_config.lora_config) # type: ignore[arg-type] + tokenizer_group.ping() + self.vocab_size = vllm_config.model_config.get_vocab_size() + self.vllm_config = vllm_config + + tokenizer = tokenizer_group.get_lora_tokenizer(None) + tokenizer_info = xgr.TokenizerInfo.from_huggingface( + tokenizer, vocab_size=self.vocab_size) + self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + + self.max_cache_size = max_cache_size + self.request_key_to_grammar: OrderedDict[GuidedDecodingKey, + Grammar] = OrderedDict() + + self.executor = ThreadPoolExecutor() + self.requests: Set[Request] = set() + self._requests_lock = threading.Lock() + self.grammar_bitmask = xgr.allocate_token_bitmask( + self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size) + + def __getitem__(self, key: GuidedDecodingKey) -> Optional[Grammar]: + if key in self.request_key_to_grammar: + # Move accessed item to the end (most recently used) + value = self.request_key_to_grammar.pop(key) + self.request_key_to_grammar[key] = value + return value + return None + + def remove_requests(self, request_ids: List[str]) -> None: + with self._requests_lock: + self.requests = { + req + for req in self.requests if req.request_id not in request_ids + } + + def populate_cache(self, request: Request): + if not request.use_guided_decoding: + return False + grammar = self.request_key_to_grammar.get(request.guided_decoding_key) + if grammar: + request.grammar 
= copy.copy(grammar) + return False + request.grammar = self.cache(request) + return True + + def cache(self, request: Request): + return self.executor.submit(self._executor_loop, request) + + def _executor_loop(self, request: Request) -> Grammar: + key = request.guided_decoding_key + with self._requests_lock: + self.requests.add(request) + if key in self.request_key_to_grammar: + grammar = self.request_key_to_grammar[key] + return copy.copy(grammar) + grammar = self.initialize_grammar(key) + # If cache is full, remove the least recently used item + if len(self.request_key_to_grammar) >= self.max_cache_size: + self.request_key_to_grammar.popitem(last=False) + self.request_key_to_grammar[key] = grammar + return copy.copy(grammar) + + def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: + # Note that the request was validated in the engine core client, + # so at this point we know it is a supported type of request. + # + # TODO: we still need to handle xgrammar compilation failures + request_type, grammar_spec = key + + if request_type == GuidedDecodingOptions.json: + # TODO -- allow any_whitespace to be configurable + # pending merge of https://github.com/vllm-project/vllm/pull/12744 + ctx = self.compiler.compile_json_schema(grammar_spec, + any_whitespace=False) + elif request_type == GuidedDecodingOptions.json_object: + ctx = self.compiler.compile_builtin_json_grammar() + elif request_type == GuidedDecodingOptions.grammar: + ctx = self.compiler.compile_grammar(grammar_spec) + else: + logger.error("Validation should have already occurred. " + "Please file an issue.") + raise ValueError( + f"grammar is not of valid supported types. ({request_type!s})") + + return Grammar( + matcher=xgr.GrammarMatcher(ctx), + vocab_size=self.vocab_size, + ctx=ctx, + max_rollback_tokens=self.vllm_config.speculative_config. 
+ num_lookahead_slots + if self.vllm_config.speculative_config else MAX_ROLLBACK_TOKENS) + + def setup_grammars(self): + with self._requests_lock: + for req in self.requests: + if req.grammar is not None: + continue + + # Check if grammar is ready in cache + grammar = self[req.guided_decoding_key] + if grammar is not None: + req.grammar = copy.copy(grammar) + continue diff --git a/vllm/v1/guided_decoding/utils.py b/vllm/v1/guided_decoding/utils.py new file mode 100644 index 0000000000000..e01d8771be244 --- /dev/null +++ b/vllm/v1/guided_decoding/utils.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import re +from typing import List + +import xgrammar + +from vllm.sampling_params import SamplingParams + + +def has_xgrammar_unsupported_json_features(schema: dict) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Check for enum restrictions + if "enum" in obj: + return True + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and any( + key in obj for key in [ + "minimum", "maximum", "exclusiveMinimum", + "exclusiveMaximum", "multipleOf" + ]): + return True + + # Check for array unsupported keywords + if obj.get("type") == "array" and any(key in obj for key in [ + "uniqueItems", "contains", "minContains", "maxContains", + "minItems", "maxItems" + ]): + return True + + # Unsupported keywords for strings + if obj.get("type") == "string" and any( + key in obj for key in ["minLength", "maxLength", "format"]): + return True + + # Unsupported keywords for objects + if obj.get("type") == "object" and any(key in obj for key in [ + "minProperties", "maxProperties", "propertyNames", + "patternProperties" + ]): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def grammar_is_likely_lark(grammar_str: str) -> bool: + """ + Check if grammar appears to use Lark syntax. + + Args: + grammar_str: Input grammar string + + Returns: + bool: True if grammar appears to be in Lark format, False otherwise + + Examples: + >>> grammar_is_likely_lark("rule: 'abc'") + True + >>> grammar_is_likely_lark("rule ::= 'abc'") + False + """ + if not grammar_str or not isinstance(grammar_str, str): + return False + + for line in grammar_str.split('\n'): + # Remove both comment styles + line = re.sub(r'(#|//).*$', '', line).strip() + if not line: + continue + + # Look for EBNF rule definition + if '::=' in line: + return False + + return True + + +def convert_lark_to_ebnf(grammar_str: str) -> str: + """ + Convert a Lark grammar string to EBNF format. 
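+ Only rule definitions, quoted string terminals and '|' alternatives
+ are handled; Lark directives and other constructs are not translated.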
+ + EBNF reference: + https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + Lark grammar reference: + https://lark-parser.readthedocs.io/en/latest/grammar.html + + Args: + grammar_str: Input grammar in Lark format + + Returns: + str: Converted grammar in EBNF format + + Examples: + >>> print(convert_lark_to_ebnf("rule: 'hello'")) + root ::= rule + rule ::= "hello" + """ + if not isinstance(grammar_str, str): + raise ValueError(f"Grammar must be a string, got {type(grammar_str)}") + if not grammar_str.strip(): + raise ValueError("Grammar string cannot be empty") + + defined_rules = set() + referenced_rules = set() + output_lines = [] + + def clean_line(line: str) -> str: + """Remove comments and whitespace from line.""" + return re.sub(r'(#|//).*$', '', line).strip() + + def check_quotes(text: str, rule_name: str, line_num: int) -> None: + """Validate quote matching in text.""" + if text.count("'") % 2 != 0 or text.count('"') % 2 != 0: + raise ValueError( + f"Mismatched quotes in {rule_name} on line {line_num}") + + def extract_references(text: str) -> set: + """Extract rule references from text.""" + # Remove quoted strings and special characters + text = re.sub(r'"[^"]*"', '', text) + text = re.sub(r'[+*?()|\[\]{}]', ' ', text) + return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + + # First pass: Find root rule and validate rule definitions + lines = [clean_line(line) for line in grammar_str.split('\n')] + first_rule = None + + for line_num, line in enumerate(lines, 1): + if not line or line.startswith('|'): + continue + + if ':' in line: + try: + name = line.split(':', 1)[0].strip().strip('?') + defined_rules.add(name) + if first_rule is None: + first_rule = name + if name == 'start': + first_rule = 'start' + except IndexError as e: + raise ValueError(f"Invalid rule format on line {line_num}. 
" + "Expected 'rule_name: definition'") from e + + if not defined_rules: + raise ValueError("No valid rules found in grammar") + + # Add root rule + output_lines.append(f"root ::= {first_rule}") + + # Second pass: Process rule definitions and alternatives + current_rule = None + current_definition = [] + + for line_num, line in enumerate(lines, 1): + if not line: + continue + + try: + if ':' in line and not line.startswith('|'): + # Save previous rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Process new rule + name, definition = line.split(':', 1) + current_rule = name.strip().strip('?') + + check_quotes(definition, f"rule '{current_rule}'", line_num) + definition = re.sub(r"'([^']*)'", r'"\1"', definition) + referenced_rules.update(extract_references(definition)) + current_definition = [definition.strip()] + + elif line.startswith('|'): + if not current_rule: + raise ValueError(f"Alternative '|' on line {line_num} " + "without a preceding rule definition") + + alt_def = line[1:].strip() + check_quotes(alt_def, f"alternative for rule '{current_rule}'", + line_num) + alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) + referenced_rules.update(extract_references(alt_def)) + current_definition.append(alt_def) + + except ValueError as e: + raise ValueError(f"Error on line {line_num}: {str(e)}") from e + + # Add final rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Validate all rules are defined + undefined_rules = referenced_rules - defined_rules - {'root'} + if undefined_rules: + raise ValueError("Referenced rules are not defined: " + f"{', '.join(sorted(undefined_rules))}") + + return '\n'.join(output_lines) + + +def choice_as_grammar(choice: List[str]) -> str: + + def escape_ebnf_string(s: str) -> str: + """Escape special characters in a EBNF string.""" + # Escape double quotes and backslashes + return re.sub(r'(["\\])', r'\\\1', s) + + escaped_choices = (escape_ebnf_string(c) for c in choice) + grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) + return grammar + + +def validate_guided_decoding_request(sampling_params: SamplingParams) -> None: + """Validate that the request is supported by guided decoding. + + Raises ValueError if the request is not supported. 
+ """ + if sampling_params.guided_decoding is None: + return + + gd_params = sampling_params.guided_decoding + + if gd_params.regex: + raise ValueError("Regex guided decoding is not supported.") + + if gd_params.choice: + choice_grammar = choice_as_grammar(gd_params.choice) + try: + xgrammar.Grammar.from_ebnf(choice_grammar) + except Exception as err: + raise ValueError("Failed to transform choices into a grammar: " + "{err}") from err + gd_params.choice = None + gd_params.grammar = choice_grammar + return + + if gd_params.json: + if isinstance(gd_params.json, str): + try: + schema = json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + schema = gd_params.json + + if has_xgrammar_unsupported_json_features(schema): + raise ValueError("The provided JSON schema contains features not " + "supported by xgrammar.") + return + + if gd_params.grammar: + if grammar_is_likely_lark(gd_params.grammar): + # xgrammar supports EBNF grammars only + try: + gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to convert the grammar from Lark to EBNF. ") from e + + # Test parsing EBNF grammar, possibly already converted from Lark + try: + # parse the grammar, but we aren't compiling it. + xgrammar.Grammar.from_ebnf(gd_params.grammar) + except Exception as e: + raise ValueError("Invalid grammar specification.") from e diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 52d7faeeb0664..3f941a8b7209c 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,18 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import enum +import functools +import json +from concurrent.futures import Future +from concurrent.futures._base import TimeoutError from typing import TYPE_CHECKING, List, Optional, Union -from vllm.lora.request import LoRARequest +from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreRequest, FinishReason) +from vllm.v1.guided_decoding import (Grammar, GuidedDecodingKey, + GuidedDecodingOptions) from vllm.v1.utils import ConstantList if TYPE_CHECKING: + + from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange +logger = init_logger(__name__) + class Request: @@ -21,9 +33,9 @@ def __init__( request_id: str, prompt: Optional[str], prompt_token_ids: List[int], - multi_modal_inputs: Optional[List["MultiModalKwargs"]], + multi_modal_inputs: Optional[List[MultiModalKwargs]], multi_modal_hashes: Optional[List[str]], - multi_modal_placeholders: Optional[List["PlaceholderRange"]], + multi_modal_placeholders: Optional[List[PlaceholderRange]], sampling_params: SamplingParams, eos_token_id: Optional[int], arrival_time: float, @@ -35,7 +47,9 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request - self.status = RequestStatus.WAITING + self.status = (RequestStatus.WAITING_FOR_FSM + if sampling_params.guided_decoding is not None else + RequestStatus.WAITING) self.events: List[EngineCoreEvent] = [] self.stop_reason: Union[int, str, None] = None assert sampling_params.max_tokens is not None @@ -65,8 +79,11 @@ def __init__( self.output_token_ids = ConstantList(self._output_token_ids) self.all_token_ids = ConstantList(self._all_token_ids) + # Grammar fields, including the grammar object and the bitmask + self._grammar: 
Optional[Union[Future[Grammar], Grammar]] = None + @classmethod - def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": + def from_engine_core_request(cls, request: EngineCoreRequest) -> Request: return cls( request_id=request.request_id, prompt=request.prompt, @@ -134,26 +151,78 @@ def get_num_encoder_tokens(self, input_id: int) -> int: num_tokens = self.mm_positions[input_id]["length"] return num_tokens + @property + def use_guided_decoding(self) -> bool: + return self.sampling_params.guided_decoding is not None + + @functools.cached_property + def guided_decoding_key(self) -> GuidedDecodingKey: + params = self.sampling_params.guided_decoding + assert params is not None, "params can't be None." + if params.json is not None: + if not isinstance(params.json, str): + json_str = json.dumps(params.json) + else: + json_str = params.json + return (GuidedDecodingOptions.json, json_str) + elif params.json_object: + return (GuidedDecodingOptions.json_object, "") + elif params.regex is not None: + return (GuidedDecodingOptions.regex, params.regex) + elif params.choice is not None: + if not isinstance(params.choice, str): + json_str = json.dumps(params.choice) + else: + json_str = params.choice + return (GuidedDecodingOptions.choice, json_str) + elif params.grammar is not None: + return (GuidedDecodingOptions.grammar, params.grammar) + else: + raise ValueError("No valid guided decoding parameter found") + + def _check_grammar_completion(self) -> bool: + if isinstance(self._grammar, Future): + try: + self._grammar = self._grammar.result(timeout=0.0001) + self.status = RequestStatus.WAITING + except TimeoutError: + return False + return True + + @property + def is_grammar_ready(self) -> bool: + return self._check_grammar_completion() + + @property + def grammar(self) -> Optional[Grammar]: + self._check_grammar_completion() + return self._grammar if isinstance(self._grammar, Grammar) else None + + @grammar.setter + def grammar(self, grammar: Union[Grammar, Future[Grammar]]) -> None: + self._grammar = grammar + class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = 0 - RUNNING = 1 - PREEMPTED = 2 - # Note: anything after PREEMPTED (2) will be considered + WAITING_FOR_FSM = enum.auto() + RUNNING = enum.auto() + PREEMPTED = enum.auto() + # Note: anything after PREEMPTED will be considered # as a finished status. 
- FINISHED_STOPPED = 3 - FINISHED_LENGTH_CAPPED = 4 - FINISHED_ABORTED = 5 - FINISHED_IGNORED = 6 + FINISHED_STOPPED = enum.auto() + FINISHED_LENGTH_CAPPED = enum.auto() + FINISHED_ABORTED = enum.auto() + FINISHED_IGNORED = enum.auto() @staticmethod - def is_finished(status: "RequestStatus") -> bool: + def is_finished(status: RequestStatus) -> bool: return status > RequestStatus.PREEMPTED @staticmethod def get_finished_reason( - status: "RequestStatus") -> Union[FinishReason, None]: + status: RequestStatus) -> Union[FinishReason, None]: return _FINISHED_REASON_MAP.get(status) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4d0ae9a205a15..00f77b92d49df 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -28,6 +28,7 @@ FlashAttentionMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_cache import MMInputCacheClient +from vllm.v1.guided_decoding import apply_bitmask from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput @@ -950,6 +951,43 @@ def execute_model( sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) + # Apply guided decoding bitmasks if present + grammar_bitmask = scheduler_output.grammar_bitmask + if grammar_bitmask is not None: + # We receive the guided decoding bitmask from the scheduler, but the + # indices of the requests in the batch may not match the indices of + # the bitmask since the scheduler doesn't know how the gpu runner is + # ordering the requests in the batch. We need to sort the bitmask to + # match the order of the requests used here. + req_id_indices: Dict[str, int] = {} + indices_match = True + for req_id in self.input_batch.req_ids: + if req_id not in scheduler_output.guided_decoding_request_ids: + # not a guided decoding request + continue + batch_index = self.input_batch.req_id_to_index[req_id] + if batch_index != scheduler_output.guided_decoding_request_ids[ + req_id]: + indices_match = False + req_id_indices[req_id] = batch_index + + sorted_bitmask: Optional[torch.Tensor] = None + if not indices_match: + # Sort the bitmask to match the order of the requests + sorted_bitmask = torch.zeros_like(grammar_bitmask) + for req_id, batch_index in req_id_indices.items(): + orig_index = scheduler_output.guided_decoding_request_ids[ + req_id] + sorted_bitmask[batch_index] = grammar_bitmask[orig_index] + grammar_bitmask = sorted_bitmask + + # TODO: compatibility with spec decode + apply_bitmask( + logits, + grammar_bitmask.to(self.device, non_blocking=True), + list(req_id_indices.values()), + ) + # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if not self.use_spec_decode: @@ -1364,7 +1402,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. 
Args: - kv_cache_config: Configuration for the KV cache, including the KV + kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ if len(kv_cache_config.groups) > 1: @@ -1396,10 +1434,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def get_kv_cache_spec(self) -> KVCacheSpec: """ - Generates the KVCacheSpec by parsing the kv cache format from each + Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache + KVCacheSpec: A dictionary mapping layer names to their KV cache format. Layers that do not need KV cache are not included. """
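Usage sketch (not part of the patch): the V1 structured-output path added above is exercised end to end in tests/v1/entrypoints/llm/test_guided_generate.py. The snippet below is a minimal illustration reusing only the public names that appear in this diff (LLM, SamplingParams, GuidedDecodingParams, the "xgrammar" backend, the VLLM_USE_V1 flag, and the Qwen/Qwen2.5-1.5B-Instruct model); the schema is invented for the example and deliberately avoids keywords rejected by has_xgrammar_unsupported_json_features (pattern, format, length limits, numeric bounds, etc.).

import json
import os

# Opt into the V1 engine, as the tests do via monkeypatch.
os.environ["VLLM_USE_V1"] = "1"

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Illustrative schema using only xgrammar-supported features.
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024)
params = SamplingParams(
    temperature=1.0,
    max_tokens=200,
    guided_decoding=GuidedDecodingParams(json=schema, backend="xgrammar"),
)
outputs = llm.generate(
    prompts=[f"Give an example JSON that fits this schema: {schema}"],
    sampling_params=params,
)
print(json.loads(outputs[0].outputs[0].text))

For reference, the compile/fill/apply/advance cycle that the patch splits across GuidedDecodingManager, Scheduler.schedule() and the GPU model runner reduces to the xgrammar calls sketched below. This standalone sketch uses only the xgrammar APIs already referenced in the diff; the tokenizer loading via transformers and the random logits are stand-ins for the engine's real tokenizer group and model output.

import torch
import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
vocab_size = len(tokenizer)  # the engine uses model_config.get_vocab_size()

# Compile once per unique schema (GuidedDecodingManager does this in a
# ThreadPoolExecutor so compilation overlaps with other work).
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
ctx = compiler.compile_json_schema('{"type": "object"}', any_whitespace=False)
matcher = xgr.GrammarMatcher(ctx)

# Each step: the scheduler fills one bitmask row per guided request...
bitmask = xgr.allocate_token_bitmask(1, vocab_size)
matcher.fill_next_token_bitmask(bitmask, 0)

# ...and the model runner masks the corresponding logits rows before sampling.
logits = torch.randn(1, vocab_size)
xgr.apply_token_bitmask_inplace(logits, bitmask, indices=[0])

# After sampling, update_from_output() advances the FSM with the chosen token.
next_token = int(logits.argmax(dim=-1))
assert matcher.accept_token(next_token)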