
[V1][Core] Structured decoding #12388

Open · wants to merge 94 commits into main from v1/structured-decoding

Conversation

aarnphm
Contributor

@aarnphm aarnphm commented Jan 24, 2025

Add structured decoding to v1 core engine.

Currently the grammars are being setup/cached on the scheduler.

Signed-off-by: Aaron Pham contact@aarnphm.xyz


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀


mergify bot commented Jan 24, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @aarnphm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 24, 2025
@aarnphm aarnphm force-pushed the v1/structured-decoding branch from f66e36b to 55741d4 Compare January 24, 2025 06:11
@mergify mergify bot removed the needs-rebase label Jan 24, 2025
@aarnphm aarnphm changed the title feat: initial guided decoding implementation on scheduler [V1][Core] Initial guided decoding implementation on scheduler Jan 24, 2025
@aarnphm aarnphm changed the title [V1][Core] Initial guided decoding implementation on scheduler [V1][Core] Structured decoding on scheduler-level Jan 24, 2025
@aarnphm aarnphm force-pushed the v1/structured-decoding branch from 57e16d9 to 78bfa36 Compare January 27, 2025 06:38
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
@aarnphm aarnphm force-pushed the v1/structured-decoding branch from 78bfa36 to d719c93 Compare January 27, 2025 06:51
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
@mgoin mgoin self-requested a review January 29, 2025 16:21
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
@aarnphm aarnphm force-pushed the v1/structured-decoding branch from bf23eb2 to 733fef4 Compare January 31, 2025 02:04
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
@aarnphm aarnphm force-pushed the v1/structured-decoding branch from 733fef4 to 420f52f Compare January 31, 2025 02:05
@mmoskal
Contributor

mmoskal commented Jan 31, 2025

I love seeing structured decoding being integrated deeply into vLLM!

I would love to see llguidance supported as well, though. Compared to XGrammar, it is significantly faster, has near-zero compilation time, and has much broader JSON Schema support. We've been using it in production instances.

If needed, I'm happy to add additional APIs to the Python bindings (server-side integrations so far have been native) or otherwise help.

[timing diagram]

@mergify mergify bot added the v1 label Feb 1, 2025

mergify bot commented Feb 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @aarnphm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 6, 2025
@mergify mergify bot removed the needs-rebase label Feb 6, 2025
russellb and others added 5 commits February 26, 2025 14:44
The scheduler sends a bitmask for guided decoding down to the gpu worker,
but the indices into this bitmask may not match the order of requests
used in the gpu worker. This change detects the discrepancy and creates
a reordered bitmask when necessary before applying it to the logits.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
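
A minimal sketch of the reordering described in the commit message above, under assumed names: scheduler_indices maps request id to its row in the scheduler-ordered bitmask, and batch_req_ids lists request ids in the GPU worker's batch order (neither name comes from this PR).

import torch


def reorder_bitmask(bitmask: torch.Tensor,
                    scheduler_indices: dict,
                    batch_req_ids: list) -> torch.Tensor:
    """Rearrange bitmask rows from scheduler order into worker batch order."""
    # Fast path: the two orderings already agree, so reuse the tensor as-is.
    if all(scheduler_indices.get(req_id) == i
           for i, req_id in enumerate(batch_req_ids)):
        return bitmask
    # Otherwise build a reordered copy with one row per request, in batch order.
    reordered = torch.empty((len(batch_req_ids), bitmask.shape[1]),
                            dtype=bitmask.dtype, device=bitmask.device)
    for worker_index, req_id in enumerate(batch_req_ids):
        reordered[worker_index] = bitmask[scheduler_indices[req_id]]
    return reordered
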
We cache the compiled grammar, but we need a unique matcher instance
for each request. The code previously re-used the same matcher for all
requests using the same grammar. If multiple parallel requests had the
same grammar, they would mostly fail as a result.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
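
A sketch of the per-request matcher fix described above, assuming xgrammar's GrammarCompiler.compile_json_schema / GrammarMatcher API; compiled_cache and get_matcher are illustrative names, not code from this PR.

import xgrammar as xgr

# Cache only the compiled grammar, keyed by schema; it is safe to share.
compiled_cache: dict = {}


def get_matcher(compiler: xgr.GrammarCompiler,
                schema: str) -> xgr.GrammarMatcher:
    if schema not in compiled_cache:
        compiled_cache[schema] = compiler.compile_json_schema(schema)
    # Every request gets its own matcher; matchers carry per-request FSM state
    # and must not be shared across parallel requests.
    return xgr.GrammarMatcher(compiled_cache[schema])
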
This code did a bit of a dance to get the correct indices for the logits
and then used the old wrong ones.  Oops.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
@aarnphm aarnphm marked this pull request as ready for review February 26, 2025 22:46
Comment on lines +64 to +67
# With a list, we can safely pop the index
# of a request that is not yet ready (in this case,
# the one that uses guided decoding) while still maintaining
# the order of all requests in the existing waiting queue.

Not quite -- we pop a ready request from the middle, even when higher-priority guided decoding requests that aren't ready yet remain ahead of it in the list. We don't pop the not-ready ones.
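
A toy illustration of the behavior described in the reply above, with a hypothetical is_ready predicate standing in for the grammar-readiness check.

from typing import Callable, List, Optional


def pop_first_ready(waiting: List[object],
                    is_ready: Callable[[object], bool]) -> Optional[object]:
    """Pop the first ready request, leaving not-ready ones in place."""
    for i, req in enumerate(waiting):
        if is_ready(req):
            # A ready request may be removed from the middle of the list;
            # earlier not-ready requests (e.g. grammars still compiling)
            # stay put and keep their relative order.
            return waiting.pop(i)
    return None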

Comment on lines +365 to +379
        # Prepare the guided decoding bitmask for this batch.
        grammar_bitmask = None
        if guided_decoding_request_ids:
            # Fill the bitmask using the index of each request equal to its
            # position in the batch. Resize the bitmask down to the size of
            # the batch.
            grammar_bitmask = self.guided_decoding_manager.grammar_bitmask
            assert grammar_bitmask is not None
            for req_id, batch_index in guided_decoding_request_ids.items():
                request = self.requests[req_id]
                assert request.grammar is not None
                if not request.grammar.matcher.is_terminated():
                    request.grammar.fill_bitmask(grammar_bitmask, batch_index)
            if len(self.running) < grammar_bitmask.shape[0]:
                grammar_bitmask = grammar_bitmask[:len(self.running)]

This is for follow-up. I don't think we need to block on it.

When this was originally written, I thought we'd be re-using this bitmask as-is in the GPU worker. However, the GPU worker maintains its own batch data structures that are not in the same order as the scheduler. We end up having to reconstruct the bitmask there, so there's no good reason to have one that has empty unused rows in it. We should just create one that has a size equal to the number of guided requests to save some space.
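
A sketch of that follow-up, allocating one bitmask row per guided request rather than per batch slot; allocate_token_bitmask follows xgrammar's API, while build_compact_bitmask and guided_request_ids are illustrative names.

import xgrammar as xgr


def build_compact_bitmask(requests: dict, guided_request_ids: list,
                          vocab_size: int):
    # One row per guided request instead of one row per running request.
    bitmask = xgr.allocate_token_bitmask(len(guided_request_ids), vocab_size)
    for row, req_id in enumerate(guided_request_ids):
        # fill_bitmask mirrors the Grammar helper used elsewhere in this PR.
        requests[req_id].grammar.fill_bitmask(bitmask, row)
    return bitmask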

@@ -131,6 +135,9 @@ def add_request(self, request: EngineCoreRequest):
                                         request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_guided_decoding:
            # Start grammar compilation asynchronously
            self.guided_decoding_manager.should_cache(req)

We should consider renaming this. It returns a bool, but we don't use it. Maybe warmup_cache() or populate_cache() or something. What we're really saying is "check if you have a grammar cached already and if not, get that started!"

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output."""

        if not self.scheduler.has_unfinished_requests():
            return EngineCoreOutputs(
                outputs=[], scheduler_stats=self.scheduler.make_stats())

        # Check for cached grammars and allocate bitmask if necessary

Suggested change
# Check for cached grammars and allocate bitmask if necessary
# Check cache for compiled grammars and add them to requests
# when they're ready.

            tokenizer, vocab_size=self.vocab_size)
        self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)

        self.request_key_to_grammar: Dict[GuidedDecodingKey, Grammar] = {}

We need to cap the size of this cache (make it an LRU cache perhaps). Otherwise, we have a denial-of-service vulnerability here. A malicious user can send a stream of requests with a slightly different valid grammar each time and the memory consumption here will grow unbounded.
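
One possible shape for that cap, a minimal LRU wrapper around the existing dict; MAX_CACHED_GRAMMARS is an illustrative bound, not something from this PR.

from collections import OrderedDict

MAX_CACHED_GRAMMARS = 256  # illustrative bound, not from this PR


class GrammarLRUCache:
    """Bounded mapping of GuidedDecodingKey -> compiled grammar."""

    def __init__(self, max_size: int = MAX_CACHED_GRAMMARS):
        self.max_size = max_size
        self._cache: OrderedDict = OrderedDict()

    def get(self, key):
        if key not in self._cache:
            return None
        self._cache.move_to_end(key)  # mark as most recently used
        return self._cache[key]

    def put(self, key, value):
        self._cache[key] = value
        self._cache.move_to_end(key)
        if len(self._cache) > self.max_size:
            self._cache.popitem(last=False)  # evict least recently used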

Comment on lines +249 to +250
            if (request.use_guided_decoding
                    and request.status == RequestStatus.WAITING_FOR_FSM):

Checking for use_guided_decoding doesn't seem necessary here

Suggested change
if (request.use_guided_decoding
and request.status == RequestStatus.WAITING_FOR_FSM):
if request.status == RequestStatus.WAITING_FOR_FSM:

Comment on lines +223 to +226
    @staticmethod
    def is_waiting(status: RequestStatus) -> bool:
        return status <= RequestStatus.WAITING_FOR_FSM


I think we can remove this. It's not necessary

Comment on lines +615 to +616
"Failed to advance FSM for request %s "
"for tokens %s", req_id, generated_token_ids[0])

If this happens, it's a bug. It would be kind to our users to make that a little more clear in the message instead of implying there may be something they did wrong.

Suggested change
"Failed to advance FSM for request %s "
"for tokens %s", req_id, generated_token_ids[0])
"Failed to advance FSM for request %s "
"for tokens %s. Please file an issue.",
req_id, generated_token_ids[0])

Comment on lines +722 to +725
        if request.use_guided_decoding:
            # NOTE: grammar should NOT be None
            # if use_guided_decoding is True
            request.grammar.reset()  # type: ignore[union-attr]

This probably isn't necessary. Each request now gets its own matcher instance, and each matcher shares a common, cached, compiled grammar. We're resetting the matcher here, but it's never reused, so this is probably just wasted work. We need to run some load tests to be sure, but I think so!


Suggested change
if request.use_guided_decoding:
# NOTE: grammar should NOT be None
# if use_guided_decoding is True
request.grammar.reset() # type: ignore[union-attr]

        scheduler_output = self.scheduler.schedule()

        if scheduler_output.total_num_scheduled_tokens == 0:

Suggested change
if scheduler_output.total_num_scheduled_tokens == 0:
# This case may occur when the only unfinished requests are
# guided decoding requests where the grammar has not finished
# compiling yet, so there's nothing to run.
if scheduler_output.total_num_scheduled_tokens == 0:

Comment on lines +171 to +172
# the bitmask allocation for grammars
# should be ready at this point.

The bitmask allocation is only done once at startup and is not asynchronous (anymore), so this comment isn't relevant anymore.

Suggested change
# the bitmask allocation for grammars
# should be ready at this point.


# the bitmask allocation for grammars
# should be ready at this point.
# Currently we will broadcast the bitmask

Suggested change
# Currently we will broadcast the bitmask
# Currently we will broadcast the bitmask. It is populated during
# each schedule() run.

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output."""

        if not self.scheduler.has_unfinished_requests():
            return EngineCoreOutputs(
                outputs=[], scheduler_stats=self.scheduler.make_stats())

        # Check for cached grammars and allocate bitmask if necessary
        self.setup_grammars()

After this change we can remove setup_grammars() from this class since it's just a direct passthrough.

Suggested change
self.setup_grammars()
self.guided_decoding_manager.setup_grammars()

Comment on lines +263 to +265
    def setup_grammars(self):
        self.guided_decoding_manager.setup_grammars()


Suggested change
def setup_grammars(self):
self.guided_decoding_manager.setup_grammars()

Comment on lines +88 to +98
    def _validate_guided_decoding(
            self, params: Union[SamplingParams, PoolingParams]) -> None:
        if not isinstance(params, SamplingParams):
            return
        if self.decoding_config.guided_decoding_backend != "xgrammar":
            raise ValueError(
                "Only xgrammar guided decoding is supported in V1.")
        if (params.guided_decoding and params.guided_decoding.backend
                and params.guided_decoding.backend != 'xgrammar'):
            raise ValueError(
                "Only xgrammar guided decoding is supported in V1.")

Validation got added in two different spots at different times during development (and I think they were both me, oops!)

We should reconcile the changes here with the checks triggered in core_client.py

@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0

not a blocker, but I think it'd be nice to move a lot of this code out into new files instead of putting it all in __init__.py



GuidedDecodingKey = Tuple[GuidedDecodingOptions, str]
MAX_ROLLBACK_TOKENS = 100

Where does 100 come from?

Comment on lines +61 to +66
    _matcher_lock: threading.Lock = field(
        default_factory=lambda: threading.Lock(),
        repr=False,
        init=False,
        hash=False,
    )

I think this lock can be removed. We are not sharing a matcher between requests anymore (we can't), so the lock should not be necessary

Comment on lines +89 to +93
    def __copy__(self):
        return Grammar(matcher=xgr.GrammarMatcher(self.ctx),
                       vocab_size=self.vocab_size,
                       ctx=self.ctx,
                       max_rollback_tokens=self.max_rollback_tokens)

I think we can simplify code in this file if we only cache the compiled grammar from xgrammar and not an instance of this Grammar class. It's the compiled grammar that's the only thing that gets reused. That way we'd just instantiate a new Grammar each time instead of doing a copy.copy() on a Grammar from the cache.
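
A sketch of that simplification as a method on the manager class, caching only the compiled grammar and building a fresh Grammar per request; compiled_grammar_cache and _compile are hypothetical names, while the Grammar constructor mirrors the __copy__ shown above.

def grammar_for_request(self, key: GuidedDecodingKey) -> Grammar:
    # Only xgrammar's compiled grammar is cached; it is the sole piece
    # that can safely be shared across requests.
    ctx = self.compiled_grammar_cache.get(key)  # hypothetical cache dict
    if ctx is None:
        ctx = self._compile(key)  # hypothetical compile helper
        self.compiled_grammar_cache[key] = ctx
    # Construct a new Grammar (with its own matcher) for every request,
    # instead of copy.copy() on a cached Grammar instance.
    return Grammar(matcher=xgr.GrammarMatcher(ctx),
                   vocab_size=self.vocab_size,
                   ctx=ctx,
                   max_rollback_tokens=MAX_ROLLBACK_TOKENS)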


        self.request_key_to_grammar: Dict[GuidedDecodingKey, Grammar] = {}

        self.executor = ThreadPoolExecutor()

The default thread pool size is too high. We need to pin it to something more reasonable.

https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor

If max_workers is None or not given, it will default to the number of processors on the machine, multiplied by 5, assuming that ThreadPoolExecutor is often used to overlap I/O instead of CPU work
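
A minimal example of pinning the pool size as suggested; the exact worker count here is just an illustrative choice.

import os
from concurrent.futures import ThreadPoolExecutor

# Grammar compilation is CPU-bound, so avoid the default of
# (CPU count * 5) threads; a small cap is plenty here.
executor = ThreadPoolExecutor(max_workers=min(4, os.cpu_count() or 1))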

Comment on lines +116 to +117
        self.requests: Set[Request] = set()
        self._requests_lock = threading.Lock()

I think the purpose of this set of requests could use some explanation. I feel like we could simplify this, but some comments explaining what's going on here would probably be helpful in the meantime.

Comment on lines +174 to +177
            json_str = json.dumps(params.choice)
        else:
            json_str = params.choice
        return (GuidedDecodingOptions.choice, json_str)

The use of json_str as a variable name here is confusing.

I also don't think converting it to a string is the right thing. This code may not ever run because we actually convert choice into a grammar elsewhere.

Co-authored-by: Russell Bryant <rbryant@redhat.com>