Skip to content

Commit

Permalink
[Core] choice-based structured output with xgrammar (vllm-project#12632)
Browse files Browse the repository at this point in the history
  • Loading branch information
russellb authored and I746365 committed Feb 15, 2025
1 parent 3096661 commit 4af5c25
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 7 deletions.
2 changes: 1 addition & 1 deletion requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11
outlines == 0.1.11
lark == 1.2.2
xgrammar >= 0.1.6; platform_machine == "x86_64"
xgrammar >= 0.1.11; platform_machine == "x86_64"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs
Expand Down
9 changes: 4 additions & 5 deletions vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ def maybe_backend_fallback(
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

# xgrammar doesn't support regex or choice, fallback to outlines
if guided_params.regex is not None or guided_params.choice is not None:
logger.warning(
"xgrammar only supports json or grammar guided decoding. "
"Falling back to use outlines instead.")
# xgrammar doesn't support regex, fallback to outlines
if guided_params.regex is not None:
logger.warning("xgrammar does not support regex guided decoding. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

# xgrammar doesn't support some JSON schema features
Expand Down
31 changes: 30 additions & 1 deletion vllm/model_executor/guided_decoding/xgrammar_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import copy
import json
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, List

import torch
from transformers import PreTrainedTokenizerFast
Expand Down Expand Up @@ -228,11 +229,39 @@ def from_guided_params(cls,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
elif guided_params.choice:
choice_str = GrammarConfig.choice_as_grammar(guided_params.choice)
try:
xgr.Grammar.from_ebnf(choice_str)
except RuntimeError as err:
raise ValueError(str(err)) from err

return cls(
grammar_str=choice_str,
vocab_size=model_config.hf_text_config.vocab_size,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
else:
raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar"
)

@staticmethod
def escape_ebnf_string(s: str) -> str:
"""Escape special characters in a EBNF string."""
# Escape double quotes and backslashes
return re.sub(r'(["\\])', r'\\\1', s)

@staticmethod
def choice_as_grammar(choice: List[str] | None) -> str:
if choice is None:
raise ValueError("Choice is not set")
escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
return grammar


@dataclass
class XGrammarLogitsProcessor:
Expand Down

0 comments on commit 4af5c25

Please sign in to comment.