From d719c931ee3edae0e9e988624b0632ccf69f317a Mon Sep 17 00:00:00 2001
From: Aaron Pham
Date: Fri, 24 Jan 2025 00:41:46 -0500
Subject: [PATCH 01/84] feat: initial guided decoding implementation on scheduler

Signed-off-by: Aaron Pham
---
 vllm/utils.py                            |  69 +++++++++
 vllm/v1/core/guided_decoding/__init__.py | 182 +++++++++++++++++++++++
 vllm/v1/core/guided_decoding/grammar.py  | 147 ++++++++++++++++++
 vllm/v1/core/scheduler.py                | 142 ++++++++++++++----
 vllm/v1/engine/core.py                   |   2 +
 vllm/v1/request.py                       |  43 +++++-
 vllm/v1/worker/gpu_input_batch.py        |   8 +
 vllm/v1/worker/gpu_model_runner.py       |  24 ++-
 8 files changed, 578 insertions(+), 39 deletions(-)
 create mode 100644 vllm/v1/core/guided_decoding/__init__.py
 create mode 100644 vllm/v1/core/guided_decoding/grammar.py

diff --git a/vllm/utils.py b/vllm/utils.py
index 17bffd2846b46..d31434b612a05 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -22,6 +22,7 @@
 import threading
 import time
 import traceback
+import types
 import uuid
 import warnings
 import weakref
@@ -2206,3 +2207,71 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
     else:
         func = partial(method, obj)  # type: ignore
     return func(*args, **kwargs)
+
+
+class LazyLoader(types.ModuleType):
+    """
+    LazyLoader module borrowed from Tensorflow
+    https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/util/lazy_loader.py
+    with an addition of "module caching". This will throw an
+    exception if the module cannot be imported.
+
+    Lazily import a module, mainly to avoid pulling in large dependencies.
+    `contrib`, and `ffmpeg` are examples of modules that are large and not always
+    needed, and this allows them to only be loaded when they are used.
+    """
+
+    def __init__(
+        self,
+        local_name: str,
+        parent_module_globals: Dict[str, Any],
+        name: str,
+        warning: Optional[str] = None,
+        exc_msg: Optional[str] = None,
+        exc: Type[Exception] = Exception,
+    ):
+        self._local_name = local_name
+        self._parent_module_globals = parent_module_globals
+        self._warning = warning
+        self._exc_msg = exc_msg
+        self._exc = exc
+        self._module: types.ModuleType | None = None
+
+        super().__init__(str(name))
+
+    def _load(self) -> types.ModuleType:
+        """Load the module and insert it into the parent's globals."""
+        from . import warn_deprecated
+
+        # Import the target module and insert it into the parent's namespace
+        try:
+            module = importlib.import_module(self.__name__)
+            self._parent_module_globals[self._local_name] = module
+            # The additional add to sys.modules ensures library is actually loaded.
+            sys.modules[self._local_name] = module
+        except ModuleNotFoundError as err:
+            raise self._exc(f"{self._exc_msg} (reason: {err})") from None
+
+        # Emit a warning if one was specified
+        if self._warning:
+            warnings.warn(self._warning,
+                          category=DeprecationWarning,
+                          stacklevel=4)
+            # Make sure to only warn once.
+            self._warning = None
+
+        # Update this object's dict so that if someone keeps a reference to the
+        # LazyLoader, lookups are efficient (__getattr__ is only called on lookups
+        # that fail).
+ self.__dict__.update(module.__dict__) + return module + + def __getattr__(self, item: Any) -> Any: + if self._module is None: + self._module = self._load() + return getattr(self._module, item) + + def __dir__(self) -> List[str]: + if self._module is None: + self._module = self._load() + return dir(self._module) diff --git a/vllm/v1/core/guided_decoding/__init__.py b/vllm/v1/core/guided_decoding/__init__.py new file mode 100644 index 0000000000000..24a8d17fc6a13 --- /dev/null +++ b/vllm/v1/core/guided_decoding/__init__.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import copy +import threading +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, + get_args) + +from transformers import PreTrainedTokenizer + +from vllm.config import DecodingConfig, ModelConfig +from vllm.logger import init_logger +from vllm.utils import LazyLoader +from vllm.v1.request import GuidedDecodingKey, Request, RequestStatus + +from .grammar import Grammar + +if TYPE_CHECKING: + import xgrammar as xgr + from transformers import PreTrainedTokenizer + from typing_extensions import LiteralString + + from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + +logger = init_logger(__name__) + +__all__ = ["Grammar", "GuidedDecodingManager"] + + +@dataclass +class GrammarCache: + value: Grammar | None + event: threading.Event + + +T = TypeVar("T", bound=str) + + +class GuidedDecodingManager(ABC, Generic[T]): + + @abstractmethod + def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: + ... + + def flush(self): + with self._lock: + self.cache.clear() + + def cache_grammar(self, request: Request): + return self.executor.submit(self._add_grammar_to_cache, request) + + def get_grammar(self, request: Request): + with self._lock: + entry = self.cache.get(request.guided_decoding_key) + if entry is None or not entry.event.is_set(): return None + return copy.copy(entry.value) if entry.value else None + + def should_add(self, request: Request): + if not request.use_guided_decoding: return False + request.grammar = self.get_grammar(request) + if not request.grammar: + request.grammar = self.cache_grammar(request) + request.status = RequestStatus.WAITING_FOR_FSM + return True + return False + + def _add_grammar_to_cache(self, request: Request): + key = request.guided_decoding_key + with self._lock: + cache_hit = False + if key in self.cache: + cache_hit, entry = True, self.cache[key] + else: + entry = GrammarCache(None, threading.Event()) + self.cache[key] = entry + + if cache_hit: + entry.event.wait() + else: + entry.value = self.initialize_cache(key) + entry.event.set() + return copy.copy(entry.value) if entry.value else None + + @classmethod + def from_backend(cls, /, backend: LiteralString = "xgrammar", *, + tokenizer_group: BaseTokenizerGroup, + model_config: ModelConfig) -> GuidedDecodingManager[T]: + manager_cls = cls._registry.get(backend) + if manager_cls is None: raise ValueError( f"Backend '{backend}' not found in registry. 
Available backends: {list(cls._registry)}") + return manager_cls(tokenizer_group=tokenizer_group, model_config=model_config) + + _registry: dict[str, type[GuidedDecodingManager[T]]] = {} + _backend: T + + def __init__(self, *, tokenizer_group: BaseTokenizerGroup, model_config: ModelConfig): + self.model_config = model_config + self.tokenizer = tokenizer_group.get_lora_tokenizer(None) + self.cache: dict[GuidedDecodingKey, GrammarCache] = {} + self.executor = ThreadPoolExecutor() + self._lock = threading.Lock() + + def __init_subclass__(cls, **kwargs: Any): + if not hasattr(cls, '__orig_bases__'): + raise TypeError( + f"{cls.__qualname__} must be subclass of GuidedDecodingManager" + ) + + backend = None + for base in cls.__orig_bases__: + if (origin := get_args(base)) and issubclass( + base.__origin__, GuidedDecodingManager): + backend = get_args(origin[0])[0] + break + + if backend is None: + raise TypeError( + f"Class {cls.__qualname__} must specify backend as a Literal type" + ) + + if backend in cls._registry: + name = cls._registry[backend].__qualname__ + raise ValueError( + f"Backend '{backend}' is already registered to {name}") + + # Set the backend value from the Literal type + cls._backend = backend + cls._registry[backend] = cls + + +class XGrammarManager(GuidedDecodingManager[Literal["xgrammar"]]): + # cache GrammarCompiler instances based on given tokenizer + _compiler_cache: dict[str, xgr.GrammarCompiler] = {} + _compiler: xgr.GrammarCompiler | None = None + + def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: + request_type, grammar_spec = key + compiler = XGrammarManager.get_compiler(self.tokenizer) + if request_type == "json": + if type(grammar_spec) is not str: + ctx = compiler.compile_builtin_json_grammar() + else: + ctx = compiler.compile_json_schema(grammar_spec) + elif request_type == "grammar": + ctx = compiler.compile_grammar(grammar_spec) + else: + raise ValueError("grammar is not of valid supported types.") + return Grammar.from_backend( + self._backend, + matcher=xgr.GrammarMatcher(ctx), + vocab_size=self.model_config.hf_text_config.vocab_size, + ctx=ctx) + + def flush(self): + super().flush() + if self._compiler: self._compiler.clear_cache() + for compiler in self._compiler_cache.values(): + compiler.clear_cache() + self._compiler_cache.clear() + + @classmethod + def get_compiler( + cls, + tokenizer: PreTrainedTokenizer, + *, + max_threads: int = 8, + # passthrough to TokenizerInfo + vocab_size: int | None = None, + stop_token_ids: list[int] | int | None = None + ) -> xgr.GrammarCompiler: + cache_key = str(hash(tokenizer)) + if cache_key not in cls._compiler_cache: + tokenizer_info = xgr.TokenizerInfo.from_huggingface( + tokenizer, + stop_token_ids=stop_token_ids, + vocab_size=vocab_size) + cls._compiler_cache[cache_key] = xgr.GrammarCompiler( + tokenizer_info, max_threads=max_threads) + return cls._compiler_cache[cache_key] diff --git a/vllm/v1/core/guided_decoding/grammar.py b/vllm/v1/core/guided_decoding/grammar.py new file mode 100644 index 0000000000000..baa9b6c693ca0 --- /dev/null +++ b/vllm/v1/core/guided_decoding/grammar.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import (TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args, + overload) + +from vllm.utils import LazyLoader + +if TYPE_CHECKING: + import torch + import xgrammar as xgr + from typing_extensions import LiteralString, Self +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + torch = LazyLoader("torch", globals(), 
"torch") + +T = TypeVar("T", bound=str) + + +class Grammar(ABC, Generic[T]): + finished: bool = False + + @abstractmethod + def accept_token(self, token: int) -> bool: + """Whether to accept the token and advance the machine state.""" + + @abstractmethod + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: + """Fill the bitmask for the token at the given index.""" + + @abstractmethod + def allocate_bitmask(self, batch_size: int, + vocab_size: int) -> torch.Tensor: + """Allocate a bitmask for the given batch size and vocabulary size.""" + + @staticmethod + @abstractmethod + def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: + """Apply the bitmask to the logits.""" + + @abstractmethod + def reset(self): + """Reset the machine state.""" + + @abstractmethod + def copy(self) -> Self: + """Copy the grammar object.""" + + def __copy__(self): + return self.copy() + + _registry: dict[str, type[Grammar[T]]] = {} + _backend: T + + def __init_subclass__(cls): + if not hasattr(cls, '__orig_bases__'): + raise TypeError( + f"Class {cls.__qualname__} must be a subclass of GrammarObject" + ) + + backend = None + for base in cls.__orig_bases__: + if (origin := get_args(base)) and issubclass( + base.__origin__, Grammar): + backend = get_args(origin[0])[0] + break + + if backend is None: + raise TypeError( + f"Class {cls.__qualname__} must specify backend as Literal type" + ) + + if backend in cls._registry: + name = cls._registry[backend].__qualname__ + raise ValueError( + f"Backend '{backend}' is already registered to {name}") + + # Set the backend value from the Literal type + cls._backend = backend + cls._registry[backend] = cls + + @overload + @classmethod + def from_backend( + cls, + backend: Literal["xgrammar"] = ..., + *, + matcher: xgr.GrammarMatcher = ..., + vocab_size: int = ..., + ctx: xgr.CompiledGrammar = ..., + ) -> XGrammar: + ... + + @overload + @classmethod + def from_backend(cls, + backend: LiteralString = ..., + **kwargs: Any) -> Grammar: + ... + + @classmethod + def from_backend(cls, backend: LiteralString = "xgrammar", **kwargs: Any) -> Grammar[T]: + grammar_cls = cls._registry.get(backend) + if grammar_cls is None: + raise ValueError( + f"No grammar implementation registered for '{backend}'") + return grammar_cls(**kwargs) + + +class XGrammar(Grammar[Literal["xgrammar"]]): + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string + # for jump-forward decoding + + def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, + ctx: xgr.CompiledGrammar) -> None: + # TODO: support max_rollback_tokens + self.matcher = matcher + self.vocab_size = vocab_size + self.ctx = ctx + + def accept_token(self, token: int) -> bool: + # NOTE: accept_token will determines whether we accept this token + # and will also update the machine state + return self.matcher.accept_token(token) + + def allocate_bitmask(self, batch_size: int, + vocab_size: int) -> torch.Tensor: + return xgr.allocate_token_bitmask(batch_size, vocab_size) + + # this should be ran in parallel with model decoding + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: + self.matcher.fill_next_token_bitmask(bitmask, idx) + + @staticmethod + def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: + # Note: In this method, if the tensors have different dimensions + # on CPU device fails, but on GPU it runs without error. 
Hence the + # unsqueeze above for scores, to match the token bitmask shape + xgr.apply_token_bitmask_inplace(logits, vocab_mask) + + def reset(self): + self.matcher.reset() + + def copy(self): + return XGrammar(matcher=xgr.GrammarMatcher(self.ctx), + vocab_size=self.vocab_size, + ctx=self.ctx) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 8ded5e5787133..97bb87f996e93 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,13 +1,19 @@ +from __future__ import annotations + from collections import deque from dataclasses import dataclass -from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, - Tuple, Union) +from typing import (TYPE_CHECKING, Any, Deque, Dict, Iterable, List, Optional, + Set, Tuple, Union) -from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig +from vllm.config import (CacheConfig, DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) +from vllm.v1.core.guided_decoding import GuidedDecodingManager +from vllm.v1.core.guided_decoding.grammar import Grammar from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats @@ -15,6 +21,8 @@ from vllm.v1.request import Request, RequestStatus if TYPE_CHECKING: + import torch + from vllm.multimodal import MultiModalKwargs from vllm.multimodal.base import PlaceholderRange @@ -28,14 +36,17 @@ def __init__( scheduler_config: SchedulerConfig, model_config: ModelConfig, cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + lora_config: LoRAConfig | None, + decoding_config: DecodingConfig, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config + self.model_config = model_config + self.decoding_config = decoding_config # TODO: Support LoRA. assert lora_config is None, "V1 does not support LoRA yet." - # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_scheduled_tokens = \ @@ -59,6 +70,22 @@ def __init__( self.waiting: Deque[Request] = deque() self.running: List[Request] = [] + # A set of unready requests that might be waiting for grammar compilation + # we can also use this for tracking spec decode request + self.grammar_queue: Deque[Request] = deque() + # initialize the tokenizer on the scheduler (this is used for constrained decoding) + tokenizer_group = init_tokenizer_from_configs( + model_config=model_config, + scheduler_config=scheduler_config, + parallel_config=parallel_config, + lora_config=lora_config) + tokenizer_group.ping() + # setup guided decoding, right now uses xgrammar + self.guided_decoding_manager = GuidedDecodingManager[Any].from_backend( + backend=decoding_config.guided_decoding_backend, + tokenizer_group=tokenizer_group, + model_config=model_config) + # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished # requests so that they can free the cached states for those requests. 
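The GuidedDecodingManager introduced in this patch compiles each distinct guided-decoding key once on a ThreadPoolExecutor, and a threading.Event lets concurrent requests that share a key wait for the first compilation instead of repeating it, while the scheduler parks those requests in grammar_queue with status WAITING_FOR_FSM. Below is a stripped-down sketch of that compile-once pattern; the class and function names are illustrative stand-ins, not the vLLM types.

import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Dict, Hashable, List, Optional, Tuple


class AsyncCompileCache:
    """Compile each key at most once; later submitters wait on an Event."""

    def __init__(self, compile_fn: Callable[[Hashable], object]) -> None:
        self._compile_fn = compile_fn
        self._entries: Dict[Hashable, Tuple[threading.Event, List[object]]] = {}
        self._lock = threading.Lock()
        self._executor = ThreadPoolExecutor()

    def get(self, key: Hashable) -> Optional[object]:
        # Non-blocking lookup for the scheduler's hot path.
        with self._lock:
            entry = self._entries.get(key)
        if entry is None or not entry[0].is_set():
            return None
        return entry[1][0]

    def submit(self, key: Hashable) -> Future:
        # Returns a Future the request can hold while compilation runs.
        return self._executor.submit(self._compile_once, key)

    def _compile_once(self, key: Hashable) -> object:
        with self._lock:
            entry = self._entries.get(key)
            owner = entry is None
            if owner:
                entry = (threading.Event(), [])
                self._entries[key] = entry
        event, slot = entry
        if owner:
            slot.append(self._compile_fn(key))  # the expensive FSM/grammar build
            event.set()
        else:
            event.wait()  # another worker thread is already compiling this key
        return slot[0]

The future returned by submit() is what the manager hands back on a cache miss and what the request holds on request.grammar while it sits in grammar_queue, so scheduling can continue while compilation finishes in the background.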
@@ -105,6 +132,24 @@ def schedule(self) -> "SchedulerOutput": scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] + # we need to check the grammar queue for any requests that have finished FSM compilation + newly_ready_reqs: List[Request] = [] + remaining_grammar_reqs: Deque[Request] = deque() + while self.grammar_queue: + request = self.grammar_queue.popleft() + grammar = self.guided_decoding_manager.get(request) + if grammar is not None: + request.grammar = grammar + request.status = RequestStatus.WAITING + newly_ready_reqs.append(request) + else: + remaining_grammar_reqs.append(request) + self.grammar_queue = remaining_grammar_reqs + + # append all newly ready requests to waiting queue with higher priority + for req in newly_ready_reqs: + self.waiting.appendleft(req) + req_to_new_block_ids: Dict[str, List[int]] = {} num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens @@ -112,6 +157,12 @@ def schedule(self) -> "SchedulerOutput": scheduled_encoder_inputs: Dict[str, List[int]] = {} encoder_budget = self.max_num_encoder_input_tokens + # Create a shared bitmask tensor for the whole batch + vocab_size = self.model_config.get_vocab_size() + guided_decoding_bitmasks: Dict[str, torch.Tensor] = {} + guided_decoding_reqs: List[Request] = [] + batch_has_grammar = False + # First, schedule the RUNNING requests. # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be # in the "partial" state, where the request has some tokens computed @@ -125,6 +176,12 @@ def schedule(self) -> "SchedulerOutput": assert not has_partial_request assert token_budget > 0 request = self.running[req_index] + + # Skip requests waiting for FSM + if request.status == RequestStatus.WAITING_FOR_FSM: + req_index += 1 + continue + num_new_tokens = request.num_tokens - request.num_computed_tokens num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 @@ -182,6 +239,14 @@ def schedule(self) -> "SchedulerOutput": self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + # Track if we need guided decoding + # Create individual bitmask for requests with grammar + if request.grammar is not None: + batch_has_grammar = True + if request.request_id not in guided_decoding_bitmasks: + bitmask = request.grammar.allocate_bitmask(1, vocab_size) + guided_decoding_bitmasks[request.request_id] = bitmask + # Next, schedule the WAITING requests. if not preempted_reqs: while self.waiting: @@ -193,6 +258,20 @@ def schedule(self) -> "SchedulerOutput": break request = self.waiting[0] + + # Skip requests waiting for FSM + if request.status == RequestStatus.WAITING_FOR_FSM: + self.waiting.rotate(-1) + continue + + # Track guided decoding needs + if request.grammar is not None: + batch_has_grammar = True + if request.request_id not in guided_decoding_bitmasks: + bitmask = request.grammar.allocate_bitmask( + 1, vocab_size) + guided_decoding_bitmasks[request.request_id] = bitmask + # Get already-cached tokens. 
computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks(request) @@ -293,6 +372,7 @@ def schedule(self) -> "SchedulerOutput": req.num_computed_tokens) for req in scheduled_running_reqs ] preempted_req_ids = {req.request_id for req in preempted_reqs} + scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_resumed_reqs=resumed_reqs_data, @@ -307,6 +387,7 @@ def schedule(self) -> "SchedulerOutput": # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, + guided_decoding_bitmasks=guided_decoding_bitmasks, free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), ) @@ -485,9 +566,13 @@ def _check_stop(self, request: Request) -> bool: return False def add_request(self, request: Request) -> None: - self.waiting.append(request) self.requests[request.request_id] = request + add_to_grammar_queue = self.guided_decoding_manager.collect(request) + + if add_to_grammar_queue: self.grammar_queue.append(request) + else: self.waiting.append(request) + def finish_requests( self, request_ids: Union[str, Iterable[str]], @@ -550,7 +635,6 @@ class NewRequestData: mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams block_ids: List[int] - num_computed_tokens: int @classmethod def from_request( @@ -559,17 +643,16 @@ def from_request( block_ids: List[int], num_computed_tokens: int, ) -> "NewRequestData": - return cls( - req_id=request.request_id, - prompt_token_ids=request.prompt_token_ids, - prompt=request.prompt, - mm_inputs=request.mm_inputs, - mm_hashes=request.mm_hashes, - mm_positions=request.mm_positions, - sampling_params=request.sampling_params, - block_ids=block_ids, - num_computed_tokens=num_computed_tokens, - ) + return cls(req_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt, + mm_inputs=request.mm_inputs, + mm_hashes=request.mm_hashes, + mm_positions=request.mm_positions, + sampling_params=request.sampling_params, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens, + grammar=request.grammar) @dataclass @@ -578,6 +661,7 @@ class ResumedRequestData: req_id: str block_ids: List[int] num_computed_tokens: int + grammar: Optional[Grammar] @classmethod def from_request( @@ -586,11 +670,10 @@ def from_request( block_ids: List[int], num_computed_tokens: int, ) -> "ResumedRequestData": - return cls( - req_id=request.request_id, - block_ids=block_ids, - num_computed_tokens=num_computed_tokens, - ) + return cls(req_id=request.request_id, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens, + grammar=request.grammar) @dataclass @@ -599,6 +682,7 @@ class RunningRequestData: req_id: str new_block_ids: List[int] num_computed_tokens: int + grammar: Optional[Grammar] @classmethod def from_request( @@ -607,11 +691,10 @@ def from_request( new_block_ids: List[int], num_computed_tokens: int, ) -> "RunningRequestData": - return cls( - req_id=request.request_id, - new_block_ids=new_block_ids, - num_computed_tokens=num_computed_tokens, - ) + return cls(req_id=request.request_id, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + grammar=request.grammar) @dataclass @@ -626,6 +709,9 @@ class SchedulerOutput: scheduled_encoder_inputs: Dict[str, List[int]] num_common_prefix_blocks: int + # request_id -> bitmask + guided_decoding_bitmasks: Dict[str, torch.Tensor] + preempted_req_ids: Set[str] finished_req_ids: Set[str] free_encoder_input_ids: List[Tuple[str, int]] diff --git 
a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index cf94033a38d96..4cd2383738642 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -59,7 +59,9 @@ def __init__( scheduler_config=vllm_config.scheduler_config, model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, + parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config, + decoding_config=vllm_config.decoding_config, ) self.mm_input_mapper_server = MMInputMapperServer( diff --git a/vllm/v1/request.py b/vllm/v1/request.py index eefcdaf29e753..b744d342870d9 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,17 +1,26 @@ +from __future__ import annotations + import enum -from typing import TYPE_CHECKING, List, Optional, Union +from functools import cached_property +from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, + Union) -from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import RequestMetrics from vllm.v1.engine import EngineCoreRequest from vllm.v1.utils import ConstantList if TYPE_CHECKING: + from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange + from vllm.v1.core.guided_decoding import Grammar from vllm.v1.core.kv_cache_utils import BlockHashType +GuidedDecodingObject = Union[str, Dict[str, Any]] +GuidedDecodingKey = Tuple[Literal["json", "regex", "grammar", "choice"], + GuidedDecodingObject] + class Request: @@ -27,6 +36,7 @@ def __init__( eos_token_id: Optional[int], arrival_time: float, lora_request: Optional[LoRARequest] = None, + grammar: Optional[Grammar] = None, ) -> None: self.request_id = request_id self.sampling_params = sampling_params @@ -70,6 +80,9 @@ def __init__( self.output_token_ids = ConstantList(self._output_token_ids) self.all_token_ids = ConstantList(self._all_token_ids) + # grammar objects + self.grammar = grammar + @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( @@ -131,18 +144,32 @@ def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: self._kv_block_hashes.append(block_hash) + @property + def use_guided_decoding(self) -> bool: + return self.sampling_params.guided_decoding is not None + + @cached_property + def guided_decoding_key(self) -> GuidedDecodingKey: + params = self.sampling_params.guided_decoding + if params.json is not None: return ("json", params.json) + elif params.regex is not None: return ("regex", params.regex) + elif params.choice is not None: return ("choice", params.choice) + elif params.grammar is not None: return ("grammar", params.grammar) + else: raise ValueError("No valid guided decoding parameter found") + class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = 0 - RUNNING = 1 - PREEMPTED = 2 + WAITING_FOR_FSM = enum.auto() + RUNNING = enum.auto() + PREEMPTED = enum.auto() # Note: anything after PREEMPTED (2) will be considered # as a finished status. 
- FINISHED_STOPPED = 3 - FINISHED_LENGTH_CAPPED = 4 - FINISHED_ABORTED = 5 - FINISHED_IGNORED = 6 + FINISHED_STOPPED = enum.auto() + FINISHED_LENGTH_CAPPED = enum.auto() + FINISHED_ABORTED = enum.auto() + FINISHED_IGNORED = enum.auto() @staticmethod def is_finished(status: "RequestStatus") -> bool: diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 28d8e39053874..9e9a5febd1b46 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -1,4 +1,5 @@ # Datastructures defining an input batch +from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set @@ -8,6 +9,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType +from vllm.v1.core.guided_decoding.grammar import Grammar from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable @@ -32,6 +34,7 @@ class CachedRequestState: mrope_positions: Optional[torch.Tensor] = None mrope_position_delta: Optional[int] = None + grammar: Optional[Grammar] = None @property def num_tokens(self) -> int: @@ -166,6 +169,7 @@ def __init__( self.num_logprobs: Dict[str, int] = {} self.prompt_logprob_reqs: Set[str] = set() + self.guided_decoding_reqs: Set[str] = set() def add_request( self, @@ -233,6 +237,8 @@ def add_request( if sampling_params.prompt_logprobs: self.prompt_logprob_reqs.add(req_id) + if request.grammar is not None: self.guided_decoding_reqs.add(req_id) + def remove_request(self, req_id: str) -> Optional[int]: req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: @@ -249,6 +255,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) + self.guided_decoding_reqs.discard(req_id) return req_index def clear(self) -> None: @@ -264,6 +271,7 @@ def clear(self) -> None: self.generators.clear() self.num_logprobs.clear() self.prompt_logprob_reqs.clear() + self.guided_decoding_reqs.clear() def condense(self, empty_req_indices: List[int]) -> None: if self.num_reqs == 0: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4b3c325ded906..d39bdb1bf2546 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -248,6 +248,19 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.block_table.append_row(req_index, start_index, req_data.new_block_ids) + # we should advance the FSM here + if req_id in scheduler_output.guided_decoding_bitmasks and req_state.grammar is not None: + token_idx = scheduler_output.num_scheduled_tokens[req_id] - 1 + token_id = self.input_batch.token_ids_cpu[req_index, token_idx] + # Advance the FSM state + if not req_state.grammar.accept_token(token_id): + # This shouldn't happen since we masked the logits, but handle gracefully + logger.error( + f"Grammar rejected token {token_id} for request {req_id}" + ) + req_state.status = RequestStatus.FINISHED_ABORTED + continue + req_ids_to_add: List[str] = [] # Add new requests to the cached states. 
for new_req_data in scheduler_output.scheduled_new_reqs: @@ -768,6 +781,11 @@ def execute_model( hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(hidden_states, None) + # We will need to apply the logits inplace from here + # so the scheduler_output should contains both the grammar + # of the running request to advance as well as the specific bitmask + # broadcasted from the scheduler.schedule() + # Sample the next token and get logprobs if needed. sampling_metadata = self._prepare_sampling(scheduler_output) sampler_output = self.model.sample( @@ -1007,7 +1025,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. Args: - kv_cache_config: Configuration for the KV cache, including the KV + kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ if len(kv_cache_config.groups) > 1: @@ -1039,10 +1057,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def get_kv_cache_spec(self) -> KVCacheSpec: """ - Generates the KVCacheSpec by parsing the kv cache format from each + Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache + KVCacheSpec: A dictionary mapping layer names to their KV cache format. Layers that do not need KV cache are not included. """ From 36bc041e30429d160ba4818e347db419d25727dd Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 28 Jan 2025 04:59:38 +0000 Subject: [PATCH 02/84] chore: --wip-- Signed-off-by: Aaron Pham --- vllm/v1/core/guided_decoding/__init__.py | 78 +++++++++++++----------- vllm/v1/core/guided_decoding/grammar.py | 23 ++++--- vllm/v1/core/scheduler.py | 70 ++++++++++----------- vllm/v1/request.py | 4 +- vllm/v1/worker/gpu_input_batch.py | 8 +-- 5 files changed, 99 insertions(+), 84 deletions(-) diff --git a/vllm/v1/core/guided_decoding/__init__.py b/vllm/v1/core/guided_decoding/__init__.py index 24a8d17fc6a13..20a1468e1329d 100644 --- a/vllm/v1/core/guided_decoding/__init__.py +++ b/vllm/v1/core/guided_decoding/__init__.py @@ -5,12 +5,11 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, - get_args) +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args from transformers import PreTrainedTokenizer -from vllm.config import DecodingConfig, ModelConfig +from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.utils import LazyLoader from vllm.v1.request import GuidedDecodingKey, Request, RequestStatus @@ -23,6 +22,8 @@ from typing_extensions import LiteralString from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup + + from .grammar import XGrammar else: xgr = LazyLoader("xgr", globals(), "xgrammar") @@ -48,58 +49,67 @@ def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: def flush(self): with self._lock: - self.cache.clear() + self.grammar_cache.clear() + + def cache(self, request: Request): + + def _executor_loop(request: Request): + key = request.guided_decoding_key + with self._lock: + cache_hit = False + if key in self.grammar_cache: + cache_hit, entry = True, self.grammar_cache[key] + else: + entry = GrammarCache(None, threading.Event()) + self.grammar_cache[key] = entry + + if cache_hit: + entry.event.wait() + else: + entry.value = 
self.initialize_cache(key) + entry.event.set() + return copy.copy(entry.value) if entry.value else None - def cache_grammar(self, request: Request): - return self.executor.submit(self._add_grammar_to_cache, request) + return self.executor.submit(_executor_loop, request) - def get_grammar(self, request: Request): + def get(self, request: Request): with self._lock: - entry = self.cache.get(request.guided_decoding_key) + entry = self.grammar_cache.get(request.guided_decoding_key) if entry is None or not entry.event.is_set(): return None return copy.copy(entry.value) if entry.value else None - def should_add(self, request: Request): + def collect(self, request: Request): if not request.use_guided_decoding: return False - request.grammar = self.get_grammar(request) + request.grammar = self.get(request) if not request.grammar: - request.grammar = self.cache_grammar(request) + request.grammar = self.cache(request) request.status = RequestStatus.WAITING_FOR_FSM return True return False - def _add_grammar_to_cache(self, request: Request): - key = request.guided_decoding_key - with self._lock: - cache_hit = False - if key in self.cache: - cache_hit, entry = True, self.cache[key] - else: - entry = GrammarCache(None, threading.Event()) - self.cache[key] = entry - - if cache_hit: - entry.event.wait() - else: - entry.value = self.initialize_cache(key) - entry.event.set() - return copy.copy(entry.value) if entry.value else None - @classmethod - def from_backend(cls, /, backend: LiteralString = "xgrammar", *, + def from_backend(cls, + backend: LiteralString = "xgrammar", + /, + *, tokenizer_group: BaseTokenizerGroup, model_config: ModelConfig) -> GuidedDecodingManager[T]: manager_cls = cls._registry.get(backend) - if manager_cls is None: raise ValueError( f"Backend '{backend}' not found in registry. Available backends: {list(cls._registry)}") - return manager_cls(tokenizer_group=tokenizer_group, model_config=model_config) + if manager_cls is None: + raise ValueError( + f"Backend '{backend}' not found in registry. 
Available backends: {list(cls._registry)}" + ) + return manager_cls(tokenizer_group=tokenizer_group, + model_config=model_config) _registry: dict[str, type[GuidedDecodingManager[T]]] = {} _backend: T - def __init__(self, *, tokenizer_group: BaseTokenizerGroup, model_config: ModelConfig): + def __init__(self, *, tokenizer_group: BaseTokenizerGroup, + model_config: ModelConfig): self.model_config = model_config self.tokenizer = tokenizer_group.get_lora_tokenizer(None) - self.cache: dict[GuidedDecodingKey, GrammarCache] = {} + self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {} self.executor = ThreadPoolExecutor() self._lock = threading.Lock() @@ -136,7 +146,7 @@ class XGrammarManager(GuidedDecodingManager[Literal["xgrammar"]]): _compiler_cache: dict[str, xgr.GrammarCompiler] = {} _compiler: xgr.GrammarCompiler | None = None - def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: + def initialize_cache(self, key: GuidedDecodingKey) -> XGrammar: request_type, grammar_spec = key compiler = XGrammarManager.get_compiler(self.tokenizer) if request_type == "json": diff --git a/vllm/v1/core/guided_decoding/grammar.py b/vllm/v1/core/guided_decoding/grammar.py index baa9b6c693ca0..b634ae169a393 100644 --- a/vllm/v1/core/guided_decoding/grammar.py +++ b/vllm/v1/core/guided_decoding/grammar.py @@ -4,17 +4,19 @@ from typing import (TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args, overload) +from typing_extensions import Annotated, LiteralString + from vllm.utils import LazyLoader if TYPE_CHECKING: import torch import xgrammar as xgr - from typing_extensions import LiteralString, Self + from typing_extensions import Self else: xgr = LazyLoader("xgr", globals(), "xgrammar") torch = LazyLoader("torch", globals(), "torch") -T = TypeVar("T", bound=str) +T = TypeVar("T", bound=Annotated[LiteralString, str]) class Grammar(ABC, Generic[T]): @@ -93,13 +95,19 @@ def from_backend( @overload @classmethod - def from_backend(cls, - backend: LiteralString = ..., - **kwargs: Any) -> Grammar: + def from_backend( + cls, + backend: Literal["outlines"] = ..., + *, + guide: str = ..., + whitespace_pattern: str | None = ..., + ) -> XGrammar: ... 
@classmethod - def from_backend(cls, backend: LiteralString = "xgrammar", **kwargs: Any) -> Grammar[T]: + def from_backend(cls, + backend: LiteralString = "xgrammar", + **kwargs: Any) -> Grammar[T]: grammar_cls = cls._registry.get(backend) if grammar_cls is None: raise ValueError( @@ -108,8 +116,7 @@ def from_backend(cls, backend: LiteralString = "xgrammar", **kwargs: Any) -> Gra class XGrammar(Grammar[Literal["xgrammar"]]): - # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string - # for jump-forward decoding + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, ctx: xgr.CompiledGrammar) -> None: diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 97bb87f996e93..a404e4268923c 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,9 +1,10 @@ from __future__ import annotations from collections import deque +from concurrent import futures from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Deque, Dict, Iterable, List, Optional, - Set, Tuple, Union) +from typing import (TYPE_CHECKING, Any, Deque, Dict, Iterable, List, Literal, + Optional, Set, Tuple, Union) from vllm.config import (CacheConfig, DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -70,22 +71,6 @@ def __init__( self.waiting: Deque[Request] = deque() self.running: List[Request] = [] - # A set of unready requests that might be waiting for grammar compilation - # we can also use this for tracking spec decode request - self.grammar_queue: Deque[Request] = deque() - # initialize the tokenizer on the scheduler (this is used for constrained decoding) - tokenizer_group = init_tokenizer_from_configs( - model_config=model_config, - scheduler_config=scheduler_config, - parallel_config=parallel_config, - lora_config=lora_config) - tokenizer_group.ping() - # setup guided decoding, right now uses xgrammar - self.guided_decoding_manager = GuidedDecodingManager[Any].from_backend( - backend=decoding_config.guided_decoding_backend, - tokenizer_group=tokenizer_group, - model_config=model_config) - # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished # requests so that they can free the cached states for those requests. @@ -117,6 +102,21 @@ def __init__( self.encoder_cache_manager = EncoderCacheManager( cache_size=encoder_cache_size) + # A request queue for grammar compilation + self.grammar: Deque[Request] = deque() + # initialize the tokenizer on the scheduler (this is used for constrained decoding) + tokenizer_group = init_tokenizer_from_configs( + model_config=model_config, + scheduler_config=scheduler_config, + parallel_config=parallel_config, + lora_config=lora_config) + tokenizer_group.ping() + # setup guided decoding, right now uses xgrammar + self.guided_decoding_manager = GuidedDecodingManager.from_backend( + backend=decoding_config.guided_decoding_backend, + tokenizer_group=tokenizer_group, + model_config=model_config) + def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. 
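In this revision the pending compilation lives on request.grammar as a Future, and the schedule() hunk just below drains self.grammar by polling each future with a 50 ms timeout so grammar compilation never blocks a scheduling step. Here is that polling pattern in isolation, using a toy request type rather than vLLM's Request.

from collections import deque
from concurrent.futures import TimeoutError as FutureTimeout
from dataclasses import dataclass
from typing import Any, Deque, List


@dataclass
class PendingRequest:  # toy stand-in for vllm.v1.request.Request
    req_id: str
    grammar: Any  # a Future while compiling, the Grammar once resolved


def drain_grammar_queue(queue: Deque[PendingRequest],
                        poll_timeout: float = 0.05) -> List[PendingRequest]:
    """Move requests whose grammar finished compiling out of the queue."""
    ready: List[PendingRequest] = []
    still_waiting: Deque[PendingRequest] = deque()
    while queue:
        req = queue.popleft()
        try:
            # Blocks at most poll_timeout per request, keeping schedule() cheap.
            req.grammar = req.grammar.result(timeout=poll_timeout)
            ready.append(req)
        except FutureTimeout:
            still_waiting.append(req)
    queue.extend(still_waiting)
    return ready

Requests that resolve here are pushed back onto the waiting queue with appendleft so they run ahead of newly arrived work, as the hunk below does; note that concurrent.futures.TimeoutError is the public name for the futures._base.TimeoutError the diff catches.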
@@ -133,21 +133,20 @@ def schedule(self) -> "SchedulerOutput": preempted_reqs: List[Request] = [] # we need to check the grammar queue for any requests that have finished FSM compilation - newly_ready_reqs: List[Request] = [] - remaining_grammar_reqs: Deque[Request] = deque() - while self.grammar_queue: - request = self.grammar_queue.popleft() - grammar = self.guided_decoding_manager.get(request) - if grammar is not None: - request.grammar = grammar + newly_grammar_reqs: List[Request] = [] + scheduled_grammar_reqs: Deque[Request] = deque() + while self.grammar: + request = self.grammar.popleft() + try: + request.grammar = request.grammar.result(timeout=0.05) request.status = RequestStatus.WAITING - newly_ready_reqs.append(request) - else: - remaining_grammar_reqs.append(request) - self.grammar_queue = remaining_grammar_reqs + newly_grammar_reqs.append(request) + except futures._base.TimeoutError: + scheduled_grammar_reqs.append(request) + self.grammar = scheduled_grammar_reqs # append all newly ready requests to waiting queue with higher priority - for req in newly_ready_reqs: + for req in newly_grammar_reqs: self.waiting.appendleft(req) req_to_new_block_ids: Dict[str, List[int]] = {} @@ -161,7 +160,6 @@ def schedule(self) -> "SchedulerOutput": vocab_size = self.model_config.get_vocab_size() guided_decoding_bitmasks: Dict[str, torch.Tensor] = {} guided_decoding_reqs: List[Request] = [] - batch_has_grammar = False # First, schedule the RUNNING requests. # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be @@ -242,7 +240,6 @@ def schedule(self) -> "SchedulerOutput": # Track if we need guided decoding # Create individual bitmask for requests with grammar if request.grammar is not None: - batch_has_grammar = True if request.request_id not in guided_decoding_bitmasks: bitmask = request.grammar.allocate_bitmask(1, vocab_size) guided_decoding_bitmasks[request.request_id] = bitmask @@ -266,7 +263,6 @@ def schedule(self) -> "SchedulerOutput": # Track guided decoding needs if request.grammar is not None: - batch_has_grammar = True if request.request_id not in guided_decoding_bitmasks: bitmask = request.grammar.allocate_bitmask( 1, vocab_size) @@ -568,10 +564,10 @@ def _check_stop(self, request: Request) -> bool: def add_request(self, request: Request) -> None: self.requests[request.request_id] = request - add_to_grammar_queue = self.guided_decoding_manager.collect(request) - - if add_to_grammar_queue: self.grammar_queue.append(request) - else: self.waiting.append(request) + if self.guided_decoding_manager.collect(request): + self.grammar.append(request) + else: + self.waiting.append(request) def finish_requests( self, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index b744d342870d9..d2a3b6e80d798 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -11,6 +11,8 @@ from vllm.v1.utils import ConstantList if TYPE_CHECKING: + from concurrent.futures import Future + from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange @@ -81,7 +83,7 @@ def __init__( self.all_token_ids = ConstantList(self._all_token_ids) # grammar objects - self.grammar = grammar + self.grammar: Optional[Grammar[Any] | Future[Grammar[Any]]] = grammar @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 9e9a5febd1b46..5d62b09c6312d 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ 
b/vllm/v1/worker/gpu_input_batch.py @@ -169,7 +169,7 @@ def __init__( self.num_logprobs: Dict[str, int] = {} self.prompt_logprob_reqs: Set[str] = set() - self.guided_decoding_reqs: Set[str] = set() + self.grammar_reqs: Set[str] = set() def add_request( self, @@ -237,7 +237,7 @@ def add_request( if sampling_params.prompt_logprobs: self.prompt_logprob_reqs.add(req_id) - if request.grammar is not None: self.guided_decoding_reqs.add(req_id) + if request.grammar is not None: self.grammar_reqs.add(req_id) def remove_request(self, req_id: str) -> Optional[int]: req_index = self.req_id_to_index.pop(req_id, None) @@ -255,7 +255,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) - self.guided_decoding_reqs.discard(req_id) + self.grammar_reqs.discard(req_id) return req_index def clear(self) -> None: @@ -271,7 +271,7 @@ def clear(self) -> None: self.generators.clear() self.num_logprobs.clear() self.prompt_logprob_reqs.clear() - self.guided_decoding_reqs.clear() + self.grammar_reqs.clear() def condense(self, empty_req_indices: List[int]) -> None: if self.num_reqs == 0: From 39068c88abbe2c2f2b43aa29360b3ff628449099 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 28 Jan 2025 17:30:49 +0000 Subject: [PATCH 03/84] chore: remove lazy loader Signed-off-by: Aaron Pham --- vllm/utils.py | 68 ------------------------ vllm/v1/core/guided_decoding/__init__.py | 5 +- vllm/v1/core/guided_decoding/grammar.py | 9 +--- 3 files changed, 3 insertions(+), 79 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index d31434b612a05..46a72bbe047a4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2207,71 +2207,3 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any], else: func = partial(method, obj) # type: ignore return func(*args, **kwargs) - - -class LazyLoader(types.ModuleType): - """ - LazyLoader module borrowed from Tensorflow - https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/util/lazy_loader.py - with a addition of "module caching". This will throw an - exception if module cannot be imported. - - Lazily import a module, mainly to avoid pulling in large dependencies. - `contrib`, and `ffmpeg` are examples of modules that are large and not always - needed, and this allows them to only be loaded when they are used. - """ - - def __init__( - self, - local_name: str, - parent_module_globals: Dict[str, Any], - name: str, - warning: Optional[str] = None, - exc_msg: Optional[str] = None, - exc: Type[Exception] = Exception, - ): - self._local_name = local_name - self._parent_module_globals = parent_module_globals - self._warning = warning - self._exc_msg = exc_msg - self._exc = exc - self._module: types.ModuleType | None = None - - super().__init__(str(name)) - - def _load(self) -> types.ModuleType: - """Load the module and insert it into the parent's globals.""" - from . import warn_deprecated - - # Import the target module and insert it into the parent's namespace - try: - module = importlib.import_module(self.__name__) - self._parent_module_globals[self._local_name] = module - # The additional add to sys.modules ensures library is actually loaded. 
- sys.modules[self._local_name] = module - except ModuleNotFoundError as err: - raise self._exc(f"{self._exc_msg} (reason: {err})") from None - - # Emit a warning if one was specified - if self._warning: - warnings.warn(self._warning, - category=DeprecationWarning, - stacklevel=4) - # Make sure to only warn once. - self._warning = None - - # Update this object's dict so that if someone keeps a reference to the - # LazyLoader, lookups are efficient (__getattr__ is only called on lookups - # that fail). - self.__dict__.update(module.__dict__) - return module - - def __getattr__(self, item: Any) -> Any: - if self._module is None: - self._module = self._load() - return getattr(self._module, item) - - def __dir__(self) -> List[str]: - if self._module is None: - self._module = self._load() - return dir(self._module) diff --git a/vllm/v1/core/guided_decoding/__init__.py b/vllm/v1/core/guided_decoding/__init__.py index 20a1468e1329d..3032a41857e35 100644 --- a/vllm/v1/core/guided_decoding/__init__.py +++ b/vllm/v1/core/guided_decoding/__init__.py @@ -8,24 +8,21 @@ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args from transformers import PreTrainedTokenizer +import xgrammar as xgr from vllm.config import ModelConfig from vllm.logger import init_logger -from vllm.utils import LazyLoader from vllm.v1.request import GuidedDecodingKey, Request, RequestStatus from .grammar import Grammar if TYPE_CHECKING: - import xgrammar as xgr from transformers import PreTrainedTokenizer from typing_extensions import LiteralString from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from .grammar import XGrammar -else: - xgr = LazyLoader("xgr", globals(), "xgrammar") logger = init_logger(__name__) diff --git a/vllm/v1/core/guided_decoding/grammar.py b/vllm/v1/core/guided_decoding/grammar.py index b634ae169a393..feee130e5e3e1 100644 --- a/vllm/v1/core/guided_decoding/grammar.py +++ b/vllm/v1/core/guided_decoding/grammar.py @@ -5,16 +5,11 @@ overload) from typing_extensions import Annotated, LiteralString - -from vllm.utils import LazyLoader +import torch +import xgrammar as xgr if TYPE_CHECKING: - import torch - import xgrammar as xgr from typing_extensions import Self -else: - xgr = LazyLoader("xgr", globals(), "xgrammar") - torch = LazyLoader("torch", globals(), "torch") T = TypeVar("T", bound=Annotated[LiteralString, str]) From 2bb535e7d17b12217560567be004341329bc21fa Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 30 Jan 2025 16:27:20 +0000 Subject: [PATCH 04/84] fix: update types and attach bitmask to requests Signed-off-by: Aaron Pham --- vllm/utils.py | 1 - vllm/v1/core/guided_decoding/__init__.py | 114 +++-------------------- vllm/v1/core/guided_decoding/grammar.py | 105 +-------------------- vllm/v1/core/scheduler.py | 45 ++++----- vllm/v1/request.py | 28 ++++-- 5 files changed, 62 insertions(+), 231 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 46a72bbe047a4..17bffd2846b46 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -22,7 +22,6 @@ import threading import time import traceback -import types import uuid import warnings import weakref diff --git a/vllm/v1/core/guided_decoding/__init__.py b/vllm/v1/core/guided_decoding/__init__.py index 3032a41857e35..7561aded48bb9 100644 --- a/vllm/v1/core/guided_decoding/__init__.py +++ b/vllm/v1/core/guided_decoding/__init__.py @@ -1,13 +1,11 @@ from __future__ import annotations -import copy +import copy, enum import threading -from abc import ABC, abstractmethod from concurrent.futures import 
ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args +from typing import TYPE_CHECKING, TypeVar -from transformers import PreTrainedTokenizer import xgrammar as xgr from vllm.config import ModelConfig @@ -17,8 +15,7 @@ from .grammar import Grammar if TYPE_CHECKING: - from transformers import PreTrainedTokenizer - from typing_extensions import LiteralString + from typing_extensions import Self from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup @@ -31,18 +28,11 @@ @dataclass class GrammarCache: - value: Grammar | None + value: Optional[Grammar] event: threading.Event -T = TypeVar("T", bound=str) - - -class GuidedDecodingManager(ABC, Generic[T]): - - @abstractmethod - def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: - ... +class GuidedDecodingManager: def flush(self): with self._lock: @@ -84,68 +74,21 @@ def collect(self, request: Request): return True return False - @classmethod - def from_backend(cls, - backend: LiteralString = "xgrammar", - /, - *, - tokenizer_group: BaseTokenizerGroup, - model_config: ModelConfig) -> GuidedDecodingManager[T]: - manager_cls = cls._registry.get(backend) - if manager_cls is None: - raise ValueError( - f"Backend '{backend}' not found in registry. Available backends: {list(cls._registry)}" - ) - return manager_cls(tokenizer_group=tokenizer_group, - model_config=model_config) - - _registry: dict[str, type[GuidedDecodingManager[T]]] = {} - _backend: T - - def __init__(self, *, tokenizer_group: BaseTokenizerGroup, + def __init__(self, *, backend: str, tokenizer_group: BaseTokenizerGroup, model_config: ModelConfig): + self._backend = backend self.model_config = model_config self.tokenizer = tokenizer_group.get_lora_tokenizer(None) self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {} self.executor = ThreadPoolExecutor() self._lock = threading.Lock() - - def __init_subclass__(cls, **kwargs: Any): - if not hasattr(cls, '__orig_bases__'): - raise TypeError( - f"{cls.__qualname__} must be subclass of GuidedDecodingManager" - ) - - backend = None - for base in cls.__orig_bases__: - if (origin := get_args(base)) and issubclass( - base.__origin__, GuidedDecodingManager): - backend = get_args(origin[0])[0] - break - - if backend is None: - raise TypeError( - f"Class {cls.__qualname__} must specify backend as a Literal type" - ) - - if backend in cls._registry: - name = cls._registry[backend].__qualname__ - raise ValueError( - f"Backend '{backend}' is already registered to {name}") - - # Set the backend value from the Literal type - cls._backend = backend cls._registry[backend] = cls - -class XGrammarManager(GuidedDecodingManager[Literal["xgrammar"]]): - # cache GrammarCompiler instances based on given tokenizer - _compiler_cache: dict[str, xgr.GrammarCompiler] = {} - _compiler: xgr.GrammarCompiler | None = None - - def initialize_cache(self, key: GuidedDecodingKey) -> XGrammar: + def initialize_cache(self, key: GuidedDecodingKey) -> Self: request_type, grammar_spec = key - compiler = XGrammarManager.get_compiler(self.tokenizer) + tokenizer_info = xgr.TokenizerInfo.from_huggingface( + tokenizer, stop_token_ids=stop_token_ids, vocab_size=vocab_size) + compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=max_threads) if request_type == "json": if type(grammar_spec) is not str: ctx = compiler.compile_builtin_json_grammar() @@ -155,35 +98,6 @@ def initialize_cache(self, key: GuidedDecodingKey) -> XGrammar: ctx = compiler.compile_grammar(grammar_spec) 
else: raise ValueError("grammar is not of valid supported types.") - return Grammar.from_backend( - self._backend, - matcher=xgr.GrammarMatcher(ctx), - vocab_size=self.model_config.hf_text_config.vocab_size, - ctx=ctx) - - def flush(self): - super().flush() - if self._compiler: self._compiler.clear_cache() - for compiler in self._compiler_cache.values(): - compiler.clear_cache() - self._compiler_cache.clear() - - @classmethod - def get_compiler( - cls, - tokenizer: PreTrainedTokenizer, - *, - max_threads: int = 8, - # passthrough to TokenizerInfo - vocab_size: int | None = None, - stop_token_ids: list[int] | int | None = None - ) -> xgr.GrammarCompiler: - cache_key = str(hash(tokenizer)) - if cache_key not in cls._compiler_cache: - tokenizer_info = xgr.TokenizerInfo.from_huggingface( - tokenizer, - stop_token_ids=stop_token_ids, - vocab_size=vocab_size) - cls._compiler_cache[cache_key] = xgr.GrammarCompiler( - tokenizer_info, max_threads=max_threads) - return cls._compiler_cache[cache_key] + return Grammar(matcher=xgr.GrammarMatcher(ctx), + vocab_size=self.model_config.hf_text_config.vocab_size, + ctx=ctx) diff --git a/vllm/v1/core/guided_decoding/grammar.py b/vllm/v1/core/guided_decoding/grammar.py index feee130e5e3e1..2b4c294e5100d 100644 --- a/vllm/v1/core/guided_decoding/grammar.py +++ b/vllm/v1/core/guided_decoding/grammar.py @@ -14,103 +14,8 @@ T = TypeVar("T", bound=Annotated[LiteralString, str]) -class Grammar(ABC, Generic[T]): +class Grammar: finished: bool = False - - @abstractmethod - def accept_token(self, token: int) -> bool: - """Whether to accept the token and advance the machine state.""" - - @abstractmethod - def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: - """Fill the bitmask for the token at the given index.""" - - @abstractmethod - def allocate_bitmask(self, batch_size: int, - vocab_size: int) -> torch.Tensor: - """Allocate a bitmask for the given batch size and vocabulary size.""" - - @staticmethod - @abstractmethod - def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - """Apply the bitmask to the logits.""" - - @abstractmethod - def reset(self): - """Reset the machine state.""" - - @abstractmethod - def copy(self) -> Self: - """Copy the grammar object.""" - - def __copy__(self): - return self.copy() - - _registry: dict[str, type[Grammar[T]]] = {} - _backend: T - - def __init_subclass__(cls): - if not hasattr(cls, '__orig_bases__'): - raise TypeError( - f"Class {cls.__qualname__} must be a subclass of GrammarObject" - ) - - backend = None - for base in cls.__orig_bases__: - if (origin := get_args(base)) and issubclass( - base.__origin__, Grammar): - backend = get_args(origin[0])[0] - break - - if backend is None: - raise TypeError( - f"Class {cls.__qualname__} must specify backend as Literal type" - ) - - if backend in cls._registry: - name = cls._registry[backend].__qualname__ - raise ValueError( - f"Backend '{backend}' is already registered to {name}") - - # Set the backend value from the Literal type - cls._backend = backend - cls._registry[backend] = cls - - @overload - @classmethod - def from_backend( - cls, - backend: Literal["xgrammar"] = ..., - *, - matcher: xgr.GrammarMatcher = ..., - vocab_size: int = ..., - ctx: xgr.CompiledGrammar = ..., - ) -> XGrammar: - ... - - @overload - @classmethod - def from_backend( - cls, - backend: Literal["outlines"] = ..., - *, - guide: str = ..., - whitespace_pattern: str | None = ..., - ) -> XGrammar: - ... 
- - @classmethod - def from_backend(cls, - backend: LiteralString = "xgrammar", - **kwargs: Any) -> Grammar[T]: - grammar_cls = cls._registry.get(backend) - if grammar_cls is None: - raise ValueError( - f"No grammar implementation registered for '{backend}'") - return grammar_cls(**kwargs) - - -class XGrammar(Grammar[Literal["xgrammar"]]): # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, @@ -135,15 +40,15 @@ def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: @staticmethod def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - # Note: In this method, if the tensors have different dimensions - # on CPU device fails, but on GPU it runs without error. Hence the - # unsqueeze above for scores, to match the token bitmask shape xgr.apply_token_bitmask_inplace(logits, vocab_mask) def reset(self): self.matcher.reset() def copy(self): - return XGrammar(matcher=xgr.GrammarMatcher(self.ctx), + return Grammar(matcher=xgr.GrammarMatcher(self.ctx), vocab_size=self.vocab_size, ctx=self.ctx) + + def __copy__(self): + return self.copy() diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index a404e4268923c..51c281f56dbcb 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -3,6 +3,7 @@ from collections import deque from concurrent import futures from dataclasses import dataclass +from re import A from typing import (TYPE_CHECKING, Any, Deque, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union) @@ -38,7 +39,7 @@ def __init__( model_config: ModelConfig, cache_config: CacheConfig, parallel_config: ParallelConfig, - lora_config: LoRAConfig | None, + lora_config: Optional[LoRAConfig], decoding_config: DecodingConfig, ) -> None: self.scheduler_config = scheduler_config @@ -112,7 +113,7 @@ def __init__( lora_config=lora_config) tokenizer_group.ping() # setup guided decoding, right now uses xgrammar - self.guided_decoding_manager = GuidedDecodingManager.from_backend( + self.guided_decoding_manager = GuidedDecodingManager( backend=decoding_config.guided_decoding_backend, tokenizer_group=tokenizer_group, model_config=model_config) @@ -138,6 +139,8 @@ def schedule(self) -> "SchedulerOutput": while self.grammar: request = self.grammar.popleft() try: + # When request first added via add_request, then it will be a future call + # check timeout and add it directly to previous queue request.grammar = request.grammar.result(timeout=0.05) request.status = RequestStatus.WAITING newly_grammar_reqs.append(request) @@ -158,8 +161,6 @@ def schedule(self) -> "SchedulerOutput": # Create a shared bitmask tensor for the whole batch vocab_size = self.model_config.get_vocab_size() - guided_decoding_bitmasks: Dict[str, torch.Tensor] = {} - guided_decoding_reqs: List[Request] = [] # First, schedule the RUNNING requests. 
# NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be @@ -256,17 +257,8 @@ def schedule(self) -> "SchedulerOutput": request = self.waiting[0] - # Skip requests waiting for FSM - if request.status == RequestStatus.WAITING_FOR_FSM: - self.waiting.rotate(-1) - continue - - # Track guided decoding needs - if request.grammar is not None: - if request.request_id not in guided_decoding_bitmasks: - bitmask = request.grammar.allocate_bitmask( - 1, vocab_size) - guided_decoding_bitmasks[request.request_id] = bitmask + # allocate bitmask on request on first round + if request.grammar: request.allocate_grammar_bitmask(vocab_size=vocab_size) # Get already-cached tokens. computed_blocks, num_computed_tokens = \ @@ -365,7 +357,7 @@ def schedule(self) -> "SchedulerOutput": running_reqs_data = [ self._make_running_request_data( req, req_to_new_block_ids[req.request_id], - req.num_computed_tokens) for req in scheduled_running_reqs + req.num_computed_tokens, grammar=req.grammar, grammar_bitmask=req.grammar_bitmask) for req in scheduled_running_reqs ] preempted_req_ids = {req.request_id for req in preempted_reqs} @@ -395,6 +387,9 @@ def _make_running_request_data( request: Request, new_block_ids: List[int], num_computed_tokens: int, + *, + grammar: Optional[Grammar] = None, + grammar_bitmask: Optional[Any] = None, ) -> "RunningRequestData": # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating # them at each scheduling step. @@ -402,6 +397,8 @@ def _make_running_request_data( req_data = self.running_reqs_data[request.request_id] req_data.new_block_ids = new_block_ids req_data.num_computed_tokens = num_computed_tokens + req_data.grammar = grammar + req_data.grammar_bitmask=grammar_bitmask else: req_data = RunningRequestData.from_request(request, new_block_ids, num_computed_tokens) @@ -632,6 +629,9 @@ class NewRequestData: sampling_params: SamplingParams block_ids: List[int] + grammar: Optional[Grammar] + grammar_bitmask: Any + @classmethod def from_request( cls, @@ -648,7 +648,7 @@ def from_request( sampling_params=request.sampling_params, block_ids=block_ids, num_computed_tokens=num_computed_tokens, - grammar=request.grammar) + grammar=request.grammar, grammar_bitmask=request.grammar_bitmask) @dataclass @@ -657,7 +657,9 @@ class ResumedRequestData: req_id: str block_ids: List[int] num_computed_tokens: int + grammar: Optional[Grammar] + grammar_bitmask: Any @classmethod def from_request( @@ -669,7 +671,7 @@ def from_request( return cls(req_id=request.request_id, block_ids=block_ids, num_computed_tokens=num_computed_tokens, - grammar=request.grammar) + grammar=request.grammar, grammar_bitmask=request.grammar_bitmask) @dataclass @@ -678,7 +680,9 @@ class RunningRequestData: req_id: str new_block_ids: List[int] num_computed_tokens: int + grammar: Optional[Grammar] + grammar_bitmask: Any @classmethod def from_request( @@ -690,7 +694,7 @@ def from_request( return cls(req_id=request.request_id, new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, - grammar=request.grammar) + grammar=request.grammar, grammar_bitmask=request.grammar_bitmask) @dataclass @@ -705,9 +709,6 @@ class SchedulerOutput: scheduled_encoder_inputs: Dict[str, List[int]] num_common_prefix_blocks: int - # request_id -> bitmask - guided_decoding_bitmasks: Dict[str, torch.Tensor] - preempted_req_ids: Set[str] finished_req_ids: Set[str] free_encoder_input_ids: List[Tuple[str, int]] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index d2a3b6e80d798..861507813ffd4 100644 --- a/vllm/v1/request.py 
+++ b/vllm/v1/request.py @@ -19,10 +19,14 @@ from vllm.v1.core.guided_decoding import Grammar from vllm.v1.core.kv_cache_utils import BlockHashType -GuidedDecodingObject = Union[str, Dict[str, Any]] -GuidedDecodingKey = Tuple[Literal["json", "regex", "grammar", "choice"], - GuidedDecodingObject] +class GuidedDecodingOptions(enum.Enum): + json = enum.auto() + regex = enum.auto() + grammar = enum.auto() + choice = enum.auto() +GuidedDecodingObject = Union[str, Dict[str, Any]] +GuidedDecodingKey = Tuple[GuidedDecodingOptions, GuidedDecodingObject] class Request: @@ -83,7 +87,8 @@ def __init__( self.all_token_ids = ConstantList(self._all_token_ids) # grammar objects - self.grammar: Optional[Grammar[Any] | Future[Grammar[Any]]] = grammar + self.grammar: Optional[Union[Grammar, Future[Grammar]]] = grammar + self._grammar_bitmask = None @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": @@ -150,15 +155,22 @@ def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: def use_guided_decoding(self) -> bool: return self.sampling_params.guided_decoding is not None + @property + def grammar_bitmask(self): return self._grammar_bitmask + @cached_property def guided_decoding_key(self) -> GuidedDecodingKey: params = self.sampling_params.guided_decoding - if params.json is not None: return ("json", params.json) - elif params.regex is not None: return ("regex", params.regex) - elif params.choice is not None: return ("choice", params.choice) - elif params.grammar is not None: return ("grammar", params.grammar) + if params.json is not None: return (GuidedDecodingOptions.json, params.json) + elif params.regex is not None: return (GuidedDecodingOptions.regex, params.regex) + elif params.choice is not None: return (GuidedDecodingOptions.choice, params.choice) + elif params.grammar is not None: return (GuidedDecodingOptions.grammar, params.grammar) else: raise ValueError("No valid guided decoding parameter found") + def allocate_grammar_bitmask(self, vocab_size: int): + if self._grammar_bitmask is None: self._grammar_bitmask = self.grammar.allocate_bitmask(1, vocab_size=vocab_size) + return self._grammar_bitmask + class RequestStatus(enum.IntEnum): """Status of a request.""" From 420f52f934c5421bbfaf6ca71d09004de1999512 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 30 Jan 2025 23:34:08 +0000 Subject: [PATCH 05/84] chore: --wip-- Signed-off-by: Aaron Pham --- vllm/v1/core/guided_decoding/__init__.py | 72 +++++++++++++++------- vllm/v1/core/guided_decoding/grammar.py | 54 ----------------- vllm/v1/core/scheduler.py | 77 ++++++------------------ vllm/v1/engine/core.py | 26 ++++++++ vllm/v1/request.py | 34 +++++++---- vllm/v1/worker/gpu_input_batch.py | 2 +- 6 files changed, 119 insertions(+), 146 deletions(-) delete mode 100644 vllm/v1/core/guided_decoding/grammar.py diff --git a/vllm/v1/core/guided_decoding/__init__.py b/vllm/v1/core/guided_decoding/__init__.py index 7561aded48bb9..9469c653fe2ed 100644 --- a/vllm/v1/core/guided_decoding/__init__.py +++ b/vllm/v1/core/guided_decoding/__init__.py @@ -1,29 +1,59 @@ from __future__ import annotations -import copy, enum +import copy import threading from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Optional +import torch import xgrammar as xgr -from vllm.config import ModelConfig -from vllm.logger import init_logger +from vllm.config import VllmConfig from vllm.v1.request import GuidedDecodingKey, 
Request, RequestStatus -from .grammar import Grammar - if TYPE_CHECKING: - from typing_extensions import Self - from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup - from .grammar import XGrammar +__all__ = ["Grammar", "GuidedDecodingManager"] + -logger = init_logger(__name__) +class Grammar: + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding -__all__ = ["Grammar", "GuidedDecodingManager"] + def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, + ctx: xgr.CompiledGrammar) -> None: + self.matcher = matcher + self.vocab_size = vocab_size + self.ctx = ctx + + def accept_token(self, token: int) -> bool: + # NOTE: accept_token will determines whether we accept this token + # and will also update the machine state + return self.matcher.accept_token(token) + + def allocate_bitmask(self, batch_size: int, + vocab_size: int) -> torch.Tensor: + return xgr.allocate_token_bitmask(batch_size, vocab_size) + + # this should be ran in parallel with model decoding + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: + self.matcher.fill_next_token_bitmask(bitmask, idx) + + @staticmethod + def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: + xgr.apply_token_bitmask_inplace(logits, vocab_mask) + + def reset(self): + self.matcher.reset() + + def copy(self): + return Grammar(matcher=xgr.GrammarMatcher(self.ctx), + vocab_size=self.vocab_size, + ctx=self.ctx) + + def __copy__(self): + return self.copy() @dataclass @@ -74,20 +104,17 @@ def collect(self, request: Request): return True return False - def __init__(self, *, backend: str, tokenizer_group: BaseTokenizerGroup, - model_config: ModelConfig): - self._backend = backend - self.model_config = model_config + def __init__(self, *, vllm_config: VllmConfig, + tokenizer_group: BaseTokenizerGroup): + self.vllm_config = vllm_config self.tokenizer = tokenizer_group.get_lora_tokenizer(None) self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {} self.executor = ThreadPoolExecutor() self._lock = threading.Lock() - cls._registry[backend] = cls - def initialize_cache(self, key: GuidedDecodingKey) -> Self: + def initialize_cache(self, key: GuidedDecodingKey, max_threads: int = 8): request_type, grammar_spec = key - tokenizer_info = xgr.TokenizerInfo.from_huggingface( - tokenizer, stop_token_ids=stop_token_ids, vocab_size=vocab_size) + tokenizer_info = xgr.TokenizerInfo.from_huggingface(self.tokenizer) compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=max_threads) if request_type == "json": if type(grammar_spec) is not str: @@ -98,6 +125,7 @@ def initialize_cache(self, key: GuidedDecodingKey) -> Self: ctx = compiler.compile_grammar(grammar_spec) else: raise ValueError("grammar is not of valid supported types.") - return Grammar(matcher=xgr.GrammarMatcher(ctx), - vocab_size=self.model_config.hf_text_config.vocab_size, - ctx=ctx) + return Grammar( + matcher=xgr.GrammarMatcher(ctx), + vocab_size=self.vllm_config.model_config.hf_text_config.vocab_size, + ctx=ctx) diff --git a/vllm/v1/core/guided_decoding/grammar.py b/vllm/v1/core/guided_decoding/grammar.py deleted file mode 100644 index 2b4c294e5100d..0000000000000 --- a/vllm/v1/core/guided_decoding/grammar.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import (TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args, - overload) - -from typing_extensions import Annotated, LiteralString -import 
torch -import xgrammar as xgr - -if TYPE_CHECKING: - from typing_extensions import Self - -T = TypeVar("T", bound=Annotated[LiteralString, str]) - - -class Grammar: - finished: bool = False - # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding - - def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, - ctx: xgr.CompiledGrammar) -> None: - # TODO: support max_rollback_tokens - self.matcher = matcher - self.vocab_size = vocab_size - self.ctx = ctx - - def accept_token(self, token: int) -> bool: - # NOTE: accept_token will determines whether we accept this token - # and will also update the machine state - return self.matcher.accept_token(token) - - def allocate_bitmask(self, batch_size: int, - vocab_size: int) -> torch.Tensor: - return xgr.allocate_token_bitmask(batch_size, vocab_size) - - # this should be ran in parallel with model decoding - def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: - self.matcher.fill_next_token_bitmask(bitmask, idx) - - @staticmethod - def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - xgr.apply_token_bitmask_inplace(logits, vocab_mask) - - def reset(self): - self.matcher.reset() - - def copy(self): - return Grammar(matcher=xgr.GrammarMatcher(self.ctx), - vocab_size=self.vocab_size, - ctx=self.ctx) - - def __copy__(self): - return self.copy() diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 51c281f56dbcb..a1de5b20475e3 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -14,8 +14,7 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.guided_decoding import GuidedDecodingManager -from vllm.v1.core.guided_decoding.grammar import Grammar +from vllm.v1.core.guided_decoding import Grammar from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats @@ -40,13 +39,11 @@ def __init__( cache_config: CacheConfig, parallel_config: ParallelConfig, lora_config: Optional[LoRAConfig], - decoding_config: DecodingConfig, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config self.model_config = model_config - self.decoding_config = decoding_config # TODO: Support LoRA. assert lora_config is None, "V1 does not support LoRA yet." # Scheduling constraints. @@ -103,21 +100,6 @@ def __init__( self.encoder_cache_manager = EncoderCacheManager( cache_size=encoder_cache_size) - # A request queue for grammar compilation - self.grammar: Deque[Request] = deque() - # initialize the tokenizer on the scheduler (this is used for constrained decoding) - tokenizer_group = init_tokenizer_from_configs( - model_config=model_config, - scheduler_config=scheduler_config, - parallel_config=parallel_config, - lora_config=lora_config) - tokenizer_group.ping() - # setup guided decoding, right now uses xgrammar - self.guided_decoding_manager = GuidedDecodingManager( - backend=decoding_config.guided_decoding_backend, - tokenizer_group=tokenizer_group, - model_config=model_config) - def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. 
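
The copy()/__copy__ pair kept on the consolidated Grammar exists so that one cached, already-compiled grammar can serve many requests: the CompiledGrammar (ctx) is shared, while each request gets a fresh GrammarMatcher and therefore its own FSM state. A small sketch of that sharing, using a stand-in wrapper with the same shape as the patch's class (the tokenizer id is a placeholder):

import copy
import xgrammar as xgr
from transformers import AutoTokenizer

class Grammar:
    """Stand-in mirroring the patch: shared ctx, per-instance matcher."""
    def __init__(self, matcher, vocab_size, ctx):
        self.matcher, self.vocab_size, self.ctx = matcher, vocab_size, ctx
    def copy(self):
        # Reuses the compiled grammar but builds a fresh matcher (fresh FSM state).
        return Grammar(xgr.GrammarMatcher(self.ctx), self.vocab_size, self.ctx)
    def __copy__(self):
        return self.copy()

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
info = xgr.TokenizerInfo.from_huggingface(tokenizer)
ctx = xgr.GrammarCompiler(info).compile_builtin_json_grammar()
cached = Grammar(xgr.GrammarMatcher(ctx), len(tokenizer), ctx)

per_request = copy.copy(cached)                   # what the manager hands out per request
assert per_request.ctx is cached.ctx              # compilation is shared
assert per_request.matcher is not cached.matcher  # FSM state is not shared
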
@@ -133,25 +115,6 @@ def schedule(self) -> "SchedulerOutput": scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] - # we need to check the grammar queue for any requests that have finished FSM compilation - newly_grammar_reqs: List[Request] = [] - scheduled_grammar_reqs: Deque[Request] = deque() - while self.grammar: - request = self.grammar.popleft() - try: - # When request first added via add_request, then it will be a future call - # check timeout and add it directly to previous queue - request.grammar = request.grammar.result(timeout=0.05) - request.status = RequestStatus.WAITING - newly_grammar_reqs.append(request) - except futures._base.TimeoutError: - scheduled_grammar_reqs.append(request) - self.grammar = scheduled_grammar_reqs - - # append all newly ready requests to waiting queue with higher priority - for req in newly_grammar_reqs: - self.waiting.appendleft(req) - req_to_new_block_ids: Dict[str, List[int]] = {} num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens @@ -238,13 +201,6 @@ def schedule(self) -> "SchedulerOutput": self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget - # Track if we need guided decoding - # Create individual bitmask for requests with grammar - if request.grammar is not None: - if request.request_id not in guided_decoding_bitmasks: - bitmask = request.grammar.allocate_bitmask(1, vocab_size) - guided_decoding_bitmasks[request.request_id] = bitmask - # Next, schedule the WAITING requests. if not preempted_reqs: while self.waiting: @@ -258,7 +214,8 @@ def schedule(self) -> "SchedulerOutput": request = self.waiting[0] # allocate bitmask on request on first round - if request.grammar: request.allocate_grammar_bitmask(vocab_size=vocab_size) + if request.grammar: + request.allocate_grammar_bitmask(vocab_size=vocab_size) # Get already-cached tokens. computed_blocks, num_computed_tokens = \ @@ -356,8 +313,12 @@ def schedule(self) -> "SchedulerOutput": ] running_reqs_data = [ self._make_running_request_data( - req, req_to_new_block_ids[req.request_id], - req.num_computed_tokens, grammar=req.grammar, grammar_bitmask=req.grammar_bitmask) for req in scheduled_running_reqs + req, + req_to_new_block_ids[req.request_id], + req.num_computed_tokens, + grammar=req.grammar, + grammar_bitmask=req.grammar_bitmask) + for req in scheduled_running_reqs ] preempted_req_ids = {req.request_id for req in preempted_reqs} @@ -375,7 +336,6 @@ def schedule(self) -> "SchedulerOutput": # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - guided_decoding_bitmasks=guided_decoding_bitmasks, free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), ) @@ -398,7 +358,7 @@ def _make_running_request_data( req_data.new_block_ids = new_block_ids req_data.num_computed_tokens = num_computed_tokens req_data.grammar = grammar - req_data.grammar_bitmask=grammar_bitmask + req_data.grammar_bitmask = grammar_bitmask else: req_data = RunningRequestData.from_request(request, new_block_ids, num_computed_tokens) @@ -480,6 +440,8 @@ def update_from_output( scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", ) -> EngineCoreOutputs: + # concern: batchsize >>>1000 + # compilation << update # NOTE(woosuk): This method doesn't consider speculative decoding. 
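
The polling loop removed above (and revisited in later commits) follows a simple pattern: give each compilation Future a short window with result(timeout=...), promote the requests whose grammar finished, and requeue the rest for the next scheduling step. A generic sketch of that pattern, independent of the scheduler types:

import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor, TimeoutError

def compile_grammar(spec: str) -> str:
    time.sleep(0.2)                      # stand-in for xgrammar compilation
    return f"compiled({spec})"

def drain_one_pass(queue: deque) -> list:
    """One scheduling step: promote finished futures, requeue the rest."""
    ready, still_waiting = [], deque()
    while queue:
        req_id, fut = queue.popleft()
        try:
            # Mirrors the removed code: request.grammar.result(timeout=0.05)
            ready.append((req_id, fut.result(timeout=0.05)))
        except TimeoutError:
            still_waiting.append((req_id, fut))
    queue.extend(still_waiting)
    return ready

executor = ThreadPoolExecutor()
pending = deque((rid, executor.submit(compile_grammar, spec))
                for rid, spec in [("req-0", "schema-a"), ("req-1", "schema-b")])
print(drain_one_pass(pending))           # likely [] on the first pass
time.sleep(0.3)
print(drain_one_pass(pending))           # both requests promoted once compilation finished
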
sampled_token_ids = model_runner_output.sampled_token_ids num_scheduled_tokens = scheduler_output.num_scheduled_tokens @@ -560,11 +522,7 @@ def _check_stop(self, request: Request) -> bool: def add_request(self, request: Request) -> None: self.requests[request.request_id] = request - - if self.guided_decoding_manager.collect(request): - self.grammar.append(request) - else: - self.waiting.append(request) + self.waiting.append(request) def finish_requests( self, @@ -648,7 +606,8 @@ def from_request( sampling_params=request.sampling_params, block_ids=block_ids, num_computed_tokens=num_computed_tokens, - grammar=request.grammar, grammar_bitmask=request.grammar_bitmask) + grammar=request.grammar, + grammar_bitmask=request.grammar_bitmask) @dataclass @@ -671,7 +630,8 @@ def from_request( return cls(req_id=request.request_id, block_ids=block_ids, num_computed_tokens=num_computed_tokens, - grammar=request.grammar, grammar_bitmask=request.grammar_bitmask) + grammar=request.grammar, + grammar_bitmask=request.grammar_bitmask) @dataclass @@ -694,7 +654,8 @@ def from_request( return cls(req_id=request.request_id, new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, - grammar=request.grammar, grammar_bitmask=request.grammar_bitmask) + grammar=request.grammar, + grammar_bitmask=request.grammar_bitmask) @dataclass diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 4cd2383738642..d3906d3adc651 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -16,6 +16,7 @@ from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.utils import get_exception_traceback, zmq_socket_ctx +from vllm.v1.core.guided_decoding import GuidedDecodingManager from vllm.v1.core.kv_cache_utils import get_kv_cache_config from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile, @@ -67,6 +68,28 @@ def __init__( self.mm_input_mapper_server = MMInputMapperServer( vllm_config.model_config) + # initialize the tokenizer on the scheduler (this is used for constrained decoding) + tokenizer_group = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + lora_config=vllm_config.lora_config) + tokenizer_group.ping() + # setup guided decoding, right now uses xgrammar + self.guided_decoding_manager = GuidedDecodingManager( + vllm_config=vllm_config, tokenizer_group=tokenizer_group) + + # while self.grammar: + # request = self.grammar.popleft() + # try: + # # When request first added via add_request, then it will be a future call + # # check timeout and add it directly to previous queue + # request.grammar = request.grammar.result(timeout=0.05) + # request.status = RequestStatus.WAITING + # newly_grammar_reqs.append(request) + # except futures._base.TimeoutError: + # scheduled_grammar_reqs.append(request) + def _initialize_kv_caches(self, vllm_config: VllmConfig) -> Tuple[int, int]: start = time.time() @@ -127,6 +150,9 @@ def step(self) -> EngineCoreOutputs: scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) + # update FSM async here + # two broadcast (bitmask + calculate) <-- manager + # copy CPU -> CPU IPC (concat multiple bitmask?) 
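
The "concat multiple bitmask" note above is about doing one host copy (or one IPC message) per step instead of one per request. A sketch of what that batching could look like with plain torch and xgrammar (not the layout this series ultimately settles on):

import torch
import xgrammar as xgr

vocab_size = 32000
num_reqs = 4

# Each request owns a (1, ceil(vocab_size / 32)) int32 bitmask; in the real flow
# each grammar's matcher fills its own row via fill_next_token_bitmask.
per_request_masks = [xgr.allocate_token_bitmask(1, vocab_size) for _ in range(num_reqs)]

# Concatenate into a single (num_reqs, ceil(vocab_size / 32)) tensor so the whole
# batch moves in one copy, then mask the batched logits in place.
batch_bitmask = torch.cat(per_request_masks, dim=0)
logits = torch.randn(num_reqs, vocab_size)
xgr.apply_token_bitmask_inplace(logits, batch_bitmask)
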
engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) return engine_core_outputs diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 861507813ffd4..79610f28d2655 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -19,15 +19,18 @@ from vllm.v1.core.guided_decoding import Grammar from vllm.v1.core.kv_cache_utils import BlockHashType + class GuidedDecodingOptions(enum.Enum): - json = enum.auto() - regex = enum.auto() - grammar = enum.auto() - choice = enum.auto() + json = enum.auto() + regex = enum.auto() + grammar = enum.auto() + choice = enum.auto() + GuidedDecodingObject = Union[str, Dict[str, Any]] GuidedDecodingKey = Tuple[GuidedDecodingOptions, GuidedDecodingObject] + class Request: def __init__( @@ -156,19 +159,28 @@ def use_guided_decoding(self) -> bool: return self.sampling_params.guided_decoding is not None @property - def grammar_bitmask(self): return self._grammar_bitmask + def grammar_bitmask(self): + return self._grammar_bitmask @cached_property def guided_decoding_key(self) -> GuidedDecodingKey: params = self.sampling_params.guided_decoding - if params.json is not None: return (GuidedDecodingOptions.json, params.json) - elif params.regex is not None: return (GuidedDecodingOptions.regex, params.regex) - elif params.choice is not None: return (GuidedDecodingOptions.choice, params.choice) - elif params.grammar is not None: return (GuidedDecodingOptions.grammar, params.grammar) - else: raise ValueError("No valid guided decoding parameter found") + assert params is not None, "params can't be None." + if params.json is not None: + return (GuidedDecodingOptions.json, params.json) + elif params.regex is not None: + return (GuidedDecodingOptions.regex, params.regex) + elif params.choice is not None: + return (GuidedDecodingOptions.choice, params.choice) + elif params.grammar is not None: + return (GuidedDecodingOptions.grammar, params.grammar) + else: + raise ValueError("No valid guided decoding parameter found") def allocate_grammar_bitmask(self, vocab_size: int): - if self._grammar_bitmask is None: self._grammar_bitmask = self.grammar.allocate_bitmask(1, vocab_size=vocab_size) + if self._grammar_bitmask is None: + self._grammar_bitmask = self.grammar.allocate_bitmask( + 1, vocab_size=vocab_size) return self._grammar_bitmask diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 5d62b09c6312d..f418075e51244 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -9,7 +9,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType -from vllm.v1.core.guided_decoding.grammar import Grammar +from vllm.v1.core.guided_decoding import Grammar from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable From 9daf140fd2e8874d2810feba9f55d02b5561e3d1 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 8 Feb 2025 19:52:04 +0000 Subject: [PATCH 06/84] chore: --wip-- cleanup Signed-off-by: Aaron Pham --- vllm/v1/core/guided_decoding/__init__.py | 131 ----------------------- vllm/v1/engine/core.py | 2 +- vllm/v1/request.py | 21 +--- 3 files changed, 4 insertions(+), 150 deletions(-) delete mode 100644 vllm/v1/core/guided_decoding/__init__.py diff --git a/vllm/v1/core/guided_decoding/__init__.py b/vllm/v1/core/guided_decoding/__init__.py deleted file mode 100644 index 9469c653fe2ed..0000000000000 --- a/vllm/v1/core/guided_decoding/__init__.py +++ /dev/null @@ -1,131 +0,0 @@ -from __future__ import 
annotations - -import copy -import threading -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional - -import torch -import xgrammar as xgr - -from vllm.config import VllmConfig -from vllm.v1.request import GuidedDecodingKey, Request, RequestStatus - -if TYPE_CHECKING: - from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup - -__all__ = ["Grammar", "GuidedDecodingManager"] - - -class Grammar: - # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding - - def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, - ctx: xgr.CompiledGrammar) -> None: - self.matcher = matcher - self.vocab_size = vocab_size - self.ctx = ctx - - def accept_token(self, token: int) -> bool: - # NOTE: accept_token will determines whether we accept this token - # and will also update the machine state - return self.matcher.accept_token(token) - - def allocate_bitmask(self, batch_size: int, - vocab_size: int) -> torch.Tensor: - return xgr.allocate_token_bitmask(batch_size, vocab_size) - - # this should be ran in parallel with model decoding - def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: - self.matcher.fill_next_token_bitmask(bitmask, idx) - - @staticmethod - def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - xgr.apply_token_bitmask_inplace(logits, vocab_mask) - - def reset(self): - self.matcher.reset() - - def copy(self): - return Grammar(matcher=xgr.GrammarMatcher(self.ctx), - vocab_size=self.vocab_size, - ctx=self.ctx) - - def __copy__(self): - return self.copy() - - -@dataclass -class GrammarCache: - value: Optional[Grammar] - event: threading.Event - - -class GuidedDecodingManager: - - def flush(self): - with self._lock: - self.grammar_cache.clear() - - def cache(self, request: Request): - - def _executor_loop(request: Request): - key = request.guided_decoding_key - with self._lock: - cache_hit = False - if key in self.grammar_cache: - cache_hit, entry = True, self.grammar_cache[key] - else: - entry = GrammarCache(None, threading.Event()) - self.grammar_cache[key] = entry - - if cache_hit: - entry.event.wait() - else: - entry.value = self.initialize_cache(key) - entry.event.set() - return copy.copy(entry.value) if entry.value else None - - return self.executor.submit(_executor_loop, request) - - def get(self, request: Request): - with self._lock: - entry = self.grammar_cache.get(request.guided_decoding_key) - if entry is None or not entry.event.is_set(): return None - return copy.copy(entry.value) if entry.value else None - - def collect(self, request: Request): - if not request.use_guided_decoding: return False - request.grammar = self.get(request) - if not request.grammar: - request.grammar = self.cache(request) - request.status = RequestStatus.WAITING_FOR_FSM - return True - return False - - def __init__(self, *, vllm_config: VllmConfig, - tokenizer_group: BaseTokenizerGroup): - self.vllm_config = vllm_config - self.tokenizer = tokenizer_group.get_lora_tokenizer(None) - self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {} - self.executor = ThreadPoolExecutor() - self._lock = threading.Lock() - - def initialize_cache(self, key: GuidedDecodingKey, max_threads: int = 8): - request_type, grammar_spec = key - tokenizer_info = xgr.TokenizerInfo.from_huggingface(self.tokenizer) - compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=max_threads) - if request_type == "json": - if 
type(grammar_spec) is not str: - ctx = compiler.compile_builtin_json_grammar() - else: - ctx = compiler.compile_json_schema(grammar_spec) - elif request_type == "grammar": - ctx = compiler.compile_grammar(grammar_spec) - else: - raise ValueError("grammar is not of valid supported types.") - return Grammar( - matcher=xgr.GrammarMatcher(ctx), - vocab_size=self.vllm_config.model_config.hf_text_config.vocab_size, - ctx=ctx) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0cd0820224c7b..0c063ff3bf416 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -17,9 +17,9 @@ from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.utils import get_exception_traceback, zmq_socket_ctx -from vllm.v1.core.guided_decoding import GuidedDecodingManager from vllm.v1.core.kv_cache_utils import get_kv_cache_config from vllm.v1.core.scheduler import Scheduler +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, EngineCoreRequestType, EngineCoreRequestUnion, EngineCoreResetPrefixCache) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index a424e057b0b99..96bdb392014fb 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -2,7 +2,8 @@ from __future__ import annotations import enum -from typing import TYPE_CHECKING, List, Optional, Union +import functools +from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any, Tuple from vllm.sampling_params import SamplingParams from vllm.sequence import RequestMetrics @@ -44,7 +45,6 @@ def __init__( eos_token_id: Optional[int], arrival_time: float, lora_request: Optional[LoRARequest] = None, - grammar: Optional[Grammar] = None, ) -> None: self.request_id = request_id self.sampling_params = sampling_params @@ -90,10 +90,6 @@ def __init__( self.output_token_ids = ConstantList(self._output_token_ids) self.all_token_ids = ConstantList(self._all_token_ids) - # grammar objects - self.grammar: Optional[Union[Grammar, Future[Grammar]]] = grammar - self._grammar_bitmask = None - @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( @@ -155,11 +151,7 @@ def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: def use_guided_decoding(self) -> bool: return self.sampling_params.guided_decoding is not None - @property - def grammar_bitmask(self): - return self._grammar_bitmask - - @cached_property + @functools.cached_property def guided_decoding_key(self) -> GuidedDecodingKey: params = self.sampling_params.guided_decoding assert params is not None, "params can't be None." 
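
One detail worth noting about guided_decoding_key: params.json may be a dict, and the manager later uses the key to index its grammar cache, so the spec has to be reduced to something hashable and canonical. A possible normalization (the helper name is illustrative, not part of the patch):

import enum
import json
from typing import Any, Dict, Tuple, Union

class GuidedDecodingOptions(enum.Enum):
    json = enum.auto()
    regex = enum.auto()
    grammar = enum.auto()
    choice = enum.auto()

def normalize_guided_decoding_key(
        option: GuidedDecodingOptions,
        spec: Union[str, Dict[str, Any]]) -> Tuple[GuidedDecodingOptions, str]:
    # Serialize dict schemas with sorted keys so equivalent schemas share one
    # cache entry and the tuple stays hashable; strings pass through unchanged.
    if isinstance(spec, dict):
        spec = json.dumps(spec, sort_keys=True)
    return option, spec

key = normalize_guided_decoding_key(
    GuidedDecodingOptions.json, {"type": "object", "required": ["name"]})
grammar_cache = {key: "compiled grammar placeholder"}
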
@@ -174,17 +166,10 @@ def guided_decoding_key(self) -> GuidedDecodingKey: else: raise ValueError("No valid guided decoding parameter found") - def allocate_grammar_bitmask(self, vocab_size: int): - if self._grammar_bitmask is None: - self._grammar_bitmask = self.grammar.allocate_bitmask( - 1, vocab_size=vocab_size) - return self._grammar_bitmask - class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = 0 - WAITING_FOR_FSM = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() # Note: anything after PREEMPTED (2) will be considered From 15a454755c202f8c3414b470b82b0cb811a8209b Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 12 Feb 2025 06:32:28 +0000 Subject: [PATCH 07/84] feat: base implementation Signed-off-by: Aaron Pham --- vllm/v1/core/scheduler.py | 8 +- vllm/v1/engine/core.py | 90 ++++++++++++++----- vllm/v1/guided_decoding/__init__.py | 133 ++++++++++++++++++++++++++++ vllm/v1/request.py | 43 ++++++++- vllm/v1/worker/gpu_input_batch.py | 8 +- vllm/v1/worker/gpu_model_runner.py | 15 +++- 6 files changed, 257 insertions(+), 40 deletions(-) create mode 100644 vllm/v1/guided_decoding/__init__.py diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index e5fa836795038..bdb629182854f 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,9 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import time -from collections import deque -from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union - from collections import deque from concurrent import futures from dataclasses import dataclass @@ -20,6 +17,7 @@ SchedulerOutput) from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) +from vllm.v1.guided_decoding import Grammar from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -359,8 +357,6 @@ def _make_cached_request_data( new_block_ids: List[int], num_computed_tokens: int, *, - grammar: Optional[Grammar] = None, - grammar_bitmask: Optional[Any] = None, resumed_from_preemption: bool, ) -> "CachedRequestData": # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating @@ -370,8 +366,6 @@ def _make_cached_request_data( req_data.resumed_from_preemption = resumed_from_preemption req_data.new_block_ids = new_block_ids req_data.num_computed_tokens = num_computed_tokens - req_data.grammar = grammar - req_data.grammar_bitmask = grammar_bitmask else: req_data = CachedRequestData.from_request(request, resumed_from_preemption, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 610410e5ab8c3..30c49bfd7816f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -8,6 +8,7 @@ from typing import Any, List, Tuple, Type import psutil +import torch import zmq import zmq.asyncio @@ -15,6 +16,7 @@ from vllm.logger import init_logger from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import get_exception_traceback, zmq_socket_ctx from vllm.v1.core.kv_cache_utils import get_kv_cache_config from vllm.v1.core.scheduler import Scheduler @@ -22,9 +24,8 @@ EngineCoreRequestType) from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor -from vllm.v1.request import Request, RequestStatus +from vllm.v1.request import GuidedDecodingOptions, Request, RequestStatus from vllm.v1.serial_utils 
import MsgpackDecoder, MsgpackEncoder -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -76,20 +77,14 @@ def __init__( parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) tokenizer_group.ping() - # setup guided decoding, right now uses xgrammar + self.tokenizer_group = tokenizer_group + self.use_guided_decoding = False + + # Initialize guided decoding manager + from vllm.v1.guided_decoding import GuidedDecodingManager self.guided_decoding_manager = GuidedDecodingManager( - vllm_config=vllm_config, tokenizer_group=tokenizer_group) - - # while self.grammar: - # request = self.grammar.popleft() - # try: - # # When request first added via add_request, then it will be a future call - # # check timeout and add it directly to previous queue - # request.grammar = request.grammar.result(timeout=0.05) - # request.status = RequestStatus.WAITING - # newly_grammar_reqs.append(request) - # except futures._base.TimeoutError: - # scheduled_grammar_reqs.append(request) + tokenizer_group=self.tokenizer_group, + model_config=vllm_config.model_config) def _initialize_kv_caches(self, vllm_config: VllmConfig) -> Tuple[int, int]: @@ -130,15 +125,17 @@ def add_request(self, request: EngineCoreRequest): request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) + if req.use_guided_decoding: + self.use_guided_decoding = True + # Start grammar compilation asynchronously + self.guided_decoding_manager.should_cache(req) + else: + self.use_guided_decoding = False self.scheduler.add_request(req) def abort_requests(self, request_ids: List[str]): """Abort requests from the scheduler.""" - - # TODO: The scheduler doesn't really need to know the - # specific finish reason, TBD whether we propagate that - # (i.e. client-aborted vs stop criteria met). self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) @@ -149,13 +146,47 @@ def step(self) -> EngineCoreOutputs: return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) + # Calculate bitmasks for all active requests + if self.use_guided_decoding: + self.calculate_grammar_bitmasks() + scheduler_output = self.scheduler.schedule() + + # Attach bitmasks to scheduler output for broadcasting to workers + if self.use_guided_decoding: + scheduler_output.guided_decoding_bitmasks = { + req.request_id: req.bitmask + for req in self.scheduler.running + if req.use_guided_decoding and req.is_grammar_ready + } + output = self.model_executor.execute_model(scheduler_output) - # update FSM async here - # two broadcast (bitmask + calculate) <-- manager - # copy CPU -> CPU IPC (concat multiple bitmask?) 
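
On the worker side, each broadcast bitmask has batch dimension 1 while the logits cover the whole batch, so the mask has to be applied to the matching row. A sketch of that per-row application (the request-id-to-row mapping and shapes here are assumptions, not the exact model-runner code):

import torch
import xgrammar as xgr

vocab_size = 32000
req_ids = ["req-0", "req-1", "req-2"]
logits = torch.randn(len(req_ids), vocab_size)

# As broadcast from the scheduler: request id -> (1, ceil(vocab_size / 32)) bitmask.
guided_decoding_bitmasks = {"req-1": xgr.allocate_token_bitmask(1, vocab_size)}

for row, req_id in enumerate(req_ids):
    bitmask = guided_decoding_bitmasks.get(req_id)
    if bitmask is None:
        continue
    # logits[row:row + 1] is a (1, vocab_size) view, so the in-place masking
    # writes into the batch tensor and both arguments have matching batch size.
    xgr.apply_token_bitmask_inplace(logits[row:row + 1], bitmask.to(logits.device))
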
+ engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) + + if self.use_guided_decoding: + # Advance FSM for each request using guided decoding + for req in self.scheduler.running: + if not req.use_guided_decoding or not req.is_grammar_ready: + continue + + # Get the generated tokens for this request + if req.request_id in output.outputs: + generated_tokens = output.outputs[req.request_id].token_ids + # Advance FSM for each generated token + for token in generated_tokens: + if not req.grammar.accept_token(token): + # Token was rejected by grammar - mark request as finished with error + self.scheduler.finish_requests( + [req.request_id], + RequestStatus.FINISHED_GRAMMAR_ERROR) + break + + # Update bitmask for next token prediction if request is still running + if req.request_id not in self.scheduler.finished_requests: + req.grammar.fill_bitmask(req.bitmask, 0) + return engine_core_outputs def shutdown(self): @@ -167,6 +198,21 @@ def profile(self, is_start: bool = True): def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() + def calculate_grammar_bitmasks(self): + for req in self.scheduler.running: + # ignore requests that doesn't use guided decoding + # or ignore requests that grammar is not ready + if not req.use_guided_decoding or not req.is_grammar_ready: + continue + + # Check if grammar is ready in cache + grammar = self.guided_decoding_manager.get(req) + if grammar is not None: + req.grammar = grammar + req.allocate_bitmask(1, + self.guided_decoding_manager.vocab_size) + continue + class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py new file mode 100644 index 0000000000000..d9e1119b37459 --- /dev/null +++ b/vllm/v1/guided_decoding/__init__.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import copy +import threading +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch +import xgrammar as xgr + +from vllm.v1.request import RequestStatus + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup + from vllm.v1.request import GuidedDecodingKey, Request + + +class Grammar: + finished: bool = False + + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding + + def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, + ctx: xgr.CompiledGrammar) -> None: + # TODO: support max_rollback_tokens + self.matcher = matcher + self.vocab_size = vocab_size + self.ctx = ctx + + def accept_token(self, token: int) -> bool: + # NOTE: accept_token will determines whether we accept this token + # and will also update the machine state + return self.matcher.accept_token(token) + + def allocate_bitmask(self, batch_size: int, + vocab_size: int) -> torch.Tensor: + return xgr.allocate_token_bitmask(batch_size, vocab_size) + + # this should be ran in parallel with model decoding + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: + self.matcher.fill_next_token_bitmask(bitmask, idx) + + @staticmethod + def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: + xgr.apply_token_bitmask_inplace(logits, vocab_mask) + + def reset(self): + self.matcher.reset() + + def copy(self): + return 
Grammar(matcher=xgr.GrammarMatcher(self.ctx), + vocab_size=self.vocab_size, + ctx=self.ctx) + + def __copy__(self): + return self.copy() + + +@dataclass +class GrammarCache: + value: Optional[Grammar] + event: threading.Event + + +class GuidedDecodingManager: + + def __init__(self, tokenizer_group: BaseTokenizerGroup, + model_config: ModelConfig): + self.tokenizer_group = tokenizer_group + self.model_config = model_config + self.vocab_size = model_config.get_vocab_size() + self.tokenizer = tokenizer_group.get_lora_tokenizer(None) + self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {} + self.executor = ThreadPoolExecutor() + self._lock = threading.Lock() + + def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: + request_type, grammar_spec = key + tokenizer_info = xgr.TokenizerInfo.from_huggingface( + self.tokenizer, vocab_size=self.vocab_size) + compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + + if request_type == "json": + if type(grammar_spec) is not str: + ctx = compiler.compile_builtin_json_grammar() + else: + ctx = compiler.compile_json_schema(grammar_spec) + elif request_type == "grammar": + ctx = compiler.compile_grammar(grammar_spec) + else: + raise ValueError("grammar is not of valid supported types.") + + return Grammar(matcher=xgr.GrammarMatcher(ctx), + vocab_size=self.model_config.hf_text_config.vocab_size, + ctx=ctx) + + def should_cache(self, request: Request): + if not request.use_guided_decoding: return False + request.grammar = self.get(request) + if not request.grammar: + request.grammar = self.cache(request) + request.status = RequestStatus.WAITING_FOR_FSM + return True + return False + + def cache(self, request: Request): + return self.executor.submit(self._executor_loop, request) + + def _executor_loop(self, request: Request): + key = request.guided_decoding_key + with self._lock: + cache_hit = False + if key in self.grammar_cache: + cache_hit, entry = True, self.grammar_cache[key] + else: + entry = GrammarCache(None, threading.Event()) + self.grammar_cache[key] = entry + + if cache_hit: + entry.event.wait() + else: + entry.value = self.initialize_cache(key) + entry.event.set() + return copy.copy(entry.value) if entry.value else None + + def get(self, request: Request): + with self._lock: + entry = self.grammar_cache.get(request.guided_decoding_key) + if entry is None or not entry.event.is_set(): return None + return copy.copy(entry.value) if entry.value else None diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 8edcc305d40d8..6cb56588eb3a0 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations + import enum import functools -from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any, Tuple +from concurrent.futures import Future +from concurrent.futures._base import TimeoutError +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from vllm.sampling_params import SamplingParams from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, @@ -11,11 +14,12 @@ from vllm.v1.utils import ConstantList if TYPE_CHECKING: - from concurrent.futures import Future + import torch from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange + from vllm.v1.guided_decoding import Grammar class GuidedDecodingOptions(enum.Enum): @@ -79,6 +83,10 @@ def __init__( self.output_token_ids = ConstantList(self._output_token_ids) self.all_token_ids = 
ConstantList(self._all_token_ids) + # Grammar fields, including the grammar object and the bitmask + self._grammar: Future[Grammar] | Grammar | None = None + self._bitmask = None + @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( @@ -163,10 +171,40 @@ def guided_decoding_key(self) -> GuidedDecodingKey: else: raise ValueError("No valid guided decoding parameter found") + @property + def grammar(self) -> Optional[Grammar | Future[Grammar]]: + return self._grammar + + @grammar.setter + def grammar(self, grammar: Grammar | Future[Grammar]) -> None: + self._grammar = grammar + + def allocate_bitmask(self, batch_size: int, vocab_size: int) -> None: + if isinstance(self._grammar, Future): + try: + self.grammar = self.grammar.result(timeout=0.05) + self.status = RequestStatus.WAITING + except TimeoutError: + pass + if self.grammar: + self._bitmask = self.grammar.allocate_bitmask( + batch_size, vocab_size) + + @functools.cached_property + def bitmask(self) -> Optional[torch.Tensor]: + return self._bitmask + + @property + def is_grammar_ready(self) -> bool: + if isinstance(self._grammar, Future): + return self._grammar.done() + return self._grammar is not None + class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = 0 + WAITING_FOR_FSM = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() # Note: anything after PREEMPTED (2) will be considered @@ -175,6 +213,7 @@ class RequestStatus(enum.IntEnum): FINISHED_LENGTH_CAPPED = enum.auto() FINISHED_ABORTED = enum.auto() FINISHED_IGNORED = enum.auto() + FINISHED_GRAMMAR_ERROR = enum.auto() @staticmethod def is_finished(status: "RequestStatus") -> bool: diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 11a3e9b1f0632..f77e51ce5b9b3 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple import numpy as np import torch @@ -12,7 +12,6 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType -from vllm.v1.core.guided_decoding import Grammar from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable @@ -37,7 +36,6 @@ class CachedRequestState: mrope_positions: Optional[torch.Tensor] = None mrope_position_delta: Optional[int] = None - grammar: Optional[Grammar] = None lora_request: Optional[LoRARequest] = None @@ -250,7 +248,8 @@ def add_request( if sampling_params.prompt_logprobs is not None: self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs - if request.grammar is not None: self.grammar_reqs.add(req_id) + if request.grammar is not None: + self.grammar_reqs.add(req_id) # Add request lora ID if request.lora_request: @@ -366,7 +365,6 @@ def condense(self, empty_req_indices: List[int]) -> None: self.request_lora_mapping[empty_index] = self.request_lora_mapping[ last_req_index] - # Decrement last_req_index since it is now empty. 
last_req_index -= 1 def make_sampling_metadata( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index dd42687cabfb4..6417702c1b74e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -27,6 +27,7 @@ FlashAttentionMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_mapper import MMInputMapperClient +from vllm.v1.guided_decoding import Grammar from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput @@ -850,10 +851,16 @@ def execute_model( sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) - # We will need to apply the logits inplace from here - # so the scheduler_output should contains both the grammar - # of the running request to advance as well as the specific bitmask - # broadcasted from the scheduler.schedule() + # Apply guided decoding bitmasks if present + if hasattr(scheduler_output, 'guided_decoding_bitmasks'): + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + if req_id in scheduler_output.guided_decoding_bitmasks: + # TODO: We need to ensure that the bitmask + bitmask = scheduler_output.guided_decoding_bitmasks[req_id] + if bitmask is not None: + # Apply bitmask to logits + bitmask = bitmask.to(self.device, non_blocking=True) + Grammar.apply_bitmask(logits, bitmask) # Sample the next token and get logprobs if needed. sampling_metadata = self._prepare_sampling(batch_changed) From 49f7b9602a0aad01b8ab342597aae3eca13cf481 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 12 Feb 2025 06:45:54 +0000 Subject: [PATCH 08/84] fix: update the states within the scheduler Signed-off-by: Aaron Pham --- vllm/v1/core/scheduler.py | 34 ++++++++++-------- vllm/v1/core/scheduler_output.py | 32 ++++++++++------- vllm/v1/engine/core.py | 47 ++---------------------- vllm/v1/guided_decoding/__init__.py | 21 ++++++----- vllm/v1/worker/gpu_input_batch.py | 15 +++++--- vllm/v1/worker/gpu_model_runner.py | 55 ++++++++++++++++------------- 6 files changed, 95 insertions(+), 109 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index bdb629182854f..6c6e51b4fd28c 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -2,11 +2,7 @@ import time from collections import deque -from concurrent import futures -from dataclasses import dataclass -from re import A -from typing import (TYPE_CHECKING, Any, Deque, Dict, Iterable, List, Literal, - Optional, Set, Tuple, Union) +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig from vllm.logger import init_logger @@ -17,7 +13,6 @@ SchedulerOutput) from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) -from vllm.v1.guided_decoding import Grammar from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -109,6 +104,7 @@ def schedule(self) -> "SchedulerOutput": scheduled_resumed_reqs: List[Request] = [] scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] + guided_decoding_request_ids: Set[str] = set() req_to_new_block_ids: Dict[str, List[int]] = {} num_scheduled_tokens: Dict[str, int] = {} @@ -124,11 +120,6 @@ def schedule(self) -> "SchedulerOutput": while 
req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - # Skip requests waiting for FSM - if request.status == RequestStatus.WAITING_FOR_FSM: - req_index += 1 - continue - num_new_tokens = request.num_tokens - request.num_computed_tokens num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 @@ -207,6 +198,9 @@ def schedule(self) -> "SchedulerOutput": request = self.waiting[0] + if request.use_guided_decoding: + guided_decoding_request_ids.add(request.request_id) + # Check that adding the request still respects the max_loras # constraint. if self.lora_config and request.lora_request: @@ -346,7 +340,7 @@ def schedule(self) -> "SchedulerOutput": # the previous and the current steps. finished_req_ids=self.finished_req_ids, free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), - ) + guided_decoding_request_ids=guided_decoding_request_ids) self.finished_req_ids = set() return scheduler_output @@ -449,8 +443,6 @@ def update_from_output( scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", ) -> EngineCoreOutputs: - # concern: batchsize >>>1000 - # compilation << update # NOTE(woosuk): This method doesn't consider speculative decoding. sampled_token_ids = model_runner_output.sampled_token_ids logprobs = model_runner_output.logprobs @@ -496,6 +488,18 @@ def update_from_output( new_logprobs = None new_token_ids = None + # Handle guided decoding FSM advancement if applicable + if request.use_guided_decoding and request.is_grammar_ready: + req_index = model_runner_output.req_id_to_index.get(req_id) + if req_index is not None: + token_id = sampled_token_ids[req_index] + # grammar should already be ready here. + can_accept_token = request.grammar.accept_token(token_id) + if not can_accept_token: + request.status = RequestStatus.FINISHED_GRAMMAR_ERROR + self._free_request(request) + continue + if request.num_computed_tokens == request.num_tokens: req_index = model_runner_output.req_id_to_index[req_id] # NOTE(woosuk): Currently, we assume that each request @@ -562,8 +566,8 @@ def _check_stop(self, request: Request) -> bool: return False def add_request(self, request: Request) -> None: - self.requests[request.request_id] = request self.waiting.append(request) + self.requests[request.request_id] = request self.request_queued(request) def finish_requests( diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index 990b3dd0ed780..b065dad54a88e 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -4,10 +4,13 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple if TYPE_CHECKING: + import torch + from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.base import PlaceholderRange from vllm.sampling_params import SamplingParams + from vllm.v1.guided_decoding import Grammar from vllm.v1.request import Request @@ -24,6 +27,8 @@ class NewRequestData: block_ids: List[int] num_computed_tokens: int lora_request: Optional["LoRARequest"] + grammar: Optional["Grammar"] + bitmask: Optional["torch.Tensor"] @classmethod def from_request( @@ -32,18 +37,18 @@ def from_request( block_ids: List[int], num_computed_tokens: int, ) -> "NewRequestData": - return cls( - req_id=request.request_id, - prompt_token_ids=request.prompt_token_ids, - prompt=request.prompt, - mm_inputs=request.mm_inputs, - mm_hashes=request.mm_hashes, - mm_positions=request.mm_positions, - sampling_params=request.sampling_params, - 
block_ids=block_ids, - num_computed_tokens=num_computed_tokens, - lora_request=request.lora_request, - ) + return cls(req_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt, + mm_inputs=request.mm_inputs, + mm_hashes=request.mm_hashes, + mm_positions=request.mm_positions, + sampling_params=request.sampling_params, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens, + lora_request=request.lora_request, + grammar=request.grammar, + bitmask=request.bitmask) @dataclass @@ -106,3 +111,6 @@ class SchedulerOutput: # List of (req_id, encoder_input_index) tuples. # Used to free the encoder cache. free_encoder_input_ids: List[Tuple[str, int]] + + # Set of request ids for all requests that uses guided decoding + guided_decoding_request_ids: Set[str] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 30c49bfd7816f..bb64b6ace6c4e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -16,7 +16,6 @@ from vllm.logger import init_logger from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import get_exception_traceback, zmq_socket_ctx from vllm.v1.core.kv_cache_utils import get_kv_cache_config from vllm.v1.core.scheduler import Scheduler @@ -24,6 +23,7 @@ EngineCoreRequestType) from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor +from vllm.v1.guided_decoding import GuidedDecodingManager from vllm.v1.request import GuidedDecodingOptions, Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.version import __version__ as VLLM_VERSION @@ -71,20 +71,10 @@ def __init__( vllm_config.model_config) # initialize the tokenizer on the scheduler (this is used for constrained decoding) - tokenizer_group = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, - lora_config=vllm_config.lora_config) - tokenizer_group.ping() - self.tokenizer_group = tokenizer_group + # and guided decoding manager self.use_guided_decoding = False - - # Initialize guided decoding manager - from vllm.v1.guided_decoding import GuidedDecodingManager self.guided_decoding_manager = GuidedDecodingManager( - tokenizer_group=self.tokenizer_group, - model_config=vllm_config.model_config) + vllm_config=vllm_config) def _initialize_kv_caches(self, vllm_config: VllmConfig) -> Tuple[int, int]: @@ -152,41 +142,10 @@ def step(self) -> EngineCoreOutputs: scheduler_output = self.scheduler.schedule() - # Attach bitmasks to scheduler output for broadcasting to workers - if self.use_guided_decoding: - scheduler_output.guided_decoding_bitmasks = { - req.request_id: req.bitmask - for req in self.scheduler.running - if req.use_guided_decoding and req.is_grammar_ready - } - output = self.model_executor.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) - - if self.use_guided_decoding: - # Advance FSM for each request using guided decoding - for req in self.scheduler.running: - if not req.use_guided_decoding or not req.is_grammar_ready: - continue - - # Get the generated tokens for this request - if req.request_id in output.outputs: - generated_tokens = output.outputs[req.request_id].token_ids - # Advance FSM for each generated token - for token in generated_tokens: - if not 
req.grammar.accept_token(token): - # Token was rejected by grammar - mark request as finished with error - self.scheduler.finish_requests( - [req.request_id], - RequestStatus.FINISHED_GRAMMAR_ERROR) - break - - # Update bitmask for next token prediction if request is still running - if req.request_id not in self.scheduler.finished_requests: - req.grammar.fill_bitmask(req.bitmask, 0) - return engine_core_outputs def shutdown(self): diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index d9e1119b37459..aac6573ba7e4f 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -10,6 +10,8 @@ import torch import xgrammar as xgr +from vllm.config import VllmConfig +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.v1.request import RequestStatus if TYPE_CHECKING: @@ -19,16 +21,15 @@ class Grammar: - finished: bool = False - - # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, ctx: xgr.CompiledGrammar) -> None: + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding # TODO: support max_rollback_tokens self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx + self.prefilled = False def accept_token(self, token: int) -> bool: # NOTE: accept_token will determines whether we accept this token @@ -67,11 +68,15 @@ class GrammarCache: class GuidedDecodingManager: - def __init__(self, tokenizer_group: BaseTokenizerGroup, - model_config: ModelConfig): - self.tokenizer_group = tokenizer_group - self.model_config = model_config - self.vocab_size = model_config.get_vocab_size() + def __init__(self, vllm_config: VllmConfig): + tokenizer_group = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + lora_config=vllm_config.lora_config) + tokenizer_group.ping() + self.model_config = vllm_config.model_config + self.vocab_size = vllm_config.model_config.get_vocab_size() self.tokenizer = tokenizer_group.get_lora_tokenizer(None) self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {} self.executor = ThreadPoolExecutor() diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index f77e51ce5b9b3..cd09f6313f5e4 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # Datastructures defining an input batch -from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple import numpy as np import torch @@ -12,10 +11,13 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType +from vllm.v1.guided_decoding import Grammar from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: + import torch + from vllm.multimodal.inputs import PlaceholderRange @@ -38,6 +40,8 @@ class CachedRequestState: mrope_position_delta: Optional[int] = None lora_request: Optional[LoRARequest] = None + grammar: Optional[Grammar] = None + bitmask: Optional[torch.Tensor] = None @property def 
num_tokens(self) -> int: @@ -171,7 +175,7 @@ def __init__( self.lora_id_to_request_ids: Dict[int, Set[str]] = {} self.lora_id_to_lora_request: Dict[int, LoRARequest] = {} - self.grammar_reqs: Set[str] = set() + self.grammar_reqs: Dict[str, torch.Tensor] = {} # req_index -> generator # NOTE(woosuk): The indices of the requests that do not have their own @@ -249,7 +253,7 @@ def add_request( self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs if request.grammar is not None: - self.grammar_reqs.add(req_id) + self.grammar_reqs[req_id] = request.bitmask # Add request lora ID if request.lora_request: @@ -279,7 +283,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.repetition_penalties_reqs.discard(req_id) self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) - self.grammar_reqs.discard(req_id) + self.grammar_reqs.pop(req_id, None) self.num_prompt_logprobs.pop(req_id, None) # LoRA @@ -365,6 +369,7 @@ def condense(self, empty_req_indices: List[int]) -> None: self.request_lora_mapping[empty_index] = self.request_lora_mapping[ last_req_index] + # Decrement last_req_index since it is now empty. last_req_index -= 1 def make_sampling_metadata( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6417702c1b74e..cbf02f233f253 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -260,19 +260,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: assert req_index is not None removed_req_indices.append(req_index) - # we should advance the FSM here - if req_id in scheduler_output.guided_decoding_bitmasks and req_state.grammar is not None: - token_idx = scheduler_output.num_scheduled_tokens[req_id] - 1 - token_id = self.input_batch.token_ids_cpu[req_index, token_idx] - # Advance the FSM state - if not req_state.grammar.accept_token(token_id): - # This shouldn't happen since we masked the logits, but handle gracefully - logger.error( - f"Grammar rejected token {token_id} for request {req_id}" - ) - req_state.status = RequestStatus.FINISHED_ABORTED - continue - req_ids_to_add: List[str] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: @@ -296,6 +283,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], lora_request=new_req_data.lora_request, + grammar=new_req_data.grammar, + bitmask=new_req_data.bitmask, ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -359,6 +348,18 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.input_batch.block_table.append_row(req_index, start_index, req_data.new_block_ids) + # Fill the bitmask + if req_id in scheduler_output.guided_decoding_request_ids and req_state.grammar is not None: + if not req_state.grammar.prefilled: + req_state.grammar.prefilled = True + else: + token_idx = scheduler_output.num_scheduled_tokens[ + req_id] - 1 + if not req_state.grammar.matcher.is_terminated(): + assert req_state.bitmask is not None + req_state.grammar.fill_bitmask(req_state.bitmask, + token_idx) + # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. 
         removed_req_indices = sorted(removed_req_indices, reverse=True)
@@ -377,7 +378,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         self.input_batch.condense(removed_req_indices)
         return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
 
-    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
+    def _prepare_inputs(
+        self, scheduler_output: "SchedulerOutput"
+    ) -> Tuple[FlashAttentionMetadata, torch.Tensor,
+               Optional[List[torch.Tensor]]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -387,6 +391,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # This way, we can overlap the copy with the following CPU operations.
         self.input_batch.block_table.commit(num_reqs)
 
+        # Prepare bitmasks for guided decoding
+        bitmask: Optional[torch.Tensor] = None
+
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens_list: List[int] = []
@@ -397,6 +404,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             num_scheduled_tokens_list.append(num_tokens)
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
+            if req_id in scheduler_output.guided_decoding_request_ids:
+                bitmask = self.requests[req_id].bitmask
         num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
                                                     dtype=np.int32)
         assert max_num_scheduled_tokens > 0
@@ -538,7 +547,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # tokens from the partial requests.
         # TODO: Support prompt logprobs.
         logits_indices = query_start_loc[1:] - 1
-        return attn_metadata, logits_indices
+
+        return attn_metadata, logits_indices, bitmask
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -798,7 +808,8 @@ def execute_model(
             encoder_outputs = []
 
         # Prepare the decoder inputs.
-        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
+        attn_metadata, logits_indices, bitmask = self._prepare_inputs(
+            scheduler_output)
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -852,15 +863,9 @@ def execute_model(
         logits = self.model.compute_logits(sample_hidden_states, None)
 
         # Apply guided decoding bitmasks if present
-        if hasattr(scheduler_output, 'guided_decoding_bitmasks'):
-            for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
-                if req_id in scheduler_output.guided_decoding_bitmasks:
-                    # TODO: We need to ensure that the bitmask
-                    bitmask = scheduler_output.guided_decoding_bitmasks[req_id]
-                    if bitmask is not None:
-                        # Apply bitmask to logits
-                        bitmask = bitmask.to(self.device, non_blocking=True)
-                        Grammar.apply_bitmask(logits, bitmask)
+        if bitmask is not None:
+            Grammar.apply_bitmask(logits,
+                                  bitmask.to(self.device, non_blocking=True))
 
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self._prepare_sampling(batch_changed)

From cd357e5adf5f7a0e4601602293782f3853d6626a Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Wed, 12 Feb 2025 20:36:27 -0500
Subject: [PATCH 09/84] [CI/Build] Ignore ruff warning up007

This change ignores the following warning from ruff:

    UP007 Use `X | Y` for type annotations

We need to continue using the `Optional[T]` syntax until Python 3.10 or
above is our minimum supported version. I'm not sure why I only see this
in one of my environments, but not another. In any case, it seems
harmless to explicitly ignore it.
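For reference, the two spellings the rule distinguishes look like this; a
standalone illustration, not code from this series (the future import is an
assumption that keeps the new syntax legal as an annotation on older
interpreters):

    from __future__ import annotations

    from typing import Optional


    # The spelling kept for now: valid at runtime on every supported Python.
    def first_token(ids: list[int]) -> Optional[int]:
        return ids[0] if ids else None


    # What UP007 asks for; fine here only because annotations are postponed
    # by the __future__ import (or when running on Python 3.10+).
    def last_token(ids: list[int]) -> int | None:
        return ids[-1] if ids else None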
Signed-off-by: Russell Bryant --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9892967b82d79..849e8781e24ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,8 @@ ignore = [ "UP032", # Python 3.8 typing "UP006", "UP035", - + # Can remove once 3.10+ is the minimum Python version + "UP007", ] [tool.mypy] From 9a7b08118ff6291607169c89c17cce4ca1883f06 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 12 Feb 2025 20:35:46 -0500 Subject: [PATCH 10/84] Resolve ruff errors `pre-commit run -a ruff` now passes. Signed-off-by: Russell Bryant --- vllm/v1/engine/core.py | 7 +++---- vllm/v1/guided_decoding/__init__.py | 12 ++++++------ vllm/v1/request.py | 10 +++++----- vllm/v1/worker/gpu_model_runner.py | 3 ++- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index bb64b6ace6c4e..80fa25f629937 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -8,7 +8,6 @@ from typing import Any, List, Tuple, Type import psutil -import torch import zmq import zmq.asyncio @@ -24,7 +23,7 @@ from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor from vllm.v1.guided_decoding import GuidedDecodingManager -from vllm.v1.request import GuidedDecodingOptions, Request, RequestStatus +from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.version import __version__ as VLLM_VERSION @@ -70,8 +69,8 @@ def __init__( self.mm_input_mapper_server = MMInputMapperServer( vllm_config.model_config) - # initialize the tokenizer on the scheduler (this is used for constrained decoding) - # and guided decoding manager + # initialize the tokenizer on the scheduler (this is used for + # constrained decoding) and guided decoding manager self.use_guided_decoding = False self.guided_decoding_manager = GuidedDecodingManager( vllm_config=vllm_config) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index aac6573ba7e4f..faf41f63aa452 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -15,8 +15,6 @@ from vllm.v1.request import RequestStatus if TYPE_CHECKING: - from vllm.config import ModelConfig - from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.request import GuidedDecodingKey, Request @@ -24,8 +22,8 @@ class Grammar: def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, ctx: xgr.CompiledGrammar) -> None: - # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding - # TODO: support max_rollback_tokens + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string + # for jump-forward decoding TODO: support max_rollback_tokens self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx @@ -103,7 +101,8 @@ def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: ctx=ctx) def should_cache(self, request: Request): - if not request.use_guided_decoding: return False + if not request.use_guided_decoding: + return False request.grammar = self.get(request) if not request.grammar: request.grammar = self.cache(request) @@ -134,5 +133,6 @@ def _executor_loop(self, request: Request): def get(self, request: Request): with self._lock: entry = self.grammar_cache.get(request.guided_decoding_key) - if entry is None or not 
entry.event.is_set(): return None + if entry is None or not entry.event.is_set(): + return None return copy.copy(entry.value) if entry.value else None diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 6cb56588eb3a0..4343e6934dc79 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -40,9 +40,9 @@ def __init__( request_id: str, prompt: Optional[str], prompt_token_ids: List[int], - multi_modal_inputs: Optional[List["MultiModalKwargs"]], + multi_modal_inputs: Optional[List[MultiModalKwargs]], multi_modal_hashes: Optional[List[str]], - multi_modal_placeholders: Optional[List["PlaceholderRange"]], + multi_modal_placeholders: Optional[List[PlaceholderRange]], sampling_params: SamplingParams, eos_token_id: Optional[int], arrival_time: float, @@ -88,7 +88,7 @@ def __init__( self._bitmask = None @classmethod - def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": + def from_engine_core_request(cls, request: EngineCoreRequest) -> Request: return cls( request_id=request.request_id, prompt=request.prompt, @@ -216,12 +216,12 @@ class RequestStatus(enum.IntEnum): FINISHED_GRAMMAR_ERROR = enum.auto() @staticmethod - def is_finished(status: "RequestStatus") -> bool: + def is_finished(status: RequestStatus) -> bool: return status > RequestStatus.PREEMPTED @staticmethod def get_finished_reason( - status: "RequestStatus") -> Union[FinishReason, None]: + status: RequestStatus) -> Union[FinishReason, None]: return _FINISHED_REASON_MAP.get(status) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cbf02f233f253..146458367ac1c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -349,7 +349,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: req_data.new_block_ids) # Fill the bitmask - if req_id in scheduler_output.guided_decoding_request_ids and req_state.grammar is not None: + if (req_id in scheduler_output.guided_decoding_request_ids + and req_state.grammar is not None): if not req_state.grammar.prefilled: req_state.grammar.prefilled = True else: From 2e43e046ee2f5f7a7305f24eddee2ac0522e024d Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 13 Feb 2025 02:02:19 +0000 Subject: [PATCH 11/84] chore: manage requests within manager class Signed-off-by: Aaron Pham --- vllm/v1/core/scheduler.py | 22 ++++++++++++++++---- vllm/v1/engine/core.py | 18 +++++----------- vllm/v1/guided_decoding/__init__.py | 19 +++++++++-------- vllm/v1/request.py | 32 ++++++++++++++--------------- vllm/v1/worker/gpu_model_runner.py | 3 +++ 5 files changed, 52 insertions(+), 42 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 6c6e51b4fd28c..37ee91a9a88c5 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -124,6 +124,11 @@ def schedule(self) -> "SchedulerOutput": num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 + if request.status == RequestStatus.WAITING_FOR_FSM: + # wait for grammar to be ready + req_index += 1 + continue + # Schedule encoder inputs. 
encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( self._try_schedule_encoder_inputs(request, @@ -201,6 +206,16 @@ def schedule(self) -> "SchedulerOutput": if request.use_guided_decoding: guided_decoding_request_ids.add(request.request_id) + if request.status == RequestStatus.WAITING_FOR_FSM: + if request.is_grammar_ready: + request.status = RequestStatus.WAITING + request.grammar.prefilled = True + # else: + # # Skip this request but keep in the queue + # self.waiting.popleft() + # self.waiting.append(request) + # continue + # # Check that adding the request still respects the max_loras # constraint. if self.lora_config and request.lora_request: @@ -256,7 +271,7 @@ def schedule(self) -> "SchedulerOutput": self.waiting.popleft() self.running.append(request) - if request.status == RequestStatus.WAITING: + if RequestStatus.is_waiting(request.status): scheduled_new_reqs.append(request) self.request_scheduled(request, scheduled_timestamp) elif request.status == RequestStatus.PREEMPTED: @@ -489,14 +504,13 @@ def update_from_output( new_token_ids = None # Handle guided decoding FSM advancement if applicable - if request.use_guided_decoding and request.is_grammar_ready: + if request.use_guided_decoding and request.is_grammar_ready and not request.grammar.prefilled: req_index = model_runner_output.req_id_to_index.get(req_id) if req_index is not None: token_id = sampled_token_ids[req_index] - # grammar should already be ready here. + # accept token will also advance the FSM can_accept_token = request.grammar.accept_token(token_id) if not can_accept_token: - request.status = RequestStatus.FINISHED_GRAMMAR_ERROR self._free_request(request) continue diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 80fa25f629937..8e6a95552db54 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -69,9 +69,6 @@ def __init__( self.mm_input_mapper_server = MMInputMapperServer( vllm_config.model_config) - # initialize the tokenizer on the scheduler (this is used for - # constrained decoding) and guided decoding manager - self.use_guided_decoding = False self.guided_decoding_manager = GuidedDecodingManager( vllm_config=vllm_config) @@ -115,11 +112,8 @@ def add_request(self, request: EngineCoreRequest): req = Request.from_engine_core_request(request) if req.use_guided_decoding: - self.use_guided_decoding = True # Start grammar compilation asynchronously self.guided_decoding_manager.should_cache(req) - else: - self.use_guided_decoding = False self.scheduler.add_request(req) @@ -136,8 +130,7 @@ def step(self) -> EngineCoreOutputs: outputs=[], scheduler_stats=self.scheduler.make_stats()) # Calculate bitmasks for all active requests - if self.use_guided_decoding: - self.calculate_grammar_bitmasks() + self.calculate_grammar_bitmasks() scheduler_output = self.scheduler.schedule() @@ -157,18 +150,17 @@ def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() def calculate_grammar_bitmasks(self): - for req in self.scheduler.running: - # ignore requests that doesn't use guided decoding - # or ignore requests that grammar is not ready - if not req.use_guided_decoding or not req.is_grammar_ready: - continue + for req in self.guided_decoding_manager.requests: # Check if grammar is ready in cache grammar = self.guided_decoding_manager.get(req) if grammar is not None: + print(req.use_guided_decoding, req.is_grammar_ready, + req.grammar, grammar) req.grammar = grammar req.allocate_bitmask(1, self.guided_decoding_manager.vocab_size) + print(req.status) continue diff --git 
a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index faf41f63aa452..72c54ba166744 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -1,18 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -import copy import threading from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, Optional, Set import torch import xgrammar as xgr from vllm.config import VllmConfig from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.v1.request import RequestStatus +from vllm.v1.request import GuidedDecodingOptions, RequestStatus if TYPE_CHECKING: from vllm.v1.request import GuidedDecodingKey, Request @@ -76,9 +75,10 @@ def __init__(self, vllm_config: VllmConfig): self.model_config = vllm_config.model_config self.vocab_size = vllm_config.model_config.get_vocab_size() self.tokenizer = tokenizer_group.get_lora_tokenizer(None) - self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {} + self.grammar_cache: Dict[GuidedDecodingKey, GrammarCache] = {} self.executor = ThreadPoolExecutor() self._lock = threading.Lock() + self.requests: Set[Request] = set() def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: request_type, grammar_spec = key @@ -86,12 +86,12 @@ def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: self.tokenizer, vocab_size=self.vocab_size) compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) - if request_type == "json": - if type(grammar_spec) is not str: + if request_type == GuidedDecodingOptions.json: + if not isinstance(grammar_spec, str): ctx = compiler.compile_builtin_json_grammar() else: ctx = compiler.compile_json_schema(grammar_spec) - elif request_type == "grammar": + elif request_type == GuidedDecodingOptions.grammar: ctx = compiler.compile_grammar(grammar_spec) else: raise ValueError("grammar is not of valid supported types.") @@ -115,6 +115,7 @@ def cache(self, request: Request): def _executor_loop(self, request: Request): key = request.guided_decoding_key + self.requests.add(request) with self._lock: cache_hit = False if key in self.grammar_cache: @@ -128,11 +129,11 @@ def _executor_loop(self, request: Request): else: entry.value = self.initialize_cache(key) entry.event.set() - return copy.copy(entry.value) if entry.value else None + return entry.value if entry.value else None def get(self, request: Request): with self._lock: entry = self.grammar_cache.get(request.guided_decoding_key) if entry is None or not entry.event.is_set(): return None - return copy.copy(entry.value) if entry.value else None + return entry.value if entry.value else None diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 4343e6934dc79..3c3ae09eea82e 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -4,6 +4,7 @@ import enum import functools +import json from concurrent.futures import Future from concurrent.futures._base import TimeoutError from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union @@ -54,7 +55,7 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request - self.status = RequestStatus.WAITING + self.status = RequestStatus.WAITING_FOR_FSM if sampling_params.guided_decoding is not None else RequestStatus.WAITING self.events: List[EngineCoreEvent] = [] self.stop_reason: Union[int, str, None] = None assert sampling_params.max_tokens is not None @@ -84,7 +85,7 @@ def 
__init__( self.all_token_ids = ConstantList(self._all_token_ids) # Grammar fields, including the grammar object and the bitmask - self._grammar: Future[Grammar] | Grammar | None = None + self.grammar: Future[Grammar] | Grammar | None = None self._bitmask = None @classmethod @@ -161,7 +162,10 @@ def guided_decoding_key(self) -> GuidedDecodingKey: params = self.sampling_params.guided_decoding assert params is not None, "params can't be None." if params.json is not None: - return (GuidedDecodingOptions.json, params.json) + key = params.json + if params.json_object or type(key) is not str: + key = json.dumps(params.json) + return (GuidedDecodingOptions.json, key) elif params.regex is not None: return (GuidedDecodingOptions.regex, params.regex) elif params.choice is not None: @@ -171,22 +175,15 @@ def guided_decoding_key(self) -> GuidedDecodingKey: else: raise ValueError("No valid guided decoding parameter found") - @property - def grammar(self) -> Optional[Grammar | Future[Grammar]]: - return self._grammar - - @grammar.setter - def grammar(self, grammar: Grammar | Future[Grammar]) -> None: - self._grammar = grammar - def allocate_bitmask(self, batch_size: int, vocab_size: int) -> None: if isinstance(self._grammar, Future): try: self.grammar = self.grammar.result(timeout=0.05) self.status = RequestStatus.WAITING except TimeoutError: - pass - if self.grammar: + return + + if self.grammar is not None: self._bitmask = self.grammar.allocate_bitmask( batch_size, vocab_size) @@ -197,8 +194,8 @@ def bitmask(self) -> Optional[torch.Tensor]: @property def is_grammar_ready(self) -> bool: if isinstance(self._grammar, Future): - return self._grammar.done() - return self._grammar is not None + return not self._grammar.running() and self._grammar.done() + return self.status == RequestStatus.WAITING and self._grammar is not None class RequestStatus(enum.IntEnum): @@ -213,12 +210,15 @@ class RequestStatus(enum.IntEnum): FINISHED_LENGTH_CAPPED = enum.auto() FINISHED_ABORTED = enum.auto() FINISHED_IGNORED = enum.auto() - FINISHED_GRAMMAR_ERROR = enum.auto() @staticmethod def is_finished(status: RequestStatus) -> bool: return status > RequestStatus.PREEMPTED + @staticmethod + def is_waiting(status: "RequestStatus") -> bool: + return status < RequestStatus.WAITING_FOR_FSM + @staticmethod def get_finished_reason( status: RequestStatus) -> Union[FinishReason, None]: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 146458367ac1c..ec3df3b7d3dfc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -393,6 +393,8 @@ def _prepare_inputs( self.input_batch.block_table.commit(num_reqs) # Prepare bitmasks for guided decoding + # OPTIMIZATION: We shouldn't copy over + # the bitmask like this multiple times bitmask: Optional[torch.Tensor] = None # Get the number of scheduled tokens for each request. @@ -406,6 +408,7 @@ def _prepare_inputs( max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) if req_id in scheduler_output.guided_decoding_request_ids: + print(self.requests[req_id]) bitmask = self.requests[req_id].bitmask num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, dtype=np.int32) From ccde524cab58ec9fd9e74fa7171ab38cf8bb58f0 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 12 Feb 2025 20:59:46 -0500 Subject: [PATCH 12/84] Drop grammar getter/setter on Request Since the getter/setter methods only did a literal get/set of the underlying attribute, just make the underlying attribute public. 
It'll do the same thing with less code. Signed-off-by: Russell Bryant --- vllm/v1/request.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 3c3ae09eea82e..b6ff2db0d2a23 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -176,7 +176,7 @@ def guided_decoding_key(self) -> GuidedDecodingKey: raise ValueError("No valid guided decoding parameter found") def allocate_bitmask(self, batch_size: int, vocab_size: int) -> None: - if isinstance(self._grammar, Future): + if isinstance(self.grammar, Future): try: self.grammar = self.grammar.result(timeout=0.05) self.status = RequestStatus.WAITING @@ -193,9 +193,9 @@ def bitmask(self) -> Optional[torch.Tensor]: @property def is_grammar_ready(self) -> bool: - if isinstance(self._grammar, Future): - return not self._grammar.running() and self._grammar.done() - return self.status == RequestStatus.WAITING and self._grammar is not None + if isinstance(self.grammar, Future): + return not self.grammar.running() and self.grammar.done() + return self.status == RequestStatus.WAITING and self.grammar is not None class RequestStatus(enum.IntEnum): From 1587d34b8dfbfe76d90c60112c81cde2f2a72894 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 12 Feb 2025 21:20:27 -0500 Subject: [PATCH 13/84] mypy: Fix return type of GPUModelRunner._prepare_inputs() The bitmask is an `Optional[torch.Tensor]``, not an `Optional[List[torch.Tensor]]``. Signed-off-by: Russell Bryant --- vllm/v1/worker/gpu_model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ec3df3b7d3dfc..fb9b4c77a478c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -381,8 +381,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: def _prepare_inputs( self, scheduler_output: "SchedulerOutput" - ) -> Tuple[FlashAttentionMetadata, torch.Tensor, - Optional[List[torch.Tensor]]]: + ) -> Tuple[FlashAttentionMetadata, torch.Tensor, Optional[torch.Tensor]]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs From 227cc7f59901b0be501719db69cbe3cb19fcedb2 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 12 Feb 2025 21:51:32 -0500 Subject: [PATCH 14/84] Resolve remaining mypy warnings We needed the getter/setter for grammar on Request after all to deal with typing and to express where the type must be a grammar instead of a Future[grammar]. We also needed to be explicit in some places that the grammar is not None. 
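The shape of the fix is a small narrowing property. The following is a
condensed sketch of the pattern only (stand-in Grammar class, simplified
names), not the full Request implementation from the diff:

    from __future__ import annotations

    from concurrent.futures import Future
    from typing import Optional, Union


    class Grammar:
        """Stand-in for the real vllm.v1.guided_decoding.Grammar."""


    class Request:
        def __init__(self) -> None:
            # May hold the compiled grammar, or the Future still computing it.
            self._grammar: Union[Future[Grammar], Grammar, None] = None

        @property
        def grammar(self) -> Optional[Grammar]:
            # Narrow the type: callers only ever see a ready Grammar (or
            # None), so mypy knows downstream method calls are safe.
            return self._grammar if isinstance(self._grammar,
                                               Grammar) else None

        @grammar.setter
        def grammar(self, value: Union[Grammar, Future[Grammar]]) -> None:
            self._grammar = value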
Signed-off-by: Russell Bryant --- vllm/v1/core/scheduler.py | 7 +++++-- vllm/v1/request.py | 24 ++++++++++++++++-------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 37ee91a9a88c5..4e125e3bca081 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -2,6 +2,7 @@ import time from collections import deque +from concurrent.futures import Future from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig @@ -207,7 +208,7 @@ def schedule(self) -> "SchedulerOutput": guided_decoding_request_ids.add(request.request_id) if request.status == RequestStatus.WAITING_FOR_FSM: - if request.is_grammar_ready: + if request.grammar and request.is_grammar_ready: request.status = RequestStatus.WAITING request.grammar.prefilled = True # else: @@ -504,7 +505,9 @@ def update_from_output( new_token_ids = None # Handle guided decoding FSM advancement if applicable - if request.use_guided_decoding and request.is_grammar_ready and not request.grammar.prefilled: + if (request.use_guided_decoding and request.grammar + and request.is_grammar_ready + and not request.grammar.prefilled): req_index = model_runner_output.req_id_to_index.get(req_id) if req_index is not None: token_id = sampled_token_ids[req_index] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index b6ff2db0d2a23..c33a534e793c5 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -85,7 +85,7 @@ def __init__( self.all_token_ids = ConstantList(self._all_token_ids) # Grammar fields, including the grammar object and the bitmask - self.grammar: Future[Grammar] | Grammar | None = None + self._grammar: Future[Grammar] | Grammar | None = None self._bitmask = None @classmethod @@ -176,15 +176,15 @@ def guided_decoding_key(self) -> GuidedDecodingKey: raise ValueError("No valid guided decoding parameter found") def allocate_bitmask(self, batch_size: int, vocab_size: int) -> None: - if isinstance(self.grammar, Future): + if isinstance(self._grammar, Future): try: - self.grammar = self.grammar.result(timeout=0.05) + self._grammar = self._grammar.result(timeout=0.05) self.status = RequestStatus.WAITING except TimeoutError: return - if self.grammar is not None: - self._bitmask = self.grammar.allocate_bitmask( + if self._grammar is not None: + self._bitmask = self._grammar.allocate_bitmask( batch_size, vocab_size) @functools.cached_property @@ -193,9 +193,17 @@ def bitmask(self) -> Optional[torch.Tensor]: @property def is_grammar_ready(self) -> bool: - if isinstance(self.grammar, Future): - return not self.grammar.running() and self.grammar.done() - return self.status == RequestStatus.WAITING and self.grammar is not None + if isinstance(self._grammar, Future): + return not self._grammar.running() and self._grammar.done() + return self.status == RequestStatus.WAITING and self._grammar is not None + + @property + def grammar(self) -> Optional[Grammar]: + return self._grammar if isinstance(self._grammar, Grammar) else None + + @grammar.setter + def grammar(self, grammar: Grammar | Future[Grammar]) -> None: + self._grammar = grammar class RequestStatus(enum.IntEnum): From c0b235d9d0a6c705128dd402961dfb0bd5e6751d Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 12 Feb 2025 21:56:32 -0500 Subject: [PATCH 15/84] Finish getting pre-commit to pass Resolve remaining ruff and yapf issues. 
Signed-off-by: Russell Bryant --- vllm/v1/core/scheduler.py | 9 ++++----- vllm/v1/request.py | 9 ++++++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 4e125e3bca081..2d5c350ab3bc7 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -2,7 +2,6 @@ import time from collections import deque -from concurrent.futures import Future from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig @@ -207,10 +206,10 @@ def schedule(self) -> "SchedulerOutput": if request.use_guided_decoding: guided_decoding_request_ids.add(request.request_id) - if request.status == RequestStatus.WAITING_FOR_FSM: - if request.grammar and request.is_grammar_ready: - request.status = RequestStatus.WAITING - request.grammar.prefilled = True + if (request.status == RequestStatus.WAITING_FOR_FSM + and request.grammar and request.is_grammar_ready): + request.status = RequestStatus.WAITING + request.grammar.prefilled = True # else: # # Skip this request but keep in the queue # self.waiting.popleft() diff --git a/vllm/v1/request.py b/vllm/v1/request.py index c33a534e793c5..a707696d6f98a 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -55,7 +55,9 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request - self.status = RequestStatus.WAITING_FOR_FSM if sampling_params.guided_decoding is not None else RequestStatus.WAITING + self.status = (RequestStatus.WAITING_FOR_FSM + if sampling_params.guided_decoding is not None else + RequestStatus.WAITING) self.events: List[EngineCoreEvent] = [] self.stop_reason: Union[int, str, None] = None assert sampling_params.max_tokens is not None @@ -195,7 +197,8 @@ def bitmask(self) -> Optional[torch.Tensor]: def is_grammar_ready(self) -> bool: if isinstance(self._grammar, Future): return not self._grammar.running() and self._grammar.done() - return self.status == RequestStatus.WAITING and self._grammar is not None + return (self.status == RequestStatus.WAITING + and self._grammar is not None) @property def grammar(self) -> Optional[Grammar]: @@ -224,7 +227,7 @@ def is_finished(status: RequestStatus) -> bool: return status > RequestStatus.PREEMPTED @staticmethod - def is_waiting(status: "RequestStatus") -> bool: + def is_waiting(status: RequestStatus) -> bool: return status < RequestStatus.WAITING_FOR_FSM @staticmethod From 49fdce01971598536ff81cfb8d137fa3babbf5c5 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 12 Feb 2025 22:27:03 -0500 Subject: [PATCH 16/84] Updat michael's suggestions Co-authored-by: Michael Goin --- vllm/v1/request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index a707696d6f98a..fd3519677477c 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -215,7 +215,7 @@ class RequestStatus(enum.IntEnum): WAITING_FOR_FSM = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() - # Note: anything after PREEMPTED (2) will be considered + # Note: anything after PREEMPTED will be considered # as a finished status. 
FINISHED_STOPPED = enum.auto() FINISHED_LENGTH_CAPPED = enum.auto() From e9a2304918895c1b7d3f8d7a79b84fd17cf67e16 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 13 Feb 2025 04:44:16 +0000 Subject: [PATCH 17/84] chore: update according to Michael's review Signed-off-by: Aaron Pham --- vllm/v1/core/scheduler.py | 1 - vllm/v1/core/scheduler_output.py | 4 +- vllm/v1/engine/core.py | 6 +- vllm/v1/guided_decoding/__init__.py | 102 ++++++++++++++-------------- vllm/v1/request.py | 42 ++++++------ vllm/v1/worker/gpu_input_batch.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 10 +-- 7 files changed, 83 insertions(+), 86 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 2d5c350ab3bc7..5cd7bc9fc271a 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -365,7 +365,6 @@ def _make_cached_request_data( request: Request, new_block_ids: List[int], num_computed_tokens: int, - *, resumed_from_preemption: bool, ) -> "CachedRequestData": # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index b065dad54a88e..187dc3979555e 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -28,7 +28,7 @@ class NewRequestData: num_computed_tokens: int lora_request: Optional["LoRARequest"] grammar: Optional["Grammar"] - bitmask: Optional["torch.Tensor"] + grammar_bitmask: Optional["torch.Tensor"] @classmethod def from_request( @@ -48,7 +48,7 @@ def from_request( num_computed_tokens=num_computed_tokens, lora_request=request.lora_request, grammar=request.grammar, - bitmask=request.bitmask) + grammar_bitmask=request.grammar_bitmask) @dataclass diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 8e6a95552db54..89500ac12edf1 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -153,13 +153,13 @@ def calculate_grammar_bitmasks(self): for req in self.guided_decoding_manager.requests: # Check if grammar is ready in cache - grammar = self.guided_decoding_manager.get(req) + grammar = self.guided_decoding_manager.get_grammar(req) if grammar is not None: print(req.use_guided_decoding, req.is_grammar_ready, req.grammar, grammar) req.grammar = grammar - req.allocate_bitmask(1, - self.guided_decoding_manager.vocab_size) + req.allocate_grammar_bitmask( + 1, self.guided_decoding_manager.vocab_size) print(req.status) continue diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 72c54ba166744..075b6a7564619 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -1,28 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +import enum import threading from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Optional, Set +from typing import TYPE_CHECKING, Dict, Set, Tuple import torch import xgrammar as xgr from vllm.config import VllmConfig from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.v1.request import GuidedDecodingOptions, RequestStatus if TYPE_CHECKING: - from vllm.v1.request import GuidedDecodingKey, Request + from vllm.v1.request import Request + + +class GuidedDecodingOptions(enum.Enum): + json = enum.auto() + regex = enum.auto() + grammar = enum.auto() + choice = enum.auto() + + +GuidedDecodingKey = Tuple[GuidedDecodingOptions, str] class Grammar: + # NOTE: This would be a generic-enough class for + 
# supporting different backends, in the future. + # For now, just xgrammar. + # + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string + # for jump-forward decoding + # TODO: support max_rollback_tokens def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, ctx: xgr.CompiledGrammar) -> None: - # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string - # for jump-forward decoding TODO: support max_rollback_tokens self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx @@ -57,12 +71,6 @@ def __copy__(self): return self.copy() -@dataclass -class GrammarCache: - value: Optional[Grammar] - event: threading.Event - - class GuidedDecodingManager: def __init__(self, vllm_config: VllmConfig): @@ -74,39 +82,25 @@ def __init__(self, vllm_config: VllmConfig): tokenizer_group.ping() self.model_config = vllm_config.model_config self.vocab_size = vllm_config.model_config.get_vocab_size() - self.tokenizer = tokenizer_group.get_lora_tokenizer(None) - self.grammar_cache: Dict[GuidedDecodingKey, GrammarCache] = {} - self.executor = ThreadPoolExecutor() - self._lock = threading.Lock() - self.requests: Set[Request] = set() - def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: - request_type, grammar_spec = key tokenizer_info = xgr.TokenizerInfo.from_huggingface( - self.tokenizer, vocab_size=self.vocab_size) - compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + tokenizer_group.get_lora_tokenizer(None), + vocab_size=self.vocab_size) + self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) - if request_type == GuidedDecodingOptions.json: - if not isinstance(grammar_spec, str): - ctx = compiler.compile_builtin_json_grammar() - else: - ctx = compiler.compile_json_schema(grammar_spec) - elif request_type == GuidedDecodingOptions.grammar: - ctx = compiler.compile_grammar(grammar_spec) - else: - raise ValueError("grammar is not of valid supported types.") + self.grammar_cache: Dict[GuidedDecodingKey, Grammar] = {} - return Grammar(matcher=xgr.GrammarMatcher(ctx), - vocab_size=self.model_config.hf_text_config.vocab_size, - ctx=ctx) + self.executor = ThreadPoolExecutor() + self.requests: Set[Request] = set() + + self._lock = threading.Lock() def should_cache(self, request: Request): if not request.use_guided_decoding: return False - request.grammar = self.get(request) + request.grammar = self.get_grammar(request) if not request.grammar: request.grammar = self.cache(request) - request.status = RequestStatus.WAITING_FOR_FSM return True return False @@ -117,23 +111,31 @@ def _executor_loop(self, request: Request): key = request.guided_decoding_key self.requests.add(request) with self._lock: - cache_hit = False if key in self.grammar_cache: - cache_hit, entry = True, self.grammar_cache[key] - else: - entry = GrammarCache(None, threading.Event()) - self.grammar_cache[key] = entry + return self.grammar_cache[key] + + self.grammar_cache[key] = self.initialize_grammar(key) + return self.grammar_cache[key] + + def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: + request_type, grammar_spec = key - if cache_hit: - entry.event.wait() + if request_type == GuidedDecodingOptions.json: + if not isinstance(grammar_spec, str): + ctx = self.compiler.compile_builtin_json_grammar() + else: + ctx = self.compiler.compile_json_schema(grammar_spec) + elif request_type == GuidedDecodingOptions.grammar: + ctx = self.compiler.compile_grammar(grammar_spec) else: - entry.value = 
self.initialize_cache(key) - entry.event.set() - return entry.value if entry.value else None + raise ValueError( + f"`grammar` is not of valid supported types. ({request_type!s})" + ) + + return Grammar(matcher=xgr.GrammarMatcher(ctx), + vocab_size=self.vocab_size, + ctx=ctx) - def get(self, request: Request): + def get_grammar(self, request: Request): with self._lock: - entry = self.grammar_cache.get(request.guided_decoding_key) - if entry is None or not entry.event.is_set(): - return None - return entry.value if entry.value else None + return self.grammar_cache.get(request.guided_decoding_key) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index fd3519677477c..d5644c4a9b882 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -7,11 +7,13 @@ import json from concurrent.futures import Future from concurrent.futures._base import TimeoutError -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Union from vllm.sampling_params import SamplingParams from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreRequest, FinishReason) +from vllm.v1.guided_decoding import (Grammar, GuidedDecodingKey, + GuidedDecodingOptions) from vllm.v1.utils import ConstantList if TYPE_CHECKING: @@ -20,18 +22,6 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange - from vllm.v1.guided_decoding import Grammar - - -class GuidedDecodingOptions(enum.Enum): - json = enum.auto() - regex = enum.auto() - grammar = enum.auto() - choice = enum.auto() - - -GuidedDecodingObject = Union[str, Dict[str, Any]] -GuidedDecodingKey = Tuple[GuidedDecodingOptions, GuidedDecodingObject] class Request: @@ -88,7 +78,7 @@ def __init__( # Grammar fields, including the grammar object and the bitmask self._grammar: Future[Grammar] | Grammar | None = None - self._bitmask = None + self._grammar_bitmask = None @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> Request: @@ -164,20 +154,26 @@ def guided_decoding_key(self) -> GuidedDecodingKey: params = self.sampling_params.guided_decoding assert params is not None, "params can't be None." 
if params.json is not None: - key = params.json - if params.json_object or type(key) is not str: - key = json.dumps(params.json) - return (GuidedDecodingOptions.json, key) + if params.json_object or not isinstance(params.json, str): + json_str = json.dumps(params.json) + else: + json_str = params.json + return (GuidedDecodingOptions.json, json_str) elif params.regex is not None: return (GuidedDecodingOptions.regex, params.regex) elif params.choice is not None: - return (GuidedDecodingOptions.choice, params.choice) + if not isinstance(params.choice, str): + json_str = json.dumps(params.choice) + else: + json_str = params.choice + return (GuidedDecodingOptions.choice, json_str) elif params.grammar is not None: return (GuidedDecodingOptions.grammar, params.grammar) else: raise ValueError("No valid guided decoding parameter found") - def allocate_bitmask(self, batch_size: int, vocab_size: int) -> None: + def allocate_grammar_bitmask(self, batch_size: int, + vocab_size: int) -> None: if isinstance(self._grammar, Future): try: self._grammar = self._grammar.result(timeout=0.05) @@ -186,12 +182,12 @@ def allocate_bitmask(self, batch_size: int, vocab_size: int) -> None: return if self._grammar is not None: - self._bitmask = self._grammar.allocate_bitmask( + self._grammar_bitmask = self._grammar.allocate_bitmask( batch_size, vocab_size) @functools.cached_property - def bitmask(self) -> Optional[torch.Tensor]: - return self._bitmask + def grammar_bitmask(self) -> Optional[torch.Tensor]: + return self._grammar_bitmask @property def is_grammar_ready(self) -> bool: diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index cd09f6313f5e4..e21bcb6bc4438 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -41,7 +41,7 @@ class CachedRequestState: lora_request: Optional[LoRARequest] = None grammar: Optional[Grammar] = None - bitmask: Optional[torch.Tensor] = None + grammar_bitmask: Optional[torch.Tensor] = None @property def num_tokens(self) -> int: @@ -253,7 +253,7 @@ def add_request( self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs if request.grammar is not None: - self.grammar_reqs[req_id] = request.bitmask + self.grammar_reqs[req_id] = request.grammar_bitmask # Add request lora ID if request.lora_request: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fb9b4c77a478c..6fba6bd39b89a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -284,7 +284,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: output_token_ids=[], lora_request=new_req_data.lora_request, grammar=new_req_data.grammar, - bitmask=new_req_data.bitmask, + grammar_bitmask=new_req_data.grammar_bitmask, ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -357,9 +357,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: token_idx = scheduler_output.num_scheduled_tokens[ req_id] - 1 if not req_state.grammar.matcher.is_terminated(): - assert req_state.bitmask is not None - req_state.grammar.fill_bitmask(req_state.bitmask, - token_idx) + assert req_state.grammar_bitmask is not None + req_state.grammar.fill_bitmask( + req_state.grammar_bitmask, token_idx) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. 
@@ -408,7 +408,7 @@ def _prepare_inputs( num_tokens) if req_id in scheduler_output.guided_decoding_request_ids: print(self.requests[req_id]) - bitmask = self.requests[req_id].bitmask + bitmask = self.requests[req_id].grammar_bitmask num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, dtype=np.int32) assert max_num_scheduled_tokens > 0 From 872c66f7d24bf93e1134a6b6d7d246b7bc834037 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 13 Feb 2025 18:42:08 +0000 Subject: [PATCH 18/84] chore: simplify cache implementations Signed-off-by: Aaron Pham --- vllm/v1/engine/core.py | 8 ++++---- vllm/v1/guided_decoding/__init__.py | 20 ++++++-------------- vllm/v1/request.py | 6 +++--- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 14a72ac48e6a2..c0bc0024da809 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -158,14 +158,14 @@ def calculate_grammar_bitmasks(self): for req in self.guided_decoding_manager.requests: # Check if grammar is ready in cache - grammar = self.guided_decoding_manager.get_grammar(req) + print(req.is_grammar_ready, req.grammar, + self.guided_decoding_manager.grammar_cache) + grammar = self.guided_decoding_manager.grammar_cache.get( + req.guided_decoding_key) if grammar is not None: - print(req.use_guided_decoding, req.is_grammar_ready, - req.grammar, grammar) req.grammar = grammar req.allocate_grammar_bitmask( 1, self.guided_decoding_manager.vocab_size) - print(req.status) continue diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 075b6a7564619..7594322c82bca 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations import enum -import threading from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Dict, Set, Tuple @@ -80,12 +79,11 @@ def __init__(self, vllm_config: VllmConfig): parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) tokenizer_group.ping() - self.model_config = vllm_config.model_config self.vocab_size = vllm_config.model_config.get_vocab_size() + tokenizer = tokenizer_group.get_lora_tokenizer(None) tokenizer_info = xgr.TokenizerInfo.from_huggingface( - tokenizer_group.get_lora_tokenizer(None), - vocab_size=self.vocab_size) + tokenizer, vocab_size=self.vocab_size) self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) self.grammar_cache: Dict[GuidedDecodingKey, Grammar] = {} @@ -93,14 +91,13 @@ def __init__(self, vllm_config: VllmConfig): self.executor = ThreadPoolExecutor() self.requests: Set[Request] = set() - self._lock = threading.Lock() - def should_cache(self, request: Request): if not request.use_guided_decoding: return False - request.grammar = self.get_grammar(request) + request.grammar = self.grammar_cache.get(request.guided_decoding_key) if not request.grammar: request.grammar = self.cache(request) + print(request.grammar) return True return False @@ -110,9 +107,8 @@ def cache(self, request: Request): def _executor_loop(self, request: Request): key = request.guided_decoding_key self.requests.add(request) - with self._lock: - if key in self.grammar_cache: - return self.grammar_cache[key] + if key in self.grammar_cache: + return self.grammar_cache[key] self.grammar_cache[key] = self.initialize_grammar(key) return self.grammar_cache[key] @@ -135,7 +131,3 @@ def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: return 
Grammar(matcher=xgr.GrammarMatcher(ctx), vocab_size=self.vocab_size, ctx=ctx) - - def get_grammar(self, request: Request): - with self._lock: - return self.grammar_cache.get(request.guided_decoding_key) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index d5644c4a9b882..45d3793529798 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -7,7 +7,7 @@ import json from concurrent.futures import Future from concurrent.futures._base import TimeoutError -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union, cast from vllm.sampling_params import SamplingParams from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, @@ -198,7 +198,7 @@ def is_grammar_ready(self) -> bool: @property def grammar(self) -> Optional[Grammar]: - return self._grammar if isinstance(self._grammar, Grammar) else None + return cast(Optional[Grammar], self._grammar) @grammar.setter def grammar(self, grammar: Grammar | Future[Grammar]) -> None: @@ -224,7 +224,7 @@ def is_finished(status: RequestStatus) -> bool: @staticmethod def is_waiting(status: RequestStatus) -> bool: - return status < RequestStatus.WAITING_FOR_FSM + return status <= RequestStatus.WAITING_FOR_FSM @staticmethod def get_finished_reason( From a8a2f2744f59e6f4f0d34aa8e4ec87be9b28bcac Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 13 Feb 2025 13:48:17 -0500 Subject: [PATCH 19/84] Changes to get a test request working In the scheduler, if a request is still waiting on FSM completion, we now push it back to the end of the waiting queue. Each time we check on a request in this state, we also check to see if it has actually finished. It was previously stuck without ever getting the result. Signed-off-by: Russell Bryant --- vllm/v1/core/scheduler.py | 18 +++++++++--------- vllm/v1/request.py | 16 +++++++++++----- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 5cd7bc9fc271a..42b71e0910213 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -206,15 +206,15 @@ def schedule(self) -> "SchedulerOutput": if request.use_guided_decoding: guided_decoding_request_ids.add(request.request_id) - if (request.status == RequestStatus.WAITING_FOR_FSM - and request.grammar and request.is_grammar_ready): - request.status = RequestStatus.WAITING - request.grammar.prefilled = True - # else: - # # Skip this request but keep in the queue - # self.waiting.popleft() - # self.waiting.append(request) - # continue + if request.status == RequestStatus.WAITING_FOR_FSM: + if request.grammar and request.is_grammar_ready: + request.status = RequestStatus.WAITING + request.grammar.prefilled = True + else: + # Skip this request but keep in the queue + self.waiting.popleft() + self.waiting.append(request) + continue # # Check that adding the request still respects the max_loras # constraint. 
diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 45d3793529798..7392b670c8980 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -7,7 +7,7 @@ import json from concurrent.futures import Future from concurrent.futures._base import TimeoutError -from typing import TYPE_CHECKING, List, Optional, Union, cast +from typing import TYPE_CHECKING, List, Optional, Union from vllm.sampling_params import SamplingParams from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, @@ -172,14 +172,19 @@ def guided_decoding_key(self) -> GuidedDecodingKey: else: raise ValueError("No valid guided decoding parameter found") - def allocate_grammar_bitmask(self, batch_size: int, - vocab_size: int) -> None: + def _check_grammar_completion(self) -> bool: if isinstance(self._grammar, Future): try: self._grammar = self._grammar.result(timeout=0.05) self.status = RequestStatus.WAITING except TimeoutError: - return + return False + return True + + def allocate_grammar_bitmask(self, batch_size: int, + vocab_size: int) -> None: + if not self._check_grammar_completion(): + return if self._grammar is not None: self._grammar_bitmask = self._grammar.allocate_bitmask( @@ -198,7 +203,8 @@ def is_grammar_ready(self) -> bool: @property def grammar(self) -> Optional[Grammar]: - return cast(Optional[Grammar], self._grammar) + self._check_grammar_completion() + return self._grammar if isinstance(self._grammar, Grammar) else None @grammar.setter def grammar(self, grammar: Grammar | Future[Grammar]) -> None: From 3fda1482ea62ef35cd81832037596e34814579c4 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 13 Feb 2025 14:11:42 -0500 Subject: [PATCH 20/84] Resolve mypy error in request when we go to call allocate_bitmask() on the Grammar, mypy wants to know that the type is guaranteed to be a Grammar and not Future[Grammar]. This makes it clear. Signed-off-by: Russell Bryant --- vllm/v1/request.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 7392b670c8980..d09ac030f7cbf 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -9,6 +9,7 @@ from concurrent.futures._base import TimeoutError from typing import TYPE_CHECKING, List, Optional, Union +from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreRequest, FinishReason) @@ -23,6 +24,8 @@ from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange +logger = init_logger(__name__) + class Request: @@ -187,8 +190,13 @@ def allocate_grammar_bitmask(self, batch_size: int, return if self._grammar is not None: - self._grammar_bitmask = self._grammar.allocate_bitmask( - batch_size, vocab_size) + if isinstance(self._grammar, Grammar): + self._grammar_bitmask = self._grammar.allocate_bitmask( + batch_size, vocab_size) + else: + logger.error( + "Grammar is not ready yet. This should never happen." 
+ " Please file an issue.") @functools.cached_property def grammar_bitmask(self) -> Optional[torch.Tensor]: From d7a64eb4eb2a6c6c41ea81ef71365de62517dafe Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 13 Feb 2025 19:17:43 +0000 Subject: [PATCH 21/84] chore: remove debug print Signed-off-by: Aaron Pham --- vllm/v1/engine/core.py | 2 -- vllm/v1/worker/gpu_model_runner.py | 1 - 2 files changed, 3 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c0bc0024da809..926e89f2986d8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -158,8 +158,6 @@ def calculate_grammar_bitmasks(self): for req in self.guided_decoding_manager.requests: # Check if grammar is ready in cache - print(req.is_grammar_ready, req.grammar, - self.guided_decoding_manager.grammar_cache) grammar = self.guided_decoding_manager.grammar_cache.get( req.guided_decoding_key) if grammar is not None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c026180328e59..d37dc75bb3d73 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -409,7 +409,6 @@ def _prepare_inputs( max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) if req_id in scheduler_output.guided_decoding_request_ids: - print(self.requests[req_id]) bitmask = self.requests[req_id].grammar_bitmask num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, dtype=np.int32) From 34c08ac48ffca0137c036949c23f269f89dbb319 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 13 Feb 2025 15:15:27 -0500 Subject: [PATCH 22/84] Enable some v1 structured output tests This change updates tests.entrypoints.llm.test_guided_generate to run its test cases against both v0 and v1. It skips cases we know won't work with v1. I still see failures here, but it seems like a real bug vs the test doing the wrong thing. 
Signed-off-by: Russell Bryant --- tests/entrypoints/llm/test_guided_generate.py | 43 +++++++++++++++---- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 932a35a9950ec..f950ad8fbaaaf 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import os import re import weakref @@ -14,23 +15,47 @@ MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct" GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] +GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"] -@pytest.fixture(scope="module") -def llm(): - # pytest caches the fixture so we use weakref.proxy to - # enable garbage collection - llm = LLM(model=MODEL_NAME, max_model_len=1024) +@pytest.fixture(autouse=True) +def v1(request, run_with_both_engines, monkeypatch): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + use_v1 = os.getenv('VLLM_USE_V1') == '1' + if use_v1 and 'guided_decoding_backend' in request.fixturenames: + guided_decoding_backend = request.getfixturevalue( + 'guided_decoding_backend') + if guided_decoding_backend not in GUIDED_DECODING_BACKENDS_V1: + pytest.skip(f"Skipping test because {guided_decoding_backend} " + "is not in GUIDED_DECODING_BACKENDS_V1") - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) - del llm - cleanup_dist_env_and_memory() + if use_v1 and "regex" in request.node.name: + pytest.skip("Skipping test because V1 does not support regex") + + +@pytest.fixture(scope="function") +def llm(monkeypatch): + with monkeypatch.context() as m: + m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, max_model_len=1024) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + del llm + cleanup_dist_env_and_memory() @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) def test_guided_regex(sample_regex, llm, guided_decoding_backend: str): + use_v1 = os.getenv('VLLM_USE_V1') == '1' + if use_v1: + pytest.skip("Skipping test because V1 does not support regex") + sampling_params = SamplingParams(temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams( From 3b736ced23f0b61bfa22133521fa58d6fa045118 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 13 Feb 2025 15:48:42 -0500 Subject: [PATCH 23/84] Validate structured output backend for v1 Raise a ValueError exception if the guided decoding backend parameter is not set to "xgrammar", as that is the only valid choice right now. 
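For illustration, a request that picks a different backend should now fail fast (sketch only; GuidedDecodingParams and its backend field are the ones referenced in the diff below, and the error text matches the new check in the processor):

    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    params = SamplingParams(
        guided_decoding=GuidedDecodingParams(json={"type": "object"},
                                             backend="outlines"))
    # Under V1, Processor.process_inputs() now raises:
    #   ValueError: Only xgrammar guided decoding is supported in V1.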
Signed-off-by: Russell Bryant --- tests/entrypoints/llm/test_guided_generate.py | 4 ---- vllm/v1/engine/processor.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index f950ad8fbaaaf..ff4c44de364fe 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -52,10 +52,6 @@ def llm(monkeypatch): @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) def test_guided_regex(sample_regex, llm, guided_decoding_backend: str): - use_v1 = os.getenv('VLLM_USE_V1') == '1' - if use_v1: - pytest.skip("Skipping test because V1 does not support regex") - sampling_params = SamplingParams(temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index b7eee5a39972b..71b91ea403c31 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -83,6 +83,15 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + def _validate_guided_decoding( + self, params: Union[SamplingParams, PoolingParams]) -> None: + if not isinstance(params, SamplingParams): + return + if (params.guided_decoding + and params.guided_decoding.backend != 'xgrammar'): + raise ValueError( + "Only xgrammar guided decoding is supported in V1.") + def process_inputs( self, request_id: str, @@ -100,6 +109,7 @@ def process_inputs( self._validate_logprobs(params) self._validate_lora(lora_request) + self._validate_guided_decoding(params) if arrival_time is None: arrival_time = time.time() From 1a258fe9e32fca128413ca80e133cb05f6aaa292 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 18 Feb 2025 16:31:28 -0500 Subject: [PATCH 24/84] wip fixes for bitmask initialization and communication Signed-off-by: Russell Bryant --- tests/entrypoints/llm/test_guided_generate.py | 2 +- vllm/v1/core/scheduler.py | 36 ++++++++++++------- vllm/v1/core/scheduler_output.py | 2 ++ vllm/v1/engine/core.py | 5 +++ vllm/v1/worker/gpu_model_runner.py | 7 ++-- 5 files changed, 35 insertions(+), 17 deletions(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 70c055fbc7559..f2a1e8097bdab 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -13,7 +13,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import GuidedDecodingParams, SamplingParams -MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct" +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"] diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 7be2dcd90c857..a06f79d843ff5 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import time -from collections import deque -from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, SpeculativeConfig) @@ -37,6 +36,7 @@ def __init__( self.lora_config = lora_config self.speculative_config = speculative_config self.log_stats = log_stats + self.vocab_size = 
model_config.get_vocab_size() # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -59,7 +59,7 @@ def __init__( # req_id -> Request self.requests: Dict[str, Request] = {} # Priority queues for requests. - self.waiting: Deque[Request] = deque() + self.waiting: List[Request] = [] self.running: List[Request] = [] # The requests that have been scheduled and are being executed # by the executor. @@ -144,6 +144,11 @@ def schedule(self) -> "SchedulerOutput": req_index += 1 continue + if request.use_guided_decoding: + if request.grammar_bitmask is None: + request.allocate_grammar_bitmask(1, self.vocab_size) + guided_decoding_request_ids.add(request.request_id) + # Schedule encoder inputs. encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( self._try_schedule_encoder_inputs(request, @@ -170,7 +175,7 @@ def schedule(self) -> "SchedulerOutput": preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 - self.waiting.appendleft(preempted_req) + self.waiting.insert(0, preempted_req) preempted_reqs.append(preempted_req) if preempted_req == request: # No more request to preempt. @@ -224,24 +229,28 @@ def schedule(self) -> "SchedulerOutput": # Next, schedule the WAITING requests. if not preempted_reqs: + num_to_skip: int = 0 while self.waiting and token_budget > 0: if len(self.running) == self.max_num_running_reqs: break - request = self.waiting[0] + if num_to_skip >= len(self.waiting): + break - if request.use_guided_decoding: - guided_decoding_request_ids.add(request.request_id) + request = self.waiting[num_to_skip] if request.status == RequestStatus.WAITING_FOR_FSM: if request.grammar and request.is_grammar_ready: request.status = RequestStatus.WAITING request.grammar.prefilled = True - else: - # Skip this request but keep in the queue - self.waiting.popleft() - self.waiting.append(request) - continue + num_to_skip += 1 + continue + + if request.use_guided_decoding: + if request.grammar_bitmask is None: + request.allocate_grammar_bitmask(1, self.vocab_size) + guided_decoding_request_ids.add(request.request_id) + # # Check that adding the request still respects the max_loras # constraint. @@ -296,7 +305,7 @@ def schedule(self) -> "SchedulerOutput": # The request cannot be scheduled. 
break - self.waiting.popleft() + self.waiting.pop(num_to_skip) self.running.append(request) self.scheduled_req_ids.add(request.request_id) if RequestStatus.is_waiting(request.status): @@ -416,6 +425,7 @@ def _make_cached_request_data( new_token_ids, new_block_ids) self._cached_reqs_data[request.request_id] = req_data + req_data.grammar_bitmask = request.grammar_bitmask return req_data def _try_schedule_encoder_inputs( diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index eafd9f614e622..e910018b3a10f 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -63,6 +63,7 @@ class CachedRequestData: new_token_ids: List[int] new_block_ids: List[int] num_computed_tokens: int + grammar_bitmask: Optional["torch.Tensor"] @classmethod def from_request( @@ -78,6 +79,7 @@ def from_request( new_token_ids=new_token_ids, new_block_ids=new_block_ids, num_computed_tokens=request.num_computed_tokens, + grammar_bitmask=request.grammar_bitmask, ) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0786b0e738f36..afcb3ad5e2638 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -155,6 +155,9 @@ def step(self) -> EngineCoreOutputs: self.calculate_grammar_bitmasks() scheduler_output = self.scheduler.schedule() + if scheduler_output.total_num_scheduled_tokens == 0: + return EngineCoreOutputs( + outputs=[], scheduler_stats=self.scheduler.make_stats()) output = self.model_executor.execute_model(scheduler_output) @@ -221,6 +224,8 @@ def reset_prefix_cache(self): def calculate_grammar_bitmasks(self): for req in self.guided_decoding_manager.requests: + if req.grammar_bitmask is not None: + continue # Check if grammar is ready in cache grammar = self.guided_decoding_manager.grammar_cache.get( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index aa2fa2169d1b7..2ac8c00b056a1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -392,7 +392,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_index, start_index:end_token_index] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec decode tokens. self.input_batch.num_tokens[req_index] = end_token_index - # Fill the bitmask if (req_id in scheduler_output.guided_decoding_request_ids and req_state.grammar is not None): @@ -456,8 +455,10 @@ def _prepare_inputs( num_scheduled_tokens[i] = num_tokens max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) - if req_id in scheduler_output.guided_decoding_request_ids: - bitmask = self.requests[req_id].grammar_bitmask + # TODO - + # -- we have a bitmask per request + # -- need to pull the latest bitmask out of scheduler_output + bitmask = self.requests[req_id].grammar_bitmask # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] From 10f01f559a2a0824192754f9a358f0cbb6ef591e Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 19 Feb 2025 12:59:34 -0500 Subject: [PATCH 25/84] Clean up some remnants of inaccurate merge conflict resolution I spotted these changes in our current diff, but they're unrelated to our changes, so they were just mistakes during conflict resolution. 
Signed-off-by: Russell Bryant --- vllm/v1/engine/core.py | 4 ++++ vllm/v1/worker/gpu_model_runner.py | 7 ------- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index afcb3ad5e2638..49be831c98c85 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -141,6 +141,10 @@ def add_request(self, request: EngineCoreRequest): def abort_requests(self, request_ids: List[str]): """Abort requests from the scheduler.""" + + # TODO: The scheduler doesn't really need to know the + # specific finish reason, TBD whether we propagate that + # (i.e. client-aborted vs stop criteria met). self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2ac8c00b056a1..cc783832fe01c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -605,13 +605,6 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - # NOTE(woosuk): Due to chunked prefills, the batch may contain partial - # requests. While we should not sample any token from these partial - # requests, we do so for simplicity. We will ignore the sampled - # tokens from the partial requests. - # TODO: Support prompt logprobs. - logits_indices = query_start_loc[1:] - 1 - return attn_metadata, logits_indices, bitmask def _compute_cascade_attn_prefix_len( From a6b07d17dca2020b265a9d48ef0d909660f9df7c Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 19 Feb 2025 18:23:28 +0000 Subject: [PATCH 26/84] fix: correctly use bitmask batch-wise Signed-off-by: Aaron Pham --- vllm/v1/core/scheduler.py | 20 +++---- vllm/v1/core/scheduler_output.py | 22 +++++--- vllm/v1/engine/core.py | 23 +++++--- vllm/v1/guided_decoding/__init__.py | 88 +++++++++++++++++++++-------- vllm/v1/request.py | 24 +------- vllm/v1/utils.py | 4 +- vllm/v1/worker/gpu_input_batch.py | 7 --- vllm/v1/worker/gpu_model_runner.py | 41 +++++++------- 8 files changed, 127 insertions(+), 102 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index a06f79d843ff5..2def2cc404353 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -112,7 +112,7 @@ def schedule(self) -> "SchedulerOutput": scheduled_resumed_reqs: List[Request] = [] scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] - guided_decoding_request_ids: Set[str] = set() + guided_decoding_request_ids: Dict[str, int] = {} req_to_new_block_ids: Dict[str, List[int]] = {} num_scheduled_tokens: Dict[str, int] = {} @@ -144,10 +144,9 @@ def schedule(self) -> "SchedulerOutput": req_index += 1 continue - if request.use_guided_decoding: - if request.grammar_bitmask is None: - request.allocate_grammar_bitmask(1, self.vocab_size) - guided_decoding_request_ids.add(request.request_id) + if request.use_guided_decoding \ + and request.request_id not in guided_decoding_request_ids: + guided_decoding_request_ids[request.request_id] = req_index # Schedule encoder inputs. 
encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( @@ -246,10 +245,9 @@ def schedule(self) -> "SchedulerOutput": num_to_skip += 1 continue - if request.use_guided_decoding: - if request.grammar_bitmask is None: - request.allocate_grammar_bitmask(1, self.vocab_size) - guided_decoding_request_ids.add(request.request_id) + if request.use_guided_decoding \ + and request.request_id not in guided_decoding_request_ids: + guided_decoding_request_ids[request.request_id] = req_index # # Check that adding the request still respects the max_loras @@ -394,7 +392,8 @@ def schedule(self) -> "SchedulerOutput": # the previous and the current steps. finished_req_ids=self.finished_req_ids, free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), - guided_decoding_request_ids=guided_decoding_request_ids) + guided_decoding_request_ids=guided_decoding_request_ids, + ) self.finished_req_ids = set() return scheduler_output @@ -425,7 +424,6 @@ def _make_cached_request_data( new_token_ids, new_block_ids) self._cached_reqs_data[request.request_id] = req_data - req_data.grammar_bitmask = request.grammar_bitmask return req_data def _try_schedule_encoder_inputs( diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index e910018b3a10f..861f766ccd311 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple if TYPE_CHECKING: @@ -28,7 +28,6 @@ class NewRequestData: num_computed_tokens: int lora_request: Optional["LoRARequest"] grammar: Optional["Grammar"] - grammar_bitmask: Optional["torch.Tensor"] @classmethod def from_request( @@ -48,7 +47,6 @@ def from_request( num_computed_tokens=request.num_computed_tokens, lora_request=request.lora_request, grammar=request.grammar, - grammar_bitmask=request.grammar_bitmask, ) @@ -63,7 +61,6 @@ class CachedRequestData: new_token_ids: List[int] new_block_ids: List[int] num_computed_tokens: int - grammar_bitmask: Optional["torch.Tensor"] @classmethod def from_request( @@ -79,7 +76,6 @@ def from_request( new_token_ids=new_token_ids, new_block_ids=new_block_ids, num_computed_tokens=request.num_computed_tokens, - grammar_bitmask=request.grammar_bitmask, ) @@ -121,5 +117,17 @@ class SchedulerOutput: # Used to free the encoder cache. free_encoder_input_ids: List[Tuple[str, int]] - # Set of request ids for all requests that uses guided decoding - guided_decoding_request_ids: Set[str] + # Dict of request ids to its index within the batch + # for filling the next token bitmask + guided_decoding_request_ids: Dict[str, int] + # the bitmask for the whole batch + _grammar_bitmask: Optional["torch.Tensor"] = field(default=None, + repr=False) + + @property + def grammar_bitmask(self) -> Optional[torch.Tensor]: + return self._grammar_bitmask + + @grammar_bitmask.setter + def grammar_bitmask(self, bitmask: torch.Tensor) -> None: + self._grammar_bitmask = bitmask diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 49be831c98c85..1e0a4a4122635 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -74,8 +74,7 @@ def __init__( self.mm_input_cache_server = MMInputCacheServer( vllm_config.model_config) - self.guided_decoding_manager = GuidedDecodingManager( - vllm_config=vllm_config) + self.guided_decoding_manager = GuidedDecodingManager(vllm_config) # Setup batch queue for pipeline parallelism. 
# Batch queue for scheduled batches. This enables us to asynchronously @@ -156,9 +155,18 @@ def step(self) -> EngineCoreOutputs: outputs=[], scheduler_stats=self.scheduler.make_stats()) # Calculate bitmasks for all active requests - self.calculate_grammar_bitmasks() + self.setup_request_grammars() scheduler_output = self.scheduler.schedule() + # the bitmask allocation for grammars + # should be ready at this point. + if len(self.guided_decoding_manager.requests) > 0: + if not self.guided_decoding_manager.is_bitmask_ready: + raise ValueError("Could be a bug at this point") + # one copy + scheduler_output.grammar_bitmask = \ + self.guided_decoding_manager.grammar_bitmask + if scheduler_output.total_num_scheduled_tokens == 0: return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) @@ -226,19 +234,18 @@ def profile(self, is_start: bool = True): def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() - def calculate_grammar_bitmasks(self): + def setup_request_grammars(self): for req in self.guided_decoding_manager.requests: - if req.grammar_bitmask is not None: + if req.grammar is not None: continue # Check if grammar is ready in cache - grammar = self.guided_decoding_manager.grammar_cache.get( + grammar = self.guided_decoding_manager.request_key_to_grammar.get( req.guided_decoding_key) if grammar is not None: req.grammar = grammar - req.allocate_grammar_bitmask( - 1, self.guided_decoding_manager.vocab_size) continue + self.guided_decoding_manager.allocate_bitmask() def add_lora(self, lora_request: LoRARequest) -> None: self.model_executor.add_lora(lora_request) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 7594322c82bca..7802f676e4a7c 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -2,8 +2,9 @@ from __future__ import annotations import enum -from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Dict, Set, Tuple +import functools +from concurrent.futures import Future, ThreadPoolExecutor +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union import torch import xgrammar as xgr @@ -25,17 +26,31 @@ class GuidedDecodingOptions(enum.Enum): GuidedDecodingKey = Tuple[GuidedDecodingOptions, str] +def reset_bitmask(bitmask: torch.Tensor): + # this calls bitmask.fill_(tensor([1, 1, ...], dtype=int32)) + xgr.reset_token_bitmask(bitmask) + + +def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor, + indices: List[int]) -> None: + xgr.apply_token_bitmask_inplace(logits, vocab_mask, indices=indices) + + class Grammar: # NOTE: This would be a generic-enough class for # supporting different backends, in the future. # For now, just xgrammar. 
# + # TODO: support max_rollback_tokens # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string # for jump-forward decoding - # TODO: support max_rollback_tokens - def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, - ctx: xgr.CompiledGrammar) -> None: + def __init__( + self, + matcher: xgr.GrammarMatcher, + vocab_size: int, + ctx: xgr.CompiledGrammar, + ) -> None: self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx @@ -46,18 +61,10 @@ def accept_token(self, token: int) -> bool: # and will also update the machine state return self.matcher.accept_token(token) - def allocate_bitmask(self, batch_size: int, - vocab_size: int) -> torch.Tensor: - return xgr.allocate_token_bitmask(batch_size, vocab_size) - # this should be ran in parallel with model decoding def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: self.matcher.fill_next_token_bitmask(bitmask, idx) - @staticmethod - def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - xgr.apply_token_bitmask_inplace(logits, vocab_mask) - def reset(self): self.matcher.reset() @@ -80,24 +87,55 @@ def __init__(self, vllm_config: VllmConfig): lora_config=vllm_config.lora_config) tokenizer_group.ping() self.vocab_size = vllm_config.model_config.get_vocab_size() + self.vllm_config = vllm_config tokenizer = tokenizer_group.get_lora_tokenizer(None) tokenizer_info = xgr.TokenizerInfo.from_huggingface( tokenizer, vocab_size=self.vocab_size) self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) - self.grammar_cache: Dict[GuidedDecodingKey, Grammar] = {} + self.request_key_to_grammar: Dict[GuidedDecodingKey, Grammar] = {} self.executor = ThreadPoolExecutor() self.requests: Set[Request] = set() + self._grammar_bitmask: Optional[Union[torch.Tensor, + Future[torch.Tensor]]] = None + + def allocate_bitmask(self) -> None: + self._grammar_bitmask = self.executor.submit( + xgr.allocate_token_bitmask, + self.vllm_config.scheduler_config.max_num_seqs, + self.vocab_size / 32) + + def _ensure_bitmask_ready(self) -> bool: + if isinstance(self._grammar_bitmask, Future): + try: + self._grammar_bitmask = self._grammar_bitmask.result( + timeout=0.05) + except TimeoutError: + return False + return True + + @functools.cached_property + def grammar_bitmask(self) -> Optional[torch.Tensor]: + self._ensure_bitmask_ready() + return self._grammar_bitmask if not isinstance(self._grammar_bitmask, + Future) else None + + @property + def is_bitmask_ready(self) -> bool: + if isinstance(self._grammar_bitmask, Future): + return not self._grammar_bitmask.running( + ) and self._grammar_bitmask.done() + return self._grammar_bitmask is not None def should_cache(self, request: Request): if not request.use_guided_decoding: return False - request.grammar = self.grammar_cache.get(request.guided_decoding_key) + request.grammar = self.request_key_to_grammar.get( + request.guided_decoding_key) if not request.grammar: request.grammar = self.cache(request) - print(request.grammar) return True return False @@ -107,11 +145,11 @@ def cache(self, request: Request): def _executor_loop(self, request: Request): key = request.guided_decoding_key self.requests.add(request) - if key in self.grammar_cache: - return self.grammar_cache[key] + if key in self.request_key_to_grammar: + return self.request_key_to_grammar[key] - self.grammar_cache[key] = self.initialize_grammar(key) - return self.grammar_cache[key] + self.request_key_to_grammar[key] = self.initialize_grammar(key) + return 
self.request_key_to_grammar[key] def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: request_type, grammar_spec = key @@ -123,11 +161,15 @@ def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: ctx = self.compiler.compile_json_schema(grammar_spec) elif request_type == GuidedDecodingOptions.grammar: ctx = self.compiler.compile_grammar(grammar_spec) + elif request_type == GuidedDecodingOptions.regex: + ctx = self.compiler.compile_regex(grammar_spec) else: raise ValueError( f"`grammar` is not of valid supported types. ({request_type!s})" ) - return Grammar(matcher=xgr.GrammarMatcher(ctx), - vocab_size=self.vocab_size, - ctx=ctx) + return Grammar( + matcher=xgr.GrammarMatcher(ctx), + vocab_size=self.vocab_size, + ctx=ctx, + ) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index aec341488a25f..3b01bc491663c 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -18,7 +18,6 @@ from vllm.v1.utils import ConstantList if TYPE_CHECKING: - import torch from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs @@ -81,8 +80,7 @@ def __init__( self.all_token_ids = ConstantList(self._all_token_ids) # Grammar fields, including the grammar object and the bitmask - self._grammar: Future[Grammar] | Grammar | None = None - self._grammar_bitmask = None + self._grammar: Optional[Union[Future[Grammar], Grammar]] = None @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> Request: @@ -189,24 +187,6 @@ def _check_grammar_completion(self) -> bool: return False return True - def allocate_grammar_bitmask(self, batch_size: int, - vocab_size: int) -> None: - if not self._check_grammar_completion(): - return - - if self._grammar is not None: - if isinstance(self._grammar, Grammar): - self._grammar_bitmask = self._grammar.allocate_bitmask( - batch_size, vocab_size) - else: - logger.error( - "Grammar is not ready yet. This should never happen." - " Please file an issue.") - - @functools.cached_property - def grammar_bitmask(self) -> Optional[torch.Tensor]: - return self._grammar_bitmask - @property def is_grammar_ready(self) -> bool: if isinstance(self._grammar, Future): @@ -220,7 +200,7 @@ def grammar(self) -> Optional[Grammar]: return self._grammar if isinstance(self._grammar, Grammar) else None @grammar.setter - def grammar(self, grammar: Grammar | Future[Grammar]) -> None: + def grammar(self, grammar: Union[Grammar, Future[Grammar]]) -> None: self._grammar = grammar diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 5be4650142428..759123dc1c49e 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -157,12 +157,12 @@ def bind_kv_cache( This function: 1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with kv_caches. - 2) Associates each attention layer in the `forward_context` with its + 2) Associates each attention layer in the `forward_context` with its corresponding KV cache in kv_caches. Args: kv_caches: The allocated kv_caches with layer names as keys. - forward_context: The global forward context containing all Attention + forward_context: The global forward context containing all Attention layers with layer names as keys. runner_kv_caches: The kv_cache declared by ModelRunner. 
""" diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 48c138391e1e6..ce93e7257d0fe 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -43,7 +43,6 @@ class CachedRequestState: lora_request: Optional[LoRARequest] = None grammar: Optional[Grammar] = None - grammar_bitmask: Optional[torch.Tensor] = None @property def num_tokens(self) -> int: @@ -185,8 +184,6 @@ def __init__( self.lora_id_to_request_ids: Dict[int, Set[str]] = {} self.lora_id_to_lora_request: Dict[int, LoRARequest] = {} - self.grammar_reqs: Dict[str, torch.Tensor] = {} - # req_index -> generator # NOTE(woosuk): The indices of the requests that do not have their own # generator should not be included in the dictionary. @@ -292,9 +289,6 @@ def add_request( if sampling_params.logit_bias is not None: self.logit_bias[req_index] = sampling_params.logit_bias - if request.grammar is not None: - self.grammar_reqs[req_id] = request.grammar_bitmask - # Add request lora ID if request.lora_request: lora_id = request.lora_request.lora_int_id @@ -328,7 +322,6 @@ def remove_request(self, req_id: str) -> Optional[int]: self.repetition_penalties_reqs.discard(req_id) self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) - self.grammar_reqs.pop(req_id, None) self.num_prompt_logprobs.pop(req_id, None) # LoRA diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cc783832fe01c..591a67f151037 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -28,7 +28,7 @@ FlashAttentionMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_cache import MMInputCacheClient -from vllm.v1.guided_decoding import Grammar +from vllm.v1.guided_decoding import apply_bitmask from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput @@ -299,7 +299,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: output_token_ids=[], lora_request=new_req_data.lora_request, grammar=new_req_data.grammar, - grammar_bitmask=new_req_data.grammar_bitmask, ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -395,15 +394,18 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Fill the bitmask if (req_id in scheduler_output.guided_decoding_request_ids and req_state.grammar is not None): + idx = scheduler_output.guided_decoding_request_ids[req_id] + # should already be ready + assert scheduler_output.grammar_bitmask is not None if not req_state.grammar.prefilled: req_state.grammar.prefilled = True else: - token_idx = scheduler_output.num_scheduled_tokens[ - req_id] - 1 if not req_state.grammar.matcher.is_terminated(): - assert req_state.grammar_bitmask is not None + # NOTE: this relies on xgrammar internal bitmask, + # so we need to give the actual index + # of the the request_id in the batch req_state.grammar.fill_bitmask( - req_state.grammar_bitmask, token_idx) + scheduler_output.grammar_bitmask, idx) # Check if the batch has changed. If not, we can skip copying the # sampling metadata from CPU to GPU. 
@@ -431,7 +433,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _prepare_inputs( self, scheduler_output: "SchedulerOutput" - ) -> Tuple[FlashAttentionMetadata, torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[FlashAttentionMetadata, torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -441,11 +443,6 @@ def _prepare_inputs( # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit(num_reqs) - # Prepare bitmasks for guided decoding - # OPTIMIZATION: We shouldn't copy over - # the bitmask like this multiple times - bitmask: Optional[torch.Tensor] = None - # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) @@ -455,10 +452,6 @@ def _prepare_inputs( num_scheduled_tokens[i] = num_tokens max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) - # TODO - - # -- we have a bitmask per request - # -- need to pull the latest bitmask out of scheduler_output - bitmask = self.requests[req_id].grammar_bitmask # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] @@ -605,7 +598,7 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return attn_metadata, logits_indices, bitmask + return attn_metadata, logits_indices def _compute_cascade_attn_prefix_len( self, @@ -894,8 +887,7 @@ def execute_model( encoder_outputs = [] # Prepare the decoder inputs. - attn_metadata, logits_indices, bitmask = self._prepare_inputs( - scheduler_output) + attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -967,9 +959,14 @@ def execute_model( logits = self.model.compute_logits(sample_hidden_states, None) # Apply guided decoding bitmasks if present - if bitmask is not None: - Grammar.apply_bitmask(logits, - bitmask.to(self.device, non_blocking=True)) + if scheduler_output.grammar_bitmask is not None: + # TODO: we probably should move this before and + # after, this might not be correct + apply_bitmask( + logits, + scheduler_output.grammar_bitmask.to(self.device, + non_blocking=True), + list(scheduler_output.guided_decoding_request_ids.values())) # Sample the next token and get logprobs if needed. 
sampling_metadata = self.input_batch.get_sampling_metadata( From 7f255f07fadc56dd472df4caf4e29dc93b3a53e2 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 19 Feb 2025 18:25:38 +0000 Subject: [PATCH 27/84] fix: correct types Signed-off-by: Aaron Pham --- vllm/v1/core/scheduler_output.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index 861f766ccd311..02c2f4ee2a91a 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -125,9 +125,9 @@ class SchedulerOutput: repr=False) @property - def grammar_bitmask(self) -> Optional[torch.Tensor]: + def grammar_bitmask(self) -> Optional["torch.Tensor"]: return self._grammar_bitmask @grammar_bitmask.setter - def grammar_bitmask(self, bitmask: torch.Tensor) -> None: + def grammar_bitmask(self, bitmask: "torch.Tensor") -> None: self._grammar_bitmask = bitmask From 9ab107fe18c59780b7039b8aed7cdefb5bed1089 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 19 Feb 2025 18:32:08 +0000 Subject: [PATCH 28/84] chore: validate from decoding_config -> per request Signed-off-by: Aaron Pham --- vllm/v1/engine/async_llm.py | 5 +++-- vllm/v1/engine/core.py | 3 --- vllm/v1/engine/llm_engine.py | 1 + vllm/v1/engine/processor.py | 9 +++++++-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 1920dbf7a7dc5..a69a0e153b2bf 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -73,6 +73,7 @@ def __init__( model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, lora_config=vllm_config.lora_config, + decoding_config=vllm_config.decoding_config, tokenizer=self.tokenizer, input_registry=input_registry, ) @@ -187,8 +188,8 @@ async def generate( * 3) Adding the Request to the Detokenizer. * 4) Adding the Request to the EngineCore (separate process). - A separate output_handler loop runs in a background AsyncIO task, - pulling outputs from EngineCore and putting them into the + A separate output_handler loop runs in a background AsyncIO task, + pulling outputs from EngineCore and putting them into the per-request AsyncStream. The caller of generate() iterates the returned AsyncGenerator, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 1e0a4a4122635..01c67b26975cd 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -161,9 +161,6 @@ def step(self) -> EngineCoreOutputs: # the bitmask allocation for grammars # should be ready at this point. 
if len(self.guided_decoding_manager.requests) > 0: - if not self.guided_decoding_manager.is_bitmask_ready: - raise ValueError("Could be a bug at this point") - # one copy scheduler_output.grammar_bitmask = \ self.guided_decoding_manager.grammar_bitmask diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index c9a4c5369dfd8..f89603e6d68ed 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -59,6 +59,7 @@ def __init__( self.processor = Processor(model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, lora_config=vllm_config.lora_config, + decoding_config=vllm_config.decoding_config, tokenizer=self.tokenizer, input_registry=input_registry, mm_registry=mm_registry) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 71b91ea403c31..5aa1671e494c0 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -3,7 +3,7 @@ import time from typing import Mapping, Optional, Union -from vllm.config import CacheConfig, LoRAConfig, ModelConfig +from vllm.config import CacheConfig, DecodingConfig, LoRAConfig, ModelConfig from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, PromptType, SingletonInputsAdapter) from vllm.inputs.parse import is_encoder_decoder_inputs @@ -27,6 +27,7 @@ def __init__( model_config: ModelConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], + decoding_config: DecodingConfig, tokenizer: BaseTokenizerGroup, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, @@ -35,6 +36,7 @@ def __init__( self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config + self.decoding_config = decoding_config self.tokenizer = tokenizer self.generation_config_fields = model_config.try_get_generation_config( @@ -87,7 +89,10 @@ def _validate_guided_decoding( self, params: Union[SamplingParams, PoolingParams]) -> None: if not isinstance(params, SamplingParams): return - if (params.guided_decoding + if self.decoding_config.guided_decoding_backend != "xgrammar": + raise ValueError( + "Only xgrammar guided decoding is supported in V1.") + if (params.guided_decoding and params.guided_decoding.backend and params.guided_decoding.backend != 'xgrammar'): raise ValueError( "Only xgrammar guided decoding is supported in V1.") From 8d6bd3bf68873d528339f1cd3d80639c58a95e3a Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 19 Feb 2025 18:40:52 +0000 Subject: [PATCH 29/84] chore: passing vocab_size Signed-off-by: Aaron Pham --- vllm/v1/guided_decoding/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 7802f676e4a7c..f36db7bbfec2c 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -105,7 +105,8 @@ def allocate_bitmask(self) -> None: self._grammar_bitmask = self.executor.submit( xgr.allocate_token_bitmask, self.vllm_config.scheduler_config.max_num_seqs, - self.vocab_size / 32) + self.vocab_size, + ) def _ensure_bitmask_ready(self) -> bool: if isinstance(self._grammar_bitmask, Future): From fcb0e8514efc263017629c41b536ec5ac0f33dd0 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 19 Feb 2025 19:18:48 +0000 Subject: [PATCH 30/84] chore: comment out 0.1.13 features Signed-off-by: Aaron Pham --- vllm/v1/guided_decoding/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/guided_decoding/__init__.py 
b/vllm/v1/guided_decoding/__init__.py index f36db7bbfec2c..9a3f6ed57a0c9 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -162,8 +162,8 @@ def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: ctx = self.compiler.compile_json_schema(grammar_spec) elif request_type == GuidedDecodingOptions.grammar: ctx = self.compiler.compile_grammar(grammar_spec) - elif request_type == GuidedDecodingOptions.regex: - ctx = self.compiler.compile_regex(grammar_spec) + # elif request_type == GuidedDecodingOptions.regex: + # ctx = self.compiler.compile_regex(grammar_spec) else: raise ValueError( f"`grammar` is not of valid supported types. ({request_type!s})" From e6038f88aa8069ebd2cf0e073e35c1098ac8f8d0 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 20 Feb 2025 10:21:35 -0500 Subject: [PATCH 31/84] Resize bitmask to match the current batch size Applying the bitmask will fail when the bitmask is larger than the current batch size. Resize it here for now. Signed-off-by: Russell Bryant --- vllm/v1/worker/gpu_model_runner.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 74022d919a76c..9eadcb688bda3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -969,6 +969,13 @@ def execute_model( # Apply guided decoding bitmasks if present if scheduler_output.grammar_bitmask is not None: + if len(self.input_batch.req_ids) < self.input_batch.max_num_reqs: + # The bitmask is pre-allocated for the maximum batch size. + # When the batch size is smaller, we need to resize the bitmask + # to match the batch size. + scheduler_output.grammar_bitmask = ( + scheduler_output.grammar_bitmask[:len(self.input_batch. + req_ids)]) # TODO: we probably should move this before and # after, this might not be correct apply_bitmask( From 983089947d0d0b68993c85f6a4b42659f6c457a5 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 20 Feb 2025 10:34:43 -0500 Subject: [PATCH 32/84] set any_whitespace=False for json schema + xgrammar This should be configurable, but that's pending merge of this PR: https://github.com/vllm-project/vllm/pull/12744 we can adopt it wherever that ends up in v1. 
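For context, a minimal sketch of the compile call this pins down (assumes xgrammar's current API; the tokenizer/compiler setup mirrors GuidedDecodingManager.__init__ and the schema string is made up):

    import xgrammar as xgr
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
    compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)

    schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
    # any_whitespace=False restricts the grammar to compact JSON, i.e. no
    # free-form whitespace between tokens in the generated output.
    ctx = compiler.compile_json_schema(schema, any_whitespace=False)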
Signed-off-by: Russell Bryant --- vllm/v1/guided_decoding/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 9a3f6ed57a0c9..017d8e2710437 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -159,7 +159,10 @@ def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: if not isinstance(grammar_spec, str): ctx = self.compiler.compile_builtin_json_grammar() else: - ctx = self.compiler.compile_json_schema(grammar_spec) + # TODO -- allow any_whitespace to be configurable + # pending merge of https://github.com/vllm-project/vllm/pull/12744 + ctx = self.compiler.compile_json_schema(grammar_spec, + any_whitespace=False) elif request_type == GuidedDecodingOptions.grammar: ctx = self.compiler.compile_grammar(grammar_spec) # elif request_type == GuidedDecodingOptions.regex: From cebe28180158fe61ccaf8a37227a06fa44b934e6 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 20 Feb 2025 17:17:59 +0000 Subject: [PATCH 33/84] --wip--: debugging fsm apply Signed-off-by: Aaron Pham --- vllm/v1/core/scheduler.py | 6 ++---- vllm/v1/engine/core.py | 6 +++++- vllm/v1/guided_decoding/__init__.py | 6 ++---- vllm/v1/worker/gpu_model_runner.py | 4 +++- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 2def2cc404353..28452b2916107 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -583,12 +583,10 @@ def update_from_output( "Expect undefined behavior. This may be" " caused by spec decode + structured " "output.") + print(sampled_token_ids) token_id = sampled_token_ids[index][0] # accept token will also advance the FSM - can_accept_token = request.grammar.accept_token(token_id) - if not can_accept_token: - self._free_request(request) - continue + request.grammar.accept_token(token_id) if request.num_computed_tokens >= request.num_tokens: for output_token_id in generated_token_ids: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index a1c14fc406fc0..77c33f7494b80 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -26,7 +26,7 @@ EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.mm_input_cache import MMInputCacheServer from vllm.v1.executor.abstract import Executor -from vllm.v1.guided_decoding import GuidedDecodingManager +from vllm.v1.guided_decoding import GuidedDecodingManager, reset_bitmask from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -174,6 +174,10 @@ def step(self) -> EngineCoreOutputs: engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) # type: ignore + + if len(self.guided_decoding_manager.requests) > 0: + reset_bitmask(scheduler_output.grammar_bitmask) + return engine_core_outputs def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 017d8e2710437..49088b2580594 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -125,6 +125,7 @@ def grammar_bitmask(self) -> Optional[torch.Tensor]: @property def is_bitmask_ready(self) -> bool: + self._ensure_bitmask_ready() if isinstance(self._grammar_bitmask, Future): return not self._grammar_bitmask.running( ) and self._grammar_bitmask.done() @@ -136,13 +137,10 @@ def should_cache(self, request: 
Request): request.grammar = self.request_key_to_grammar.get( request.guided_decoding_key) if not request.grammar: - request.grammar = self.cache(request) + request.grammar = self.executor.submit(self._executor_loop, request) return True return False - def cache(self, request: Request): - return self.executor.submit(self._executor_loop, request) - def _executor_loop(self, request: Request): key = request.guided_decoding_key self.requests.add(request) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9eadcb688bda3..6058088974b37 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -441,7 +441,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_sampling_metadata() def _prepare_inputs( - self, scheduler_output: "SchedulerOutput" + self, + scheduler_output: "SchedulerOutput", ) -> Tuple[FlashAttentionMetadata, torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -968,6 +969,7 @@ def execute_model( logits = self.model.compute_logits(sample_hidden_states, None) # Apply guided decoding bitmasks if present + print(scheduler_output.guided_decoding_request_ids, scheduler_output.grammar_bitmask[0]) if scheduler_output.grammar_bitmask is not None: if len(self.input_batch.req_ids) < self.input_batch.max_num_reqs: # The bitmask is pre-allocated for the maximum batch size. From 862c093f8d74b02a411530ce117bc4658f5af5f1 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 20 Feb 2025 18:37:50 +0000 Subject: [PATCH 34/84] fix: make sure to reset the FSM once we _free_request Signed-off-by: Aaron Pham --- tests/entrypoints/llm/test_guided_generate.py | 1 - vllm/v1/core/scheduler.py | 3 +++ vllm/v1/engine/core.py | 19 ++++++++------- vllm/v1/guided_decoding/__init__.py | 23 +++++++++++-------- vllm/v1/worker/gpu_input_batch.py | 6 ++--- vllm/v1/worker/gpu_model_runner.py | 15 ++++++------ 6 files changed, 36 insertions(+), 31 deletions(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 7a5aa7fcbca56..eb9d0818d57a9 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -14,7 +14,6 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import GuidedDecodingParams, SamplingParams - MODEL_NAME = "s3://vllm-ci-model-weights/Qwen2.5-1.5B-Instruct" GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"] diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 28452b2916107..45afaf840edcd 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -692,6 +692,9 @@ def _free_request(self, request: Request) -> None: self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] self.finished_req_ids.add(request.request_id) + if request.use_guided_decoding and request.grammar: + assert request.grammar.matcher.is_terminated() + request.grammar.reset() def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 77c33f7494b80..3913447b2c627 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -26,7 +26,7 @@ EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.mm_input_cache import MMInputCacheServer from vllm.v1.executor.abstract import Executor -from 
vllm.v1.guided_decoding import GuidedDecodingManager, reset_bitmask +from vllm.v1.guided_decoding import GuidedDecodingManager from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -157,27 +157,26 @@ def step(self) -> EngineCoreOutputs: outputs=[], scheduler_stats=self.scheduler.make_stats()) # Calculate bitmasks for all active requests - self.setup_request_grammars() + self.setup_grammars() scheduler_output = self.scheduler.schedule() + + if scheduler_output.total_num_scheduled_tokens == 0: + return EngineCoreOutputs( + outputs=[], scheduler_stats=self.scheduler.make_stats()) + # the bitmask allocation for grammars # should be ready at this point. + # Currently we will broadcast the bitmask if len(self.guided_decoding_manager.requests) > 0: scheduler_output.grammar_bitmask = \ self.guided_decoding_manager.grammar_bitmask - if scheduler_output.total_num_scheduled_tokens == 0: - return EngineCoreOutputs( - outputs=[], scheduler_stats=self.scheduler.make_stats()) - output = self.model_executor.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) # type: ignore - if len(self.guided_decoding_manager.requests) > 0: - reset_bitmask(scheduler_output.grammar_bitmask) - return engine_core_outputs def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: @@ -237,7 +236,7 @@ def profile(self, is_start: bool = True): def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() - def setup_request_grammars(self): + def setup_grammars(self): for req in self.guided_decoding_manager.requests: if req.grammar is not None: continue diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 49088b2580594..d2c88bc372f71 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -62,8 +62,8 @@ def accept_token(self, token: int) -> bool: return self.matcher.accept_token(token) # this should be ran in parallel with model decoding - def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: - self.matcher.fill_next_token_bitmask(bitmask, idx) + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool: + return self.matcher.fill_next_token_bitmask(bitmask, idx) def reset(self): self.matcher.reset() @@ -102,11 +102,13 @@ def __init__(self, vllm_config: VllmConfig): Future[torch.Tensor]]] = None def allocate_bitmask(self) -> None: - self._grammar_bitmask = self.executor.submit( - xgr.allocate_token_bitmask, - self.vllm_config.scheduler_config.max_num_seqs, - self.vocab_size, - ) + # NOTE: We will only want to allocate this once + if self._grammar_bitmask is None: + self._grammar_bitmask = self.executor.submit( + xgr.allocate_token_bitmask, + self.vllm_config.scheduler_config.max_num_seqs, + self.vocab_size, + ) def _ensure_bitmask_ready(self) -> bool: if isinstance(self._grammar_bitmask, Future): @@ -137,11 +139,14 @@ def should_cache(self, request: Request): request.grammar = self.request_key_to_grammar.get( request.guided_decoding_key) if not request.grammar: - request.grammar = self.executor.submit(self._executor_loop, request) + request.grammar = self.cache(request) return True return False - def _executor_loop(self, request: Request): + def cache(self, request: Request): + return self.executor.submit(self._executor_loop, request) + + def _executor_loop(self, request: Request) -> Grammar: key = request.guided_decoding_key self.requests.add(request) if 
key in self.request_key_to_grammar: diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ce93e7257d0fe..0f9c73f8db03f 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -10,7 +10,6 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType -from vllm.v1.guided_decoding import Grammar from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import copy_slice from vllm.v1.worker.block_table import BlockTable @@ -21,6 +20,7 @@ import torch from vllm.multimodal.inputs import PlaceholderRange + from vllm.v1.guided_decoding import Grammar @dataclass @@ -42,7 +42,7 @@ class CachedRequestState: mrope_position_delta: Optional[int] = None lora_request: Optional[LoRARequest] = None - grammar: Optional[Grammar] = None + grammar: Optional["Grammar"] = None @property def num_tokens(self) -> int: @@ -210,7 +210,7 @@ def req_ids(self) -> List[str]: def add_request( self, - request: "CachedRequestState", + request: CachedRequestState, req_index: Optional[int] = None, ) -> None: if req_index is None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6058088974b37..792ced04f9e8d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -969,15 +969,14 @@ def execute_model( logits = self.model.compute_logits(sample_hidden_states, None) # Apply guided decoding bitmasks if present - print(scheduler_output.guided_decoding_request_ids, scheduler_output.grammar_bitmask[0]) if scheduler_output.grammar_bitmask is not None: - if len(self.input_batch.req_ids) < self.input_batch.max_num_reqs: - # The bitmask is pre-allocated for the maximum batch size. - # When the batch size is smaller, we need to resize the bitmask - # to match the batch size. - scheduler_output.grammar_bitmask = ( - scheduler_output.grammar_bitmask[:len(self.input_batch. - req_ids)]) + # if len(self.input_batch.req_ids) < self.input_batch.max_num_reqs: + # # The bitmask is pre-allocated for the maximum batch size. + # # When the batch size is smaller, we need to resize the bitmask + # # to match the batch size. + # scheduler_output.grammar_bitmask = ( + # scheduler_output.grammar_bitmask[:len(self.input_batch. 
+ # req_ids)]) # TODO: we probably should move this before and # after, this might not be correct apply_bitmask( From 0fc85e3754891a39c0163b4ead75b9a24317af26 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Fri, 21 Feb 2025 01:08:54 +0000 Subject: [PATCH 35/84] revert: apply grammar bitmask from update states Signed-off-by: Aaron Pham --- pyproject.toml | 2 + vllm/v1/core/scheduler.py | 68 ++++++++++++++--------------- vllm/v1/engine/core.py | 8 +++- vllm/v1/guided_decoding/__init__.py | 12 +++++ vllm/v1/worker/gpu_model_runner.py | 23 +++++----- 5 files changed, 66 insertions(+), 47 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1c03e9e17be55..84106093ee34e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,8 @@ ignore = [ "F405", "F403", # lambda expression assignment "E731", + # line length + "E501", # Loop control variable not used within loop body "B007", # f-string format diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 45afaf840edcd..78f8022bb6b31 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -76,6 +76,9 @@ def __init__( # Request id -> CachedRequestData self._cached_reqs_data: Dict[str, CachedRequestData] = {} + # The list of guided decoding request left within the queue + self.guided_decoding_requests: List[Request] = [] + # Encoder-related. # Calculate encoder cache size if applicable # NOTE: For now we use the same budget for both compute and space. @@ -139,14 +142,15 @@ def schedule(self) -> "SchedulerOutput": num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 - if request.status == RequestStatus.WAITING_FOR_FSM: - # wait for grammar to be ready - req_index += 1 - continue + # Guided decoding related. + if request.use_guided_decoding: + if request.status == RequestStatus.WAITING_FOR_FSM: + # Still waiting for FSM initialization + req_index += 1 + continue - if request.use_guided_decoding \ - and request.request_id not in guided_decoding_request_ids: - guided_decoding_request_ids[request.request_id] = req_index + if request.request_id not in guided_decoding_request_ids: + guided_decoding_request_ids[request.request_id] = req_index # Schedule encoder inputs. encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( @@ -229,26 +233,23 @@ def schedule(self) -> "SchedulerOutput": # Next, schedule the WAITING requests. if not preempted_reqs: num_to_skip: int = 0 - while self.waiting and token_budget > 0: + while num_to_skip < len(self.waiting) and token_budget > 0: if len(self.running) == self.max_num_running_reqs: break - if num_to_skip >= len(self.waiting): - break - request = self.waiting[num_to_skip] - if request.status == RequestStatus.WAITING_FOR_FSM: - if request.grammar and request.is_grammar_ready: - request.status = RequestStatus.WAITING - request.grammar.prefilled = True - num_to_skip += 1 - continue - - if request.use_guided_decoding \ - and request.request_id not in guided_decoding_request_ids: + if request.use_guided_decoding: guided_decoding_request_ids[request.request_id] = req_index + if request.status == RequestStatus.WAITING_FOR_FSM: + if request.grammar and request.is_grammar_ready: + request.status = RequestStatus.WAITING + request.grammar.prefilled = True + else: + num_to_skip += 1 + continue + # # Check that adding the request still respects the max_loras # constraint. 
@@ -572,21 +573,21 @@ def update_from_output( # Handle guided decoding FSM advancement if applicable if (request.use_guided_decoding and request.grammar - and request.is_grammar_ready - and not request.grammar.prefilled): + and request.is_grammar_ready): index = model_runner_output.req_id_to_index.get(req_id) if index is not None: - # TODO - fix spec decode + structured output compatibility - if len(sampled_token_ids[index]) > 1: - logger.error("Structured output does not currently " - "support more than one token at a time. " - "Expect undefined behavior. This may be" - " caused by spec decode + structured " - "output.") - print(sampled_token_ids) - token_id = sampled_token_ids[index][0] - # accept token will also advance the FSM - request.grammar.accept_token(token_id) + token_ids = sampled_token_ids[index] + if len(token_ids) > 1: + logger.error( + "Structured output does not currently support more than one token at a time. Only the first token will be used." + ) + token_id = token_ids[0] + # accept_token advances the FSM + accepted = request.grammar.accept_token(token_id) + if not accepted: + logger.error( + "Failed to advance FSM for request %s with token %d", + req_id, token_id) if request.num_computed_tokens >= request.num_tokens: for output_token_id in generated_token_ids: @@ -693,7 +694,6 @@ def _free_request(self, request: Request) -> None: del self.requests[request.request_id] self.finished_req_ids.add(request.request_id) if request.use_guided_decoding and request.grammar: - assert request.grammar.matcher.is_terminated() request.grammar.reset() def get_num_unfinished_requests(self) -> int: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d05ec3b035564..da73c2b186be4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -149,6 +149,8 @@ def abort_requests(self, request_ids: List[str]): self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) + self.guided_decoding_manager.remove_requests(request_ids) + def step(self) -> EngineCoreOutputs: """Schedule, execute, and make output.""" @@ -177,6 +179,9 @@ def step(self) -> EngineCoreOutputs: engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) # type: ignore + if len(self.guided_decoding_manager.requests) > 0: + self.guided_decoding_manager.reset_bitmask() + return engine_core_outputs def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: @@ -251,8 +256,7 @@ def setup_grammars(self): continue # Check if grammar is ready in cache - grammar = self.guided_decoding_manager.request_key_to_grammar.get( - req.guided_decoding_key) + grammar = self.guided_decoding_manager[req.guided_decoding_key] if grammar is not None: req.grammar = grammar continue diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index d2c88bc372f71..107cf16f7510a 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -101,6 +101,9 @@ def __init__(self, vllm_config: VllmConfig): self._grammar_bitmask: Optional[Union[torch.Tensor, Future[torch.Tensor]]] = None + def __getitem__(self, key: GuidedDecodingKey) -> Optional[Grammar]: + return self.request_key_to_grammar.get(key) + def allocate_bitmask(self) -> None: # NOTE: We will only want to allocate this once if self._grammar_bitmask is None: @@ -133,6 +136,15 @@ def is_bitmask_ready(self) -> bool: ) and self._grammar_bitmask.done() return self._grammar_bitmask is not None + def reset_bitmask(self): + reset_bitmask(self.grammar_bitmask) + + def 
remove_requests(self, request_ids: List[str]) -> None: + self.requests = { + req + for req in self.requests if req.request_id not in request_ids + } + def should_cache(self, request: Request): if not request.use_guided_decoding: return False diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 792ced04f9e8d..35325a467b493 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -968,21 +968,22 @@ def execute_model( sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) + # NOTE: We are currently broadcasting the bitmask + # to each worker + grammar_bitmask = scheduler_output.grammar_bitmask + # Apply guided decoding bitmasks if present - if scheduler_output.grammar_bitmask is not None: - # if len(self.input_batch.req_ids) < self.input_batch.max_num_reqs: - # # The bitmask is pre-allocated for the maximum batch size. - # # When the batch size is smaller, we need to resize the bitmask - # # to match the batch size. - # scheduler_output.grammar_bitmask = ( - # scheduler_output.grammar_bitmask[:len(self.input_batch. - # req_ids)]) + if grammar_bitmask is not None: + if len(self.input_batch.req_ids) < self.input_batch.max_num_reqs: + # The bitmask is pre-allocated for the maximum batch size. + # When the batch size is smaller, we need to resize the bitmask + # to match the batch size. + grammar_bitmask = grammar_bitmask[:len(self.input_batch.req_ids + )] # TODO: we probably should move this before and # after, this might not be correct apply_bitmask( - logits, - scheduler_output.grammar_bitmask.to(self.device, - non_blocking=True), + logits, grammar_bitmask.to(self.device, non_blocking=True), list(scheduler_output.guided_decoding_request_ids.values())) # Sample the next token and get logprobs if needed. 
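Note: `apply_bitmask` itself is not shown anywhere in this series. The sketch below is only an illustration of what such a helper can do, assuming the packed int32 layout produced by `xgrammar.allocate_token_bitmask` (32 vocabulary entries per element); the function name and loop structure are illustrative, and recent xgrammar releases also expose an `apply_token_bitmask_inplace` helper that a real implementation would likely prefer.

```
# Hedged sketch, not vLLM's actual helper: expand a packed int32 token
# bitmask and mask out disallowed tokens in the logits, in place.
from typing import List

import torch


def apply_bitmask_sketch(logits: torch.Tensor, bitmask: torch.Tensor,
                         indices: List[int]) -> None:
    """Mask disallowed tokens in-place for the given batch rows."""
    vocab_size = logits.shape[-1]
    shifts = torch.arange(32, dtype=torch.int32, device=bitmask.device)
    for row in indices:
        packed = bitmask[row].unsqueeze(-1)          # (num_words, 1) int32
        allowed = ((packed >> shifts) & 1).bool()    # (num_words, 32)
        allowed = allowed.reshape(-1)[:vocab_size]   # (vocab_size,)
        logits[row].masked_fill_(~allowed.to(logits.device),
                                 float("-inf"))
```

The row indices passed in have to match the batch positions recorded in `guided_decoding_request_ids`, which is why that mapping is carried alongside the bitmask.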
From 6a372ea8ca056403a963bf5bf141393b3bc97c52 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sun, 23 Feb 2025 15:54:29 -0500 Subject: [PATCH 36/84] Revert changes to v0 guided decoding tests Signed-off-by: Russell Bryant --- tests/entrypoints/llm/test_guided_generate.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index de9bffe69e800..314dc59328cb0 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import json -import os import re import weakref @@ -15,24 +14,6 @@ MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] -GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"] - - -@pytest.fixture(autouse=True) -def v1(request, run_with_both_engines, monkeypatch): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - use_v1 = os.getenv('VLLM_USE_V1') == '1' - if use_v1 and 'guided_decoding_backend' in request.fixturenames: - guided_decoding_backend = request.getfixturevalue( - 'guided_decoding_backend') - if guided_decoding_backend not in GUIDED_DECODING_BACKENDS_V1: - pytest.skip(f"Skipping test because {guided_decoding_backend} " - "is not in GUIDED_DECODING_BACKENDS_V1") - - if use_v1 and "regex" in request.node.name: - pytest.skip("Skipping test because V1 does not support regex") @pytest.fixture(scope="module") From a43afca59eb8f207f21f056ae2ea182c3a84e640 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sun, 23 Feb 2025 16:13:59 -0500 Subject: [PATCH 37/84] create v1 tests_guided_generate for llm entrypoint Signed-off-by: Russell Bryant --- tests/v1/entrypoints/llm/__init__.py | 0 .../entrypoints/llm/test_guided_generate.py | 46 +++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 tests/v1/entrypoints/llm/__init__.py create mode 100644 tests/v1/entrypoints/llm/test_guided_generate.py diff --git a/tests/v1/entrypoints/llm/__init__.py b/tests/v1/entrypoints/llm/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py new file mode 100644 index 0000000000000..bf8cf79d45b33 --- /dev/null +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json + +import jsonschema +import pytest + +from vllm.entrypoints.llm import LLM +from vllm.outputs import RequestOutput +from vllm.sampling_params import GuidedDecodingParams, SamplingParams + +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" +GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_json_completion(monkeypatch, sample_json_schema, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_json_schema, + backend=guided_decoding_backend)) + outputs = llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not 
None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) From fb40918772bc3a788739d6378528cf6a3b524844 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sun, 23 Feb 2025 16:21:06 -0500 Subject: [PATCH 38/84] Drop unused Scheduler.guided_decoding_requests Signed-off-by: Russell Bryant --- vllm/v1/core/scheduler.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 78f8022bb6b31..218795ce6dd32 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -76,9 +76,6 @@ def __init__( # Request id -> CachedRequestData self._cached_reqs_data: Dict[str, CachedRequestData] = {} - # The list of guided decoding request left within the queue - self.guided_decoding_requests: List[Request] = [] - # Encoder-related. # Calculate encoder cache size if applicable # NOTE: For now we use the same budget for both compute and space. From b8e016c37bd14916bf1bda7430c388d3e3d2256d Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sun, 23 Feb 2025 20:07:21 -0500 Subject: [PATCH 39/84] Allow grammar compilation to complete When checking if the grammar compilation is complete in the scheduler, give it a chance to finish. This prevents requests getting stuck in waiting for forever. Signed-off-by: Russell Bryant --- vllm/v1/request.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 3b01bc491663c..ffa18e2ba37f7 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -181,7 +181,7 @@ def guided_decoding_key(self) -> GuidedDecodingKey: def _check_grammar_completion(self) -> bool: if isinstance(self._grammar, Future): try: - self._grammar = self._grammar.result(timeout=0.05) + self._grammar = self._grammar.result(timeout=0.0001) self.status = RequestStatus.WAITING except TimeoutError: return False @@ -189,10 +189,7 @@ def _check_grammar_completion(self) -> bool: @property def is_grammar_ready(self) -> bool: - if isinstance(self._grammar, Future): - return not self._grammar.running() and self._grammar.done() - return (self.status == RequestStatus.WAITING - and self._grammar is not None) + return self._check_grammar_completion() @property def grammar(self) -> Optional[Grammar]: From c63ca92d59a50135f12a8fa946a46b9d66c952bd Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sun, 23 Feb 2025 20:19:21 -0500 Subject: [PATCH 40/84] Remove some dead committed We were checking for pending grammar compilation when iterating the list of running requests. A guided decode request is not put in this list until grammar compilation is already complete, so this condition will never happen. Signed-off-by: Russell Bryant --- vllm/v1/core/scheduler.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 218795ce6dd32..04f01000bff9a 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -140,14 +140,9 @@ def schedule(self) -> "SchedulerOutput": assert num_new_tokens > 0 # Guided decoding related. 
- if request.use_guided_decoding: - if request.status == RequestStatus.WAITING_FOR_FSM: - # Still waiting for FSM initialization - req_index += 1 - continue - - if request.request_id not in guided_decoding_request_ids: - guided_decoding_request_ids[request.request_id] = req_index + if (request.use_guided_decoding + and request.request_id not in guided_decoding_request_ids): + guided_decoding_request_ids[request.request_id] = req_index # Schedule encoder inputs. encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( From 074b65df17fff76ae47173c950cbeab4f3d95a10 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sun, 23 Feb 2025 21:22:26 -0500 Subject: [PATCH 41/84] Fix index calculation for guided requests in a batch We weren't keeping track of the index in the batch of each guided decode request quite right. Signed-off-by: Russell Bryant --- vllm/v1/core/scheduler.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 04f01000bff9a..8a70432bf171a 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -140,8 +140,7 @@ def schedule(self) -> "SchedulerOutput": assert num_new_tokens > 0 # Guided decoding related. - if (request.use_guided_decoding - and request.request_id not in guided_decoding_request_ids): + if request.use_guided_decoding: guided_decoding_request_ids[request.request_id] = req_index # Schedule encoder inputs. @@ -231,16 +230,13 @@ def schedule(self) -> "SchedulerOutput": request = self.waiting[num_to_skip] - if request.use_guided_decoding: - guided_decoding_request_ids[request.request_id] = req_index - - if request.status == RequestStatus.WAITING_FOR_FSM: - if request.grammar and request.is_grammar_ready: - request.status = RequestStatus.WAITING - request.grammar.prefilled = True - else: - num_to_skip += 1 - continue + if (request.use_guided_decoding + and request.status == RequestStatus.WAITING_FOR_FSM): + if request.grammar and request.is_grammar_ready: + request.status = RequestStatus.WAITING + else: + num_to_skip += 1 + continue # # Check that adding the request still respects the max_loras @@ -297,6 +293,9 @@ def schedule(self) -> "SchedulerOutput": break self.waiting.pop(num_to_skip) + if request.use_guided_decoding: + guided_decoding_request_ids[request.request_id] = req_index + req_index += 1 self.running.append(request) self.scheduled_req_ids.add(request.request_id) if RequestStatus.is_waiting(request.status): From 727dab065b792d4933dcc58bf4d0215ab27d9b77 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Sun, 23 Feb 2025 21:24:01 -0500 Subject: [PATCH 42/84] Make guided decoding manager more thread-safe I occasionally see the following error when running the guided decoding: ``` Processed prompts: 0%| | 0/2 [00:00 --- vllm/v1/engine/core.py | 11 +---------- vllm/v1/guided_decoding/__init__.py | 27 ++++++++++++++++++++++----- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 9dc6b97691dd8..773e4d3284479 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -254,16 +254,7 @@ def add_lora(self, lora_request: LoRARequest) -> None: self.model_executor.add_lora(lora_request) def setup_grammars(self): - for req in self.guided_decoding_manager.requests: - if req.grammar is not None: - continue - - # Check if grammar is ready in cache - grammar = self.guided_decoding_manager[req.guided_decoding_key] - if grammar is not None: - req.grammar = grammar - continue - 
self.guided_decoding_manager.allocate_bitmask() + self.guided_decoding_manager.setup_grammars() class EngineCoreProc(EngineCore): diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 107cf16f7510a..b68aa42f7693e 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -3,6 +3,7 @@ import enum import functools +import threading from concurrent.futures import Future, ThreadPoolExecutor from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union @@ -98,6 +99,7 @@ def __init__(self, vllm_config: VllmConfig): self.executor = ThreadPoolExecutor() self.requests: Set[Request] = set() + self._requests_lock = threading.Lock() self._grammar_bitmask: Optional[Union[torch.Tensor, Future[torch.Tensor]]] = None @@ -140,10 +142,11 @@ def reset_bitmask(self): reset_bitmask(self.grammar_bitmask) def remove_requests(self, request_ids: List[str]) -> None: - self.requests = { - req - for req in self.requests if req.request_id not in request_ids - } + with self._requests_lock: + self.requests = { + req + for req in self.requests if req.request_id not in request_ids + } def should_cache(self, request: Request): if not request.use_guided_decoding: @@ -160,7 +163,8 @@ def cache(self, request: Request): def _executor_loop(self, request: Request) -> Grammar: key = request.guided_decoding_key - self.requests.add(request) + with self._requests_lock: + self.requests.add(request) if key in self.request_key_to_grammar: return self.request_key_to_grammar[key] @@ -192,3 +196,16 @@ def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: vocab_size=self.vocab_size, ctx=ctx, ) + + def setup_grammars(self): + with self._requests_lock: + for req in self.requests: + if req.grammar is not None: + continue + + # Check if grammar is ready in cache + grammar = self[req.guided_decoding_key] + if grammar is not None: + req.grammar = grammar + continue + self.allocate_bitmask() From adb50ff2dc89838f171a85e8e961f4e3ed807a07 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 24 Feb 2025 09:54:48 +0000 Subject: [PATCH 43/84] chore: remove prefilled check Signed-off-by: Aaron Pham --- vllm/v1/guided_decoding/__init__.py | 1 - vllm/v1/worker/gpu_model_runner.py | 15 ++++++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index b68aa42f7693e..bfa214756a479 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -55,7 +55,6 @@ def __init__( self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx - self.prefilled = False def accept_token(self, token: int) -> bool: # NOTE: accept_token will determines whether we accept this token diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3045d8d9e2c65..949d39ecebe33 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -407,15 +407,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: idx = scheduler_output.guided_decoding_request_ids[req_id] # should already be ready assert scheduler_output.grammar_bitmask is not None - if not req_state.grammar.prefilled: - req_state.grammar.prefilled = True - else: - if not req_state.grammar.matcher.is_terminated(): - # NOTE: this relies on xgrammar internal bitmask, - # so we need to give the actual index - # of the the request_id in the batch - req_state.grammar.fill_bitmask( - scheduler_output.grammar_bitmask, idx) + if 
not req_state.grammar.matcher.is_terminated(): + # NOTE: this relies on xgrammar internal bitmask, + # so we need to give the actual index + # of the the request_id in the batch + req_state.grammar.fill_bitmask( + scheduler_output.grammar_bitmask, idx) # Check if the batch has changed. If not, we can skip copying the # sampling metadata from CPU to GPU. From c85408a6b14ddca89adac5a56e69d09290a667a0 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 24 Feb 2025 11:51:29 -0500 Subject: [PATCH 44/84] Re-enable line length checks in ruff Signed-off-by: Russell Bryant --- pyproject.toml | 2 -- vllm/v1/core/scheduler.py | 9 +++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 84106093ee34e..1c03e9e17be55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,8 +87,6 @@ ignore = [ "F405", "F403", # lambda expression assignment "E731", - # line length - "E501", # Loop control variable not used within loop body "B007", # f-string format diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 8a70432bf171a..2ab864bcabb32 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -570,15 +570,16 @@ def update_from_output( token_ids = sampled_token_ids[index] if len(token_ids) > 1: logger.error( - "Structured output does not currently support more than one token at a time. Only the first token will be used." - ) + "Structured output does not currently support " + "more than one token at a time. Only the first " + "token will be used.") token_id = token_ids[0] # accept_token advances the FSM accepted = request.grammar.accept_token(token_id) if not accepted: logger.error( - "Failed to advance FSM for request %s with token %d", - req_id, token_id) + "Failed to advance FSM for request %s " + "with token %d", req_id, token_id) if request.num_computed_tokens >= request.num_tokens: for output_token_id in generated_token_ids: From b34e4a71f0175e5b28ba69aab2d0b495c7b244ab Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 24 Feb 2025 11:59:07 -0500 Subject: [PATCH 45/84] Fix a yapf error in main, will be fixed by #13772 Signed-off-by: Russell Bryant --- vllm/model_executor/layers/fused_moe/fused_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1ddc3ce6f8954..bc9573b36df76 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1271,7 +1271,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, # so the cache size and config are already set correctly and # do not need to be adjusted. 
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * topk_ids.shape[1]] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * + topk_ids.shape[1]] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) From aabe98bff8e9cd54abdd0d2cbf64b36de1f6f6c9 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 24 Feb 2025 14:27:31 -0500 Subject: [PATCH 46/84] Prepare the bitmask on the scheduler side instead of gpu worker Signed-off-by: Russell Bryant --- vllm/v1/core/scheduler.py | 22 +++++++++++++++++++++- vllm/v1/core/scheduler_output.py | 11 +---------- vllm/v1/engine/core.py | 7 ++++--- vllm/v1/worker/gpu_model_runner.py | 25 +------------------------ 4 files changed, 27 insertions(+), 38 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 2ab864bcabb32..531f89b8ee7ff 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -13,10 +13,12 @@ SchedulerOutput) from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) +from vllm.v1.guided_decoding import GuidedDecodingManager from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus + logger = init_logger(__name__) @@ -30,13 +32,14 @@ def __init__( lora_config: Optional[LoRAConfig], speculative_config: Optional[SpeculativeConfig], log_stats: bool, + guided_decoding_manager: GuidedDecodingManager, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config self.speculative_config = speculative_config self.log_stats = log_stats - self.vocab_size = model_config.get_vocab_size() + self.guided_decoding_manager = guided_decoding_manager # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -346,6 +349,22 @@ def schedule(self) -> "SchedulerOutput": self.kv_cache_manager.get_num_common_prefix_blocks( any_request, len(self.running))) + # Prepare the guided decoding bitmask for this batch. + grammar_bitmask = None + if guided_decoding_request_ids: + # Fill the bitmask using the index of each request equal to its + # position in the batch. Resize the bitmask down to the size of + # the batch. + grammar_bitmask = self.guided_decoding_manager.grammar_bitmask + assert grammar_bitmask is not None + for req_id, batch_index in guided_decoding_request_ids.items(): + request = self.requests[req_id] + assert request.grammar is not None + if not request.grammar.matcher.is_terminated(): + request.grammar.fill_bitmask(grammar_bitmask, batch_index) + if len(self.running) < grammar_bitmask.shape[0]: + grammar_bitmask = grammar_bitmask[:len(self.running)] + # Construct the scheduler output. 
new_reqs_data = [ NewRequestData.from_request(req, @@ -385,6 +404,7 @@ def schedule(self) -> "SchedulerOutput": finished_req_ids=self.finished_req_ids, free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), guided_decoding_request_ids=guided_decoding_request_ids, + grammar_bitmask=grammar_bitmask, ) self.finished_req_ids = set() diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index 02c2f4ee2a91a..3952cc846c844 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -121,13 +121,4 @@ class SchedulerOutput: # for filling the next token bitmask guided_decoding_request_ids: Dict[str, int] # the bitmask for the whole batch - _grammar_bitmask: Optional["torch.Tensor"] = field(default=None, - repr=False) - - @property - def grammar_bitmask(self) -> Optional["torch.Tensor"]: - return self._grammar_bitmask - - @grammar_bitmask.setter - def grammar_bitmask(self, bitmask: "torch.Tensor") -> None: - self._grammar_bitmask = bitmask + grammar_bitmask: Optional["torch.Tensor"] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 773e4d3284479..21ef19d5e5d38 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -62,6 +62,8 @@ def __init__( vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks + self.guided_decoding_manager = GuidedDecodingManager(vllm_config) + # Setup scheduler. self.scheduler = Scheduler( scheduler_config=vllm_config.scheduler_config, @@ -70,14 +72,13 @@ def __init__( lora_config=vllm_config.lora_config, speculative_config=vllm_config.speculative_config, log_stats=self.log_stats, + guided_decoding_manager=self.guided_decoding_manager, ) # Setup MM Input Mapper. self.mm_input_cache_server = MMInputCacheServer( vllm_config.model_config) - self.guided_decoding_manager = GuidedDecodingManager(vllm_config) - # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously # schedule and execute batches, and is required by pipeline parallelism @@ -158,7 +159,7 @@ def step(self) -> EngineCoreOutputs: return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) - # Calculate bitmasks for all active requests + # Check for cached grammars and allocate bitmask if necessary self.setup_grammars() scheduler_output = self.scheduler.schedule() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 949d39ecebe33..a9d9ae04aad6a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -401,18 +401,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_index, start_index:end_token_index] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec decode tokens. self.input_batch.num_tokens[req_index] = end_token_index - # Fill the bitmask - if (req_id in scheduler_output.guided_decoding_request_ids - and req_state.grammar is not None): - idx = scheduler_output.guided_decoding_request_ids[req_id] - # should already be ready - assert scheduler_output.grammar_bitmask is not None - if not req_state.grammar.matcher.is_terminated(): - # NOTE: this relies on xgrammar internal bitmask, - # so we need to give the actual index - # of the the request_id in the batch - req_state.grammar.fill_bitmask( - scheduler_output.grammar_bitmask, idx) # Check if the batch has changed. If not, we can skip copying the # sampling metadata from CPU to GPU. 
@@ -966,20 +954,9 @@ def execute_model( sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) - # NOTE: We are currently broadcasting the bitmask - # to each worker - grammar_bitmask = scheduler_output.grammar_bitmask - # Apply guided decoding bitmasks if present + grammar_bitmask = scheduler_output.grammar_bitmask if grammar_bitmask is not None: - if len(self.input_batch.req_ids) < self.input_batch.max_num_reqs: - # The bitmask is pre-allocated for the maximum batch size. - # When the batch size is smaller, we need to resize the bitmask - # to match the batch size. - grammar_bitmask = grammar_bitmask[:len(self.input_batch.req_ids - )] - # TODO: we probably should move this before and - # after, this might not be correct apply_bitmask( logits, grammar_bitmask.to(self.device, non_blocking=True), list(scheduler_output.guided_decoding_request_ids.values())) From 8895e19cdc063179f89e197d4fb174249d4e7482 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 24 Feb 2025 14:45:09 -0500 Subject: [PATCH 47/84] tests: make sample jsonschema xgrammar compatible Signed-off-by: Russell Bryant --- tests/v1/entrypoints/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index b00e168db9d32..e5c6ff026bc00 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -29,6 +29,7 @@ def sample_regex(): r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") +# Note: Ensure this only uses attributes compatible with xgrammar @pytest.fixture def sample_json_schema(): return { @@ -44,7 +45,6 @@ def sample_json_schema(): "type": "array", "items": { "type": "string", - "maxLength": 10 }, "minItems": 3 }, From 470b677e42faae066a494e13bf675e91502832d2 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 24 Feb 2025 15:44:48 -0500 Subject: [PATCH 48/84] Detect unsupported jsonschema features for xgrammar Signed-off-by: Russell Bryant --- tests/v1/entrypoints/conftest.py | 3 +- vllm/v1/guided_decoding/__init__.py | 15 ++++++++ vllm/v1/guided_decoding/utils.py | 59 +++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 vllm/v1/guided_decoding/utils.py diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index e5c6ff026bc00..f873e32f4916a 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -45,8 +45,7 @@ def sample_json_schema(): "type": "array", "items": { "type": "string", - }, - "minItems": 3 + } }, "work_history": { "type": "array", diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index bfa214756a479..4e7dc986d4ca9 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -12,10 +12,14 @@ from vllm.config import VllmConfig from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.v1.guided_decoding.utils import ( + has_xgrammar_unsupported_json_features) if TYPE_CHECKING: from vllm.v1.request import Request +import json + class GuidedDecodingOptions(enum.Enum): json = enum.auto() @@ -177,6 +181,17 @@ def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: if not isinstance(grammar_spec, str): ctx = self.compiler.compile_builtin_json_grammar() else: + try: + schema = json.loads(grammar_spec) + except json.JSONDecodeError as e: + raise ValueError( + "Invalid JSON grammar specification.") from e + + if 
has_xgrammar_unsupported_json_features(schema): + raise ValueError( + "The provided JSON schema contains features not " + "supported by xgrammar.") + # TODO -- allow any_whitespace to be configurable # pending merge of https://github.com/vllm-project/vllm/pull/12744 ctx = self.compiler.compile_json_schema(grammar_spec, diff --git a/vllm/v1/guided_decoding/utils.py b/vllm/v1/guided_decoding/utils.py new file mode 100644 index 0000000000000..fef24416e1f3a --- /dev/null +++ b/vllm/v1/guided_decoding/utils.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 + + + +def has_xgrammar_unsupported_json_features(schema: dict) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Check for enum restrictions + if "enum" in obj: + return True + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and any( + key in obj for key in [ + "minimum", "maximum", "exclusiveMinimum", + "exclusiveMaximum", "multipleOf" + ]): + return True + + # Check for array unsupported keywords + if obj.get("type") == "array" and any(key in obj for key in [ + "uniqueItems", "contains", "minContains", "maxContains", + "minItems", "maxItems" + ]): + return True + + # Unsupported keywords for strings + if obj.get("type") == "string" and any( + key in obj for key in ["minLength", "maxLength", "format"]): + return True + + # Unsupported keywords for objects + if obj.get("type") == "object" and any(key in obj for key in [ + "minProperties", "maxProperties", "propertyNames", + "patternProperties" + ]): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) From 42fe5f86fdd88c72b4a88a2d98767657fb3d7d69 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 24 Feb 2025 16:14:14 -0500 Subject: [PATCH 49/84] Make bitmask allocation synchronous This only happens once and at startup, and it MUST finish before we can start the first decode (if one in the batch is using structured output anyway). Since this is just once at startup, just do it synchronously to simplify the code. 
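A rough size estimate (assumed numbers, not figures from this patch) backs up the reasoning above: the bitmask packs 32 vocabulary entries per int32, so even a generously sized batch only needs a few MiB.

```
# Back-of-the-envelope sketch with assumed values: the token bitmask is a
# (max_num_seqs, ceil(vocab_size / 32)) int32 tensor, so allocating it
# eagerly at startup costs little.
max_num_seqs = 256          # assumed scheduler_config.max_num_seqs
vocab_size = 152_064        # assumed tokenizer vocabulary size
words_per_row = (vocab_size + 31) // 32
size_mib = max_num_seqs * words_per_row * 4 / 2**20
print(f"bitmask shape: ({max_num_seqs}, {words_per_row}), ~{size_mib:.1f} MiB")
# -> bitmask shape: (256, 4752), ~4.6 MiB
```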
Signed-off-by: Russell Bryant --- vllm/v1/guided_decoding/__init__.py | 42 +++-------------------------- vllm/v1/guided_decoding/utils.py | 1 - 2 files changed, 4 insertions(+), 39 deletions(-) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 4e7dc986d4ca9..974368218a4ba 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -2,10 +2,9 @@ from __future__ import annotations import enum -import functools import threading -from concurrent.futures import Future, ThreadPoolExecutor -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple import torch import xgrammar as xgr @@ -103,44 +102,12 @@ def __init__(self, vllm_config: VllmConfig): self.executor = ThreadPoolExecutor() self.requests: Set[Request] = set() self._requests_lock = threading.Lock() - self._grammar_bitmask: Optional[Union[torch.Tensor, - Future[torch.Tensor]]] = None + self.grammar_bitmask: torch.Tensor = xgr.allocate_token_bitmask( + self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size) def __getitem__(self, key: GuidedDecodingKey) -> Optional[Grammar]: return self.request_key_to_grammar.get(key) - def allocate_bitmask(self) -> None: - # NOTE: We will only want to allocate this once - if self._grammar_bitmask is None: - self._grammar_bitmask = self.executor.submit( - xgr.allocate_token_bitmask, - self.vllm_config.scheduler_config.max_num_seqs, - self.vocab_size, - ) - - def _ensure_bitmask_ready(self) -> bool: - if isinstance(self._grammar_bitmask, Future): - try: - self._grammar_bitmask = self._grammar_bitmask.result( - timeout=0.05) - except TimeoutError: - return False - return True - - @functools.cached_property - def grammar_bitmask(self) -> Optional[torch.Tensor]: - self._ensure_bitmask_ready() - return self._grammar_bitmask if not isinstance(self._grammar_bitmask, - Future) else None - - @property - def is_bitmask_ready(self) -> bool: - self._ensure_bitmask_ready() - if isinstance(self._grammar_bitmask, Future): - return not self._grammar_bitmask.running( - ) and self._grammar_bitmask.done() - return self._grammar_bitmask is not None - def reset_bitmask(self): reset_bitmask(self.grammar_bitmask) @@ -222,4 +189,3 @@ def setup_grammars(self): if grammar is not None: req.grammar = grammar continue - self.allocate_bitmask() diff --git a/vllm/v1/guided_decoding/utils.py b/vllm/v1/guided_decoding/utils.py index fef24416e1f3a..898b44fe6b173 100644 --- a/vllm/v1/guided_decoding/utils.py +++ b/vllm/v1/guided_decoding/utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 - def has_xgrammar_unsupported_json_features(schema: dict) -> bool: """Check if JSON schema contains features unsupported by xgrammar.""" From ada4790db387c90cade236855c31509ef42c2bb0 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 24 Feb 2025 16:21:24 -0500 Subject: [PATCH 50/84] Fix compat with TP > 1 Some leftover code was still causing the grammar to get serialized and sent to gpu workers. This wasn't necessary. Stripping it out gets this working with TP>1. 
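As a simplified sketch of the resulting split (these are not the actual vLLM classes): stateful grammar objects remain in the scheduler process, and only plain data is broadcast to tensor-parallel workers.

```
# Illustration only: what crosses the process boundary vs. what stays
# behind after this fix.
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch


@dataclass
class WorkerBroadcastPayload:
    """What gets serialized and sent to every GPU worker."""
    guided_decoding_request_ids: Dict[str, int]   # request id -> batch row
    grammar_bitmask: Optional[torch.Tensor]       # plain tensor, cheap to ship


@dataclass
class SchedulerOnlyState:
    """Never leaves the scheduler: holds the compiled grammars/matchers."""
    request_key_to_grammar: Dict[tuple, object] = field(default_factory=dict)
```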
Signed-off-by: Russell Bryant --- vllm/v1/core/scheduler_output.py | 5 +---- vllm/v1/worker/gpu_input_batch.py | 2 -- vllm/v1/worker/gpu_model_runner.py | 1 - 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index 3952cc846c844..70477105c6461 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple if TYPE_CHECKING: @@ -10,7 +10,6 @@ from vllm.multimodal import MultiModalKwargs from vllm.multimodal.base import PlaceholderRange from vllm.sampling_params import SamplingParams - from vllm.v1.guided_decoding import Grammar from vllm.v1.request import Request @@ -27,7 +26,6 @@ class NewRequestData: block_ids: List[int] num_computed_tokens: int lora_request: Optional["LoRARequest"] - grammar: Optional["Grammar"] @classmethod def from_request( @@ -46,7 +44,6 @@ def from_request( block_ids=block_ids, num_computed_tokens=request.num_computed_tokens, lora_request=request.lora_request, - grammar=request.grammar, ) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index a3a5cc823022d..d880e4bc0d9ea 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -20,7 +20,6 @@ import torch from vllm.multimodal.inputs import PlaceholderRange - from vllm.v1.guided_decoding import Grammar @dataclass @@ -42,7 +41,6 @@ class CachedRequestState: mrope_position_delta: Optional[int] = None lora_request: Optional[LoRARequest] = None - grammar: Optional["Grammar"] = None @property def num_tokens(self) -> int: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a9d9ae04aad6a..af64520d16508 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -308,7 +308,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], lora_request=new_req_data.lora_request, - grammar=new_req_data.grammar, ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) From 331a7ff62a829dcf5df34055f7d5e9aaf8d7408a Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 24 Feb 2025 16:26:01 -0500 Subject: [PATCH 51/84] Make pre-commit happy again Signed-off-by: Russell Bryant --- vllm/v1/core/scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 531f89b8ee7ff..8ac020455084e 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -18,7 +18,6 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus - logger = init_logger(__name__) From 098437929f5095b93794f4e7b575b746ba41e185 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 24 Feb 2025 21:28:59 +0000 Subject: [PATCH 52/84] chore: remove reset_bitmask after every steps Signed-off-by: Aaron Pham --- vllm/v1/engine/core.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 21ef19d5e5d38..11a0550d4c481 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -180,9 +180,6 @@ def step(self) -> EngineCoreOutputs: engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) # type: ignore - if len(self.guided_decoding_manager.requests) > 0: - 
self.guided_decoding_manager.reset_bitmask() - return engine_core_outputs def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: From 9b62eef70d2ceb9c35de36bdc8edf7ca936f40f9 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 24 Feb 2025 18:00:05 -0500 Subject: [PATCH 53/84] revert: update whitespace Signed-off-by: Aaron Pham --- vllm/v1/utils.py | 4 +-- vllm/v1/worker/gpu_input_batch.py | 47 ++----------------------------- 2 files changed, 5 insertions(+), 46 deletions(-) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index f2ddd9deb775f..62271255b0c05 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -157,12 +157,12 @@ def bind_kv_cache( This function: 1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with kv_caches. - 2) Associates each attention layer in the `forward_context` with its + 2) Associates each attention layer in the `forward_context` with its corresponding KV cache in kv_caches. Args: kv_caches: The allocated kv_caches with layer names as keys. - forward_context: The global forward context containing all Attention + forward_context: The global forward context containing all Attention layers with layer names as keys. runner_kv_caches: The kv_cache declared by ModelRunner. """ diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index d880e4bc0d9ea..bd1c369acb30f 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -17,8 +17,6 @@ _SAMPLING_EPS = 1e-5 if TYPE_CHECKING: - import torch - from vllm.multimodal.inputs import PlaceholderRange @@ -145,7 +143,7 @@ def __init__( device="cpu", pin_memory=pin_memory) self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: Set[str] = set() # Presence penalty related data structures @@ -170,7 +168,7 @@ def __init__( device="cpu", pin_memory=pin_memory) self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: Set[str] = set() # req_index -> (min_tokens, stop_token_ids) @@ -194,9 +192,6 @@ def __init__( self.logit_bias: List[Optional[Dict[int, float]]] = [None] * max_num_reqs - self.has_allowed_token_ids: Set[str] = set() - self.allowed_token_ids_mask: Optional[torch.Tensor] = None - self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None self.req_output_token_ids: List[Optional[List[int]]] = [] @@ -211,7 +206,7 @@ def req_ids(self) -> List[str]: def add_request( self, - request: CachedRequestState, + request: "CachedRequestState", req_index: Optional[int] = None, ) -> None: if req_index is None: @@ -292,22 +287,6 @@ def add_request( if sampling_params.logit_bias is not None: self.logit_bias[req_index] = sampling_params.logit_bias - if sampling_params.allowed_token_ids: - self.has_allowed_token_ids.add(req_id) - if self.allowed_token_ids_mask_cpu_tensor is None: - # Lazy allocation for this tensor, which can be large. 
- self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs, - self.vocab_size, - dtype=torch.bool, - device=self.device) - self.allowed_token_ids_mask_cpu_tensor = torch.zeros( - self.max_num_reqs, - self.vocab_size, - dtype=torch.bool, - device="cpu") - self.allowed_token_ids_mask_cpu_tensor[req_index][ - sampling_params.allowed_token_ids] = True - # Add request lora ID if request.lora_request: lora_id = request.lora_request.lora_int_id @@ -353,9 +332,6 @@ def remove_request(self, req_id: str) -> Optional[int]: self.request_lora_mapping[req_index] = 0 self.logit_bias[req_index] = None - self.has_allowed_token_ids.discard(req_id) - if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) return req_index def condense(self, empty_req_indices: List[int]) -> None: @@ -424,11 +400,6 @@ def condense(self, empty_req_indices: List[int]) -> None: self.logit_bias[empty_index] = self.logit_bias[last_req_index] - if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[ - empty_index] = self.allowed_token_ids_mask_cpu_tensor[ - last_req_index] - # Decrement last_req_index since it is now empty. last_req_index -= 1 @@ -471,13 +442,6 @@ def _make_sampling_metadata(self) -> SamplingMetadata: else: prompt_token_ids = None - allowed_token_ids_mask: Optional[torch.Tensor] = None - if not self.no_allowed_token_ids: - assert self.allowed_token_ids_mask is not None - copy_slice(self.allowed_token_ids_mask_cpu_tensor, - self.allowed_token_ids_mask, num_reqs) - allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] - return SamplingMetadata( temperature=temperature, all_greedy=self.all_greedy, @@ -496,7 +460,6 @@ def _make_sampling_metadata(self) -> SamplingMetadata: min_tokens=self.min_tokens, no_penalties=self.no_penalties, logit_bias=self.logit_bias[:num_reqs], - allowed_token_ids_mask=allowed_token_ids_mask, ) def get_sampling_metadata( @@ -587,7 +550,3 @@ def max_num_logprobs(self) -> Optional[int]: @property def no_prompt_logprob(self) -> bool: return not self.num_prompt_logprobs - - @property - def no_allowed_token_ids(self) -> bool: - return len(self.has_allowed_token_ids) == 0 From 2f756e57bd37296f1a771480c36b34eeba1677b3 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 09:36:37 -0500 Subject: [PATCH 54/84] Add tests/v1/guided_decoding/test_utils.py Add unit tests for checking for unsupported xgrammar json schema features. 
Signed-off-by: Russell Bryant --- tests/v1/guided_decoding/__init__.py | 0 tests/v1/guided_decoding/test_utils.py | 132 +++++++++++++++++++++++++ 2 files changed, 132 insertions(+) create mode 100644 tests/v1/guided_decoding/__init__.py create mode 100644 tests/v1/guided_decoding/test_utils.py diff --git a/tests/v1/guided_decoding/__init__.py b/tests/v1/guided_decoding/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/guided_decoding/test_utils.py b/tests/v1/guided_decoding/test_utils.py new file mode 100644 index 0000000000000..edc304714e2a0 --- /dev/null +++ b/tests/v1/guided_decoding/test_utils.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +from vllm.v1.guided_decoding.utils import ( + has_xgrammar_unsupported_json_features) + + +def test_has_xgrammar_unsupported_json_features(): + schemas_with_unsupported_features: List[dict] = [{ + "type": "string", + "pattern": "^[a-zA-Z]+$" + }, { + "type": + "string", + "enum": ["active", "inactive", "pending"] + }, { + "type": "integer", + "minimum": 0 + }, { + "type": "integer", + "maximum": 120 + }, { + "type": "integer", + "exclusiveMinimum": 120 + }, { + "type": "integer", + "exclusiveMaximum": 120 + }, { + "type": "integer", + "multipleOf": 120 + }, { + "type": "number", + "minimum": 0 + }, { + "type": "number", + "maximum": 120 + }, { + "type": "number", + "exclusiveMinimum": 120 + }, { + "type": "number", + "exclusiveMaximum": 120 + }, { + "type": "number", + "multipleOf": 120 + }, { + "type": "array", + "uniqueItems": True + }, { + "type": "array", + "contains": { + "type": "string" + } + }, { + "type": "array", + "minContains": 1 + }, { + "type": "array", + "maxContains": 5 + }, { + "type": "array", + "minItems": 1 + }, { + "type": "array", + "maxItems": 10 + }, { + "type": "string", + "minLength": 1 + }, { + "type": "string", + "maxLength": 100 + }, { + "type": "string", + "format": "email" + }, { + "type": "object", + "minProperties": 1 + }, { + "type": "object", + "maxProperties": 5 + }, { + "type": "object", + "propertyNames": { + "pattern": "^[a-z]+$" + } + }, { + "type": "object", + "patternProperties": { + "^S": { + "type": "string" + } + } + }] + + for schema in schemas_with_unsupported_features: + assert has_xgrammar_unsupported_json_features(schema) + + schema_without_unsupported_features: dict = { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "status": { + "type": "string" + }, + "scores": { + "type": "array", + "items": { + "type": "number" + } + }, + "address": { + "type": "object", + "properties": { + "street": { + "type": "string" + }, + "city": { + "type": "string" + } + } + } + } + } + + assert not has_xgrammar_unsupported_json_features( + schema_without_unsupported_features) From 1be1709a400e026520d095067cfeacb7eb29d75d Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 10:02:53 -0500 Subject: [PATCH 55/84] add v1 structured output regex test case This isn't supported. Check for the expected exception. 
Signed-off-by: Russell Bryant --- .../entrypoints/llm/test_guided_generate.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py index bf8cf79d45b33..2fea10f59d5ef 100644 --- a/tests/v1/entrypoints/llm/test_guided_generate.py +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -44,3 +44,35 @@ def test_guided_json_completion(monkeypatch, sample_json_schema, print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + regex=sample_regex, + backend=guided_decoding_backend)) + with pytest.raises(ValueError, + match="grammar is not of valid supported types"): + llm.generate(prompts=[ + f"Give an example IPv4 address with this regex: {sample_regex}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + # Once regex is supported -- + #assert outputs is not None + #for output in outputs: + # assert output is not None + # assert isinstance(output, RequestOutput) + # prompt = output.prompt + # generated_text = output.outputs[0].text + # print(generated_text) + # assert generated_text is not None + # assert re.fullmatch(sample_regex, generated_text) is not None + # print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From 0128affc9f3891b3070e2b04c4767d8959269604 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 10:03:51 -0500 Subject: [PATCH 56/84] Restore some code lost in a merge from main Signed-off-by: Russell Bryant --- vllm/v1/worker/gpu_input_batch.py | 43 +++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index bd1c369acb30f..d9fc53490c076 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -143,7 +143,7 @@ def __init__( device="cpu", pin_memory=pin_memory) self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: Set[str] = set() # Presence penalty related data structures @@ -168,7 +168,7 @@ def __init__( device="cpu", pin_memory=pin_memory) self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: Set[str] = set() # req_index -> (min_tokens, stop_token_ids) @@ -192,6 +192,9 @@ def __init__( self.logit_bias: List[Optional[Dict[int, float]]] = [None] * max_num_reqs + self.has_allowed_token_ids: Set[str] = set() + self.allowed_token_ids_mask: Optional[torch.Tensor] = None + self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None self.req_output_token_ids: List[Optional[List[int]]] = [] @@ -287,6 +290,22 @@ def add_request( if sampling_params.logit_bias is not None: self.logit_bias[req_index] = sampling_params.logit_bias + if sampling_params.allowed_token_ids: + self.has_allowed_token_ids.add(req_id) + if self.allowed_token_ids_mask_cpu_tensor is None: + # Lazy 
allocation for this tensor, which can be large. + self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device=self.device) + self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device="cpu") + self.allowed_token_ids_mask_cpu_tensor[req_index][ + sampling_params.allowed_token_ids] = True + # Add request lora ID if request.lora_request: lora_id = request.lora_request.lora_int_id @@ -332,6 +351,9 @@ def remove_request(self, req_id: str) -> Optional[int]: self.request_lora_mapping[req_index] = 0 self.logit_bias[req_index] = None + self.has_allowed_token_ids.discard(req_id) + if self.allowed_token_ids_mask_cpu_tensor is not None: + self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) return req_index def condense(self, empty_req_indices: List[int]) -> None: @@ -400,6 +422,11 @@ def condense(self, empty_req_indices: List[int]) -> None: self.logit_bias[empty_index] = self.logit_bias[last_req_index] + if self.allowed_token_ids_mask_cpu_tensor is not None: + self.allowed_token_ids_mask_cpu_tensor[ + empty_index] = self.allowed_token_ids_mask_cpu_tensor[ + last_req_index] + # Decrement last_req_index since it is now empty. last_req_index -= 1 @@ -442,6 +469,13 @@ def _make_sampling_metadata(self) -> SamplingMetadata: else: prompt_token_ids = None + allowed_token_ids_mask: Optional[torch.Tensor] = None + if not self.no_allowed_token_ids: + assert self.allowed_token_ids_mask is not None + copy_slice(self.allowed_token_ids_mask_cpu_tensor, + self.allowed_token_ids_mask, num_reqs) + allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] + return SamplingMetadata( temperature=temperature, all_greedy=self.all_greedy, @@ -460,6 +494,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata: min_tokens=self.min_tokens, no_penalties=self.no_penalties, logit_bias=self.logit_bias[:num_reqs], + allowed_token_ids_mask=allowed_token_ids_mask, ) def get_sampling_metadata( @@ -550,3 +585,7 @@ def max_num_logprobs(self) -> Optional[int]: @property def no_prompt_logprob(self) -> bool: return not self.num_prompt_logprobs + + @property + def no_allowed_token_ids(self) -> bool: + return len(self.has_allowed_token_ids) == 0 From 9cc90ff22f74c03faef8fc3ab19f78e415f16438 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 10:18:33 -0500 Subject: [PATCH 57/84] Validate schema is supoprted before sending to threadpool If we send it off to the thread pool, we can't handle the ValueError exception as easily. We still need to deal with handling if xgrammar fails on something we deemed supported, but that's for later. 
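For illustration (not part of the patch): an exception raised inside a callable submitted to a ThreadPoolExecutor is captured in the returned Future and only re-raised when its result is consumed, so a ValueError from grammar compilation cannot be caught at submission time. A minimal sketch, using a hypothetical compile_grammar stand-in:

from concurrent.futures import ThreadPoolExecutor

def compile_grammar(spec: str):
    # Stand-in for grammar compilation rejecting an unsupported schema.
    raise ValueError("unsupported schema feature")

with ThreadPoolExecutor() as executor:
    future = executor.submit(compile_grammar, "{}")  # returns immediately, no exception here
    try:
        future.result()  # the ValueError only surfaces once the result is awaited
    except ValueError as err:
        print(f"surfaced late: {err}")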
Signed-off-by: Russell Bryant --- vllm/v1/guided_decoding/__init__.py | 40 ++++++++++++++++++----------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 974368218a4ba..6d2e61e31feaa 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -10,6 +10,7 @@ import xgrammar as xgr from vllm.config import VllmConfig +from vllm.logger import init_logger from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.v1.guided_decoding.utils import ( has_xgrammar_unsupported_json_features) @@ -19,6 +20,8 @@ import json +logger = init_logger(__name__) + class GuidedDecodingOptions(enum.Enum): json = enum.auto() @@ -129,6 +132,7 @@ def should_cache(self, request: Request): return False def cache(self, request: Request): + self._validate_grammer_is_supported(request.guided_decoding_key) return self.executor.submit(self._executor_loop, request) def _executor_loop(self, request: Request) -> Grammar: @@ -141,6 +145,24 @@ def _executor_loop(self, request: Request) -> Grammar: self.request_key_to_grammar[key] = self.initialize_grammar(key) return self.request_key_to_grammar[key] + def _validate_grammer_is_supported(self, key: GuidedDecodingKey): + request_type, grammar_spec = key + if request_type == GuidedDecodingOptions.json: + try: + schema = json.loads(grammar_spec) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + + if has_xgrammar_unsupported_json_features(schema): + raise ValueError( + "The provided JSON schema contains features not " + "supported by xgrammar.") + return + elif request_type == GuidedDecodingOptions.grammar: + return + raise ValueError( + f"grammar is not of valid supported types. ({request_type!s})") + def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: request_type, grammar_spec = key @@ -148,29 +170,17 @@ def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: if not isinstance(grammar_spec, str): ctx = self.compiler.compile_builtin_json_grammar() else: - try: - schema = json.loads(grammar_spec) - except json.JSONDecodeError as e: - raise ValueError( - "Invalid JSON grammar specification.") from e - - if has_xgrammar_unsupported_json_features(schema): - raise ValueError( - "The provided JSON schema contains features not " - "supported by xgrammar.") - # TODO -- allow any_whitespace to be configurable # pending merge of https://github.com/vllm-project/vllm/pull/12744 ctx = self.compiler.compile_json_schema(grammar_spec, any_whitespace=False) elif request_type == GuidedDecodingOptions.grammar: ctx = self.compiler.compile_grammar(grammar_spec) - # elif request_type == GuidedDecodingOptions.regex: - # ctx = self.compiler.compile_regex(grammar_spec) else: + logger.error("Validation should have already occurred. " + "Please file an issue.") raise ValueError( - f"`grammar` is not of valid supported types. ({request_type!s})" - ) + f"grammar is not of valid supported types. 
({request_type!s})") return Grammar( matcher=xgr.GrammarMatcher(ctx), From 3a8f955a9400ef741c9999071c20e86280992431 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 25 Feb 2025 10:25:32 -0500 Subject: [PATCH 58/84] chore: remove unused code Given that we are not using remove_bitmask, this is a legacy deadcode that is not used in current implementation Signed-off-by: Aaron Pham --- vllm/v1/guided_decoding/__init__.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 6d2e61e31feaa..04e18fa7b2976 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -33,11 +33,6 @@ class GuidedDecodingOptions(enum.Enum): GuidedDecodingKey = Tuple[GuidedDecodingOptions, str] -def reset_bitmask(bitmask: torch.Tensor): - # this calls bitmask.fill_(tensor([1, 1, ...], dtype=int32)) - xgr.reset_token_bitmask(bitmask) - - def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor, indices: List[int]) -> None: xgr.apply_token_bitmask_inplace(logits, vocab_mask, indices=indices) @@ -111,9 +106,6 @@ def __init__(self, vllm_config: VllmConfig): def __getitem__(self, key: GuidedDecodingKey) -> Optional[Grammar]: return self.request_key_to_grammar.get(key) - def reset_bitmask(self): - reset_bitmask(self.grammar_bitmask) - def remove_requests(self, request_ids: List[str]) -> None: with self._requests_lock: self.requests = { From e772efa07c6d190225fca0ba451ef2ab731ee4fe Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 25 Feb 2025 15:36:16 +0000 Subject: [PATCH 59/84] fix: correct typo Signed-off-by: Aaron Pham --- vllm/v1/guided_decoding/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 04e18fa7b2976..1e6824680f777 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -124,7 +124,7 @@ def should_cache(self, request: Request): return False def cache(self, request: Request): - self._validate_grammer_is_supported(request.guided_decoding_key) + self._validate_grammar_is_supported(request.guided_decoding_key) return self.executor.submit(self._executor_loop, request) def _executor_loop(self, request: Request) -> Grammar: @@ -137,7 +137,7 @@ def _executor_loop(self, request: Request) -> Grammar: self.request_key_to_grammar[key] = self.initialize_grammar(key) return self.request_key_to_grammar[key] - def _validate_grammer_is_supported(self, key: GuidedDecodingKey): + def _validate_grammar_is_supported(self, key: GuidedDecodingKey): request_type, grammar_spec = key if request_type == GuidedDecodingOptions.json: try: From 64a2ecff427fb6aaef18b405d0702709908414a8 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 25 Feb 2025 15:59:12 +0000 Subject: [PATCH 60/84] chore(scheduler): simplify check for use_guided_decoding Given that the grammar should already be ready once use_guided_decoding is set Signed-off-by: Aaron Pham --- vllm/v1/core/scheduler.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 8ac020455084e..f4174f2949514 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -582,8 +582,9 @@ def update_from_output( new_token_ids: List[int] = [] # Handle guided decoding FSM advancement if applicable - if (request.use_guided_decoding and request.grammar - and request.is_grammar_ready): + # NOTE: For all requests that uses guided 
decoding, the grammar + # should be ready at this point. + if request.use_guided_decoding: index = model_runner_output.req_id_to_index.get(req_id) if index is not None: token_ids = sampled_token_ids[index] @@ -593,8 +594,10 @@ def update_from_output( "more than one token at a time. Only the first " "token will be used.") token_id = token_ids[0] + assert request.grammar is not None # accept_token advances the FSM - accepted = request.grammar.accept_token(token_id) + accepted = request.grammar.accept_token( + token_id) # type: ignore[union-attr] if not accepted: logger.error( "Failed to advance FSM for request %s " @@ -704,8 +707,10 @@ def _free_request(self, request: Request) -> None: self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] self.finished_req_ids.add(request.request_id) - if request.use_guided_decoding and request.grammar: - request.grammar.reset() + if request.use_guided_decoding: + # NOTE: grammar should NOT be None + # if use_guided_decoding is True + request.grammar.reset() # type: ignore[union-attr] def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) From e8f47f344d2aaa8202ddb3aa412ba6b8b9adf705 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 14:06:57 -0500 Subject: [PATCH 61/84] Move guided decode validation to the engine core_client The initial validation checks to see if the gudied decode features are currently supported. If not, we might as well catch it in the client instead of sending it over to the engine knowing it'll fail. Signed-off-by: Russell Bryant --- .../entrypoints/llm/test_guided_generate.py | 2 +- vllm/v1/engine/core_client.py | 13 +++++++++ vllm/v1/guided_decoding/__init__.py | 27 +++--------------- vllm/v1/guided_decoding/utils.py | 28 +++++++++++++++++++ 4 files changed, 46 insertions(+), 24 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py index 2fea10f59d5ef..9082d013d54e0 100644 --- a/tests/v1/entrypoints/llm/test_guided_generate.py +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -58,7 +58,7 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str): regex=sample_regex, backend=guided_decoding_backend)) with pytest.raises(ValueError, - match="grammar is not of valid supported types"): + match="Regex guided decoding is not supported."): llm.generate(prompts=[ f"Give an example IPv4 address with this regex: {sample_regex}" ] * 2, diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9f36e11d12d76..ff901b381bd5d 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -24,6 +24,7 @@ EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.executor.abstract import Executor +from vllm.v1.guided_decoding.utils import validate_guided_decoding_request from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.utils import BackgroundProcHandle @@ -66,6 +67,15 @@ def make_client( return InprocClient(vllm_config, executor_class, log_stats) + @staticmethod + def _validate_request(request: EngineCoreRequest) -> None: + """Validate request before sending to EngineCore. + + Raises ValueError if request contents are known to be invalid or + unsupported. + """ + validate_guided_decoding_request(request.sampling_params) + @abstractmethod def shutdown(self): ... 
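# Illustrative only, not part of the diff: with _validate_request() wired into
# add_request() further down, an unsupported request now fails fast in the
# client process instead of after a round trip to the engine core, e.g.
#
#     try:
#         client.add_request(request)  # request asking for regex-guided decoding
#     except ValueError:
#         ...  # "Regex guided decoding is not supported."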
@@ -160,6 +170,7 @@ def get_output(self) -> EngineCoreOutputs: return self.engine_core.step() def add_request(self, request: EngineCoreRequest) -> None: + self._validate_request(request) self.engine_core.add_request(request) def abort_requests(self, request_ids: List[str]) -> None: @@ -368,6 +379,7 @@ def _call_utility(self, method: str, *args) -> Any: return future.result() def add_request(self, request: EngineCoreRequest) -> None: + self._validate_request(request) # NOTE: text prompt is not needed in the core engine as it has been # tokenized. request.prompt = None @@ -466,6 +478,7 @@ async def _call_utility_async(self, method: str, *args) -> Any: return await future async def add_request_async(self, request: EngineCoreRequest) -> None: + self._validate_request(request) # NOTE: text prompt is not needed in the core engine as it has been # tokenized. request.prompt = None diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 1e6824680f777..9785b647dd6d7 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -12,14 +12,10 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.v1.guided_decoding.utils import ( - has_xgrammar_unsupported_json_features) if TYPE_CHECKING: from vllm.v1.request import Request -import json - logger = init_logger(__name__) @@ -124,7 +120,6 @@ def should_cache(self, request: Request): return False def cache(self, request: Request): - self._validate_grammar_is_supported(request.guided_decoding_key) return self.executor.submit(self._executor_loop, request) def _executor_loop(self, request: Request) -> Grammar: @@ -137,25 +132,11 @@ def _executor_loop(self, request: Request) -> Grammar: self.request_key_to_grammar[key] = self.initialize_grammar(key) return self.request_key_to_grammar[key] - def _validate_grammar_is_supported(self, key: GuidedDecodingKey): - request_type, grammar_spec = key - if request_type == GuidedDecodingOptions.json: - try: - schema = json.loads(grammar_spec) - except json.JSONDecodeError as e: - raise ValueError("Invalid JSON grammar specification.") from e - - if has_xgrammar_unsupported_json_features(schema): - raise ValueError( - "The provided JSON schema contains features not " - "supported by xgrammar.") - return - elif request_type == GuidedDecodingOptions.grammar: - return - raise ValueError( - f"grammar is not of valid supported types. ({request_type!s})") - def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: + # Note that the request was validated in the engine core client, + # so at this point we know it is a supported type of request. 
+ # + # TODO: we still need to handle xgrammar compilation failures request_type, grammar_spec = key if request_type == GuidedDecodingOptions.json: diff --git a/vllm/v1/guided_decoding/utils.py b/vllm/v1/guided_decoding/utils.py index 898b44fe6b173..fa355229779ab 100644 --- a/vllm/v1/guided_decoding/utils.py +++ b/vllm/v1/guided_decoding/utils.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 +import json + +from vllm.sampling_params import SamplingParams + def has_xgrammar_unsupported_json_features(schema: dict) -> bool: """Check if JSON schema contains features unsupported by xgrammar.""" @@ -56,3 +60,27 @@ def check_object(obj: dict) -> bool: return False return check_object(schema) + + +def validate_guided_decoding_request(sampling_params: SamplingParams) -> None: + """Validate that the request is supported by guided decoding. + + Raises ValueError if the request is not supported. + """ + if sampling_params.guided_decoding is None: + return + + gd_params = sampling_params.guided_decoding + if gd_params.regex: + raise ValueError("Regex guided decoding is not supported.") + if gd_params.choice: + raise ValueError("Choice guided decoding is not supported.") + if isinstance(gd_params.json, str): + try: + schema = json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + + if has_xgrammar_unsupported_json_features(schema): + raise ValueError("The provided JSON schema contains features not " + "supported by xgrammar.") From f3f7d51119e907782eff10201c723ed6ee6af9b8 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 14:41:32 -0500 Subject: [PATCH 62/84] test for expected behavior of a choice guided decode request Signed-off-by: Russell Bryant --- .../entrypoints/llm/test_guided_generate.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py index 9082d013d54e0..6ef0eafc119c6 100644 --- a/tests/v1/entrypoints/llm/test_guided_generate.py +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -76,3 +76,35 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str): # assert generated_text is not None # assert re.fullmatch(sample_regex, generated_text) is not None # print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_choice_completion(monkeypatch, sample_guided_choice, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + choice=sample_guided_choice, + backend=guided_decoding_backend)) + with pytest.raises(ValueError, + match="Choice guided decoding is not supported."): + llm.generate( + prompts="The best language for type-safe systems programming is ", + sampling_params=sampling_params, + use_tqdm=True) + + # Once choice is supported -- + #assert outputs is not None + #for output in outputs: + # assert output is not None + # assert isinstance(output, RequestOutput) + # prompt = output.prompt + # generated_text = output.outputs[0].text + # print(generated_text) + # assert generated_text is not None + # assert generated_text in sample_guided_choice + # print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From 
9582f8c947ee2930402dce6577927b2c945a58df Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 14:54:13 -0500 Subject: [PATCH 63/84] Validate jsonschema features for both str and dict cases Signed-off-by: Russell Bryant --- vllm/v1/guided_decoding/utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/v1/guided_decoding/utils.py b/vllm/v1/guided_decoding/utils.py index fa355229779ab..61f4dbb287ced 100644 --- a/vllm/v1/guided_decoding/utils.py +++ b/vllm/v1/guided_decoding/utils.py @@ -71,15 +71,21 @@ def validate_guided_decoding_request(sampling_params: SamplingParams) -> None: return gd_params = sampling_params.guided_decoding + if gd_params.regex: raise ValueError("Regex guided decoding is not supported.") + if gd_params.choice: raise ValueError("Choice guided decoding is not supported.") - if isinstance(gd_params.json, str): - try: - schema = json.loads(gd_params.json) - except json.JSONDecodeError as e: - raise ValueError("Invalid JSON grammar specification.") from e + + if gd_params.json: + if isinstance(gd_params.json, str): + try: + schema = json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + schema = gd_params.json if has_xgrammar_unsupported_json_features(schema): raise ValueError("The provided JSON schema contains features not " From acd5ae003e11f31770440e6d98f90c103bc92d81 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 14:54:33 -0500 Subject: [PATCH 64/84] Test for expected behavior of a request with unsupported jsonschema features Signed-off-by: Russell Bryant --- tests/v1/entrypoints/conftest.py | 3 ++- .../entrypoints/llm/test_guided_generate.py | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index f873e32f4916a..22c2add71fd2b 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -70,8 +70,9 @@ def sample_json_schema(): } +# A schema unsupported by xgrammar @pytest.fixture -def sample_complex_json_schema(): +def unsupported_json_schema(): return { "type": "object", "properties": { diff --git a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py index 6ef0eafc119c6..95d8172eade58 100644 --- a/tests/v1/entrypoints/llm/test_guided_generate.py +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -46,6 +46,29 @@ def test_guided_json_completion(monkeypatch, sample_json_schema, jsonschema.validate(instance=output_json, schema=sample_json_schema) +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=unsupported_json_schema, + backend=guided_decoding_backend)) + with pytest.raises(ValueError, + match="The provided JSON schema contains features " + "not supported by xgrammar."): + llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {unsupported_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", 
GUIDED_DECODING_BACKENDS_V1) From 4c674ae3306aee297ef6ea49b90766310dfd9b59 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 15:06:43 -0500 Subject: [PATCH 65/84] Correctly differentiate between jsonschema and json object requests Signed-off-by: Russell Bryant --- vllm/v1/guided_decoding/__init__.py | 14 +++++++------- vllm/v1/request.py | 4 +++- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 9785b647dd6d7..55d6af7584a51 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -21,6 +21,7 @@ class GuidedDecodingOptions(enum.Enum): json = enum.auto() + json_object = enum.auto() regex = enum.auto() grammar = enum.auto() choice = enum.auto() @@ -140,13 +141,12 @@ def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: request_type, grammar_spec = key if request_type == GuidedDecodingOptions.json: - if not isinstance(grammar_spec, str): - ctx = self.compiler.compile_builtin_json_grammar() - else: - # TODO -- allow any_whitespace to be configurable - # pending merge of https://github.com/vllm-project/vllm/pull/12744 - ctx = self.compiler.compile_json_schema(grammar_spec, - any_whitespace=False) + # TODO -- allow any_whitespace to be configurable + # pending merge of https://github.com/vllm-project/vllm/pull/12744 + ctx = self.compiler.compile_json_schema(grammar_spec, + any_whitespace=False) + elif request_type == GuidedDecodingOptions.json_object: + ctx = self.compiler.compile_builtin_json_grammar() elif request_type == GuidedDecodingOptions.grammar: ctx = self.compiler.compile_grammar(grammar_spec) else: diff --git a/vllm/v1/request.py b/vllm/v1/request.py index ffa18e2ba37f7..765b203bf962d 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -160,11 +160,13 @@ def guided_decoding_key(self) -> GuidedDecodingKey: params = self.sampling_params.guided_decoding assert params is not None, "params can't be None." 
if params.json is not None: - if params.json_object or not isinstance(params.json, str): + if not isinstance(params.json, str): json_str = json.dumps(params.json) else: json_str = params.json return (GuidedDecodingOptions.json, json_str) + elif params.json_object: + return (GuidedDecodingOptions.json_object, "") elif params.regex is not None: return (GuidedDecodingOptions.regex, params.regex) elif params.choice is not None: From 1b40882d1e8c5dd063fa652746fcb3b5fc8c4cb3 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 15:07:07 -0500 Subject: [PATCH 66/84] Test for correct json object (no schema) request behavior Signed-off-by: Russell Bryant --- .../entrypoints/llm/test_guided_generate.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py index 95d8172eade58..42d6ae8588590 100644 --- a/tests/v1/entrypoints/llm/test_guided_generate.py +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -46,6 +46,40 @@ def test_guided_json_completion(monkeypatch, sample_json_schema, jsonschema.validate(instance=output_json, schema=sample_json_schema) +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_json_object(monkeypatch, guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=1.0, + max_tokens=100, + n=2, + guided_decoding=GuidedDecodingParams( + json_object=True, + backend=guided_decoding_backend)) + + outputs = llm.generate( + prompts=("Generate a JSON object with curly braces for a person with " + "name and age fields for John Smith who is 31 years old."), + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + + for i in range(2): + generated_text = output.outputs[i].text + print(generated_text) + assert generated_text is not None + + # Parse to verify it is valid JSON + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS_V1) From 4f551f46af71cd7b129d8d22e59691bcd67faaab Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 15:26:00 -0500 Subject: [PATCH 67/84] Add test for a request using an EBNF style grammar Signed-off-by: Russell Bryant --- tests/v1/entrypoints/conftest.py | 18 ++++----- .../entrypoints/llm/test_guided_generate.py | 38 +++++++++++++++++++ 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index 22c2add71fd2b..7e744bf1cbfe9 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -150,12 +150,12 @@ def sample_guided_choice(): @pytest.fixture -def sample_sql_statements(): - return (""" -start: select_statement -select_statement: "SELECT" column "from" table "where" condition -column: "col_1" | "col_2" -table: "table_1" | "table_2" -condition: column "=" number -number: "1" | "2" -""") +def sample_sql_ebnf(): + return """ +root ::= select_statement +select_statement ::= "SELECT" column "from" table "where" condition +column ::= "col_1" | "col_2" +table ::= "table_1" | "table_2" +condition ::= column "=" number +number ::= "1" | "2" +""" diff --git 
a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py index 42d6ae8588590..b09e5acdde47f 100644 --- a/tests/v1/entrypoints/llm/test_guided_generate.py +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -80,6 +80,44 @@ def test_guided_json_object(monkeypatch, guided_decoding_backend: str): assert isinstance(parsed_json, dict) +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + grammar=sample_sql_ebnf, + backend=guided_decoding_backend)) + outputs = llm.generate( + prompts=("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1"), + sampling_params=sampling_params, + use_tqdm=True, + ) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") + + assert generated_text.strip() == ground_truth + + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS_V1) From d132d72acb9c9c0a0b76933b99999671a862cddc Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 15:44:14 -0500 Subject: [PATCH 68/84] Validate that EBNF grammar can be parsed during early validation Signed-off-by: Russell Bryant --- vllm/v1/guided_decoding/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/v1/guided_decoding/utils.py b/vllm/v1/guided_decoding/utils.py index 61f4dbb287ced..c9acc951d0925 100644 --- a/vllm/v1/guided_decoding/utils.py +++ b/vllm/v1/guided_decoding/utils.py @@ -2,6 +2,8 @@ import json +import xgrammar + from vllm.sampling_params import SamplingParams @@ -90,3 +92,11 @@ def validate_guided_decoding_request(sampling_params: SamplingParams) -> None: if has_xgrammar_unsupported_json_features(schema): raise ValueError("The provided JSON schema contains features not " "supported by xgrammar.") + + if gd_params.grammar: + # EBNF style grammars only right now + try: + # parse the grammar, but we aren't compiling it. 
+ xgrammar.Grammar.from_ebnf(gd_params.grammar) + except Exception as e: + raise ValueError("Invalid grammar specification.") from e From b994230769f01a6c94dff4c4fff27aadd0f93eba Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 15:44:43 -0500 Subject: [PATCH 69/84] Test for expected behavior of an invalid grammar Signed-off-by: Russell Bryant --- .../entrypoints/llm/test_guided_generate.py | 48 ++++++++++++++----- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py index b09e5acdde47f..3aa66efcbed15 100644 --- a/tests/v1/entrypoints/llm/test_guided_generate.py +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -80,6 +80,29 @@ def test_guided_json_object(monkeypatch, guided_decoding_backend: str): assert isinstance(parsed_json, dict) +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=unsupported_json_schema, + backend=guided_decoding_backend)) + with pytest.raises(ValueError, + match="The provided JSON schema contains features " + "not supported by xgrammar."): + llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {unsupported_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS_V1) @@ -121,24 +144,23 @@ def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS_V1) -def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema, - guided_decoding_backend: str): +def test_guided_grammar_ebnf_invalid(monkeypatch, + guided_decoding_backend: str): monkeypatch.setenv("VLLM_USE_V1", "1") llm = LLM(model=MODEL_NAME, max_model_len=1024) - sampling_params = SamplingParams(temperature=1.0, + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, max_tokens=1000, guided_decoding=GuidedDecodingParams( - json=unsupported_json_schema, + grammar="not a grammar", backend=guided_decoding_backend)) - with pytest.raises(ValueError, - match="The provided JSON schema contains features " - "not supported by xgrammar."): - llm.generate(prompts=[ - f"Give an example JSON for an employee profile " - f"that fits this schema: {unsupported_json_schema}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + with pytest.raises(ValueError, match="Invalid grammar specification."): + llm.generate( + prompts=("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1"), + sampling_params=sampling_params, + use_tqdm=True, + ) @pytest.mark.skip_global_cleanup From 3cc6437bb80c5c6809f16a9d7446d99e09b71481 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 16:01:14 -0500 Subject: [PATCH 70/84] Add support and test coverage for lark style grammars Signed-off-by: Russell Bryant --- tests/v1/entrypoints/conftest.py | 12 ++ .../entrypoints/llm/test_guided_generate.py | 43 +++++ vllm/v1/guided_decoding/utils.py | 168 +++++++++++++++++- 3 files 
changed, 222 insertions(+), 1 deletion(-) diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index 7e744bf1cbfe9..6d4278b4c8719 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -159,3 +159,15 @@ def sample_sql_ebnf(): condition ::= column "=" number number ::= "1" | "2" """ + + +@pytest.fixture +def sample_sql_lark(): + return (""" +start: select_statement +select_statement: "SELECT" column "from" table "where" condition +column: "col_1" | "col_2" +table: "table_1" | "table_2" +condition: column "=" number +number: "1" | "2" +""") diff --git a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py index 3aa66efcbed15..7b15261083828 100644 --- a/tests/v1/entrypoints/llm/test_guided_generate.py +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -141,6 +141,49 @@ def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf, print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_grammar_lark(monkeypatch, sample_sql_lark, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + grammar=sample_sql_lark, + backend=guided_decoding_backend)) + outputs = llm.generate( + prompts=("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1"), + sampling_params=sampling_params, + use_tqdm=True, + ) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_lark) + parser.parse(generated_text) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") + + assert generated_text.strip() == ground_truth + + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS_V1) diff --git a/vllm/v1/guided_decoding/utils.py b/vllm/v1/guided_decoding/utils.py index c9acc951d0925..f9778ed4a4ebc 100644 --- a/vllm/v1/guided_decoding/utils.py +++ b/vllm/v1/guided_decoding/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import re import xgrammar @@ -64,6 +65,163 @@ def check_object(obj: dict) -> bool: return check_object(schema) +def grammar_is_likely_lark(grammar_str: str) -> bool: + """ + Check if grammar appears to use Lark syntax. 
+ + Args: + grammar_str: Input grammar string + + Returns: + bool: True if grammar appears to be in Lark format, False otherwise + + Examples: + >>> grammar_is_likely_lark("rule: 'abc'") + True + >>> grammar_is_likely_lark("rule ::= 'abc'") + False + """ + if not grammar_str or not isinstance(grammar_str, str): + return False + + for line in grammar_str.split('\n'): + # Remove both comment styles + line = re.sub(r'(#|//).*$', '', line).strip() + if not line: + continue + + # Look for EBNF rule definition + if '::=' in line: + return False + + return True + + +def convert_lark_to_ebnf(grammar_str: str) -> str: + """ + Convert a Lark grammar string to EBNF format. + + EBNF reference: + https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + Lark grammar reference: + https://lark-parser.readthedocs.io/en/latest/grammar.html + + Args: + grammar_str: Input grammar in Lark format + + Returns: + str: Converted grammar in EBNF format + + Examples: + >>> print(convert_lark_to_ebnf("rule: 'hello'")) + root ::= rule + rule ::= "hello" + """ + if not isinstance(grammar_str, str): + raise ValueError(f"Grammar must be a string, got {type(grammar_str)}") + if not grammar_str.strip(): + raise ValueError("Grammar string cannot be empty") + + defined_rules = set() + referenced_rules = set() + output_lines = [] + + def clean_line(line: str) -> str: + """Remove comments and whitespace from line.""" + return re.sub(r'(#|//).*$', '', line).strip() + + def check_quotes(text: str, rule_name: str, line_num: int) -> None: + """Validate quote matching in text.""" + if text.count("'") % 2 != 0 or text.count('"') % 2 != 0: + raise ValueError( + f"Mismatched quotes in {rule_name} on line {line_num}") + + def extract_references(text: str) -> set: + """Extract rule references from text.""" + # Remove quoted strings and special characters + text = re.sub(r'"[^"]*"', '', text) + text = re.sub(r'[+*?()|\[\]{}]', ' ', text) + return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + + # First pass: Find root rule and validate rule definitions + lines = [clean_line(line) for line in grammar_str.split('\n')] + first_rule = None + + for line_num, line in enumerate(lines, 1): + if not line or line.startswith('|'): + continue + + if ':' in line: + try: + name = line.split(':', 1)[0].strip().strip('?') + defined_rules.add(name) + if first_rule is None: + first_rule = name + if name == 'start': + first_rule = 'start' + except IndexError as e: + raise ValueError(f"Invalid rule format on line {line_num}. 
" + "Expected 'rule_name: definition'") from e + + if not defined_rules: + raise ValueError("No valid rules found in grammar") + + # Add root rule + output_lines.append(f"root ::= {first_rule}") + + # Second pass: Process rule definitions and alternatives + current_rule = None + current_definition = [] + + for line_num, line in enumerate(lines, 1): + if not line: + continue + + try: + if ':' in line and not line.startswith('|'): + # Save previous rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Process new rule + name, definition = line.split(':', 1) + current_rule = name.strip().strip('?') + + check_quotes(definition, f"rule '{current_rule}'", line_num) + definition = re.sub(r"'([^']*)'", r'"\1"', definition) + referenced_rules.update(extract_references(definition)) + current_definition = [definition.strip()] + + elif line.startswith('|'): + if not current_rule: + raise ValueError(f"Alternative '|' on line {line_num} " + "without a preceding rule definition") + + alt_def = line[1:].strip() + check_quotes(alt_def, f"alternative for rule '{current_rule}'", + line_num) + alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) + referenced_rules.update(extract_references(alt_def)) + current_definition.append(alt_def) + + except ValueError as e: + raise ValueError(f"Error on line {line_num}: {str(e)}") from e + + # Add final rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Validate all rules are defined + undefined_rules = referenced_rules - defined_rules - {'root'} + if undefined_rules: + raise ValueError("Referenced rules are not defined: " + f"{', '.join(sorted(undefined_rules))}") + + return '\n'.join(output_lines) + + def validate_guided_decoding_request(sampling_params: SamplingParams) -> None: """Validate that the request is supported by guided decoding. @@ -94,7 +252,15 @@ def validate_guided_decoding_request(sampling_params: SamplingParams) -> None: "supported by xgrammar.") if gd_params.grammar: - # EBNF style grammars only right now + if grammar_is_likely_lark(gd_params.grammar): + # xgrammar supports EBNF grammars only + try: + gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to convert the grammar from Lark to EBNF. ") from e + + # Test parsing EBNF grammar, possibly already converted from Lark try: # parse the grammar, but we aren't compiling it. 
xgrammar.Grammar.from_ebnf(gd_params.grammar) From 95be24b71a8d031270c9868ecfe426a3bff3bee0 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 25 Feb 2025 16:28:11 -0500 Subject: [PATCH 71/84] Add support and tests for choice based guided decoding Signed-off-by: Russell Bryant --- .../entrypoints/llm/test_guided_generate.py | 31 +++++++++---------- vllm/v1/guided_decoding/utils.py | 24 +++++++++++++- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py index 7b15261083828..6c30d5dc3e555 100644 --- a/tests/v1/entrypoints/llm/test_guided_generate.py +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -250,21 +250,18 @@ def test_guided_choice_completion(monkeypatch, sample_guided_choice, guided_decoding=GuidedDecodingParams( choice=sample_guided_choice, backend=guided_decoding_backend)) - with pytest.raises(ValueError, - match="Choice guided decoding is not supported."): - llm.generate( - prompts="The best language for type-safe systems programming is ", - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate( + prompts="The best language for type-safe systems programming is ", + sampling_params=sampling_params, + use_tqdm=True) - # Once choice is supported -- - #assert outputs is not None - #for output in outputs: - # assert output is not None - # assert isinstance(output, RequestOutput) - # prompt = output.prompt - # generated_text = output.outputs[0].text - # print(generated_text) - # assert generated_text is not None - # assert generated_text in sample_guided_choice - # print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + assert generated_text in sample_guided_choice + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/v1/guided_decoding/utils.py b/vllm/v1/guided_decoding/utils.py index f9778ed4a4ebc..e01d8771be244 100644 --- a/vllm/v1/guided_decoding/utils.py +++ b/vllm/v1/guided_decoding/utils.py @@ -2,6 +2,7 @@ import json import re +from typing import List import xgrammar @@ -222,6 +223,18 @@ def extract_references(text: str) -> set: return '\n'.join(output_lines) +def choice_as_grammar(choice: List[str]) -> str: + + def escape_ebnf_string(s: str) -> str: + """Escape special characters in a EBNF string.""" + # Escape double quotes and backslashes + return re.sub(r'(["\\])', r'\\\1', s) + + escaped_choices = (escape_ebnf_string(c) for c in choice) + grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) + return grammar + + def validate_guided_decoding_request(sampling_params: SamplingParams) -> None: """Validate that the request is supported by guided decoding. 
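# Illustrative usage of the helper added above (assuming the module imports as
# vllm.v1.guided_decoding.utils; example only, not part of the diff):
#
#     from vllm.v1.guided_decoding.utils import choice_as_grammar
#     choice_as_grammar(["Python", "Rust", "C++"])
#     # -> 'root ::= "Python" | "Rust" | "C++"'
#
# Double quotes and backslashes in a choice are escaped first, so the resulting
# EBNF stays parseable by xgrammar.Grammar.from_ebnf().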
@@ -236,7 +249,15 @@ def validate_guided_decoding_request(sampling_params: SamplingParams) -> None: raise ValueError("Regex guided decoding is not supported.") if gd_params.choice: - raise ValueError("Choice guided decoding is not supported.") + choice_grammar = choice_as_grammar(gd_params.choice) + try: + xgrammar.Grammar.from_ebnf(choice_grammar) + except Exception as err: + raise ValueError("Failed to transform choices into a grammar: " + "{err}") from err + gd_params.choice = None + gd_params.grammar = choice_grammar + return if gd_params.json: if isinstance(gd_params.json, str): @@ -250,6 +271,7 @@ def validate_guided_decoding_request(sampling_params: SamplingParams) -> None: if has_xgrammar_unsupported_json_features(schema): raise ValueError("The provided JSON schema contains features not " "supported by xgrammar.") + return if gd_params.grammar: if grammar_is_likely_lark(gd_params.grammar): From 9d1fe71a3c355b7511cd65b1d5e15fc53e178147 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 25 Feb 2025 22:11:46 +0000 Subject: [PATCH 72/84] feat: spec decode compatibility [-------------] Ugh it works Also clean up a few things with notes here and there Signed-off-by: Aaron Pham --- vllm/v1/core/scheduler.py | 45 +++++++++++++++++------- vllm/v1/guided_decoding/__init__.py | 53 +++++++++++++++++++---------- vllm/v1/worker/gpu_model_runner.py | 12 +++++-- 3 files changed, 78 insertions(+), 32 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f4174f2949514..412ebc1f19473 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -60,7 +60,11 @@ def __init__( # req_id -> Request self.requests: Dict[str, Request] = {} - # Priority queues for requests. + # NOTE: Priority queues for requests. + # With list, we can safely pop the index + # of a request that are yet to be ready (in this case, + # the one that uses guided decoding) while still maintaining + # the order of all requests in existing waiting queue. self.waiting: List[Request] = [] self.running: List[Request] = [] # The requests that have been scheduled and are being executed @@ -114,6 +118,13 @@ def schedule(self) -> "SchedulerOutput": scheduled_resumed_reqs: List[Request] = [] scheduled_running_reqs: List[Request] = [] preempted_reqs: List[Request] = [] + + # NOTE: guided_decoding_request_ids maps + # guided request's (request that use structured decoding) + # request_id to the running request index. + # This will helps us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. guided_decoding_request_ids: Dict[str, int] = {} req_to_new_block_ids: Dict[str, List[int]] = {} @@ -225,6 +236,9 @@ def schedule(self) -> "SchedulerOutput": # Next, schedule the WAITING requests. if not preempted_reqs: + # NOTE: We uses num_to_skip to determine + # which guided request within the waiting queue to skip + # over if the FSM of said request are yet to be ready. num_to_skip: int = 0 while num_to_skip < len(self.waiting) and token_budget > 0: if len(self.running) == self.max_num_running_reqs: @@ -584,24 +598,31 @@ def update_from_output( # Handle guided decoding FSM advancement if applicable # NOTE: For all requests that uses guided decoding, the grammar # should be ready at this point. + # PERF: This is currently expensive given that FSM is being + # advanced here. 
if request.use_guided_decoding: + grammar = request.grammar + assert grammar is not None index = model_runner_output.req_id_to_index.get(req_id) if index is not None: - token_ids = sampled_token_ids[index] - if len(token_ids) > 1: - logger.error( - "Structured output does not currently support " - "more than one token at a time. Only the first " - "token will be used.") - token_id = token_ids[0] - assert request.grammar is not None # accept_token advances the FSM - accepted = request.grammar.accept_token( - token_id) # type: ignore[union-attr] + has_accept_tokens = [ + grammar.accept_token(token_id) + for token_id in generated_token_ids + ] + accepted = any(has_accept_tokens) if not accepted: + grammar.rollback(len(has_accept_tokens)) logger.error( "Failed to advance FSM for request %s " - "with token %d", req_id, token_id) + "for all draft " + "tokens %s", req_id, generated_token_ids) + stopped = True + else: + grammar.rollback(len(has_accept_tokens) - 1) + if stopped: + self._free_request(request) + continue if request.num_computed_tokens >= request.num_tokens: for output_token_id in generated_token_ids: diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 55d6af7584a51..2d2f55bd01cc5 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -4,6 +4,7 @@ import enum import threading from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple import torch @@ -28,6 +29,7 @@ class GuidedDecodingOptions(enum.Enum): GuidedDecodingKey = Tuple[GuidedDecodingOptions, str] +MAX_ROLLBACK_TOKENS = 100 def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor, @@ -35,6 +37,7 @@ def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor, xgr.apply_token_bitmask_inplace(logits, vocab_mask, indices=indices) +@dataclass(slots=True, unsafe_hash=True) class Grammar: # NOTE: This would be a generic-enough class for # supporting different backends, in the future. 
@@ -44,35 +47,47 @@ class Grammar: # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string # for jump-forward decoding - def __init__( - self, - matcher: xgr.GrammarMatcher, - vocab_size: int, - ctx: xgr.CompiledGrammar, - ) -> None: - self.matcher = matcher - self.vocab_size = vocab_size - self.ctx = ctx + vocab_size: int + matcher: xgr.GrammarMatcher = field(hash=False) + ctx: xgr.CompiledGrammar = field(hash=False) + max_rollback_tokens: int = field(default=MAX_ROLLBACK_TOKENS, kw_only=True) + num_processed_tokens: int = field( + default_factory=lambda: 0, + repr=False, + hash=False, + init=False, + ) + _accept_lock: threading.Lock = field( + default_factory=lambda: threading.Lock(), + repr=False, + init=False, + hash=False, + ) def accept_token(self, token: int) -> bool: # NOTE: accept_token will determines whether we accept this token # and will also update the machine state - return self.matcher.accept_token(token) + with self._accept_lock: + self.num_processed_tokens += 1 + return self.matcher.accept_token(token) # this should be ran in parallel with model decoding def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool: return self.matcher.fill_next_token_bitmask(bitmask, idx) + def rollback(self, num_tokens: int): + self.num_processed_tokens -= num_tokens + self.matcher.rollback(num_tokens) + def reset(self): + self.num_processed_tokens = 0 self.matcher.reset() - def copy(self): + def __copy__(self): return Grammar(matcher=xgr.GrammarMatcher(self.ctx), vocab_size=self.vocab_size, - ctx=self.ctx) - - def __copy__(self): - return self.copy() + ctx=self.ctx, + max_rollback_tokens=self.max_rollback_tokens) class GuidedDecodingManager: @@ -82,7 +97,7 @@ def __init__(self, vllm_config: VllmConfig): model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, parallel_config=vllm_config.parallel_config, - lora_config=vllm_config.lora_config) + lora_config=vllm_config.lora_config) # type: ignore[arg-type] tokenizer_group.ping() self.vocab_size = vllm_config.model_config.get_vocab_size() self.vllm_config = vllm_config @@ -97,7 +112,7 @@ def __init__(self, vllm_config: VllmConfig): self.executor = ThreadPoolExecutor() self.requests: Set[Request] = set() self._requests_lock = threading.Lock() - self.grammar_bitmask: torch.Tensor = xgr.allocate_token_bitmask( + self.grammar_bitmask = xgr.allocate_token_bitmask( self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size) def __getitem__(self, key: GuidedDecodingKey) -> Optional[Grammar]: @@ -159,7 +174,9 @@ def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: matcher=xgr.GrammarMatcher(ctx), vocab_size=self.vocab_size, ctx=ctx, - ) + max_rollback_tokens=self.vllm_config.speculative_config. 
+ num_lookahead_slots + if self.vllm_config.speculative_config else MAX_ROLLBACK_TOKENS) def setup_grammars(self): with self._requests_lock: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 333f0069a0fcd..b795b18763787 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -954,9 +954,17 @@ def execute_model( # Apply guided decoding bitmasks if present grammar_bitmask = scheduler_output.grammar_bitmask if grammar_bitmask is not None: + # NOTE: A non-blocking copy causes incorrect behaviour + # in speculative decoding (and we should use blocking copy + # for speculative decoding use-case) + # + # TODO: performance with both structured + speculative apply_bitmask( - logits, grammar_bitmask.to(self.device, non_blocking=True), - list(scheduler_output.guided_decoding_request_ids.values())) + logits, + grammar_bitmask.to(self.device, + non_blocking=not self.use_spec_decode), + list(scheduler_output.guided_decoding_request_ids.values()), + ) # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.get_sampling_metadata( From 83a52770d4a12408e42fd0c3ca567f1b622f8aa8 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 25 Feb 2025 22:13:39 +0000 Subject: [PATCH 73/84] fix: correct lock the matcher for both rollback and advance Signed-off-by: Aaron Pham --- vllm/v1/guided_decoding/__init__.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 2d2f55bd01cc5..f3637c28c4ad5 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -57,7 +57,7 @@ class Grammar: hash=False, init=False, ) - _accept_lock: threading.Lock = field( + _matcher_lock: threading.Lock = field( default_factory=lambda: threading.Lock(), repr=False, init=False, @@ -67,7 +67,7 @@ class Grammar: def accept_token(self, token: int) -> bool: # NOTE: accept_token will determines whether we accept this token # and will also update the machine state - with self._accept_lock: + with self._matcher_lock: self.num_processed_tokens += 1 return self.matcher.accept_token(token) @@ -76,8 +76,9 @@ def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool: return self.matcher.fill_next_token_bitmask(bitmask, idx) def rollback(self, num_tokens: int): - self.num_processed_tokens -= num_tokens - self.matcher.rollback(num_tokens) + with self._matcher_lock: + self.num_processed_tokens -= num_tokens + self.matcher.rollback(num_tokens) def reset(self): self.num_processed_tokens = 0 From d02e11a6532aecb8769033142bb5f01cb4ce2182 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 26 Feb 2025 19:02:16 +0000 Subject: [PATCH 74/84] chore: only rollback if there are more than zero processed tokens Signed-off-by: Aaron Pham --- vllm/v1/guided_decoding/__init__.py | 5 +++-- vllm/v1/worker/gpu_model_runner.py | 9 ++------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index f3637c28c4ad5..5542c6b025a04 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -77,8 +77,9 @@ def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool: def rollback(self, num_tokens: int): with self._matcher_lock: - self.num_processed_tokens -= num_tokens - self.matcher.rollback(num_tokens) + if self.num_processed_tokens > 0: + self.num_processed_tokens -= num_tokens + self.matcher.rollback(num_tokens) def 
reset(self): self.num_processed_tokens = 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b795b18763787..eb720210a8e73 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -954,15 +954,10 @@ def execute_model( # Apply guided decoding bitmasks if present grammar_bitmask = scheduler_output.grammar_bitmask if grammar_bitmask is not None: - # NOTE: A non-blocking copy causes incorrect behaviour - # in speculative decoding (and we should use blocking copy - # for speculative decoding use-case) - # - # TODO: performance with both structured + speculative + # TODO: compatibility with spec decode apply_bitmask( logits, - grammar_bitmask.to(self.device, - non_blocking=not self.use_spec_decode), + grammar_bitmask.to(self.device, non_blocking=True), list(scheduler_output.guided_decoding_request_ids.values()), ) From c64daa746c12d489588d8cf5dad825ca3d2e5a9b Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 26 Feb 2025 19:36:56 +0000 Subject: [PATCH 75/84] fix: correctly free requests based on accepted tokens Signed-off-by: Aaron Pham --- vllm/v1/core/scheduler.py | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 412ebc1f19473..ef1c74588381a 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -603,26 +603,20 @@ def update_from_output( if request.use_guided_decoding: grammar = request.grammar assert grammar is not None - index = model_runner_output.req_id_to_index.get(req_id) - if index is not None: - # accept_token advances the FSM - has_accept_tokens = [ - grammar.accept_token(token_id) - for token_id in generated_token_ids - ] - accepted = any(has_accept_tokens) - if not accepted: - grammar.rollback(len(has_accept_tokens)) - logger.error( - "Failed to advance FSM for request %s " - "for all draft " - "tokens %s", req_id, generated_token_ids) - stopped = True - else: - grammar.rollback(len(has_accept_tokens) - 1) - if stopped: - self._free_request(request) - continue + # accept_token advances the FSM + has_accept_tokens = [ + grammar.accept_token(token_id) + for token_id in generated_token_ids + ] + accepted = any(has_accept_tokens) + if not accepted: + logger.error( + "Failed to advance FSM for request %s " + "for all draft " + "tokens %s", req_id, generated_token_ids) + stopped = True + else: + grammar.rollback(len(has_accept_tokens) - 1) if request.num_computed_tokens >= request.num_tokens: for output_token_id in generated_token_ids: From ad05fe8dbc61d4be33d4479b068831e4653f5ef1 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 26 Feb 2025 14:44:35 -0500 Subject: [PATCH 76/84] Account for differences in scheduler and gpu worker batch ordering The scheduler sends a bitmask for guided decoding down to the gpu worker, but the indices into this bitmask may not match the order of requests used in the gpu worker. This change detects the discrepancy and creates a reordered bitmask when necessary before applying it to the logits. 
Signed-off-by: Russell Bryant --- .../entrypoints/llm/test_guided_generate.py | 4 +++- vllm/v1/worker/gpu_model_runner.py | 24 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/v1/entrypoints/llm/test_guided_generate.py b/tests/v1/entrypoints/llm/test_guided_generate.py index 6c30d5dc3e555..871739bcf1640 100644 --- a/tests/v1/entrypoints/llm/test_guided_generate.py +++ b/tests/v1/entrypoints/llm/test_guided_generate.py @@ -197,7 +197,9 @@ def test_guided_grammar_ebnf_invalid(monkeypatch, guided_decoding=GuidedDecodingParams( grammar="not a grammar", backend=guided_decoding_backend)) - with pytest.raises(ValueError, match="Invalid grammar specification."): + with pytest.raises(ValueError, + match="Failed to convert the grammar " + "from Lark to EBNF."): llm.generate( prompts=("Generate a sql statement that selects col_1 from " "table_1 where it is equal to 1"), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eb720210a8e73..15e1e91c89be3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -954,6 +954,30 @@ def execute_model( # Apply guided decoding bitmasks if present grammar_bitmask = scheduler_output.grammar_bitmask if grammar_bitmask is not None: + # We receive the guided decoding bitmask from the scheduler, but the + # indices of the requests in the batch may not match the indices of + # the bitmask since the scheduler doesn't know how the gpu runner is + # ordering the requests in the batch. We need to sort the bitmask to + # match the order of the requests used here. + req_id_indices: Dict[str, int] = {} + indices_match = True + for req_id in self.input_batch.req_ids: + batch_index = self.input_batch.req_id_to_index[req_id] + if batch_index != scheduler_output.guided_decoding_request_ids[ + req_id]: + indices_match = False + req_id_indices[req_id] = batch_index + + sorted_bitmask: Optional[torch.Tensor] = None + if not indices_match: + # Sort the bitmask to match the order of the requests + sorted_bitmask = torch.zeros_like(grammar_bitmask) + for req_id, batch_index in req_id_indices.items(): + orig_index = scheduler_output.guided_decoding_request_ids[ + req_id] + sorted_bitmask[batch_index] = grammar_bitmask[orig_index] + grammar_bitmask = sorted_bitmask + # TODO: compatibility with spec decode apply_bitmask( logits, From 7cf632671ca4677ba906a641b48b20459a8159cc Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 26 Feb 2025 15:05:46 -0500 Subject: [PATCH 77/84] Skip non-guided-decode requests when assembling reordered bitmask Signed-off-by: Russell Bryant --- vllm/v1/worker/gpu_model_runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 15e1e91c89be3..77947d3be4738 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -962,6 +962,9 @@ def execute_model( req_id_indices: Dict[str, int] = {} indices_match = True for req_id in self.input_batch.req_ids: + if req_id not in scheduler_output.guided_decoding_request_ids: + # not a guided decoding request + continue batch_index = self.input_batch.req_id_to_index[req_id] if batch_index != scheduler_output.guided_decoding_request_ids[ req_id]: From 84bbae1d4658f1fa42f24e3c16e334903ff0115b Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 26 Feb 2025 20:07:38 +0000 Subject: [PATCH 78/84] revert: remove rollback check for now, only advance 1 token Signed-off-by: Aaron Pham --- 
vllm/v1/core/scheduler.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index ef1c74588381a..af99c722ddc0f 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -603,20 +603,17 @@ def update_from_output( if request.use_guided_decoding: grammar = request.grammar assert grammar is not None + if len(generated_token_ids) > 1: + logger.error( + "Structured output does not currently support " + "more than one token at a time. Only the first " + "token will be used.") # accept_token advances the FSM - has_accept_tokens = [ - grammar.accept_token(token_id) - for token_id in generated_token_ids - ] - accepted = any(has_accept_tokens) + accepted = grammar.accept_token(generated_token_ids[0]) if not accepted: logger.error( "Failed to advance FSM for request %s " - "for all draft " - "tokens %s", req_id, generated_token_ids) - stopped = True - else: - grammar.rollback(len(has_accept_tokens) - 1) + "for tokens %s", req_id, generated_token_ids[0]) if request.num_computed_tokens >= request.num_tokens: for output_token_id in generated_token_ids: From c10eb6ae7f96aae9e270a1e4538d058aac9e7030 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 26 Feb 2025 16:43:41 -0500 Subject: [PATCH 79/84] Fix accidental re-use of cached grammar matcher We cache the compiled grammar, but we need a unique matcher instance for each request. The code previously re-used the same matcher for all requests using the same grammar. If multiple parallel requests had the same grammar, they would mostly fail as a result. Signed-off-by: Russell Bryant --- vllm/v1/guided_decoding/__init__.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 5542c6b025a04..2b17659292bd2 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +import copy import enum import threading from concurrent.futures import ThreadPoolExecutor @@ -130,12 +131,12 @@ def remove_requests(self, request_ids: List[str]) -> None: def should_cache(self, request: Request): if not request.use_guided_decoding: return False - request.grammar = self.request_key_to_grammar.get( - request.guided_decoding_key) - if not request.grammar: - request.grammar = self.cache(request) - return True - return False + grammar = self.request_key_to_grammar.get(request.guided_decoding_key) + if grammar: + request.grammar = copy.copy(grammar) + return False + request.grammar = self.cache(request) + return True def cache(self, request: Request): return self.executor.submit(self._executor_loop, request) @@ -145,10 +146,11 @@ def _executor_loop(self, request: Request) -> Grammar: with self._requests_lock: self.requests.add(request) if key in self.request_key_to_grammar: - return self.request_key_to_grammar[key] - - self.request_key_to_grammar[key] = self.initialize_grammar(key) - return self.request_key_to_grammar[key] + grammar = self.request_key_to_grammar[key] + return copy.copy(grammar) + grammar = self.initialize_grammar(key) + self.request_key_to_grammar[key] = grammar + return copy.copy(grammar) def initialize_grammar(self, key: GuidedDecodingKey) -> Grammar: # Note that the request was validated in the engine core client, @@ -189,5 +191,5 @@ def setup_grammars(self): # Check if grammar is ready in cache grammar = 
self[req.guided_decoding_key] if grammar is not None: - req.grammar = grammar + req.grammar = copy.copy(grammar) continue From 0518b70721bacfeee183c2760301dd1f6e4d32fb Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 26 Feb 2025 17:03:10 -0500 Subject: [PATCH 80/84] Use the correct indices for the logits bitmask This code did a bit of a dance to get the correct indices for the logits and then used the old wrong ones. Oops. Signed-off-by: Russell Bryant --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 77947d3be4738..a314a187c22af 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -985,7 +985,7 @@ def execute_model( apply_bitmask( logits, grammar_bitmask.to(self.device, non_blocking=True), - list(scheduler_output.guided_decoding_request_ids.values()), + list(req_id_indices.values()), ) # Sample the next token and get logprobs if needed. From 5f23e8b62d21303fe4d2e6e14bd8bcbe39899bdc Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 26 Feb 2025 17:32:38 -0700 Subject: [PATCH 81/84] Update vllm/v1/core/scheduler_output.py Co-authored-by: Russell Bryant --- vllm/v1/core/scheduler_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index 70477105c6461..9cea11d88067a 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -114,7 +114,7 @@ class SchedulerOutput: # Used to free the encoder cache. free_encoder_input_ids: List[Tuple[str, int]] - # Dict of request ids to its index within the batch + # Dict of request ids to their index within the batch # for filling the next token bitmask guided_decoding_request_ids: Dict[str, int] # the bitmask for the whole batch From deb9b36ac51fd614ce2cd4d54508a5ead665dfb6 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 27 Feb 2025 00:20:55 -0500 Subject: [PATCH 82/84] Apply suggestions from Russell Co-authored-by: Russell Bryant --- vllm/v1/core/scheduler.py | 12 ++++-------- vllm/v1/engine/core.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index af99c722ddc0f..87fa4ea74337f 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -246,8 +246,7 @@ def schedule(self) -> "SchedulerOutput": request = self.waiting[num_to_skip] - if (request.use_guided_decoding - and request.status == RequestStatus.WAITING_FOR_FSM): + if request.status == RequestStatus.WAITING_FOR_FSM: if request.grammar and request.is_grammar_ready: request.status = RequestStatus.WAITING else: @@ -314,7 +313,7 @@ def schedule(self) -> "SchedulerOutput": req_index += 1 self.running.append(request) self.scheduled_req_ids.add(request.request_id) - if RequestStatus.is_waiting(request.status): + if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) self.request_scheduled(request, scheduled_timestamp) elif request.status == RequestStatus.PREEMPTED: @@ -613,7 +612,8 @@ def update_from_output( if not accepted: logger.error( "Failed to advance FSM for request %s " - "for tokens %s", req_id, generated_token_ids[0]) + "for tokens %s. 
Please file an issue.", + req_id, generated_token_ids[0]) if request.num_computed_tokens >= request.num_tokens: for output_token_id in generated_token_ids: @@ -719,10 +719,6 @@ def _free_request(self, request: Request) -> None: self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] self.finished_req_ids.add(request.request_id) - if request.use_guided_decoding: - # NOTE: grammar should NOT be None - # if use_guided_decoding is True - request.grammar.reset() # type: ignore[union-attr] def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index cb276cc994386..9bfd9d377002b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -159,18 +159,21 @@ def step(self) -> EngineCoreOutputs: return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) - # Check for cached grammars and allocate bitmask if necessary - self.setup_grammars() + # Check cache for compiled grammars and add them to requests + # when they're ready. + self.guided_decoding_manager.setup_grammars() scheduler_output = self.scheduler.schedule() + # This case may occur when the only unfinished requests are + # guided decoding requests where the grammar has not finished + # compiling yet, so there's nothing to run. if scheduler_output.total_num_scheduled_tokens == 0: return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) - # the bitmask allocation for grammars - # should be ready at this point. - # Currently we will broadcast the bitmask + # Currently we will broadcast the bitmask. It is populated during + # each schedule() run. if len(self.guided_decoding_manager.requests) > 0: scheduler_output.grammar_bitmask = \ self.guided_decoding_manager.grammar_bitmask @@ -260,9 +263,6 @@ def list_loras(self) -> Set[int]: def pin_lora(self, lora_id: int) -> bool: return self.model_executor.pin_lora(lora_id) - def setup_grammars(self): - self.guided_decoding_manager.setup_grammars() - class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" From 4bcee6c16a7bf790814d6dc91a6c937d60523b92 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 27 Feb 2025 06:16:13 +0000 Subject: [PATCH 83/84] chore: update requests to remove unused function Signed-off-by: Aaron Pham --- .../structured_schema_1.json | 132 +++--------------- vllm/v1/request.py | 4 - 2 files changed, 22 insertions(+), 114 deletions(-) diff --git a/benchmarks/structured_schemas/structured_schema_1.json b/benchmarks/structured_schemas/structured_schema_1.json index 6003698469e8d..1bd189c9e704f 100644 --- a/benchmarks/structured_schemas/structured_schema_1.json +++ b/benchmarks/structured_schemas/structured_schema_1.json @@ -1,113 +1,25 @@ { - "$schema": - "https://json-schema.org/draft/2020-12/schema", - "title": - "User Profile", - "type": - "object", + "type": "array", + "items": { + "type": "object", "properties": { - "userId": { - "type": "string", - "description": "Unique identifier for the user." - }, - "personalInfo": { - "type": "object", - "properties": { - "firstName": { - "type": "string", - "description": "The user's first name." - }, - "lastName": { - "type": "string", - "description": "The user's last name." - }, - "age": { - "type": "integer", - "minimum": 0, - "description": "The user's age." 
- }, - "phoneNumbers": { - "type": - "array", - "items": { - "type": "object", - "properties": { - "type": { - "type": "string", - "enum": ["home", "work", "mobile"], - "description": "Type of phone number." - }, - "number": { - "type": "string", - "pattern": "^\\+?[1-9]\\d{1,14}$", - "description": "Phone number in E.164 format." - } - }, - "required": ["type", "number"] - }, - "description": - "List of phone numbers associated with the user." - } - }, - "required": ["firstName", "lastName"] - }, - "address": { - "type": "object", - "properties": { - "street": { - "type": "string", - "description": "Street address." - }, - "city": { - "type": "string", - "description": "City name." - }, - "state": { - "type": "string", - "description": "State or province." - }, - "postalCode": { - "type": "string", - "pattern": "^\\d{5}(-\\d{4})?$", - "description": "Postal code." - }, - "country": { - "type": "string", - "description": "Country name." - } - }, - "required": ["street", "city", "state", "postalCode", "country"] - }, - "preferences": { - "type": "object", - "properties": { - "newsletterSubscribed": { - "type": - "boolean", - "description": - "Indicates if the user is subscribed to the newsletter." - }, - "favoriteCategories": { - "type": "array", - "items": { - "type": "string" - }, - "description": "List of user's favorite categories." - } - }, - "required": ["newsletterSubscribed"] - }, - "accountStatus": { - "type": "string", - "enum": ["active", "inactive", "suspended"], - "description": "Current status of the user's account." - }, - "registrationDate": { - "type": "string", - "format": "date-time", - "description": "ISO 8601 formatted date-time of user registration." - } + "name": { "type": "string" }, + "race": { "type": "string" }, + "class": { "type": "string" }, + "level": { "type": "integer" }, + "background": { "type": "string" }, + "alignment": { "type": "string" }, + "backstory": { "type": "string" } }, - "required": - ["userId", "personalInfo", "address", "accountStatus", "registrationDate"] -} \ No newline at end of file + "required": [ + "name", + "race", + "class", + "level", + "background", + "alignment", + "backstory" + ] + } +} + diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 765b203bf962d..3f941a8b7209c 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -220,10 +220,6 @@ class RequestStatus(enum.IntEnum): def is_finished(status: RequestStatus) -> bool: return status > RequestStatus.PREEMPTED - @staticmethod - def is_waiting(status: RequestStatus) -> bool: - return status <= RequestStatus.WAITING_FOR_FSM - @staticmethod def get_finished_reason( status: RequestStatus) -> Union[FinishReason, None]: From 3b49e8e9c54434424278b653b30536302c84fbff Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 27 Feb 2025 06:58:30 +0000 Subject: [PATCH 84/84] chore: address comments and renaming for clarity Signed-off-by: Aaron Pham --- vllm/v1/engine/core.py | 2 +- vllm/v1/guided_decoding/__init__.py | 48 ++++++++++++++++------------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 9bfd9d377002b..1375a2b7a5222 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -137,7 +137,7 @@ def add_request(self, request: EngineCoreRequest): req = Request.from_engine_core_request(request) if req.use_guided_decoding: # Start grammar compilation asynchronously - self.guided_decoding_manager.should_cache(req) + self.guided_decoding_manager.populate_cache(req) self.scheduler.add_request(req) diff 
--git a/vllm/v1/guided_decoding/__init__.py b/vllm/v1/guided_decoding/__init__.py index 2b17659292bd2..27936e64841e8 100644 --- a/vllm/v1/guided_decoding/__init__.py +++ b/vllm/v1/guided_decoding/__init__.py @@ -4,9 +4,10 @@ import copy import enum import threading +from collections import OrderedDict from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, List, Optional, Set, Tuple import torch import xgrammar as xgr @@ -33,12 +34,15 @@ class GuidedDecodingOptions(enum.Enum): MAX_ROLLBACK_TOKENS = 100 -def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor, - indices: List[int]) -> None: +def apply_bitmask( + logits: torch.Tensor, + vocab_mask: torch.Tensor, + indices: List[int], +) -> None: xgr.apply_token_bitmask_inplace(logits, vocab_mask, indices=indices) -@dataclass(slots=True, unsafe_hash=True) +@dataclass(slots=True, unsafe_hash=True) # type: ignore[call-overload] class Grammar: # NOTE: This would be a generic-enough class for # supporting different backends, in the future. @@ -58,29 +62,21 @@ class Grammar: hash=False, init=False, ) - _matcher_lock: threading.Lock = field( - default_factory=lambda: threading.Lock(), - repr=False, - init=False, - hash=False, - ) def accept_token(self, token: int) -> bool: # NOTE: accept_token will determines whether we accept this token # and will also update the machine state - with self._matcher_lock: - self.num_processed_tokens += 1 - return self.matcher.accept_token(token) + self.num_processed_tokens += 1 + return self.matcher.accept_token(token) # this should be ran in parallel with model decoding def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool: return self.matcher.fill_next_token_bitmask(bitmask, idx) def rollback(self, num_tokens: int): - with self._matcher_lock: - if self.num_processed_tokens > 0: - self.num_processed_tokens -= num_tokens - self.matcher.rollback(num_tokens) + if self.num_processed_tokens > 0: + self.num_processed_tokens -= num_tokens + self.matcher.rollback(num_tokens) def reset(self): self.num_processed_tokens = 0 @@ -95,7 +91,7 @@ def __copy__(self): class GuidedDecodingManager: - def __init__(self, vllm_config: VllmConfig): + def __init__(self, vllm_config: VllmConfig, max_cache_size: int = 500): tokenizer_group = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, @@ -110,7 +106,9 @@ def __init__(self, vllm_config: VllmConfig): tokenizer, vocab_size=self.vocab_size) self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) - self.request_key_to_grammar: Dict[GuidedDecodingKey, Grammar] = {} + self.max_cache_size = max_cache_size + self.request_key_to_grammar: OrderedDict[GuidedDecodingKey, + Grammar] = OrderedDict() self.executor = ThreadPoolExecutor() self.requests: Set[Request] = set() @@ -119,7 +117,12 @@ def __init__(self, vllm_config: VllmConfig): self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size) def __getitem__(self, key: GuidedDecodingKey) -> Optional[Grammar]: - return self.request_key_to_grammar.get(key) + if key in self.request_key_to_grammar: + # Move accessed item to the end (most recently used) + value = self.request_key_to_grammar.pop(key) + self.request_key_to_grammar[key] = value + return value + return None def remove_requests(self, request_ids: List[str]) -> None: with self._requests_lock: @@ -128,7 +131,7 @@ def remove_requests(self, request_ids: 
List[str]) -> None: for req in self.requests if req.request_id not in request_ids } - def should_cache(self, request: Request): + def populate_cache(self, request: Request): if not request.use_guided_decoding: return False grammar = self.request_key_to_grammar.get(request.guided_decoding_key) @@ -149,6 +152,9 @@ def _executor_loop(self, request: Request) -> Grammar: grammar = self.request_key_to_grammar[key] return copy.copy(grammar) grammar = self.initialize_grammar(key) + # If cache is full, remove the least recently used item + if len(self.request_key_to_grammar) >= self.max_cache_size: + self.request_key_to_grammar.popitem(last=False) self.request_key_to_grammar[key] = grammar return copy.copy(grammar)
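
For readers following the final two patches, the sketch below distills the cache behaviour they add: compiled grammars live in an OrderedDict used as an LRU cache, a lookup moves the entry to the most-recently-used end, the least recently used entry is evicted once max_cache_size is reached, and every request receives its own shallow copy so matcher state is never shared between requests with the same grammar. This is a minimal illustration only; the names GrammarLRUCache, CompiledGrammar, and compile_grammar are stand-ins for this sketch, not vLLM APIs.

    # Standalone sketch (assumed names, not vLLM code) of the LRU-style
    # grammar cache plus per-request copying described in the patches above.
    import copy
    from collections import OrderedDict
    from typing import Hashable


    class CompiledGrammar:
        """Stand-in for a compiled grammar with per-request matcher state."""

        def __init__(self, key: Hashable):
            self.key = key
            self.num_processed_tokens = 0


    def compile_grammar(key: Hashable) -> CompiledGrammar:
        # Placeholder for the (expensive) grammar compilation step.
        return CompiledGrammar(key)


    class GrammarLRUCache:

        def __init__(self, max_cache_size: int = 500):
            self.max_cache_size = max_cache_size
            self._cache: OrderedDict[Hashable, CompiledGrammar] = OrderedDict()

        def get(self, key: Hashable) -> CompiledGrammar:
            if key in self._cache:
                # Cache hit: mark the entry as most recently used.
                self._cache.move_to_end(key)
            else:
                if len(self._cache) >= self.max_cache_size:
                    # Cache full: drop the least recently used entry.
                    self._cache.popitem(last=False)
                self._cache[key] = compile_grammar(key)
            # Hand out a copy so concurrent requests that share a grammar
            # never share matcher state.
            return copy.copy(self._cache[key])


    cache = GrammarLRUCache(max_cache_size=2)
    a = cache.get(("json", "{...}"))
    b = cache.get(("regex", "[0-9]+"))
    cache.get(("json", "{...}"))             # refreshes the JSON entry
    cache.get(("grammar", "root ::= ..."))   # evicts the regex entry
    assert a is not b                        # distinct per-request copies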