Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Guided decoding falls back to outlines when fails to import xgrammar #12976

Merged
merged 7 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def maybe_backend_fallback(
guided_params.backend = "outlines"

if guided_params.backend == "xgrammar":
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed)
# xgrammar only has x86 wheels for linux, fallback to outlines
from vllm.platforms import current_platform
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
Expand Down Expand Up @@ -77,6 +79,13 @@ def maybe_backend_fallback(
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
logger.warning("xgrammar module cannot be imported successfully. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

if (guided_params.backend == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to xgrammar
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/guided_decoding/xgrammar_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
try:
import xgrammar as xgr
from xgrammar.base import _core as xgr_core
xgr_installed = True
except ImportError:
xgr_installed = False
pass

from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
Expand Down