
[Bugfix][Disaggregated] patch the inflight batching on the decode node in SimpleConnector to avoid hangs in SimpleBuffer (nccl based) #13987

Merged

Conversation

@hasB4K (Contributor) commented Feb 27, 2025

Hello 👋 ,

While experimenting with disaggregated serving with NCCL and vLLM, I encountered a bug when running mistral-large.
After a lot of debugging, I realized that inflight batching was the issue: on the decode node, a request that is still prefilling (and whose KV cache needs to be fetched) can be batched together with decoding requests.

The current implementation is buggy; here is an example:

input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens

  • model_input contains 3 sequences: S1 is a prefill, while S2 and S3 are decodes.
  • model_input.attn_metadata.seq_lens will, however, contain values like [18, 17, 16].
  • The query_lens (here) will contain values like [18, 1, 1].
  • The total size of input_tokens_tensor is 20 (18 + 1 + 1), so using seq_lens to slice it is wrong in the first place.

# Excerpt from SimpleConnector on the decode node: start_pos/end_pos advance
# by slen, i.e. by a full seq_len, on every iteration of the loop.
current_tokens = input_tokens_tensor[start_pos:end_pos]
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
ret = self.select(current_tokens,

So, in this example, the KV cache of S1 will be properly fetched. But when reaching S2, current_tokens will be the slice input_tokens_tensor[18:35]... which doesn't make sense, because the size of input_tokens_tensor is 20. PyTorch won't throw an IndexError though: it will return a tensor of size 2 containing the tokens of S2 and S3. This then causes a hang on the prefill node, because the self.select() call is made with invalid tokens.
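The following is a minimal, standalone sketch (not the actual vLLM code; the tensors are made up to match the example above) showing how PyTorch silently clamps the out-of-range slices instead of raising:

import torch

# Hypothetical batch matching the example above: S1 is a prefill of 18 tokens,
# S2 and S3 are decodes contributing 1 token each -> 20 input tokens in total.
input_tokens_tensor = torch.arange(20)
seq_lens = [18, 17, 16]    # full sequence lengths (what the buggy loop iterates over)
query_lens = [18, 1, 1]    # tokens actually present in this batch

start_pos = 0
for slen in seq_lens:
    end_pos = start_pos + slen
    chunk = input_tokens_tensor[start_pos:end_pos]
    print(start_pos, end_pos, chunk.numel())
    start_pos = end_pos

# Output:
# 0 18 18   <- S1: correct
# 18 35 2   <- S2: out-of-range slice is silently clamped, no IndexError
# 35 51 0   <- S3: empty tensor
# The correct offsets would instead come from query_lens (18, 1 and 1 tokens).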

I think the patch is pretty self-explanatory otherwise, but do not hesitate to ask if you have questions. 😉
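For readers who don't open the diff, the core idea is to derive the slice offsets from the per-request query lengths rather than the full sequence lengths, and to only fetch the KV cache for the request that is actually prefilling. A rough sketch of that idea (not the literal patch; fetch_kv_cache and the values are illustrative):

import torch

def fetch_kv_cache(tokens):
    # Illustrative placeholder for the NCCL-based KV fetch on the decode node.
    print(f"fetching KV cache for {tokens.numel()} prefill tokens")

input_tokens_tensor = torch.arange(20)
seq_lens = [18, 17, 16]
query_lens = [18, 1, 1]

start_pos = 0
for seq_len, query_len in zip(seq_lens, query_lens):
    end_pos = start_pos + query_len      # advance by query_len, not seq_len
    current_tokens = input_tokens_tensor[start_pos:end_pos]
    if query_len > 1:
        # Still prefilling on the decode node: its KV cache must be fetched.
        fetch_kv_cache(current_tokens)
    # Decode requests (query_len == 1) already have their KV cache locally.
    start_pos = end_pos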

I do, however, have some follow-up questions (quoted and answered below):


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

[Bugfix][Disaggregated] patch the inflight batching on the decode node in SimpleConnector

Signed-off-by: Mathis Felardos <mathis@mistral.ai>
@hasB4K hasB4K force-pushed the vllm-public-hasb4k-2025-02-27_disagg-patch branch from ee70a61 to b6544a0 Compare February 27, 2025 19:19
@KuntaiDu (Collaborator)

In the current implementation I am assuming there is no chunked prefill so that prefill and decode jobs won't appear in the same batch.

@hasB4K (Contributor, Author) commented Feb 28, 2025

> In the current implementation I am assuming there is no chunked prefill so that prefill and decode jobs won't appear in the same batch.

I understood that, at least for the Decode node. Is it also the case for the Prefill node?

Anyway, I think a warning and/or setting bypass_model_exec=True is still better than nothing 😅

@KuntaiDu (Collaborator)

> I understood that, at least for the Decode node. Is it also the case for the Prefill node?
>
> Anyway, I think a warning and/or setting bypass_model_exec=True is still better than nothing 😅

I guess you are referring to the case where the prefill node performs chunked prefill. Just to clarify: chunked prefill + disaggregated prefill can be useful, but from what I've heard it is not the default use case, so we deprioritized the support.

But yeah, it's definitely better to have a warning there.

@KuntaiDu (Collaborator)

> Does it work with CUDA graphs? I saw this piece of code that indicates that padding can be added when CUDA graphs are used. Could it create silent bugs like this one?

KV cache receive is synchronous on the decode node, so it is compatible with CUDA graphs. KV cache send is asynchronous, but since the prefill node does not actually use CUDA graphs, I guess this is OK.

@KuntaiDu (Collaborator)

> Will SimpleConnector be replaced entirely by the LMCache Connector at some point?

No. SimpleConnector will remain the connector used for the disaggregated prefill CI, so we are not currently planning to deprecate it.

@KuntaiDu (Collaborator)

> Is there a reason why we need min_length slicing?

This min_length is mainly intended to handle non-disaggregated-prefill use cases (e.g. connecting vLLM to a persistent KV store). In that case it is likely that only part of the prefix has a KV cache match.
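As a rough illustration of that idea (not the actual SimpleConnector code; the names and values below are made up), the received KV cache is trimmed to the length of the matched prefix and the remaining tokens are prefilled normally:

# Sketch only (illustrative names): partial prefix hit from an external KV store.
remote_prefix_tokens = list(range(12))     # tokens whose KV cache the store returned
current_tokens = list(range(20))           # tokens of the current prompt
num_matched = min(len(remote_prefix_tokens), len(current_tokens))   # the "min_length" idea
# Only the first num_matched positions reuse the received KV cache;
# everything beyond num_matched is computed as a regular prefill.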

@KuntaiDu (Collaborator) left a comment


LGTM.

@KuntaiDu KuntaiDu enabled auto-merge (squash) February 28, 2025 03:07
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 28, 2025
@KuntaiDu KuntaiDu merged commit b9e4173 into vllm-project:main Feb 28, 2025
53 checks passed
@hasB4K hasB4K deleted the vllm-public-hasb4k-2025-02-27_disagg-patch branch February 28, 2025 08:17
kylehh pushed a commit to kylehh/vllm that referenced this pull request Feb 28, 2025
[Bugfix][Disaggregated] patch the inflight batching on the decode node in SimpleConnector to avoid hangs in SimpleBuffer (nccl based) (vllm-project#13987)

Signed-off-by: Mathis Felardos <mathis@mistral.ai>