diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
new file mode 100644
index 0000000000000..8eb08f3e842ca
--- /dev/null
+++ b/tests/v1/core/test_scheduler.py
@@ -0,0 +1,214 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import List, Optional
+
+from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
+from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.sampling_params import SamplingParams
+from vllm.v1.core.scheduler import Scheduler
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import Request, RequestStatus
+
+
+def create_scheduler(
+ model: str = "facebook/opt-125m",
+ max_num_seqs: int = 16,
+ max_num_batched_tokens: int = 8192,
+) -> Scheduler:
+ scheduler_config = SchedulerConfig(
+ max_num_seqs=max_num_seqs,
+ max_num_batched_tokens=max_num_batched_tokens,
+ max_model_len=max_num_batched_tokens,
+ )
+ model_config = ModelConfig(
+ model=model,
+ task="auto",
+ tokenizer=model,
+ tokenizer_mode="auto",
+ trust_remote_code=True,
+ dtype="float16",
+ seed=42,
+ )
+ cache_config = CacheConfig(
+ block_size=16,
+ gpu_memory_utilization=0.9,
+ swap_space=0,
+ cache_dtype="auto",
+ )
+ cache_config.num_gpu_blocks = 10000
+ return Scheduler(scheduler_config,
+ model_config,
+ cache_config,
+ lora_config=None)
+
+
+def create_requests(
+ num_requests: int,
+ num_tokens: int = 10,
+ mm_positions: Optional[List[PlaceholderRange]] = None,
+):
+ sampling_params = SamplingParams()
+ requests = []
+ for i in range(num_requests):
+ if mm_positions is not None:
+ mm_position = mm_positions[i]
+ mm_inputs = [MultiModalKwargs({})] * len(mm_position)
+ else:
+ mm_position = None
+ mm_inputs = None
+ request = Request(
+ request_id=f"{i}",
+ prompt=None,
+ prompt_token_ids=[i] * num_tokens,
+ sampling_params=sampling_params,
+ multi_modal_inputs=mm_inputs,
+ multi_modal_placeholders=mm_position,
+ multi_modal_hashes=None,
+ eos_token_id=None,
+ arrival_time=0,
+ )
+ requests.append(request)
+ return requests
+
+
+def test_add_requests():
+ scheduler = create_scheduler()
+ requests = create_requests(num_requests=10)
+
+ for i, request in enumerate(requests):
+ scheduler.add_request(request)
+ assert request.request_id in scheduler.requests
+ assert len(scheduler.waiting) == i + 1
+
+
+def test_finish_request():
+ scheduler = create_scheduler()
+ requests = create_requests(num_requests=10)
+ for request in requests:
+ scheduler.add_request(request)
+
+ for i, request in enumerate(requests):
+ scheduler.finish_requests(request.request_id,
+ RequestStatus.FINISHED_ABORTED)
+ assert request.request_id not in scheduler.requests
+ assert len(scheduler.waiting) == 9 - i
+
+
+def test_get_num_unfinished_requests():
+ scheduler = create_scheduler()
+ requests = create_requests(num_requests=10)
+ for request in requests:
+ scheduler.add_request(request)
+
+ for i, request in enumerate(requests):
+ scheduler.finish_requests(request.request_id,
+ RequestStatus.FINISHED_STOPPED)
+ assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
+
+
+def test_schedule():
+ scheduler = create_scheduler()
+ requests = create_requests(num_requests=10)
+ for request in requests:
+ scheduler.add_request(request)
+
+ # Test initial scheduling
+ output = scheduler.schedule()
+ assert len(output.scheduled_new_reqs) == len(requests)
+ assert len(output.scheduled_cached_reqs) == 0
+ assert len(output.finished_req_ids) == 0
+ # Verify all requests are scheduled.
+ for req_id, num_tokens in output.num_scheduled_tokens.items():
+ assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
+
+ # Verify requests moved from waiting to running
+ assert len(scheduler.waiting) == 0
+ assert len(scheduler.running) == len(requests)
+ for i, request in enumerate(requests):
+ assert scheduler.running[i] == request
+
+
+def test_schedule_multimodal_requests():
+ scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf")
+ mm_positions = [[PlaceholderRange(offset=i, length=100)]
+ for i in range(10)]
+ requests = create_requests(
+ num_requests=10,
+ num_tokens=200,
+ mm_positions=mm_positions,
+ )
+ for request in requests:
+ scheduler.add_request(request)
+
+ output = scheduler.schedule()
+ assert len(output.scheduled_new_reqs) == len(requests)
+ assert len(output.scheduled_cached_reqs) == 0
+ assert len(output.finished_req_ids) == 0
+ for req_id, num_tokens in output.num_scheduled_tokens.items():
+ assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
+ assert len(output.scheduled_encoder_inputs) == 10
+ for req_id, encoder_input in output.scheduled_encoder_inputs.items():
+ assert len(encoder_input) == 1
+
+
+def test_schedule_partial_requests():
+ """Test scheduling behavior with partial requests.
+
+ This test verifies that:
+ 1. The scheduler can handle multiple partial requests in a single step when
+ constrained by encoder budget.
+ 2. A request in RUNNING state may be unscheduled in subsequent steps if
+ there is insufficient encoder budget.
+ """
+ scheduler = create_scheduler(
+ model="llava-hf/llava-1.5-7b-hf",
+ max_num_batched_tokens=1024,
+ )
+ mm_positions = [[PlaceholderRange(offset=100, length=600)]
+ for _ in range(3)]
+ requests = create_requests(
+ num_requests=3,
+ num_tokens=800,
+ mm_positions=mm_positions,
+ )
+ for request in requests:
+ scheduler.add_request(request)
+
+ output = scheduler.schedule()
+ assert len(output.scheduled_new_reqs) == 3
+ assert len(output.scheduled_cached_reqs) == 0
+ assert len(output.finished_req_ids) == 0
+
+ assert scheduler.max_num_encoder_input_tokens == 1024
+ # The first request is scheduled fully.
+ assert output.num_scheduled_tokens[requests[0].request_id] == 800
+ # The second request is scheduled partially.
+    # The image tokens are not scheduled because of the encoder budget.
+ assert output.num_scheduled_tokens[requests[1].request_id] == 100
+ # The third request is also scheduled partially.
+    # The image tokens are not scheduled because of the encoder budget.
+ assert output.num_scheduled_tokens[requests[2].request_id] == 100
+ req_to_index = {
+ request.request_id: i
+ for i, request in enumerate(requests)
+ }
+ model_runner_output = ModelRunnerOutput(
+ req_ids=[request.request_id for request in requests],
+ req_id_to_index=req_to_index,
+ sampled_token_ids=[0] * len(requests),
+ logprob_token_ids_cpu=None,
+ logprobs_cpu=None,
+ )
+ scheduler.update_from_output(output, model_runner_output)
+
+ # Schedule the next step.
+ # Only the first and second requests are scheduled.
+ # The third request is in the RUNNING state but not scheduled in this step
+ # because of the encoder budget.
+ output = scheduler.schedule()
+ assert len(scheduler.running) == 3
+ assert len(output.scheduled_new_reqs) == 0
+ assert len(output.scheduled_cached_reqs) == 2
+ assert len(output.finished_req_ids) == 0
+ assert output.num_scheduled_tokens[requests[0].request_id] == 1
+ assert output.num_scheduled_tokens[requests[1].request_id] == 700
+ assert requests[2].request_id not in output.num_scheduled_tokens
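The new test file drives the scheduler directly with hand-built `ModelRunnerOutput` objects. As a point of reference, the sketch below (not part of the diff) shows how the same `schedule()` / `update_from_output()` cycle used in `test_schedule_partial_requests` can be run for a bounded number of steps; it reuses the `create_scheduler()` and `create_requests()` helpers defined above, and the `max_steps` bound is an arbitrary safety limit added for illustration.

```python
def run_scheduler_loop(max_steps: int = 100) -> None:
    scheduler = create_scheduler()
    requests = create_requests(num_requests=4, num_tokens=32)
    for request in requests:
        scheduler.add_request(request)

    req_to_index = {r.request_id: i for i, r in enumerate(requests)}
    for _ in range(max_steps):
        if scheduler.get_num_unfinished_requests() == 0:
            break
        output = scheduler.schedule()
        # Fake a model step: every request "samples" token 0.
        model_runner_output = ModelRunnerOutput(
            req_ids=[r.request_id for r in requests],
            req_id_to_index=req_to_index,
            sampled_token_ids=[0] * len(requests),
            logprob_token_ids_cpu=None,
            logprobs_cpu=None,
        )
        scheduler.update_from_output(output, model_runner_output)
```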
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index f4738bb33c603..fb5e83fe06274 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -67,10 +67,10 @@ def __init__(
# This is flushed at the end of each scheduling step.
self.finished_req_ids: Set[str] = set()
- # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
+ # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
- # Request id -> RunningRequestData
- self.running_reqs_data: Dict[str, RunningRequestData] = {}
+ # Request id -> CachedRequestData
+ self._cached_reqs_data: Dict[str, CachedRequestData] = {}
# Encoder-related.
# Calculate encoder cache size if applicable
@@ -115,17 +115,8 @@ def schedule(self) -> "SchedulerOutput":
encoder_budget = self.max_num_encoder_input_tokens
# First, schedule the RUNNING requests.
- # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be
- # in the "partial" state, where the request has some tokens computed
- # but not all. The constraint is due to the persistent batch in the
- # V1 model runner.
- # TODO(woosuk): Remove this constraint after refactoring model runner.
- has_partial_request = False
req_index = 0
- while req_index < len(self.running):
- # Only the last request in the RUNNING queue can be "partial".
- assert not has_partial_request
- assert token_budget > 0
+ while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
num_new_tokens = request.num_tokens - request.num_computed_tokens
num_new_tokens = min(num_new_tokens, token_budget)
@@ -137,7 +128,14 @@ def schedule(self) -> "SchedulerOutput":
request.num_computed_tokens,
num_new_tokens,
encoder_budget))
- assert num_new_tokens > 0
+ if num_new_tokens == 0:
+ # The request cannot be scheduled because the encoder budget
+ # or the encoder cache is exhausted.
+ # NOTE(woosuk): Here, by doing `continue` instead of `break`,
+ # we do not strictly follow the FCFS scheduling policy and
+ # allow the lower-priority requests to be scheduled.
+ req_index += 1
+ continue
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
@@ -172,8 +170,6 @@ def schedule(self) -> "SchedulerOutput":
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
- has_partial_request = (request.num_computed_tokens + num_new_tokens
- < request.num_tokens)
# Encoder-related.
if encoder_inputs_to_schedule:
@@ -186,13 +182,9 @@ def schedule(self) -> "SchedulerOutput":
# Next, schedule the WAITING requests.
if not preempted_reqs:
- while self.waiting:
- if has_partial_request:
- break
+ while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
- if token_budget == 0:
- break
request = self.waiting[0]
# Get already-cached tokens.
@@ -249,8 +241,6 @@ def schedule(self) -> "SchedulerOutput":
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
- has_partial_request = (num_computed_tokens + num_new_tokens
- < request.num_tokens)
# Encoder-related.
if encoder_inputs_to_schedule:
@@ -266,8 +256,11 @@ def schedule(self) -> "SchedulerOutput":
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
+ # Since some requests in the RUNNING queue may not be scheduled in
+ # this step, the total number of scheduled requests can be smaller than
+ # len(self.running).
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
- len(scheduled_running_reqs) == len(self.running))
+ len(scheduled_running_reqs) <= len(self.running))
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
@@ -286,25 +279,28 @@ def schedule(self) -> "SchedulerOutput":
for req in scheduled_new_reqs
]
resumed_reqs_data = [
- ResumedRequestData.from_request(
- req, req_to_new_block_ids[req.request_id],
- req.num_computed_tokens) for req in scheduled_resumed_reqs
+ self._make_cached_request_data(
+ req,
+ req_to_new_block_ids[req.request_id],
+ req.num_computed_tokens,
+ resumed_from_preemption=True,
+ ) for req in scheduled_resumed_reqs
]
running_reqs_data = [
- self._make_running_request_data(
- req, req_to_new_block_ids[req.request_id],
- req.num_computed_tokens) for req in scheduled_running_reqs
+ self._make_cached_request_data(
+ req,
+ req_to_new_block_ids[req.request_id],
+ req.num_computed_tokens,
+ resumed_from_preemption=False,
+ ) for req in scheduled_running_reqs
]
- preempted_req_ids = {req.request_id for req in preempted_reqs}
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
- scheduled_resumed_reqs=resumed_reqs_data,
- scheduled_running_reqs=running_reqs_data,
+ scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
- preempted_req_ids=preempted_req_ids,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
@@ -316,22 +312,26 @@ def schedule(self) -> "SchedulerOutput":
self.finished_req_ids = set()
return scheduler_output
- def _make_running_request_data(
+ def _make_cached_request_data(
self,
request: Request,
new_block_ids: List[int],
num_computed_tokens: int,
- ) -> "RunningRequestData":
- # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
+ resumed_from_preemption: bool,
+ ) -> "CachedRequestData":
+ # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
- if request.request_id in self.running_reqs_data:
- req_data = self.running_reqs_data[request.request_id]
+ if request.request_id in self._cached_reqs_data:
+ req_data = self._cached_reqs_data[request.request_id]
+ req_data.resumed_from_preemption = resumed_from_preemption
req_data.new_block_ids = new_block_ids
req_data.num_computed_tokens = num_computed_tokens
else:
- req_data = RunningRequestData.from_request(request, new_block_ids,
- num_computed_tokens)
- self.running_reqs_data[request.request_id] = req_data
+ req_data = CachedRequestData.from_request(request,
+ resumed_from_preemption,
+ new_block_ids,
+ num_computed_tokens)
+ self._cached_reqs_data[request.request_id] = req_data
return req_data
def _try_schedule_encoder_inputs(
@@ -420,7 +420,13 @@ def update_from_output(
# expensive operations inside the loop.
for request in self.running:
req_id = request.request_id
- request.num_computed_tokens += num_scheduled_tokens[req_id]
+ num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
+ if num_tokens_scheduled == 0:
+ # The request was not scheduled in this step.
+ new_running.append(request)
+ continue
+
+ request.num_computed_tokens += num_tokens_scheduled
# When the request's num_computed_tokens catches up its num_tokens,
# the request generates output tokens. Otherwise, we ignore the
# sampler output for the request.
@@ -529,7 +535,7 @@ def _free_request(self, request: Request) -> None:
assert request.is_finished()
self.kv_cache_manager.free(request)
self.encoder_cache_manager.free(request)
- self.running_reqs_data.pop(request.request_id, None)
+ self._cached_reqs_data.pop(request.request_id, None)
del self.requests[request.request_id]
self.finished_req_ids.add(request.request_id)
@@ -584,30 +590,13 @@ def from_request(
@dataclass
-class ResumedRequestData:
-
- req_id: str
- block_ids: List[int]
- num_computed_tokens: int
-
- @classmethod
- def from_request(
- cls,
- request: Request,
- block_ids: List[int],
- num_computed_tokens: int,
- ) -> "ResumedRequestData":
- return cls(
- req_id=request.request_id,
- block_ids=block_ids,
- num_computed_tokens=num_computed_tokens,
- )
-
-
-@dataclass
-class RunningRequestData:
+class CachedRequestData:
req_id: str
+ # If resumed_from_preemption is False, new_block_ids will be appended to
+ # the request's block IDs. If True, new_block_ids will be used as the
+ # request's block IDs instead of appending to the existing block IDs.
+ resumed_from_preemption: bool
new_block_ids: List[int]
num_computed_tokens: int
@@ -615,11 +604,13 @@ class RunningRequestData:
def from_request(
cls,
request: Request,
+ resumed_from_preemption: bool,
new_block_ids: List[int],
num_computed_tokens: int,
- ) -> "RunningRequestData":
+ ) -> "CachedRequestData":
return cls(
req_id=request.request_id,
+ resumed_from_preemption=resumed_from_preemption,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
)
@@ -629,14 +620,12 @@ def from_request(
class SchedulerOutput:
scheduled_new_reqs: List[NewRequestData]
- scheduled_resumed_reqs: List[ResumedRequestData]
- scheduled_running_reqs: List[RunningRequestData]
+ scheduled_cached_reqs: List[CachedRequestData]
num_scheduled_tokens: Dict[str, int]
total_num_scheduled_tokens: int
scheduled_encoder_inputs: Dict[str, List[int]]
num_common_prefix_blocks: int
- preempted_req_ids: Set[str]
finished_req_ids: Set[str]
free_encoder_input_ids: List[Tuple[str, int]]
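The scheduler now publishes a single `scheduled_cached_reqs` list whose `resumed_from_preemption` flag tells the consumer whether `new_block_ids` extends or replaces the request's block IDs. A minimal sketch of that contract, with illustrative names (`apply_cached_req_data`, `existing_block_ids`) that are not part of the diff:

```python
from typing import List


def apply_cached_req_data(
    existing_block_ids: List[int],
    new_block_ids: List[int],
    resumed_from_preemption: bool,
) -> List[int]:
    if resumed_from_preemption:
        # The request's old blocks were freed on preemption, so the
        # scheduler resends the full block list.
        return list(new_block_ids)
    # The request stayed resident; only newly allocated blocks are sent
    # and must be appended to the ones already known.
    return existing_block_ids + list(new_block_ids)


assert apply_cached_req_data([1, 2], [3], resumed_from_preemption=False) == [1, 2, 3]
assert apply_cached_req_data([1, 2], [7, 8], resumed_from_preemption=True) == [7, 8]
```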
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index 8d0785243c716..f520ee9586c5c 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -46,6 +46,8 @@ def append_row(
start: int,
block_ids: List[int],
) -> None:
+ if not block_ids:
+ return
num_blocks = len(block_ids)
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
self.num_blocks_per_row[row_idx] = start + num_blocks
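`append_row` now ignores empty block lists, which can happen when a cached running request is scheduled without any newly allocated blocks. A standalone sketch of the same guard using plain NumPy (not the vLLM `BlockTable` class):

```python
import numpy as np

block_table = np.zeros((4, 8), dtype=np.int32)
num_blocks_per_row = np.zeros(4, dtype=np.int32)


def append_row(row_idx: int, start: int, block_ids: list) -> None:
    if not block_ids:  # nothing new this step: keep the row untouched
        return
    num_blocks = len(block_ids)
    block_table[row_idx, start:start + num_blocks] = block_ids
    num_blocks_per_row[row_idx] = start + num_blocks


append_row(0, 0, [5, 6, 7])  # initial allocation
append_row(0, 3, [])         # decode step with no new blocks: no-op
assert num_blocks_per_row[0] == 3
```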
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0b5644525553e..7841fac1df34b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -205,12 +205,32 @@ def __init__(
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
- def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
- # Remove stopped requests from the cached states.
- # Keep the states of the preempted requests.
+ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
+ """Update the cached states and the persistent batch with the scheduler
+ output.
+
+ The updated states are used by the `_prepare_inputs` function to create
+ the input GPU tensors for the model.
+
+ Returns:
+ True if there is a new/resumed/paused/finished request in the batch.
+ If False, we can skip copying SamplingMetadata to the GPU.
+ """
+ # Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
+ # Remove the finished requests from the persistent batch.
+ # NOTE(woosuk): There could be an edge case where finished_req_ids and
+ # scheduled_req_ids overlap. This happens when a request is aborted and
+ # then resubmitted with the same ID. In this case, we treat them as two
+ # distinct requests - clearing the cached states for the first request
+ # and handling the second as a new request.
+ removed_req_indices: List[int] = []
+ for req_id in scheduler_output.finished_req_ids:
+ req_index = self.input_batch.remove_request(req_id)
+ if req_index is not None:
+ removed_req_indices.append(req_index)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
@@ -220,36 +240,22 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
- # Remove the requests from the persistent batch.
- stopped_req_ids = set().union(
- scheduler_output.preempted_req_ids,
- scheduler_output.finished_req_ids,
- )
- removed_req_indices: List[int] = []
- for req_id in stopped_req_ids:
+ # Remove the unscheduled requests from the persistent batch.
+ # NOTE(woosuk): The unscheduled requests are either preempted requests
+ # or running requests that are not scheduled in this step. We remove
+ # them from the persistent batch but keep their cached states since
+ # they will be scheduled again sometime in the future.
+ scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
+ cached_req_ids = self.input_batch.req_id_to_index.keys()
+ unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+ # NOTE(woosuk): The persistent batch optimization assumes that
+ # consecutive batches contain mostly the same requests. If batches
+ # have low request overlap (e.g., alternating between two distinct
+ # sets of requests), this optimization becomes very inefficient.
+ for req_id in unscheduled_req_ids:
req_index = self.input_batch.remove_request(req_id)
- if req_index is not None:
- removed_req_indices.append(req_index)
-
- # Update the states of the running requests.
- for req_data in scheduler_output.scheduled_running_reqs:
- req_id = req_data.req_id
- req_state = self.requests[req_id]
- req_index = self.input_batch.req_id_to_index[req_id]
-
- # Update the num_computed_tokens.
- req_state.num_computed_tokens = req_data.num_computed_tokens
- self.input_batch.num_computed_tokens_cpu[req_index] = (
- req_data.num_computed_tokens)
-
- # Update the block table.
- num_new_blocks = len(req_data.new_block_ids)
- if num_new_blocks == 0:
- continue
- start_index = len(req_state.block_ids)
- req_state.block_ids.extend(req_data.new_block_ids)
- self.input_batch.block_table.append_row(req_index, start_index,
- req_data.new_block_ids)
+ assert req_index is not None
+ removed_req_indices.append(req_index)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
@@ -305,14 +311,36 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
req_ids_to_add.append(req_id)
- # Update the cached states of the resumed requests.
- for res_req_data in scheduler_output.scheduled_resumed_reqs:
- req_id = res_req_data.req_id
+ # Update the states of the running/resumed requests.
+ for req_data in scheduler_output.scheduled_cached_reqs:
+ req_id = req_data.req_id
req_state = self.requests[req_id]
- req_state.block_ids = res_req_data.block_ids
- req_state.num_computed_tokens = res_req_data.num_computed_tokens
- req_ids_to_add.append(req_id)
+ # Update the cached states.
+ req_state.num_computed_tokens = req_data.num_computed_tokens
+ if not req_data.resumed_from_preemption:
+ # Append the new blocks to the existing block IDs.
+ req_state.block_ids.extend(req_data.new_block_ids)
+ else:
+ # The request is resumed from preemption.
+ # Replace the existing block IDs with the new ones.
+ req_state.block_ids = req_data.new_block_ids
+
+ req_index = self.input_batch.req_id_to_index.get(req_id)
+ if req_index is None:
+ # The request is not in the persistent batch.
+ # The request was either preempted and resumed later, or was not
+ # scheduled in the previous step and needs to be added again.
+ req_ids_to_add.append(req_id)
+ continue
+
+ # Update the persistent batch.
+ self.input_batch.num_computed_tokens_cpu[req_index] = (
+ req_data.num_computed_tokens)
+ start_index = len(req_state.block_ids) - len(
+ req_data.new_block_ids)
+ self.input_batch.block_table.append_row(req_index, start_index,
+ req_data.new_block_ids)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
@@ -330,6 +358,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
+ return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -536,10 +565,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
)
- # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
- # request in the batch. While we should not sample any token from this
- # partial request, we do so for simplicity. We will ignore the sampled
- # token from the partial request.
+ # NOTE(woosuk): Due to chunked prefills, the batch may contain partial
+ # requests. While we should not sample any token from these partial
+ # requests, we do so for simplicity. We will ignore the sampled
+ # tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices
@@ -601,22 +630,15 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
def _prepare_sampling(
self,
- scheduler_output: "SchedulerOutput",
+ batch_changed: bool,
) -> SamplingMetadata:
- skip_copy = True
- if (scheduler_output.finished_req_ids
- or scheduler_output.preempted_req_ids):
- skip_copy = False
- if (scheduler_output.scheduled_new_reqs
- or scheduler_output.scheduled_resumed_reqs):
- skip_copy = False
# Create the sampling metadata.
req_id_output_token_ids: Dict[str, List[int]] = \
{req_id: req.output_token_ids \
for req_id, req in self.requests.items()}
sampling_metadata = self.input_batch.make_sampling_metadata(
- req_id_output_token_ids, skip_copy)
+ req_id_output_token_ids, skip_copy=not batch_changed)
return sampling_metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@@ -715,7 +737,7 @@ def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
- self._update_states(scheduler_output)
+ batch_changed = self._update_states(scheduler_output)
if self.is_multimodal_model:
# Run the multimodal encoder if any.
@@ -778,7 +800,7 @@ def execute_model(
logits = self.model.compute_logits(hidden_states, None)
# Sample the next token and get logprobs if needed.
- sampling_metadata = self._prepare_sampling(scheduler_output)
+ sampling_metadata = self._prepare_sampling(batch_changed)
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
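`_update_states()` now reports whether the persistent batch changed, and that flag alone decides whether the sampling metadata is re-copied to the GPU. A condensed, standalone sketch of the gating; the names below (`ToyRunner`, `copies`) are illustrative stand-ins, not the real `GPUModelRunner` API:

```python
class ToyRunner:

    def __init__(self) -> None:
        self.copies = 0  # counts how often sampling metadata is (re)copied

    def _update_states(self, unscheduled_req_ids, req_ids_to_add) -> bool:
        # Mirrors the return value added above.
        return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0

    def _prepare_sampling(self, batch_changed: bool) -> None:
        if batch_changed:        # i.e. skip_copy = not batch_changed
            self.copies += 1     # re-copy metadata to the GPU


runner = ToyRunner()
# Step 1: two new requests join the batch -> metadata must be copied.
runner._prepare_sampling(runner._update_states(set(), ["a", "b"]))
# Step 2: identical batch, nothing added or removed -> copy is skipped.
runner._prepare_sampling(runner._update_states(set(), []))
assert runner.copies == 1
```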