
Commit

chore: --wip--
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
aarnphm committed Jan 28, 2025
1 parent d719c93 commit 36bc041
Showing 5 changed files with 99 additions and 84 deletions.
78 changes: 44 additions & 34 deletions vllm/v1/core/guided_decoding/__init__.py
@@ -5,12 +5,11 @@
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar,
get_args)
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args

from transformers import PreTrainedTokenizer

from vllm.config import DecodingConfig, ModelConfig
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.utils import LazyLoader
from vllm.v1.request import GuidedDecodingKey, Request, RequestStatus
@@ -23,6 +22,8 @@
from typing_extensions import LiteralString

from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

from .grammar import XGrammar
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")

@@ -48,58 +49,67 @@ def initialize_cache(self, key: GuidedDecodingKey) -> Grammar:

def flush(self):
with self._lock:
self.cache.clear()
self.grammar_cache.clear()

def cache(self, request: Request):

def _executor_loop(request: Request):
key = request.guided_decoding_key
with self._lock:
cache_hit = False
if key in self.grammar_cache:
cache_hit, entry = True, self.grammar_cache[key]
else:
entry = GrammarCache(None, threading.Event())
self.grammar_cache[key] = entry

if cache_hit:
entry.event.wait()
else:
entry.value = self.initialize_cache(key)
entry.event.set()
return copy.copy(entry.value) if entry.value else None

def cache_grammar(self, request: Request):
return self.executor.submit(self._add_grammar_to_cache, request)
return self.executor.submit(_executor_loop, request)

def get_grammar(self, request: Request):
def get(self, request: Request):
with self._lock:
entry = self.cache.get(request.guided_decoding_key)
entry = self.grammar_cache.get(request.guided_decoding_key)
if entry is None or not entry.event.is_set(): return None

Check failure (GitHub Actions / pre-commit): Ruff E701 at vllm/v1/core/guided_decoding/__init__.py:78:57: Multiple statements on one line (colon)
return copy.copy(entry.value) if entry.value else None

def should_add(self, request: Request):
def collect(self, request: Request):
if not request.use_guided_decoding: return False

Check failure (GitHub Actions / pre-commit): Ruff E701 at vllm/v1/core/guided_decoding/__init__.py:82:43: Multiple statements on one line (colon)
request.grammar = self.get_grammar(request)
request.grammar = self.get(request)
if not request.grammar:
request.grammar = self.cache_grammar(request)
request.grammar = self.cache(request)
request.status = RequestStatus.WAITING_FOR_FSM
return True
return False

def _add_grammar_to_cache(self, request: Request):
key = request.guided_decoding_key
with self._lock:
cache_hit = False
if key in self.cache:
cache_hit, entry = True, self.cache[key]
else:
entry = GrammarCache(None, threading.Event())
self.cache[key] = entry

if cache_hit:
entry.event.wait()
else:
entry.value = self.initialize_cache(key)
entry.event.set()
return copy.copy(entry.value) if entry.value else None

@classmethod
def from_backend(cls, /, backend: LiteralString = "xgrammar", *,
def from_backend(cls,
backend: LiteralString = "xgrammar",
/,
*,
tokenizer_group: BaseTokenizerGroup,
model_config: ModelConfig) -> GuidedDecodingManager[T]:
manager_cls = cls._registry.get(backend)
if manager_cls is None: raise ValueError( f"Backend '{backend}' not found in registry. Available backends: {list(cls._registry)}")
return manager_cls(tokenizer_group=tokenizer_group, model_config=model_config)
if manager_cls is None:
raise ValueError(
f"Backend '{backend}' not found in registry. Available backends: {list(cls._registry)}"

Check failure (GitHub Actions / pre-commit): Ruff E501 at vllm/v1/core/guided_decoding/__init__.py:100:81: Line too long (103 > 80)
)
return manager_cls(tokenizer_group=tokenizer_group,
model_config=model_config)

_registry: dict[str, type[GuidedDecodingManager[T]]] = {}
_backend: T

def __init__(self, *, tokenizer_group: BaseTokenizerGroup, model_config: ModelConfig):
def __init__(self, *, tokenizer_group: BaseTokenizerGroup,
model_config: ModelConfig):
self.model_config = model_config
self.tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.cache: dict[GuidedDecodingKey, GrammarCache] = {}
self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {}
self.executor = ThreadPoolExecutor()
self._lock = threading.Lock()

@@ -136,7 +146,7 @@ class XGrammarManager(GuidedDecodingManager[Literal["xgrammar"]]):
_compiler_cache: dict[str, xgr.GrammarCompiler] = {}
_compiler: xgr.GrammarCompiler | None = None

def initialize_cache(self, key: GuidedDecodingKey) -> Grammar:
def initialize_cache(self, key: GuidedDecodingKey) -> XGrammar:
request_type, grammar_spec = key
compiler = XGrammarManager.get_compiler(self.tokenizer)
if request_type == "json":
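For context, a minimal sketch of the compile-once caching pattern that the new cache()/_executor_loop path in this file uses: the first request for a key creates an entry and compiles the grammar, while concurrent requests for the same key block on a threading.Event instead of recompiling. The class name, the _compile stand-in, and the string key are illustrative assumptions, not the exact vLLM API.

import copy
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class _Entry:
    value: Optional[Any]      # compiled grammar once ready, else None
    event: threading.Event    # set once compilation has finished


class GrammarCacheSketch:

    def __init__(self) -> None:
        self.grammar_cache: Dict[str, _Entry] = {}
        self.executor = ThreadPoolExecutor()
        self._lock = threading.Lock()

    def _compile(self, key: str) -> Any:
        # Stand-in for the expensive backend compilation (initialize_cache).
        return f"compiled({key})"

    def cache(self, key: str) -> "Future[Any]":
        def _executor_loop(key: str) -> Any:
            with self._lock:
                cache_hit = key in self.grammar_cache
                if cache_hit:
                    entry = self.grammar_cache[key]
                else:
                    entry = _Entry(None, threading.Event())
                    self.grammar_cache[key] = entry
            if cache_hit:
                entry.event.wait()          # another worker is compiling this key
            else:
                entry.value = self._compile(key)
                entry.event.set()           # release any waiters
            return copy.copy(entry.value) if entry.value else None

        # Returns a Future; the scheduler keeps it on the request and polls it.
        return self.executor.submit(_executor_loop, key)

Under these assumptions, GrammarCacheSketch().cache("json:<schema>").result() yields the compiled value, and concurrent calls with the same key trigger only one compilation.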
23 changes: 15 additions & 8 deletions vllm/v1/core/guided_decoding/grammar.py
@@ -4,17 +4,19 @@
from typing import (TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args,
overload)

from typing_extensions import Annotated, LiteralString

from vllm.utils import LazyLoader

if TYPE_CHECKING:
import torch
import xgrammar as xgr
from typing_extensions import LiteralString, Self
from typing_extensions import Self
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
torch = LazyLoader("torch", globals(), "torch")

T = TypeVar("T", bound=str)
T = TypeVar("T", bound=Annotated[LiteralString, str])


class Grammar(ABC, Generic[T]):
@@ -93,13 +95,19 @@ def from_backend(

@overload
@classmethod
def from_backend(cls,
backend: LiteralString = ...,
**kwargs: Any) -> Grammar:
def from_backend(
cls,
backend: Literal["outlines"] = ...,
*,
guide: str = ...,
whitespace_pattern: str | None = ...,
) -> XGrammar:
...

@classmethod
def from_backend(cls, backend: LiteralString = "xgrammar", **kwargs: Any) -> Grammar[T]:
def from_backend(cls,
backend: LiteralString = "xgrammar",
**kwargs: Any) -> Grammar[T]:
grammar_cls = cls._registry.get(backend)
if grammar_cls is None:
raise ValueError(
@@ -108,8 +116,7 @@ def from_backend(cls, backend: LiteralString = "xgrammar", **kwargs: Any) -> Gra


class XGrammar(Grammar[Literal["xgrammar"]]):
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
# for jump-forward decoding
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding

Check failure (GitHub Actions / pre-commit): Ruff E501 at vllm/v1/core/guided_decoding/grammar.py:119:81: Line too long (131 > 80)

def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int,
ctx: xgr.CompiledGrammar) -> None:
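A minimal sketch of the backend-registry dispatch that Grammar.from_backend relies on in this file: an unknown backend raises the ValueError shown in the diff, and a known one is instantiated from _registry. How _registry is populated is not shown in this commit; the __init_subclass__ registration below is an assumed mechanism, and all class names are illustrative.

from __future__ import annotations

from typing import Any, ClassVar, Dict, Type


class GrammarSketch:
    _registry: ClassVar[Dict[str, Type["GrammarSketch"]]] = {}

    def __init_subclass__(cls, *, backend: str, **kwargs: Any) -> None:
        # Assumed registration hook: each subclass declares its backend name.
        super().__init_subclass__(**kwargs)
        cls._registry[backend] = cls

    @classmethod
    def from_backend(cls, backend: str = "xgrammar",
                     **kwargs: Any) -> "GrammarSketch":
        grammar_cls = cls._registry.get(backend)
        if grammar_cls is None:
            raise ValueError(
                f"Backend '{backend}' not found in registry. "
                f"Available backends: {list(cls._registry)}")
        return grammar_cls(**kwargs)


class XGrammarSketch(GrammarSketch, backend="xgrammar"):

    def __init__(self, **kwargs: Any) -> None:
        self.options = kwargs

With these assumptions, GrammarSketch.from_backend("xgrammar", vocab_size=32000) returns an XGrammarSketch, while an unregistered name fails fast with the same error message as the diff.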
70 changes: 33 additions & 37 deletions vllm/v1/core/scheduler.py
@@ -1,9 +1,10 @@
from __future__ import annotations

from collections import deque
from concurrent import futures
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Deque, Dict, Iterable, List, Optional,
Set, Tuple, Union)
from typing import (TYPE_CHECKING, Any, Deque, Dict, Iterable, List, Literal,
Optional, Set, Tuple, Union)

from vllm.config import (CacheConfig, DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
@@ -70,22 +71,6 @@ def __init__(
self.waiting: Deque[Request] = deque()
self.running: List[Request] = []

# A set of unready requests that might be waiting for grammar compilation
# we can also use this for tracking spec decode request
self.grammar_queue: Deque[Request] = deque()
# initialize the tokenizer on the scheduler (this is used for constrained decoding)
tokenizer_group = init_tokenizer_from_configs(
model_config=model_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
lora_config=lora_config)
tokenizer_group.ping()
# setup guided decoding, right now uses xgrammar
self.guided_decoding_manager = GuidedDecodingManager[Any].from_backend(
backend=decoding_config.guided_decoding_backend,
tokenizer_group=tokenizer_group,
model_config=model_config)

# The request IDs that are finished in between the previous and the
# current steps. This is used to notify the workers about the finished
# requests so that they can free the cached states for those requests.
@@ -117,6 +102,21 @@ def __init__(
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)

# A request queue for grammar compilation
self.grammar: Deque[Request] = deque()
# initialize the tokenizer on the scheduler (this is used for constrained decoding)
tokenizer_group = init_tokenizer_from_configs(
model_config=model_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
lora_config=lora_config)
tokenizer_group.ping()
# setup guided decoding, right now uses xgrammar
self.guided_decoding_manager = GuidedDecodingManager.from_backend(
backend=decoding_config.guided_decoding_backend,
tokenizer_group=tokenizer_group,
model_config=model_config)

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -133,21 +133,20 @@ def schedule(self) -> "SchedulerOutput":
preempted_reqs: List[Request] = []

# we need to check the grammar queue for any requests that have finished FSM compilation
newly_ready_reqs: List[Request] = []
remaining_grammar_reqs: Deque[Request] = deque()
while self.grammar_queue:
request = self.grammar_queue.popleft()
grammar = self.guided_decoding_manager.get(request)
if grammar is not None:
request.grammar = grammar
newly_grammar_reqs: List[Request] = []
scheduled_grammar_reqs: Deque[Request] = deque()
while self.grammar:
request = self.grammar.popleft()
try:
request.grammar = request.grammar.result(timeout=0.05)
request.status = RequestStatus.WAITING
newly_ready_reqs.append(request)
else:
remaining_grammar_reqs.append(request)
self.grammar_queue = remaining_grammar_reqs
newly_grammar_reqs.append(request)
except futures._base.TimeoutError:
scheduled_grammar_reqs.append(request)
self.grammar = scheduled_grammar_reqs

# append all newly ready requests to waiting queue with higher priority
for req in newly_ready_reqs:
for req in newly_grammar_reqs:
self.waiting.appendleft(req)

req_to_new_block_ids: Dict[str, List[int]] = {}
@@ -161,7 +160,6 @@ def schedule(self) -> "SchedulerOutput":
vocab_size = self.model_config.get_vocab_size()
guided_decoding_bitmasks: Dict[str, torch.Tensor] = {}
guided_decoding_reqs: List[Request] = []
batch_has_grammar = False

# First, schedule the RUNNING requests.
# NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be
@@ -242,7 +240,6 @@ def schedule(self) -> "SchedulerOutput":
# Track if we need guided decoding
# Create individual bitmask for requests with grammar
if request.grammar is not None:
batch_has_grammar = True
if request.request_id not in guided_decoding_bitmasks:
bitmask = request.grammar.allocate_bitmask(1, vocab_size)
guided_decoding_bitmasks[request.request_id] = bitmask
@@ -266,7 +263,6 @@ def schedule(self) -> "SchedulerOutput":

# Track guided decoding needs
if request.grammar is not None:
batch_has_grammar = True
if request.request_id not in guided_decoding_bitmasks:
bitmask = request.grammar.allocate_bitmask(
1, vocab_size)
Expand Down Expand Up @@ -568,10 +564,10 @@ def _check_stop(self, request: Request) -> bool:
def add_request(self, request: Request) -> None:
self.requests[request.request_id] = request

add_to_grammar_queue = self.guided_decoding_manager.collect(request)

if add_to_grammar_queue: self.grammar_queue.append(request)
else: self.waiting.append(request)
if self.guided_decoding_manager.collect(request):
self.grammar.append(request)
else:
self.waiting.append(request)

def finish_requests(
self,
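A minimal sketch of the non-blocking grammar-queue drain performed at the top of Scheduler.schedule() in this file: each queued request holds the Future returned by the guided-decoding manager, the scheduler polls it with a short timeout, and only requests whose compilation finished are promoted to the waiting queue. The function and constant names are illustrative; the real logic is inline in schedule() and also resets request.status to RequestStatus.WAITING.

from collections import deque
from concurrent.futures import TimeoutError as FutureTimeoutError
from typing import Any, Deque, List

GRAMMAR_POLL_TIMEOUT = 0.05  # seconds, mirroring result(timeout=0.05) in the diff


def drain_grammar_queue(grammar: Deque[Any], waiting: Deque[Any]) -> List[Any]:
    # Move requests whose grammar Future has resolved onto the waiting queue;
    # requests still compiling go back into the grammar queue for the next step.
    ready: List[Any] = []
    still_compiling: Deque[Any] = deque()
    while grammar:
        request = grammar.popleft()
        try:
            # request.grammar holds a Future[Grammar] until compilation finishes.
            request.grammar = request.grammar.result(timeout=GRAMMAR_POLL_TIMEOUT)
            ready.append(request)
        except FutureTimeoutError:
            still_compiling.append(request)
    grammar.extend(still_compiling)
    # Newly ready requests are pushed to the front so they are scheduled first.
    for req in ready:
        waiting.appendleft(req)
    return ready

The short timeout keeps the scheduling loop responsive: a slow compilation never blocks the step, it just stays queued until a later call finds its Future resolved.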
4 changes: 3 additions & 1 deletion vllm/v1/request.py
@@ -11,6 +11,8 @@
from vllm.v1.utils import ConstantList

if TYPE_CHECKING:
from concurrent.futures import Future

from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
@@ -81,7 +83,7 @@ def __init__(
self.all_token_ids = ConstantList(self._all_token_ids)

# grammar objects
self.grammar = grammar
self.grammar: Optional[Grammar[Any] | Future[Grammar[Any]]] = grammar

@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
8 changes: 4 additions & 4 deletions vllm/v1/worker/gpu_input_batch.py
@@ -169,7 +169,7 @@ def __init__(

self.num_logprobs: Dict[str, int] = {}
self.prompt_logprob_reqs: Set[str] = set()
self.guided_decoding_reqs: Set[str] = set()
self.grammar_reqs: Set[str] = set()

def add_request(
self,
@@ -237,7 +237,7 @@ def add_request(
if sampling_params.prompt_logprobs:
self.prompt_logprob_reqs.add(req_id)

if request.grammar is not None: self.guided_decoding_reqs.add(req_id)
if request.grammar is not None: self.grammar_reqs.add(req_id)

def remove_request(self, req_id: str) -> Optional[int]:
req_index = self.req_id_to_index.pop(req_id, None)
@@ -255,7 +255,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id)
self.guided_decoding_reqs.discard(req_id)
self.grammar_reqs.discard(req_id)
return req_index

def clear(self) -> None:
@@ -271,7 +271,7 @@ def clear(self) -> None:
self.generators.clear()
self.num_logprobs.clear()
self.prompt_logprob_reqs.clear()
self.guided_decoding_reqs.clear()
self.grammar_reqs.clear()

def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0:
