Skip to content

Commit

Permalink
[Bugfix][Disaggregated] patch the inflight batching on the decode node in SimpleConnector to avoid hangs in SimpleBuffer (nccl based) (#13987)
Browse files Browse the repository at this point in the history

Signed-off-by: Mathis Felardos <mathis@mistral.ai>
  • Loading branch information
hasB4K authored Feb 28, 2025
1 parent 1088f06 commit b9e4173
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/simple_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def recv_kv_caches_and_hidden_states(

input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()

hidden_or_intermediate_states_for_one_req = []
Expand All @@ -225,9 +226,21 @@ def recv_kv_caches_and_hidden_states(
# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):

start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen

if start_pos >= num_prefill_tokens:
# This can happen during inflight batching. See:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You should set --enable_chunked_prefill=False "
"and --max_num_batched_tokens "
"should be equal to max_seq_len_to_capture")
bypass_model_exec = False
assert start_pos == num_prefill_tokens
break

current_tokens = input_tokens_tensor[start_pos:end_pos]
num_tokens = slen

Expand Down Expand Up @@ -288,7 +301,7 @@ def recv_kv_caches_and_hidden_states(
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
logger.debug(
logger.warning(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = None
Expand Down

0 comments on commit b9e4173

Please sign in to comment.