Fixes delayed sampling for sequential requests (#845)
- Previously, `LLM.generate()` could not be called multiple times with
delayed sampling enabled.
- The same applied to repeated `step()` calls.
- The issue occurs when a new request starts after the last (batch) request
has finished: `cached_step_inputs` and `cached_step_outputs` still contain
elements saved from the previously served (batch) request, which should not
be the case.
- The cleanest solution would be to skip appending to
[`cached_step_inputs/outputs`](https://github.com/HabanaAI/vllm-fork/blob/50b28af6491ed6eb75794d4968fe1c679e65ea92/vllm/worker/hpu_model_runner.py#L2610-L2611)
when the freshly generated
[`output`](https://github.com/HabanaAI/vllm-fork/blob/50b28af6491ed6eb75794d4968fe1c679e65ea92/vllm/worker/hpu_model_runner.py#L2608)
is the final token of the current batch request, but there is no clean way
to check for that condition in the model runner.
- Instead,
[`_patch_prev_output`](https://github.com/HabanaAI/vllm-fork/blob/50b28af6491ed6eb75794d4968fe1c679e65ea92/vllm/worker/hpu_model_runner.py#L2776)
now checks whether the scheduler context's `output_queue` is empty, which
means there are no pending outputs to patch, and returns early in that case
(see the repro sketch below).

Tests here: habana-internal/mlperf_inference#158
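
Below is a hypothetical repro sketch of the failure mode described above. The
`VLLM_DELAYED_SAMPLING` environment variable, the model name, and the sampling
parameters are illustrative assumptions, not taken from this commit; the point
is simply that two sequential `generate()` calls used to trip over stale
`cached_step_inputs`/`cached_step_outputs`.

```python
# Hypothetical repro sketch -- env var name, model, and params are assumptions.
import os

# Assumption: delayed sampling in this fork is gated behind an env flag.
os.environ["VLLM_DELAYED_SAMPLING"] = "true"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=32)

# The first batch request completes and leaves entries in the model runner's
# cached_step_inputs / cached_step_outputs.
first = llm.generate(["Hello, my name is"], params)

# Before this fix, the second call hit _patch_prev_output with a stale cached
# entry but an empty ctx.output_queue, so patching the previous output failed.
second = llm.generate(["The capital of France is"], params)
```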
tianmu-li authored Feb 20, 2025
2 parents 50b28af + ee0b119 commit 6eeefdd
Showing 1 changed file with 6 additions and 1 deletion.
vllm/worker/hpu_model_runner.py (6 additions, 1 deletion)
@@ -2781,6 +2781,11 @@ def _patch_prev_output(self):
         model_input = self.cached_step_inputs.pop(0)
         delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(-1).tolist()
         ctx = model_input.async_callback.keywords["ctx"]
+        # If there is no output to patch, which happens when a new request
+        # starts after all in-flight requests have completed, return early;
+        # the stale cached_step_inputs/outputs entries were popped above.
+        if len(ctx.output_queue) == 0:
+            return
         assert len(ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
         output_data = ctx.output_queue[0]
         assert len(output_data.outputs) == 1
@@ -2792,4 +2797,4 @@ def _patch_prev_output(self):
             # This is a hack. Assigning output_token_ids triggers
             # a cache recomputation and we only need to update the last token
             seq_data.output_token_ids_array[-1] = real_out
-            seq_data._cached_all_token_ids[-1] = real_out
+            seq_data._cached_all_token_ids[-1] = real_out
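
For context, here is a minimal standalone sketch of the patching flow and why
the early return matters. Only the names `cached_step_inputs`,
`cached_step_outputs`, and `output_queue` mirror the real model runner; the
class and everything else is simplified for illustration.

```python
from collections import deque


class TinyRunner:
    """Illustrative-only model of the delayed-sampling patch step."""

    def __init__(self):
        self.cached_step_inputs = []   # inputs saved while sampling is delayed
        self.cached_step_outputs = []  # matching delayed token ids
        self.output_queue = deque()    # stands in for ctx.output_queue

    def _patch_prev_output(self):
        if not self.cached_step_inputs:
            return
        self.cached_step_inputs.pop(0)
        delayed = self.cached_step_outputs.pop(0)
        # The fix: a new request may start after the previous batch has fully
        # finished, so there is nothing left to patch -- return instead of
        # asserting that exactly one output is waiting.
        if len(self.output_queue) == 0:
            return
        assert len(self.output_queue) == 1
        self.output_queue[0][-1] = delayed  # patch the placeholder last token


runner = TinyRunner()
# Stale cache left over from a finished batch, but no pending output to patch:
runner.cached_step_inputs.append("old_input")
runner.cached_step_outputs.append(42)
runner._patch_prev_output()  # returns cleanly instead of failing the assert
```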
