From 36bc041e30429d160ba4818e347db419d25727dd Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 28 Jan 2025 04:59:38 +0000 Subject: [PATCH] chore: --wip-- Signed-off-by: Aaron Pham --- vllm/v1/core/guided_decoding/__init__.py | 78 +++++++++++++----------- vllm/v1/core/guided_decoding/grammar.py | 23 ++++--- vllm/v1/core/scheduler.py | 70 ++++++++++----------- vllm/v1/request.py | 4 +- vllm/v1/worker/gpu_input_batch.py | 8 +-- 5 files changed, 99 insertions(+), 84 deletions(-) diff --git a/vllm/v1/core/guided_decoding/__init__.py b/vllm/v1/core/guided_decoding/__init__.py index 24a8d17fc6a13..20a1468e1329d 100644 --- a/vllm/v1/core/guided_decoding/__init__.py +++ b/vllm/v1/core/guided_decoding/__init__.py @@ -5,12 +5,11 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, - get_args) +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args from transformers import PreTrainedTokenizer -from vllm.config import DecodingConfig, ModelConfig +from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.utils import LazyLoader from vllm.v1.request import GuidedDecodingKey, Request, RequestStatus @@ -23,6 +22,8 @@ from typing_extensions import LiteralString from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup + + from .grammar import XGrammar else: xgr = LazyLoader("xgr", globals(), "xgrammar") @@ -48,58 +49,67 @@ def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: def flush(self): with self._lock: - self.cache.clear() + self.grammar_cache.clear() + + def cache(self, request: Request): + + def _executor_loop(request: Request): + key = request.guided_decoding_key + with self._lock: + cache_hit = False + if key in self.grammar_cache: + cache_hit, entry = True, self.grammar_cache[key] + else: + entry = GrammarCache(None, threading.Event()) + self.grammar_cache[key] = entry + + if cache_hit: + entry.event.wait() + else: + entry.value = self.initialize_cache(key) + entry.event.set() + return copy.copy(entry.value) if entry.value else None - def cache_grammar(self, request: Request): - return self.executor.submit(self._add_grammar_to_cache, request) + return self.executor.submit(_executor_loop, request) - def get_grammar(self, request: Request): + def get(self, request: Request): with self._lock: - entry = self.cache.get(request.guided_decoding_key) + entry = self.grammar_cache.get(request.guided_decoding_key) if entry is None or not entry.event.is_set(): return None return copy.copy(entry.value) if entry.value else None - def should_add(self, request: Request): + def collect(self, request: Request): if not request.use_guided_decoding: return False - request.grammar = self.get_grammar(request) + request.grammar = self.get(request) if not request.grammar: - request.grammar = self.cache_grammar(request) + request.grammar = self.cache(request) request.status = RequestStatus.WAITING_FOR_FSM return True return False - def _add_grammar_to_cache(self, request: Request): - key = request.guided_decoding_key - with self._lock: - cache_hit = False - if key in self.cache: - cache_hit, entry = True, self.cache[key] - else: - entry = GrammarCache(None, threading.Event()) - self.cache[key] = entry - - if cache_hit: - entry.event.wait() - else: - entry.value = self.initialize_cache(key) - entry.event.set() - return copy.copy(entry.value) if entry.value else None - @classmethod - def from_backend(cls, /, 
backend: LiteralString = "xgrammar", *, + def from_backend(cls, + backend: LiteralString = "xgrammar", + /, + *, tokenizer_group: BaseTokenizerGroup, model_config: ModelConfig) -> GuidedDecodingManager[T]: manager_cls = cls._registry.get(backend) - if manager_cls is None: raise ValueError( f"Backend '{backend}' not found in registry. Available backends: {list(cls._registry)}") - return manager_cls(tokenizer_group=tokenizer_group, model_config=model_config) + if manager_cls is None: + raise ValueError( + f"Backend '{backend}' not found in registry. Available backends: {list(cls._registry)}" + ) + return manager_cls(tokenizer_group=tokenizer_group, + model_config=model_config) _registry: dict[str, type[GuidedDecodingManager[T]]] = {} _backend: T - def __init__(self, *, tokenizer_group: BaseTokenizerGroup, model_config: ModelConfig): + def __init__(self, *, tokenizer_group: BaseTokenizerGroup, + model_config: ModelConfig): self.model_config = model_config self.tokenizer = tokenizer_group.get_lora_tokenizer(None) - self.cache: dict[GuidedDecodingKey, GrammarCache] = {} + self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {} self.executor = ThreadPoolExecutor() self._lock = threading.Lock() @@ -136,7 +146,7 @@ class XGrammarManager(GuidedDecodingManager[Literal["xgrammar"]]): _compiler_cache: dict[str, xgr.GrammarCompiler] = {} _compiler: xgr.GrammarCompiler | None = None - def initialize_cache(self, key: GuidedDecodingKey) -> Grammar: + def initialize_cache(self, key: GuidedDecodingKey) -> XGrammar: request_type, grammar_spec = key compiler = XGrammarManager.get_compiler(self.tokenizer) if request_type == "json": diff --git a/vllm/v1/core/guided_decoding/grammar.py b/vllm/v1/core/guided_decoding/grammar.py index baa9b6c693ca0..b634ae169a393 100644 --- a/vllm/v1/core/guided_decoding/grammar.py +++ b/vllm/v1/core/guided_decoding/grammar.py @@ -4,17 +4,19 @@ from typing import (TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args, overload) +from typing_extensions import Annotated, LiteralString + from vllm.utils import LazyLoader if TYPE_CHECKING: import torch import xgrammar as xgr - from typing_extensions import LiteralString, Self + from typing_extensions import Self else: xgr = LazyLoader("xgr", globals(), "xgrammar") torch = LazyLoader("torch", globals(), "torch") -T = TypeVar("T", bound=str) +T = TypeVar("T", bound=Annotated[LiteralString, str]) class Grammar(ABC, Generic[T]): @@ -93,13 +95,19 @@ def from_backend( @overload @classmethod - def from_backend(cls, - backend: LiteralString = ..., - **kwargs: Any) -> Grammar: + def from_backend( + cls, + backend: Literal["outlines"] = ..., + *, + guide: str = ..., + whitespace_pattern: str | None = ..., + ) -> XGrammar: ... 
@classmethod - def from_backend(cls, backend: LiteralString = "xgrammar", **kwargs: Any) -> Grammar[T]: + def from_backend(cls, + backend: LiteralString = "xgrammar", + **kwargs: Any) -> Grammar[T]: grammar_cls = cls._registry.get(backend) if grammar_cls is None: raise ValueError( @@ -108,8 +116,7 @@ def from_backend(cls, backend: LiteralString = "xgrammar", **kwargs: Any) -> Gra class XGrammar(Grammar[Literal["xgrammar"]]): - # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string - # for jump-forward decoding + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int, ctx: xgr.CompiledGrammar) -> None: diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 97bb87f996e93..a404e4268923c 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,9 +1,10 @@ from __future__ import annotations from collections import deque +from concurrent import futures from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Deque, Dict, Iterable, List, Optional, - Set, Tuple, Union) +from typing import (TYPE_CHECKING, Any, Deque, Dict, Iterable, List, Literal, + Optional, Set, Tuple, Union) from vllm.config import (CacheConfig, DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -70,22 +71,6 @@ def __init__( self.waiting: Deque[Request] = deque() self.running: List[Request] = [] - # A set of unready requests that might be waiting for grammar compilation - # we can also use this for tracking spec decode request - self.grammar_queue: Deque[Request] = deque() - # initialize the tokenizer on the scheduler (this is used for constrained decoding) - tokenizer_group = init_tokenizer_from_configs( - model_config=model_config, - scheduler_config=scheduler_config, - parallel_config=parallel_config, - lora_config=lora_config) - tokenizer_group.ping() - # setup guided decoding, right now uses xgrammar - self.guided_decoding_manager = GuidedDecodingManager[Any].from_backend( - backend=decoding_config.guided_decoding_backend, - tokenizer_group=tokenizer_group, - model_config=model_config) - # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished # requests so that they can free the cached states for those requests. @@ -117,6 +102,21 @@ def __init__( self.encoder_cache_manager = EncoderCacheManager( cache_size=encoder_cache_size) + # A request queue for grammar compilation + self.grammar: Deque[Request] = deque() + # initialize the tokenizer on the scheduler (this is used for constrained decoding) + tokenizer_group = init_tokenizer_from_configs( + model_config=model_config, + scheduler_config=scheduler_config, + parallel_config=parallel_config, + lora_config=lora_config) + tokenizer_group.ping() + # setup guided decoding, right now uses xgrammar + self.guided_decoding_manager = GuidedDecodingManager.from_backend( + backend=decoding_config.guided_decoding_backend, + tokenizer_group=tokenizer_group, + model_config=model_config) + def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. 
@@ -133,21 +133,20 @@ def schedule(self) -> "SchedulerOutput": preempted_reqs: List[Request] = [] # we need to check the grammar queue for any requests that have finished FSM compilation - newly_ready_reqs: List[Request] = [] - remaining_grammar_reqs: Deque[Request] = deque() - while self.grammar_queue: - request = self.grammar_queue.popleft() - grammar = self.guided_decoding_manager.get(request) - if grammar is not None: - request.grammar = grammar + newly_grammar_reqs: List[Request] = [] + scheduled_grammar_reqs: Deque[Request] = deque() + while self.grammar: + request = self.grammar.popleft() + try: + request.grammar = request.grammar.result(timeout=0.05) request.status = RequestStatus.WAITING - newly_ready_reqs.append(request) - else: - remaining_grammar_reqs.append(request) - self.grammar_queue = remaining_grammar_reqs + newly_grammar_reqs.append(request) + except futures._base.TimeoutError: + scheduled_grammar_reqs.append(request) + self.grammar = scheduled_grammar_reqs # append all newly ready requests to waiting queue with higher priority - for req in newly_ready_reqs: + for req in newly_grammar_reqs: self.waiting.appendleft(req) req_to_new_block_ids: Dict[str, List[int]] = {} @@ -161,7 +160,6 @@ def schedule(self) -> "SchedulerOutput": vocab_size = self.model_config.get_vocab_size() guided_decoding_bitmasks: Dict[str, torch.Tensor] = {} guided_decoding_reqs: List[Request] = [] - batch_has_grammar = False # First, schedule the RUNNING requests. # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be @@ -242,7 +240,6 @@ def schedule(self) -> "SchedulerOutput": # Track if we need guided decoding # Create individual bitmask for requests with grammar if request.grammar is not None: - batch_has_grammar = True if request.request_id not in guided_decoding_bitmasks: bitmask = request.grammar.allocate_bitmask(1, vocab_size) guided_decoding_bitmasks[request.request_id] = bitmask @@ -266,7 +263,6 @@ def schedule(self) -> "SchedulerOutput": # Track guided decoding needs if request.grammar is not None: - batch_has_grammar = True if request.request_id not in guided_decoding_bitmasks: bitmask = request.grammar.allocate_bitmask( 1, vocab_size) @@ -568,10 +564,10 @@ def _check_stop(self, request: Request) -> bool: def add_request(self, request: Request) -> None: self.requests[request.request_id] = request - add_to_grammar_queue = self.guided_decoding_manager.collect(request) - - if add_to_grammar_queue: self.grammar_queue.append(request) - else: self.waiting.append(request) + if self.guided_decoding_manager.collect(request): + self.grammar.append(request) + else: + self.waiting.append(request) def finish_requests( self, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index b744d342870d9..d2a3b6e80d798 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -11,6 +11,8 @@ from vllm.v1.utils import ConstantList if TYPE_CHECKING: + from concurrent.futures import Future + from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange @@ -81,7 +83,7 @@ def __init__( self.all_token_ids = ConstantList(self._all_token_ids) # grammar objects - self.grammar = grammar + self.grammar: Optional[Grammar[Any] | Future[Grammar[Any]]] = grammar @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 9e9a5febd1b46..5d62b09c6312d 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ 
b/vllm/v1/worker/gpu_input_batch.py @@ -169,7 +169,7 @@ def __init__( self.num_logprobs: Dict[str, int] = {} self.prompt_logprob_reqs: Set[str] = set() - self.guided_decoding_reqs: Set[str] = set() + self.grammar_reqs: Set[str] = set() def add_request( self, @@ -237,7 +237,7 @@ def add_request( if sampling_params.prompt_logprobs: self.prompt_logprob_reqs.add(req_id) - if request.grammar is not None: self.guided_decoding_reqs.add(req_id) + if request.grammar is not None: self.grammar_reqs.add(req_id) def remove_request(self, req_id: str) -> Optional[int]: req_index = self.req_id_to_index.pop(req_id, None) @@ -255,7 +255,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) - self.guided_decoding_reqs.discard(req_id) + self.grammar_reqs.discard(req_id) return req_index def clear(self) -> None: @@ -271,7 +271,7 @@ def clear(self) -> None: self.generators.clear() self.num_logprobs.clear() self.prompt_logprob_reqs.clear() - self.guided_decoding_reqs.clear() + self.grammar_reqs.clear() def condense(self, empty_req_indices: List[int]) -> None: if self.num_reqs == 0:
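
The cache() path above implements a compile-once-per-key handoff: the first thread to see a guided-decoding key compiles the grammar while concurrent callers for the same key block on a threading.Event, and the scheduler only ever receives a Future from the ThreadPoolExecutor. Below is a minimal, self-contained sketch of that pattern using stand-in names (CompileCache, _CacheEntry, compile_fn) rather than the vLLM classes; it is an illustration of the locking scheme, not the implementation in this patch.

from __future__ import annotations

import copy
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Hashable, Optional


@dataclass
class _CacheEntry:
    # Mirrors GrammarCache: a slot for the compiled value plus an Event that
    # is set once compilation has finished.
    value: Optional[Any] = None
    event: threading.Event = field(default_factory=threading.Event)


class CompileCache:
    """Hypothetical stand-in for the grammar_cache/executor pair."""

    def __init__(self, compile_fn: Callable[[Hashable], Any]) -> None:
        self._compile_fn = compile_fn  # plays the role of initialize_cache()
        self._cache: Dict[Hashable, _CacheEntry] = {}
        self._lock = threading.Lock()
        self._executor = ThreadPoolExecutor()

    def submit(self, key: Hashable) -> Future:
        # The lock only guards the dict; the slow compilation runs outside it.
        def _work() -> Any:
            with self._lock:
                hit = key in self._cache
                if hit:
                    entry = self._cache[key]
                else:
                    entry = _CacheEntry()
                    self._cache[key] = entry
            if hit:
                entry.event.wait()  # another worker is (or was) compiling
            else:
                entry.value = self._compile_fn(key)
                entry.event.set()
            return copy.copy(entry.value) if entry.value else None

        return self._executor.submit(_work)

With this shape, submitting the same key from several threads still triggers exactly one call to compile_fn; the other submissions wait on the Event and return a shallow copy of the shared compiled value.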
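
On the scheduler side, add_request() routes grammar-bearing requests into self.grammar, and schedule() drains that queue by polling each stored Future with a short timeout, re-queueing anything still compiling and promoting finished requests to the front of the waiting queue. The sketch below shows that polling handshake with illustrative names (PendingRequest, poll_grammar_queue) that are not vLLM APIs; it catches the public concurrent.futures.TimeoutError (the builtin TimeoutError on Python 3.11+) rather than the private futures._base name used in the hunk above.

from __future__ import annotations

from collections import deque
from concurrent.futures import Future
from concurrent.futures import TimeoutError as FutureTimeoutError
from typing import Deque, List


class PendingRequest:
    # Illustrative request holder: `grammar` starts out as a Future and is
    # replaced by the compiled grammar once the result is available.
    def __init__(self, request_id: str, grammar_future: Future) -> None:
        self.request_id = request_id
        self.grammar = grammar_future


def poll_grammar_queue(grammar_queue: Deque[PendingRequest],
                       waiting: Deque[PendingRequest],
                       timeout: float = 0.05) -> List[PendingRequest]:
    """Move requests whose grammar finished compiling into `waiting`."""
    still_pending: Deque[PendingRequest] = deque()
    newly_ready: List[PendingRequest] = []
    while grammar_queue:
        req = grammar_queue.popleft()
        try:
            # Block only briefly; a compile that has not finished simply
            # stays queued for the next scheduling step.
            req.grammar = req.grammar.result(timeout=timeout)
            newly_ready.append(req)
        except FutureTimeoutError:
            still_pending.append(req)
    grammar_queue.extend(still_pending)
    # Finished requests are placed at the front of the waiting queue so they
    # are scheduled with priority.
    for req in newly_ready:
        waiting.appendleft(req)
    return newly_ready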
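
schedule() also builds one bitmask per scheduled request that carries a grammar, via request.grammar.allocate_bitmask(1, vocab_size). The sketch below shows that per-step collection and how such a mask would gate logits before sampling; the GrammarLike protocol, fill_bitmask(), and apply_bitmask() are simplifications (a plain bool mask rather than xgrammar's packed bitset) and are not the interface defined in grammar.py.

from __future__ import annotations

from typing import Dict, Protocol

import torch


class GrammarLike(Protocol):
    # Simplified protocol; only allocate_bitmask() appears in the diff, and a
    # plain bool mask is assumed here instead of xgrammar's packed bitset.
    def allocate_bitmask(self, batch_size: int,
                         vocab_size: int) -> torch.Tensor: ...

    def fill_bitmask(self, bitmask: torch.Tensor, index: int) -> None: ...


def build_step_bitmasks(scheduled_grammars: Dict[str, GrammarLike],
                        vocab_size: int) -> Dict[str, torch.Tensor]:
    """One bitmask per request id, as collected inside schedule()."""
    bitmasks: Dict[str, torch.Tensor] = {}
    for req_id, grammar in scheduled_grammars.items():
        mask = grammar.allocate_bitmask(1, vocab_size)  # shape (1, vocab_size)
        grammar.fill_bitmask(mask, 0)  # mark tokens legal in the current state
        bitmasks[req_id] = mask
    return bitmasks


def apply_bitmask(logits: torch.Tensor, bitmask: torch.Tensor) -> torch.Tensor:
    # With a bool mask, disallowed tokens are forced to -inf before sampling.
    return logits.masked_fill(~bitmask.to(torch.bool), float("-inf"))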