From 79961c9f21a189a37f02b8287b5b5721bc2392c3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 2 Feb 2025 20:42:32 -0800 Subject: [PATCH 1/8] [V1] Remove constraints on partial requests Signed-off-by: Woosuk Kwon --- vllm/v1/core/scheduler.py | 14 -------------- vllm/v1/worker/gpu_model_runner.py | 28 ++++++++++++++++++---------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f4738bb33c603..bf4713e344ca8 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -115,16 +115,8 @@ def schedule(self) -> "SchedulerOutput": encoder_budget = self.max_num_encoder_input_tokens # First, schedule the RUNNING requests. - # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be - # in the "partial" state, where the request has some tokens computed - # but not all. The constraint is due to the persistent batch in the - # V1 model runner. - # TODO(woosuk): Remove this constraint after refactoring model runner. - has_partial_request = False req_index = 0 while req_index < len(self.running): - # Only the last request in the RUNNING queue can be "partial". - assert not has_partial_request assert token_budget > 0 request = self.running[req_index] num_new_tokens = request.num_tokens - request.num_computed_tokens @@ -172,8 +164,6 @@ def schedule(self) -> "SchedulerOutput": num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 - has_partial_request = (request.num_computed_tokens + num_new_tokens - < request.num_tokens) # Encoder-related. if encoder_inputs_to_schedule: @@ -187,8 +177,6 @@ def schedule(self) -> "SchedulerOutput": # Next, schedule the WAITING requests. if not preempted_reqs: while self.waiting: - if has_partial_request: - break if len(self.running) == self.max_num_running_reqs: break if token_budget == 0: @@ -249,8 +237,6 @@ def schedule(self) -> "SchedulerOutput": token_budget -= num_new_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens - has_partial_request = (num_computed_tokens + num_new_tokens - < request.num_tokens) # Encoder-related. if encoder_inputs_to_schedule: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0b5644525553e..f1958877de1cc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -220,13 +220,21 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if not encoder_outputs: self.encoder_cache.pop(req_id, None) - # Remove the requests from the persistent batch. - stopped_req_ids = set().union( - scheduler_output.preempted_req_ids, - scheduler_output.finished_req_ids, - ) + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests include 1) finished requests, + # 2) preempted requests, and 3) running requests that are not scheduled + # in this step. For 1) finished requests, we will remove them from the + # persistent batch and the cached states. For 2) & 3), we will remove + # them from the persistent batch only and keep their cached states. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. 
If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. removed_req_indices: List[int] = [] - for req_id in stopped_req_ids: + for req_id in unscheduled_req_ids: req_index = self.input_batch.remove_request(req_id) if req_index is not None: removed_req_indices.append(req_index) @@ -536,10 +544,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, ) - # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial - # request in the batch. While we should not sample any token from this - # partial request, we do so for simplicity. We will ignore the sampled - # token from the partial request. + # NOTE(woosuk): Due to chunked prefills, the batch may contain partial + # requests. While we should not sample any token from these partial + # requests, we do so for simplicity. We will ignore the sampled + # tokens from the partial requests. # TODO: Support prompt logprobs. logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices From ac13c4b993a642f3b6d55c009094fe338e525bfd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Feb 2025 00:57:47 -0800 Subject: [PATCH 2/8] Fix Signed-off-by: Woosuk Kwon --- vllm/v1/core/scheduler.py | 109 ++++++++++++++--------------- vllm/v1/worker/block_table.py | 2 + vllm/v1/worker/gpu_model_runner.py | 72 +++++++++---------- 3 files changed, 89 insertions(+), 94 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index bf4713e344ca8..3998602611184 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -67,10 +67,10 @@ def __init__( # This is flushed at the end of each scheduling step. self.finished_req_ids: Set[str] = set() - # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. - # Request id -> RunningRequestData - self.running_reqs_data: Dict[str, RunningRequestData] = {} + # Request id -> CachedRequestData + self._cached_reqs_data: Dict[str, CachedRequestData] = {} # Encoder-related. # Calculate encoder cache size if applicable @@ -116,8 +116,7 @@ def schedule(self) -> "SchedulerOutput": # First, schedule the RUNNING requests. req_index = 0 - while req_index < len(self.running): - assert token_budget > 0 + while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] num_new_tokens = request.num_tokens - request.num_computed_tokens num_new_tokens = min(num_new_tokens, token_budget) @@ -129,7 +128,14 @@ def schedule(self) -> "SchedulerOutput": request.num_computed_tokens, num_new_tokens, encoder_budget)) - assert num_new_tokens > 0 + if num_new_tokens == 0: + # The request cannot be scheduled because the encoder budget + # or the encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue while True: new_blocks = self.kv_cache_manager.allocate_slots( @@ -176,11 +182,9 @@ def schedule(self) -> "SchedulerOutput": # Next, schedule the WAITING requests. if not preempted_reqs: - while self.waiting: + while self.waiting and token_budget > 0: if len(self.running) == self.max_num_running_reqs: break - if token_budget == 0: - break request = self.waiting[0] # Get already-cached tokens. 
@@ -253,7 +257,7 @@ def schedule(self) -> "SchedulerOutput": assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + - len(scheduled_running_reqs) == len(self.running)) + len(scheduled_running_reqs) <= len(self.running)) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. @@ -272,25 +276,28 @@ def schedule(self) -> "SchedulerOutput": for req in scheduled_new_reqs ] resumed_reqs_data = [ - ResumedRequestData.from_request( - req, req_to_new_block_ids[req.request_id], - req.num_computed_tokens) for req in scheduled_resumed_reqs + self._make_cached_request_data( + req, + req_to_new_block_ids[req.request_id], + req.num_computed_tokens, + resumed_from_preemption=True, + ) for req in scheduled_resumed_reqs ] running_reqs_data = [ - self._make_running_request_data( - req, req_to_new_block_ids[req.request_id], - req.num_computed_tokens) for req in scheduled_running_reqs + self._make_cached_request_data( + req, + req_to_new_block_ids[req.request_id], + req.num_computed_tokens, + resumed_from_preemption=False, + ) for req in scheduled_running_reqs ] - preempted_req_ids = {req.request_id for req in preempted_reqs} scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, - scheduled_resumed_reqs=resumed_reqs_data, - scheduled_running_reqs=running_reqs_data, + scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, - preempted_req_ids=preempted_req_ids, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between @@ -302,22 +309,26 @@ def schedule(self) -> "SchedulerOutput": self.finished_req_ids = set() return scheduler_output - def _make_running_request_data( + def _make_cached_request_data( self, request: Request, new_block_ids: List[int], num_computed_tokens: int, - ) -> "RunningRequestData": - # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating + resumed_from_preemption: bool, + ) -> "CachedRequestData": + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. - if request.request_id in self.running_reqs_data: - req_data = self.running_reqs_data[request.request_id] + if request.request_id in self._cached_reqs_data: + req_data = self._cached_reqs_data[request.request_id] + req_data.resumed_from_preemption = resumed_from_preemption req_data.new_block_ids = new_block_ids req_data.num_computed_tokens = num_computed_tokens else: - req_data = RunningRequestData.from_request(request, new_block_ids, - num_computed_tokens) - self.running_reqs_data[request.request_id] = req_data + req_data = CachedRequestData.from_request(request, + resumed_from_preemption, + new_block_ids, + num_computed_tokens) + self._cached_reqs_data[request.request_id] = req_data return req_data def _try_schedule_encoder_inputs( @@ -406,7 +417,13 @@ def update_from_output( # expensive operations inside the loop. for request in self.running: req_id = request.request_id - request.num_computed_tokens += num_scheduled_tokens[req_id] + num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) + if num_tokens_scheduled == 0: + # The request was not scheduled in this step. 
+ new_running.append(request) + continue + + request.num_computed_tokens += num_tokens_scheduled # When the request's num_computed_tokens catches up its num_tokens, # the request generates output tokens. Otherwise, we ignore the # sampler output for the request. @@ -515,7 +532,7 @@ def _free_request(self, request: Request) -> None: assert request.is_finished() self.kv_cache_manager.free(request) self.encoder_cache_manager.free(request) - self.running_reqs_data.pop(request.request_id, None) + self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] self.finished_req_ids.add(request.request_id) @@ -570,30 +587,10 @@ def from_request( @dataclass -class ResumedRequestData: - - req_id: str - block_ids: List[int] - num_computed_tokens: int - - @classmethod - def from_request( - cls, - request: Request, - block_ids: List[int], - num_computed_tokens: int, - ) -> "ResumedRequestData": - return cls( - req_id=request.request_id, - block_ids=block_ids, - num_computed_tokens=num_computed_tokens, - ) - - -@dataclass -class RunningRequestData: +class CachedRequestData: req_id: str + resumed_from_preemption: bool new_block_ids: List[int] num_computed_tokens: int @@ -601,11 +598,13 @@ class RunningRequestData: def from_request( cls, request: Request, + resumed_from_preemption: bool, new_block_ids: List[int], num_computed_tokens: int, - ) -> "RunningRequestData": + ) -> "CachedRequestData": return cls( req_id=request.request_id, + resumed_from_preemption=resumed_from_preemption, new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, ) @@ -615,14 +614,12 @@ def from_request( class SchedulerOutput: scheduled_new_reqs: List[NewRequestData] - scheduled_resumed_reqs: List[ResumedRequestData] - scheduled_running_reqs: List[RunningRequestData] + scheduled_cached_reqs: List[CachedRequestData] num_scheduled_tokens: Dict[str, int] total_num_scheduled_tokens: int scheduled_encoder_inputs: Dict[str, List[int]] num_common_prefix_blocks: int - preempted_req_ids: Set[str] finished_req_ids: Set[str] free_encoder_input_ids: List[Tuple[str, int]] diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 8d0785243c716..f520ee9586c5c 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -46,6 +46,8 @@ def append_row( start: int, block_ids: List[int], ) -> None: + if not block_ids: + return num_blocks = len(block_ids) self.block_table_np[row_idx, start:start + num_blocks] = block_ids self.num_blocks_per_row[row_idx] = start + num_blocks diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f1958877de1cc..7f40ec19c7a73 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -205,7 +205,7 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Remove stopped requests from the cached states. # Keep the states of the preempted requests. for req_id in scheduler_output.finished_req_ids: @@ -239,26 +239,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if req_index is not None: removed_req_indices.append(req_index) - # Update the states of the running requests. 
- for req_data in scheduler_output.scheduled_running_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] - req_index = self.input_batch.req_id_to_index[req_id] - - # Update the num_computed_tokens. - req_state.num_computed_tokens = req_data.num_computed_tokens - self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) - - # Update the block table. - num_new_blocks = len(req_data.new_block_ids) - if num_new_blocks == 0: - continue - start_index = len(req_state.block_ids) - req_state.block_ids.extend(req_data.new_block_ids) - self.input_batch.block_table.append_row(req_index, start_index, - req_data.new_block_ids) - req_ids_to_add: List[str] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: @@ -313,14 +293,36 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_ids_to_add.append(req_id) - # Update the cached states of the resumed requests. - for res_req_data in scheduler_output.scheduled_resumed_reqs: - req_id = res_req_data.req_id + # Update the states of the running requests. + for req_data in scheduler_output.scheduled_cached_reqs: + req_id = req_data.req_id req_state = self.requests[req_id] - req_state.block_ids = res_req_data.block_ids - req_state.num_computed_tokens = res_req_data.num_computed_tokens - req_ids_to_add.append(req_id) + # Update the cached states. + req_state.num_computed_tokens = req_data.num_computed_tokens + if not req_data.resumed_from_preemption: + # Append the new blocks to the existing block IDs. + req_state.block_ids.extend(req_data.new_block_ids) + else: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = req_data.new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + start_index = len(req_state.block_ids) - len( + req_data.new_block_ids) + self.input_batch.block_table.append_row(req_index, start_index, + req_data.new_block_ids) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. @@ -338,6 +340,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Condense the batched states if there are empty indices. if removed_req_indices: self.input_batch.condense(removed_req_indices) + return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -609,22 +612,15 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _prepare_sampling( self, - scheduler_output: "SchedulerOutput", + batch_changed: bool, ) -> SamplingMetadata: - skip_copy = True - if (scheduler_output.finished_req_ids - or scheduler_output.preempted_req_ids): - skip_copy = False - if (scheduler_output.scheduled_new_reqs - or scheduler_output.scheduled_resumed_reqs): - skip_copy = False # Create the sampling metadata. 
req_id_output_token_ids: Dict[str, List[int]] = \ {req_id: req.output_token_ids \ for req_id, req in self.requests.items()} sampling_metadata = self.input_batch.make_sampling_metadata( - req_id_output_token_ids, skip_copy) + req_id_output_token_ids, skip_copy=not batch_changed) return sampling_metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"): @@ -723,7 +719,7 @@ def execute_model( self, scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: - self._update_states(scheduler_output) + batch_changed = self._update_states(scheduler_output) if self.is_multimodal_model: # Run the multimodal encoder if any. @@ -786,7 +782,7 @@ def execute_model( logits = self.model.compute_logits(hidden_states, None) # Sample the next token and get logprobs if needed. - sampling_metadata = self._prepare_sampling(scheduler_output) + sampling_metadata = self._prepare_sampling(batch_changed) sampler_output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, From ab536d247b4a99e4e28b9d5bd888124d13efa3da Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Feb 2025 17:14:57 -0800 Subject: [PATCH 3/8] docstring & comment Signed-off-by: Woosuk Kwon --- vllm/v1/core/scheduler.py | 6 ++++++ vllm/v1/worker/gpu_model_runner.py | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 3998602611184..fb5e83fe06274 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -256,6 +256,9 @@ def schedule(self) -> "SchedulerOutput": assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running)) @@ -590,6 +593,9 @@ def from_request( class CachedRequestData: req_id: str + # If resumed_from_preemption is False, new_block_ids will be appended to + # the request's block IDs. If True, new_block_ids will be used as the + # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool new_block_ids: List[int] num_computed_tokens: int diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7f40ec19c7a73..8a87d7fa467d6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -206,6 +206,16 @@ def __init__( self.seq_lens_np = self.seq_lens_cpu.numpy() def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + Returns: + True if there is a new/resumed/paused/finished request in the batch. + If False, we can skip copying SamplingMetadata to the GPU. + """ # Remove stopped requests from the cached states. # Keep the states of the preempted requests. 
for req_id in scheduler_output.finished_req_ids: From 64fe4b423c8c348bd596458c1ecb8c38ca51fcf5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Feb 2025 21:39:32 -0800 Subject: [PATCH 4/8] Add scheduler test Signed-off-by: Woosuk Kwon --- tests/v1/core/test_scheduler.py | 207 ++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 tests/v1/core/test_scheduler.py diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py new file mode 100644 index 0000000000000..a441bc9fe10ce --- /dev/null +++ b/tests/v1/core/test_scheduler.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Optional + +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.sampling_params import SamplingParams +from vllm.v1.core.scheduler import Scheduler +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus + + +def create_scheduler( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, +) -> Scheduler: + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + ) + cache_config.num_gpu_blocks = 10000 + return Scheduler(scheduler_config, + model_config, + cache_config, + lora_config=None) + + +def create_requests( + num_requests: int, + num_tokens: int = 10, + mm_positions: Optional[List[PlaceholderRange]] = None, +): + sampling_params = SamplingParams() + requests = [] + is_multimodal = mm_positions is not None + for i in range(num_requests): + if is_multimodal: + mm_position = mm_positions[i] + mm_inputs = [MultiModalKwargs({})] * len(mm_position) + else: + mm_position = None + mm_inputs = None + request = Request( + request_id=f"{i}", + prompt=None, + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + eos_token_id=None, + arrival_time=0, + ) + requests.append(request) + return requests + + +def test_add_requests(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + + for i, request in enumerate(requests): + scheduler.add_request(request) + assert request.request_id in scheduler.requests + assert len(scheduler.waiting) == i + 1 + + +def test_finish_request(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + for i, request in enumerate(requests): + scheduler.finish_requests(request.request_id, + RequestStatus.FINISHED_ABORTED) + assert request.request_id not in scheduler.requests + assert len(scheduler.waiting) == 9 - i + + +def test_get_num_unfinished_requests(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + for i, request in enumerate(requests): + scheduler.finish_requests(request.request_id, + RequestStatus.FINISHED_STOPPED) + assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1 + + +def test_schedule(): + scheduler = 
create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + # Test initial scheduling + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == len(requests) + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + # Verify all requests are scheduled. + for req_id, num_tokens in output.num_scheduled_tokens.items(): + assert num_tokens == len(requests[int(req_id)].prompt_token_ids) + + # Verify requests moved from waiting to running + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == len(requests) + for i, request in enumerate(requests): + assert scheduler.running[i] == request + + +def test_schedule_multimodal_requests(): + scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf") + mm_positions = [[PlaceholderRange(offset=i, length=100)] + for i in range(10)] + requests = create_requests( + num_requests=10, + num_tokens=200, + mm_positions=mm_positions, + ) + for request in requests: + scheduler.add_request(request) + + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == len(requests) + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + for req_id, num_tokens in output.num_scheduled_tokens.items(): + assert num_tokens == len(requests[int(req_id)].prompt_token_ids) + assert len(output.scheduled_encoder_inputs) == 10 + for req_id, encoder_input in output.scheduled_encoder_inputs.items(): + assert len(encoder_input) == 1 + + +def test_schedule_partial_requests(): + scheduler = create_scheduler( + model="llava-hf/llava-1.5-7b-hf", + max_num_batched_tokens=1024, + ) + mm_positions = [[PlaceholderRange(offset=100, length=600)] + for _ in range(3)] + requests = create_requests( + num_requests=3, + num_tokens=800, + mm_positions=mm_positions, + ) + for request in requests: + scheduler.add_request(request) + + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 3 + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + + assert scheduler.max_num_encoder_input_tokens == 1024 + # The first request is scheduled fully. + assert output.num_scheduled_tokens[requests[0].request_id] == 800 + # The second request is scheduled partially. + # The tokens are not scheduled because of the encoder budget. + assert output.num_scheduled_tokens[requests[1].request_id] == 100 + # The third request is also scheduled partially. + # The tokens are not scheduled because of the encoder budget. + assert output.num_scheduled_tokens[requests[2].request_id] == 100 + req_to_index = { + request.request_id: i + for i, request in enumerate(requests) + } + model_runner_output = ModelRunnerOutput( + req_ids=[request.request_id for request in requests], + req_id_to_index=req_to_index, + sampled_token_ids=[0] * len(requests), + logprob_token_ids_cpu=None, + logprobs_cpu=None, + ) + scheduler.update_from_output(output, model_runner_output) + + # Schedule the next step. + # Only the first and second requests are scheduled. + # The third request is in the RUNNING state but not scheduled in this step + # because of the encoder budget. 
+ output = scheduler.schedule() + assert len(scheduler.running) == 3 + assert len(output.scheduled_new_reqs) == 0 + assert len(output.scheduled_cached_reqs) == 2 + assert len(output.finished_req_ids) == 0 + assert output.num_scheduled_tokens[requests[0].request_id] == 1 + assert output.num_scheduled_tokens[requests[1].request_id] == 700 + assert requests[2].request_id not in output.num_scheduled_tokens From 85d9702e251a40c2fef890911ba5ea65c30a1f59 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Feb 2025 21:43:09 -0800 Subject: [PATCH 5/8] Add docstring Signed-off-by: Woosuk Kwon --- tests/v1/core/test_scheduler.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index a441bc9fe10ce..515ace2bcd953 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -152,6 +152,14 @@ def test_schedule_multimodal_requests(): def test_schedule_partial_requests(): + """Test scheduling behavior with partial requests. + + This test verifies that: + 1. The scheduler can handle multiple partial requests in a single step when + constrained by encoder budget. + 2. A request in RUNNING state may be unscheduled in subsequent steps if + there is insufficient encoder budget. + """ scheduler = create_scheduler( model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, From e5d2881aebc9bd9641fa7ebe16b73c8132623730 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Feb 2025 22:00:39 -0800 Subject: [PATCH 6/8] mypy Signed-off-by: Woosuk Kwon --- tests/v1/core/test_scheduler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 515ace2bcd953..8eb08f3e842ca 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -48,9 +48,8 @@ def create_requests( ): sampling_params = SamplingParams() requests = [] - is_multimodal = mm_positions is not None for i in range(num_requests): - if is_multimodal: + if mm_positions is not None: mm_position = mm_positions[i] mm_inputs = [MultiModalKwargs({})] * len(mm_position) else: From af6fa523e44f104c4b89b71242d75f7b8b6a5e7e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Feb 2025 22:01:16 -0800 Subject: [PATCH 7/8] minor Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8a87d7fa467d6..9109892c94c26 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -303,7 +303,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: req_ids_to_add.append(req_id) - # Update the states of the running requests. + # Update the states of the running/resumed requests. 
for req_data in scheduler_output.scheduled_cached_reqs: req_id = req_data.req_id req_state = self.requests[req_id] From a9d9893b4332b94f06544af2c629930c65755666 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 4 Feb 2025 01:05:39 -0800 Subject: [PATCH 8/8] Fix Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9109892c94c26..7841fac1df34b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -216,11 +216,21 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: True if there is a new/resumed/paused/finished request in the batch. If False, we can skip copying SamplingMetadata to the GPU. """ - # Remove stopped requests from the cached states. - # Keep the states of the preempted requests. + # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) self.encoder_cache.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + removed_req_indices: List[int] = [] + for req_id in scheduler_output.finished_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) # Free the cached encoder outputs. for req_id, input_id in scheduler_output.free_encoder_input_ids: @@ -231,11 +241,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.encoder_cache.pop(req_id, None) # Remove the unscheduled requests from the persistent batch. - # NOTE(woosuk): The unscheduled requests include 1) finished requests, - # 2) preempted requests, and 3) running requests that are not scheduled - # in this step. For 1) finished requests, we will remove them from the - # persistent batch and the cached states. For 2) & 3), we will remove - # them from the persistent batch only and keep their cached states. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() cached_req_ids = self.input_batch.req_id_to_index.keys() unscheduled_req_ids = cached_req_ids - scheduled_req_ids @@ -243,11 +252,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # consecutive batches contain mostly the same requests. If batches # have low request overlap (e.g., alternating between two distinct # sets of requests), this optimization becomes very inefficient. - removed_req_indices: List[int] = [] for req_id in unscheduled_req_ids: req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) + assert req_index is not None + removed_req_indices.append(req_index) req_ids_to_add: List[str] = [] # Add new requests to the cached states.
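
A minimal, runnable sketch of the bookkeeping this series converges on (patches 1, 2, and 8): the model runner drops finished requests from both the persistent batch and the cached states, drops unscheduled requests (preempted, or simply not picked this step) from the persistent batch only — computed by diffing the cached request IDs against the keys of the scheduler's num_scheduled_tokens — and reports whether the batch membership changed so the SamplingMetadata copy can be skipped. The names below (FakeSchedulerOutput, MiniPersistentBatch, update_states) are illustrative stand-ins, not the vLLM classes; the real logic lives in GPUModelRunner._update_states and Scheduler.schedule, and this sketch simplifies away block tables, encoder caches, and newly added requests.

    from dataclasses import dataclass, field
    from typing import Dict, List, Set


    @dataclass
    class FakeSchedulerOutput:
        # req_id -> number of tokens scheduled for it in this step
        num_scheduled_tokens: Dict[str, int]
        finished_req_ids: Set[str]


    @dataclass
    class MiniPersistentBatch:
        req_id_to_index: Dict[str, int] = field(default_factory=dict)

        def remove_request(self, req_id: str):
            # Returns the freed batch index, or None if the request was absent.
            return self.req_id_to_index.pop(req_id, None)


    def update_states(batch: MiniPersistentBatch,
                      cached_states: Dict[str, object],
                      out: FakeSchedulerOutput) -> bool:
        """Simplified mirror of the eviction logic in _update_states.

        Returns True if the batch membership changed, i.e. when the
        SamplingMetadata copy cannot be skipped.
        """
        removed: List[int] = []

        # 1) Finished requests: drop both the cached states and the batch slot.
        for req_id in out.finished_req_ids:
            cached_states.pop(req_id, None)
            idx = batch.remove_request(req_id)
            if idx is not None:
                removed.append(idx)

        # 2) Unscheduled requests (preempted, or not scheduled this step):
        #    evict from the persistent batch but keep the cached states,
        #    since the scheduler may schedule them again later.
        unscheduled = batch.req_id_to_index.keys() - out.num_scheduled_tokens.keys()
        for req_id in list(unscheduled):
            idx = batch.remove_request(req_id)
            assert idx is not None
            removed.append(idx)

        # Simplified; the real code also accounts for newly added/resumed
        # requests when deciding whether the batch changed.
        return bool(removed)


    if __name__ == "__main__":
        batch = MiniPersistentBatch({"a": 0, "b": 1, "c": 2})
        states = {"a": object(), "b": object(), "c": object()}
        out = FakeSchedulerOutput(num_scheduled_tokens={"a": 1},
                                  finished_req_ids={"b"})
        changed = update_states(batch, states, out)
        assert changed
        assert set(batch.req_id_to_index) == {"a"}  # only "a" stays in the batch
        assert set(states) == {"a", "c"}            # "c" keeps its cached state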