[V1][Core] Structured decoding #12388

Open · wants to merge 94 commits into base: main

Commits (94)
d719c93
feat: initial guided decoding implementation on scheduler
aarnphm Jan 24, 2025
36bc041
chore: --wip--
aarnphm Jan 28, 2025
39068c8
chore: remove lazy loader
aarnphm Jan 28, 2025
2bb535e
fix: update types and attach bitmask to requests
aarnphm Jan 30, 2025
420f52f
chore: --wip--
aarnphm Jan 30, 2025
a5e9874
merge: branch 'main' of github.com:vllm-project/vllm into v1/structur…
aarnphm Feb 6, 2025
75e8fb4
merge: branch 'main' of github.com:vllm-project/vllm into v1/structur…
aarnphm Feb 8, 2025
9daf140
chore: --wip-- cleanup
aarnphm Feb 8, 2025
299ea58
merge: branch 'main' of github.com:vllm-project/vllm into v1/structur…
aarnphm Feb 12, 2025
15a4547
feat: base implementation
aarnphm Feb 12, 2025
49f7b96
fix: update the states within the scheduler
aarnphm Feb 12, 2025
cd357e5
[CI/Build] Ignore ruff warning up007
russellb Feb 13, 2025
9a7b081
Resolve ruff errors
russellb Feb 13, 2025
2e43e04
chore: manage requests within manager class
aarnphm Feb 13, 2025
ccde524
Drop grammar getter/setter on Request
russellb Feb 13, 2025
1587d34
mypy: Fix return type of GPUModelRunner._prepare_inputs()
russellb Feb 13, 2025
227cc7f
Resolve remaining mypy warnings
russellb Feb 13, 2025
c0b235d
Finish getting pre-commit to pass
russellb Feb 13, 2025
49fdce0
Update Michael's suggestions
aarnphm Feb 13, 2025
e9a2304
chore: update according to Michael's review
aarnphm Feb 13, 2025
f6720a8
Merge remote-tracking branch 'origin/main' into v1/structured-decoding
russellb Feb 13, 2025
872c66f
chore: simplify cache implementations
aarnphm Feb 13, 2025
a8a2f27
Changes to get a test request working
russellb Feb 13, 2025
3fda148
Resolve mypy error in request
russellb Feb 13, 2025
d7a64eb
chore: remove debug print
aarnphm Feb 13, 2025
34c08ac
Enable some v1 structured output tests
russellb Feb 13, 2025
3b736ce
Validate structured output backend for v1
russellb Feb 13, 2025
0bffe39
Merge remote-tracking branch 'origin/main' into v1/structured-decoding
russellb Feb 18, 2025
9f73ec9
Merge branch 'main' into v1/structured-decoding
aarnphm Feb 18, 2025
1a258fe
wip fixes for bitmask initialization and communication
russellb Feb 18, 2025
10f01f5
Clean up some remnants of inaccurate merge conflict resolution
russellb Feb 19, 2025
a6b07d1
fix: correctly use bitmask batch-wise
aarnphm Feb 19, 2025
7f255f0
fix: correct types
aarnphm Feb 19, 2025
9ab107f
chore: validate from decoding_config -> per request
aarnphm Feb 19, 2025
8d6bd3b
chore: passing vocab_size
aarnphm Feb 19, 2025
fcb0e85
chore: comment out 0.1.13 features
aarnphm Feb 19, 2025
3402b2a
Merge branch 'main' into v1/structured-decoding
aarnphm Feb 19, 2025
e6038f8
Resize bitmask to match the current batch size
russellb Feb 20, 2025
9830899
set any_whitespace=False for json schema + xgrammar
russellb Feb 20, 2025
cebe281
--wip--: debugging fsm apply
aarnphm Feb 20, 2025
862c093
fix: make sure to reset the FSM once we _free_request
aarnphm Feb 20, 2025
0df21ee
merge: branch 'main' of github.com:vllm-project/vllm into v1/structur…
aarnphm Feb 20, 2025
0fc85e3
revert: apply grammar bitmask from update states
aarnphm Feb 21, 2025
d95d1d7
merge: branch 'main' of github.com:vllm-project/vllm into v1/structur…
aarnphm Feb 21, 2025
62f8025
Merge remote-tracking branch 'origin/main' into v1/structured-decoding
russellb Feb 23, 2025
6a372ea
Revert changes to v0 guided decoding tests
russellb Feb 23, 2025
a43afca
create v1 tests_guided_generate for llm entrypoint
russellb Feb 23, 2025
fb40918
Drop unused Scheduler.guided_decoding_requests
russellb Feb 23, 2025
b8e016c
Allow grammar compilation to complete
russellb Feb 24, 2025
c63ca92
Remove some dead code
russellb Feb 24, 2025
074b65d
Fix index calculation for guided requests in a batch
russellb Feb 24, 2025
727dab0
Make guided decoding manager more thread-safe
russellb Feb 24, 2025
adb50ff
chore: remove prefilled check
aarnphm Feb 24, 2025
5b818f9
Merge remote-tracking branch 'origin/main' into v1/structured-decoding
russellb Feb 24, 2025
c85408a
Re-enable line length checks in ruff
russellb Feb 24, 2025
b34e4a7
Fix a yapf error in main, will be fixed by #13772
russellb Feb 24, 2025
0f2a97f
Merge remote-tracking branch 'origin/main' into v1/structured-decoding
russellb Feb 24, 2025
aabe98b
Prepare the bitmask on the scheduler side instead of gpu worker
russellb Feb 24, 2025
8895e19
tests: make sample jsonschema xgrammar compatible
russellb Feb 24, 2025
470b677
Detect unsupported jsonschema features for xgrammar
russellb Feb 24, 2025
42fe5f8
Make bitmask allocation synchronous
russellb Feb 24, 2025
ada4790
Fix compat with TP > 1
russellb Feb 24, 2025
331a7ff
Make pre-commit happy again
russellb Feb 24, 2025
0984379
chore: remove reset_bitmask after every steps
aarnphm Feb 24, 2025
9b62eef
revert: update whitespace
aarnphm Feb 24, 2025
2f756e5
Add tests/v1/guided_decoding/test_utils.py
russellb Feb 25, 2025
72adc63
Merge remote-tracking branch 'origin/main' into v1/structured-decoding
russellb Feb 25, 2025
1be1709
add v1 structured output regex test case
russellb Feb 25, 2025
0128aff
Restore some code lost in a merge from main
russellb Feb 25, 2025
9cc90ff
Validate schema is supported before sending to threadpool
russellb Feb 25, 2025
3a8f955
chore: remove unused code
aarnphm Feb 25, 2025
e772efa
fix: correct typo
aarnphm Feb 25, 2025
64a2ecf
chore(scheduler): simplify check for use_guided_decoding
aarnphm Feb 25, 2025
e8f47f3
Move guided decode validation to the engine core_client
russellb Feb 25, 2025
f3f7d51
test for expected behavior of a choice guided decode request
russellb Feb 25, 2025
9582f8c
Validate jsonschema features for both str and dict cases
russellb Feb 25, 2025
acd5ae0
Test for expected behavior of a request with unsupported jsonschema f…
russellb Feb 25, 2025
4c674ae
Correctly differentiate between jsonschema and json object requests
russellb Feb 25, 2025
1b40882
Test for correct json object (no schema) request behavior
russellb Feb 25, 2025
4f551f4
Add test for a request using an EBNF style grammar
russellb Feb 25, 2025
d132d72
Validate that EBNF grammar can be parsed during early validation
russellb Feb 25, 2025
b994230
Test for expected behavior of an invalid grammar
russellb Feb 25, 2025
3cc6437
Add support and test coverage for lark style grammars
russellb Feb 25, 2025
95be24b
Add support and tests for choice based guided decoding
russellb Feb 25, 2025
9d1fe71
feat: spec decode compatibility [-------------]
aarnphm Feb 25, 2025
83a5277
fix: correct lock the matcher for both rollback and advance
aarnphm Feb 25, 2025
d02e11a
chore: only rollback if there are more than zero processed tokens
aarnphm Feb 26, 2025
c64daa7
fix: correctly free requests based on accepted tokens
aarnphm Feb 26, 2025
ad05fe8
Account for differences in scheduler and gpu worker batch ordering
russellb Feb 26, 2025
7cf6326
Skip non-guided-decode requests when assembling reordered bitmask
russellb Feb 26, 2025
84bbae1
revert: remove rollback check for now, only advance 1 token
aarnphm Feb 26, 2025
c10eb6a
Fix accidental re-use of cached grammar matcher
russellb Feb 26, 2025
0518b70
Use the correct indices for the logits bitmask
russellb Feb 26, 2025
5f23e8b
Update vllm/v1/core/scheduler_output.py
mgoin Feb 27, 2025
69 changes: 69 additions & 0 deletions vllm/utils.py
@@ -22,6 +22,7 @@
import threading
import time
import traceback
import types
import uuid
import warnings
import weakref
@@ -2206,3 +2207,71 @@
else:
func = partial(method, obj) # type: ignore
return func(*args, **kwargs)


class LazyLoader(types.ModuleType):
"""
    LazyLoader module borrowed from TensorFlow
    https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/util/lazy_loader.py
    with the addition of "module caching". This will throw an
    exception if the module cannot be imported.

Lazily import a module, mainly to avoid pulling in large dependencies.
    `contrib` and `ffmpeg` are examples of modules that are large and not
    always needed, and this allows them to only be loaded when they are
    used.
"""

def __init__(
self,
local_name: str,
parent_module_globals: Dict[str, Any],
name: str,
warning: Optional[str] = None,
exc_msg: Optional[str] = None,
exc: Type[Exception] = Exception,
):
self._local_name = local_name
self._parent_module_globals = parent_module_globals
self._warning = warning
self._exc_msg = exc_msg
self._exc = exc
        self._module: Optional[types.ModuleType] = None

super().__init__(str(name))

def _load(self) -> types.ModuleType:
"""Load the module and insert it into the parent's globals."""

# Import the target module and insert it into the parent's namespace
try:
module = importlib.import_module(self.__name__)
self._parent_module_globals[self._local_name] = module
            # Also add to sys.modules to make sure the library is loaded.
sys.modules[self._local_name] = module
except ModuleNotFoundError as err:
raise self._exc(f"{self._exc_msg} (reason: {err})") from None

# Emit a warning if one was specified
if self._warning:
warnings.warn(self._warning,
category=DeprecationWarning,
stacklevel=4)
# Make sure to only warn once.
self._warning = None

        # Update this object's dict so that if someone keeps a reference to
        # the LazyLoader, lookups are efficient (__getattr__ is only called
        # on lookups that fail).
self.__dict__.update(module.__dict__)
return module

def __getattr__(self, item: Any) -> Any:
if self._module is None:
self._module = self._load()
return getattr(self._module, item)

def __dir__(self) -> List[str]:
if self._module is None:
self._module = self._load()
return dir(self._module)
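
A minimal usage sketch of the LazyLoader above (a hypothetical snippet, not part of the diff; the stdlib `json` module stands in for a heavy optional dependency such as xgrammar):

    from vllm.utils import LazyLoader

    # Nothing is imported at this point; the module loads on first use.
    json = LazyLoader("json", globals(), "json")

    # The first attribute access triggers _load(), which imports the module,
    # registers it in sys.modules, and copies its __dict__ into the proxy so
    # later lookups skip __getattr__ entirely.
    print(json.dumps({"lazy": True}))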
192 changes: 192 additions & 0 deletions vllm/v1/core/guided_decoding/__init__.py
@@ -0,0 +1,192 @@
from __future__ import annotations

import copy
import threading
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args

from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.utils import LazyLoader
from vllm.v1.request import GuidedDecodingKey, Request, RequestStatus

from .grammar import Grammar

if TYPE_CHECKING:
import xgrammar as xgr
from transformers import PreTrainedTokenizer
from typing_extensions import LiteralString

from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

from .grammar import XGrammar
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)

__all__ = ["Grammar", "GuidedDecodingManager"]


@dataclass
class GrammarCache:
value: Grammar | None
event: threading.Event


T = TypeVar("T", bound=str)


class GuidedDecodingManager(ABC, Generic[T]):

@abstractmethod
def initialize_cache(self, key: GuidedDecodingKey) -> Grammar:
...

def flush(self):
with self._lock:
self.grammar_cache.clear()

def cache(self, request: Request):

def _executor_loop(request: Request):
key = request.guided_decoding_key
with self._lock:
cache_hit = False
if key in self.grammar_cache:
cache_hit, entry = True, self.grammar_cache[key]
else:
entry = GrammarCache(None, threading.Event())
self.grammar_cache[key] = entry

if cache_hit:
entry.event.wait()
else:
entry.value = self.initialize_cache(key)
entry.event.set()
return copy.copy(entry.value) if entry.value else None

return self.executor.submit(_executor_loop, request)

def get(self, request: Request):
with self._lock:
entry = self.grammar_cache.get(request.guided_decoding_key)
            if entry is None or not entry.event.is_set():
                return None
return copy.copy(entry.value) if entry.value else None

def collect(self, request: Request):
        if not request.use_guided_decoding:
            return False
request.grammar = self.get(request)
if not request.grammar:
request.grammar = self.cache(request)
request.status = RequestStatus.WAITING_FOR_FSM
return True
return False

@classmethod
def from_backend(cls,
backend: LiteralString = "xgrammar",
/,
*,
tokenizer_group: BaseTokenizerGroup,
model_config: ModelConfig) -> GuidedDecodingManager[T]:
manager_cls = cls._registry.get(backend)
if manager_cls is None:
raise ValueError(
f"Backend '{backend}' not found in registry. Available backends: {list(cls._registry)}"

Check failure on line 100 in vllm/v1/core/guided_decoding/__init__.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/v1/core/guided_decoding/__init__.py:100:81: E501 Line too long (103 > 80)
)
return manager_cls(tokenizer_group=tokenizer_group,
model_config=model_config)

_registry: dict[str, type[GuidedDecodingManager[T]]] = {}
_backend: T

def __init__(self, *, tokenizer_group: BaseTokenizerGroup,
model_config: ModelConfig):
self.model_config = model_config
self.tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {}
self.executor = ThreadPoolExecutor()
self._lock = threading.Lock()

def __init_subclass__(cls, **kwargs: Any):
if not hasattr(cls, '__orig_bases__'):
raise TypeError(
f"{cls.__qualname__} must be subclass of GuidedDecodingManager"
)

backend = None
for base in cls.__orig_bases__:
if (origin := get_args(base)) and issubclass(
base.__origin__, GuidedDecodingManager):
backend = get_args(origin[0])[0]
break

if backend is None:
raise TypeError(
f"Class {cls.__qualname__} must specify backend as a Literal type"

Check failure on line 131 in vllm/v1/core/guided_decoding/__init__.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/v1/core/guided_decoding/__init__.py:131:81: E501 Line too long (82 > 80)
)

if backend in cls._registry:
name = cls._registry[backend].__qualname__
raise ValueError(
f"Backend '{backend}' is already registered to {name}")

# Set the backend value from the Literal type
cls._backend = backend
cls._registry[backend] = cls


class XGrammarManager(GuidedDecodingManager[Literal["xgrammar"]]):
# cache GrammarCompiler instances based on given tokenizer
_compiler_cache: dict[str, xgr.GrammarCompiler] = {}
_compiler: xgr.GrammarCompiler | None = None

def initialize_cache(self, key: GuidedDecodingKey) -> XGrammar:
request_type, grammar_spec = key
compiler = XGrammarManager.get_compiler(self.tokenizer)
if request_type == "json":
if type(grammar_spec) is not str:
ctx = compiler.compile_builtin_json_grammar()
else:
ctx = compiler.compile_json_schema(grammar_spec)
elif request_type == "grammar":
ctx = compiler.compile_grammar(grammar_spec)
else:
raise ValueError("grammar is not of valid supported types.")
return Grammar.from_backend(
self._backend,
matcher=xgr.GrammarMatcher(ctx),
vocab_size=self.model_config.hf_text_config.vocab_size,
ctx=ctx)

def flush(self):
super().flush()
        if self._compiler:
            self._compiler.clear_cache()
for compiler in self._compiler_cache.values():
compiler.clear_cache()
self._compiler_cache.clear()

@classmethod
def get_compiler(
cls,
tokenizer: PreTrainedTokenizer,
*,
max_threads: int = 8,
# passthrough to TokenizerInfo
vocab_size: int | None = None,
stop_token_ids: list[int] | int | None = None
) -> xgr.GrammarCompiler:
cache_key = str(hash(tokenizer))
if cache_key not in cls._compiler_cache:
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer,
stop_token_ids=stop_token_ids,
vocab_size=vocab_size)
cls._compiler_cache[cache_key] = xgr.GrammarCompiler(
tokenizer_info, max_threads=max_threads)
return cls._compiler_cache[cache_key]
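
The cache() method above implements a compile-once pattern: the first request for a given grammar key creates the cache entry and compiles it on the thread pool, while concurrent requests for the same key block on a threading.Event until the value is ready. A standalone sketch of that pattern (hypothetical names, no vLLM imports; error handling is omitted just as in the PR, so a failed compile would leave waiters blocked):

    import threading
    from concurrent.futures import ThreadPoolExecutor

    _cache: dict = {}
    _lock = threading.Lock()
    _executor = ThreadPoolExecutor()

    def get_or_compile(key, compile_fn):

        def _task():
            with _lock:
                hit = key in _cache
                if not hit:
                    # First arrival: create the entry while holding the lock.
                    _cache[key] = (threading.Event(), [])
                event, slot = _cache[key]
            if hit:
                # Another thread owns compilation; wait for it to finish.
                event.wait()
            else:
                # We won the race: compile exactly once, then wake waiters.
                slot.append(compile_fn(key))
                event.set()
            return slot[0]

        return _executor.submit(_task)

    f1 = get_or_compile("json:schema-A", lambda k: f"compiled({k})")
    f2 = get_or_compile("json:schema-A", lambda k: f"compiled({k})")
    print(f1.result(), f2.result())  # compile_fn ran exactly once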