diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py new file mode 100644 index 0000000000000..8eb08f3e842ca --- /dev/null +++ b/tests/v1/core/test_scheduler.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Optional + +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.sampling_params import SamplingParams +from vllm.v1.core.scheduler import Scheduler +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus + + +def create_scheduler( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, +) -> Scheduler: + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + ) + cache_config.num_gpu_blocks = 10000 + return Scheduler(scheduler_config, + model_config, + cache_config, + lora_config=None) + + +def create_requests( + num_requests: int, + num_tokens: int = 10, + mm_positions: Optional[List[PlaceholderRange]] = None, +): + sampling_params = SamplingParams() + requests = [] + for i in range(num_requests): + if mm_positions is not None: + mm_position = mm_positions[i] + mm_inputs = [MultiModalKwargs({})] * len(mm_position) + else: + mm_position = None + mm_inputs = None + request = Request( + request_id=f"{i}", + prompt=None, + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + eos_token_id=None, + arrival_time=0, + ) + requests.append(request) + return requests + + +def test_add_requests(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + + for i, request in enumerate(requests): + scheduler.add_request(request) + assert request.request_id in scheduler.requests + assert len(scheduler.waiting) == i + 1 + + +def test_finish_request(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + for i, request in enumerate(requests): + scheduler.finish_requests(request.request_id, + RequestStatus.FINISHED_ABORTED) + assert request.request_id not in scheduler.requests + assert len(scheduler.waiting) == 9 - i + + +def test_get_num_unfinished_requests(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + for i, request in enumerate(requests): + scheduler.finish_requests(request.request_id, + RequestStatus.FINISHED_STOPPED) + assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1 + + +def test_schedule(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + # Test initial scheduling + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == len(requests) + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + # Verify all requests are scheduled. + for req_id, num_tokens in output.num_scheduled_tokens.items(): + assert num_tokens == len(requests[int(req_id)].prompt_token_ids) + + # Verify requests moved from waiting to running + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == len(requests) + for i, request in enumerate(requests): + assert scheduler.running[i] == request + + +def test_schedule_multimodal_requests(): + scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf") + mm_positions = [[PlaceholderRange(offset=i, length=100)] + for i in range(10)] + requests = create_requests( + num_requests=10, + num_tokens=200, + mm_positions=mm_positions, + ) + for request in requests: + scheduler.add_request(request) + + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == len(requests) + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + for req_id, num_tokens in output.num_scheduled_tokens.items(): + assert num_tokens == len(requests[int(req_id)].prompt_token_ids) + assert len(output.scheduled_encoder_inputs) == 10 + for req_id, encoder_input in output.scheduled_encoder_inputs.items(): + assert len(encoder_input) == 1 + + +def test_schedule_partial_requests(): + """Test scheduling behavior with partial requests. + + This test verifies that: + 1. The scheduler can handle multiple partial requests in a single step when + constrained by encoder budget. + 2. A request in RUNNING state may be unscheduled in subsequent steps if + there is insufficient encoder budget. + """ + scheduler = create_scheduler( + model="llava-hf/llava-1.5-7b-hf", + max_num_batched_tokens=1024, + ) + mm_positions = [[PlaceholderRange(offset=100, length=600)] + for _ in range(3)] + requests = create_requests( + num_requests=3, + num_tokens=800, + mm_positions=mm_positions, + ) + for request in requests: + scheduler.add_request(request) + + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 3 + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + + assert scheduler.max_num_encoder_input_tokens == 1024 + # The first request is scheduled fully. + assert output.num_scheduled_tokens[requests[0].request_id] == 800 + # The second request is scheduled partially. + # The tokens are not scheduled because of the encoder budget. + assert output.num_scheduled_tokens[requests[1].request_id] == 100 + # The third request is also scheduled partially. + # The tokens are not scheduled because of the encoder budget. + assert output.num_scheduled_tokens[requests[2].request_id] == 100 + req_to_index = { + request.request_id: i + for i, request in enumerate(requests) + } + model_runner_output = ModelRunnerOutput( + req_ids=[request.request_id for request in requests], + req_id_to_index=req_to_index, + sampled_token_ids=[0] * len(requests), + logprob_token_ids_cpu=None, + logprobs_cpu=None, + ) + scheduler.update_from_output(output, model_runner_output) + + # Schedule the next step. + # Only the first and second requests are scheduled. + # The third request is in the RUNNING state but not scheduled in this step + # because of the encoder budget. + output = scheduler.schedule() + assert len(scheduler.running) == 3 + assert len(output.scheduled_new_reqs) == 0 + assert len(output.scheduled_cached_reqs) == 2 + assert len(output.finished_req_ids) == 0 + assert output.num_scheduled_tokens[requests[0].request_id] == 1 + assert output.num_scheduled_tokens[requests[1].request_id] == 700 + assert requests[2].request_id not in output.num_scheduled_tokens diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f4738bb33c603..fb5e83fe06274 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -67,10 +67,10 @@ def __init__( # This is flushed at the end of each scheduling step. self.finished_req_ids: Set[str] = set() - # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. - # Request id -> RunningRequestData - self.running_reqs_data: Dict[str, RunningRequestData] = {} + # Request id -> CachedRequestData + self._cached_reqs_data: Dict[str, CachedRequestData] = {} # Encoder-related. # Calculate encoder cache size if applicable @@ -115,17 +115,8 @@ def schedule(self) -> "SchedulerOutput": encoder_budget = self.max_num_encoder_input_tokens # First, schedule the RUNNING requests. - # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be - # in the "partial" state, where the request has some tokens computed - # but not all. The constraint is due to the persistent batch in the - # V1 model runner. - # TODO(woosuk): Remove this constraint after refactoring model runner. - has_partial_request = False req_index = 0 - while req_index < len(self.running): - # Only the last request in the RUNNING queue can be "partial". - assert not has_partial_request - assert token_budget > 0 + while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] num_new_tokens = request.num_tokens - request.num_computed_tokens num_new_tokens = min(num_new_tokens, token_budget) @@ -137,7 +128,14 @@ def schedule(self) -> "SchedulerOutput": request.num_computed_tokens, num_new_tokens, encoder_budget)) - assert num_new_tokens > 0 + if num_new_tokens == 0: + # The request cannot be scheduled because the encoder budget + # or the encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue while True: new_blocks = self.kv_cache_manager.allocate_slots( @@ -172,8 +170,6 @@ def schedule(self) -> "SchedulerOutput": num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 - has_partial_request = (request.num_computed_tokens + num_new_tokens - < request.num_tokens) # Encoder-related. if encoder_inputs_to_schedule: @@ -186,13 +182,9 @@ def schedule(self) -> "SchedulerOutput": # Next, schedule the WAITING requests. if not preempted_reqs: - while self.waiting: - if has_partial_request: - break + while self.waiting and token_budget > 0: if len(self.running) == self.max_num_running_reqs: break - if token_budget == 0: - break request = self.waiting[0] # Get already-cached tokens. @@ -249,8 +241,6 @@ def schedule(self) -> "SchedulerOutput": token_budget -= num_new_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens - has_partial_request = (num_computed_tokens + num_new_tokens - < request.num_tokens) # Encoder-related. if encoder_inputs_to_schedule: @@ -266,8 +256,11 @@ def schedule(self) -> "SchedulerOutput": assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + - len(scheduled_running_reqs) == len(self.running)) + len(scheduled_running_reqs) <= len(self.running)) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. @@ -286,25 +279,28 @@ def schedule(self) -> "SchedulerOutput": for req in scheduled_new_reqs ] resumed_reqs_data = [ - ResumedRequestData.from_request( - req, req_to_new_block_ids[req.request_id], - req.num_computed_tokens) for req in scheduled_resumed_reqs + self._make_cached_request_data( + req, + req_to_new_block_ids[req.request_id], + req.num_computed_tokens, + resumed_from_preemption=True, + ) for req in scheduled_resumed_reqs ] running_reqs_data = [ - self._make_running_request_data( - req, req_to_new_block_ids[req.request_id], - req.num_computed_tokens) for req in scheduled_running_reqs + self._make_cached_request_data( + req, + req_to_new_block_ids[req.request_id], + req.num_computed_tokens, + resumed_from_preemption=False, + ) for req in scheduled_running_reqs ] - preempted_req_ids = {req.request_id for req in preempted_reqs} scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, - scheduled_resumed_reqs=resumed_reqs_data, - scheduled_running_reqs=running_reqs_data, + scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, - preempted_req_ids=preempted_req_ids, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between @@ -316,22 +312,26 @@ def schedule(self) -> "SchedulerOutput": self.finished_req_ids = set() return scheduler_output - def _make_running_request_data( + def _make_cached_request_data( self, request: Request, new_block_ids: List[int], num_computed_tokens: int, - ) -> "RunningRequestData": - # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating + resumed_from_preemption: bool, + ) -> "CachedRequestData": + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. - if request.request_id in self.running_reqs_data: - req_data = self.running_reqs_data[request.request_id] + if request.request_id in self._cached_reqs_data: + req_data = self._cached_reqs_data[request.request_id] + req_data.resumed_from_preemption = resumed_from_preemption req_data.new_block_ids = new_block_ids req_data.num_computed_tokens = num_computed_tokens else: - req_data = RunningRequestData.from_request(request, new_block_ids, - num_computed_tokens) - self.running_reqs_data[request.request_id] = req_data + req_data = CachedRequestData.from_request(request, + resumed_from_preemption, + new_block_ids, + num_computed_tokens) + self._cached_reqs_data[request.request_id] = req_data return req_data def _try_schedule_encoder_inputs( @@ -420,7 +420,13 @@ def update_from_output( # expensive operations inside the loop. for request in self.running: req_id = request.request_id - request.num_computed_tokens += num_scheduled_tokens[req_id] + num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) + if num_tokens_scheduled == 0: + # The request was not scheduled in this step. + new_running.append(request) + continue + + request.num_computed_tokens += num_tokens_scheduled # When the request's num_computed_tokens catches up its num_tokens, # the request generates output tokens. Otherwise, we ignore the # sampler output for the request. @@ -529,7 +535,7 @@ def _free_request(self, request: Request) -> None: assert request.is_finished() self.kv_cache_manager.free(request) self.encoder_cache_manager.free(request) - self.running_reqs_data.pop(request.request_id, None) + self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] self.finished_req_ids.add(request.request_id) @@ -584,30 +590,13 @@ def from_request( @dataclass -class ResumedRequestData: - - req_id: str - block_ids: List[int] - num_computed_tokens: int - - @classmethod - def from_request( - cls, - request: Request, - block_ids: List[int], - num_computed_tokens: int, - ) -> "ResumedRequestData": - return cls( - req_id=request.request_id, - block_ids=block_ids, - num_computed_tokens=num_computed_tokens, - ) - - -@dataclass -class RunningRequestData: +class CachedRequestData: req_id: str + # If resumed_from_preemption is False, new_block_ids will be appended to + # the request's block IDs. If True, new_block_ids will be used as the + # request's block IDs instead of appending to the existing block IDs. + resumed_from_preemption: bool new_block_ids: List[int] num_computed_tokens: int @@ -615,11 +604,13 @@ class RunningRequestData: def from_request( cls, request: Request, + resumed_from_preemption: bool, new_block_ids: List[int], num_computed_tokens: int, - ) -> "RunningRequestData": + ) -> "CachedRequestData": return cls( req_id=request.request_id, + resumed_from_preemption=resumed_from_preemption, new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, ) @@ -629,14 +620,12 @@ def from_request( class SchedulerOutput: scheduled_new_reqs: List[NewRequestData] - scheduled_resumed_reqs: List[ResumedRequestData] - scheduled_running_reqs: List[RunningRequestData] + scheduled_cached_reqs: List[CachedRequestData] num_scheduled_tokens: Dict[str, int] total_num_scheduled_tokens: int scheduled_encoder_inputs: Dict[str, List[int]] num_common_prefix_blocks: int - preempted_req_ids: Set[str] finished_req_ids: Set[str] free_encoder_input_ids: List[Tuple[str, int]] diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 8d0785243c716..f520ee9586c5c 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -46,6 +46,8 @@ def append_row( start: int, block_ids: List[int], ) -> None: + if not block_ids: + return num_blocks = len(block_ids) self.block_table_np[row_idx, start:start + num_blocks] = block_ids self.num_blocks_per_row[row_idx] = start + num_blocks diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0b5644525553e..7841fac1df34b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -205,12 +205,32 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: - # Remove stopped requests from the cached states. - # Keep the states of the preempted requests. + def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + Returns: + True if there is a new/resumed/paused/finished request in the batch. + If False, we can skip copying SamplingMetadata to the GPU. + """ + # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) self.encoder_cache.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + removed_req_indices: List[int] = [] + for req_id in scheduler_output.finished_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) # Free the cached encoder outputs. for req_id, input_id in scheduler_output.free_encoder_input_ids: @@ -220,36 +240,22 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if not encoder_outputs: self.encoder_cache.pop(req_id, None) - # Remove the requests from the persistent batch. - stopped_req_ids = set().union( - scheduler_output.preempted_req_ids, - scheduler_output.finished_req_ids, - ) - removed_req_indices: List[int] = [] - for req_id in stopped_req_ids: + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) - - # Update the states of the running requests. - for req_data in scheduler_output.scheduled_running_reqs: - req_id = req_data.req_id - req_state = self.requests[req_id] - req_index = self.input_batch.req_id_to_index[req_id] - - # Update the num_computed_tokens. - req_state.num_computed_tokens = req_data.num_computed_tokens - self.input_batch.num_computed_tokens_cpu[req_index] = ( - req_data.num_computed_tokens) - - # Update the block table. - num_new_blocks = len(req_data.new_block_ids) - if num_new_blocks == 0: - continue - start_index = len(req_state.block_ids) - req_state.block_ids.extend(req_data.new_block_ids) - self.input_batch.block_table.append_row(req_index, start_index, - req_data.new_block_ids) + assert req_index is not None + removed_req_indices.append(req_index) req_ids_to_add: List[str] = [] # Add new requests to the cached states. @@ -305,14 +311,36 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_ids_to_add.append(req_id) - # Update the cached states of the resumed requests. - for res_req_data in scheduler_output.scheduled_resumed_reqs: - req_id = res_req_data.req_id + # Update the states of the running/resumed requests. + for req_data in scheduler_output.scheduled_cached_reqs: + req_id = req_data.req_id req_state = self.requests[req_id] - req_state.block_ids = res_req_data.block_ids - req_state.num_computed_tokens = res_req_data.num_computed_tokens - req_ids_to_add.append(req_id) + # Update the cached states. + req_state.num_computed_tokens = req_data.num_computed_tokens + if not req_data.resumed_from_preemption: + # Append the new blocks to the existing block IDs. + req_state.block_ids.extend(req_data.new_block_ids) + else: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = req_data.new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + start_index = len(req_state.block_ids) - len( + req_data.new_block_ids) + self.input_batch.block_table.append_row(req_index, start_index, + req_data.new_block_ids) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. @@ -330,6 +358,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Condense the batched states if there are empty indices. if removed_req_indices: self.input_batch.condense(removed_req_indices) + return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -536,10 +565,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, ) - # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial - # request in the batch. While we should not sample any token from this - # partial request, we do so for simplicity. We will ignore the sampled - # token from the partial request. + # NOTE(woosuk): Due to chunked prefills, the batch may contain partial + # requests. While we should not sample any token from these partial + # requests, we do so for simplicity. We will ignore the sampled + # tokens from the partial requests. # TODO: Support prompt logprobs. logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices @@ -601,22 +630,15 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _prepare_sampling( self, - scheduler_output: "SchedulerOutput", + batch_changed: bool, ) -> SamplingMetadata: - skip_copy = True - if (scheduler_output.finished_req_ids - or scheduler_output.preempted_req_ids): - skip_copy = False - if (scheduler_output.scheduled_new_reqs - or scheduler_output.scheduled_resumed_reqs): - skip_copy = False # Create the sampling metadata. req_id_output_token_ids: Dict[str, List[int]] = \ {req_id: req.output_token_ids \ for req_id, req in self.requests.items()} sampling_metadata = self.input_batch.make_sampling_metadata( - req_id_output_token_ids, skip_copy) + req_id_output_token_ids, skip_copy=not batch_changed) return sampling_metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"): @@ -715,7 +737,7 @@ def execute_model( self, scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: - self._update_states(scheduler_output) + batch_changed = self._update_states(scheduler_output) if self.is_multimodal_model: # Run the multimodal encoder if any. @@ -778,7 +800,7 @@ def execute_model( logits = self.model.compute_logits(hidden_states, None) # Sample the next token and get logprobs if needed. - sampling_metadata = self._prepare_sampling(scheduler_output) + sampling_metadata = self._prepare_sampling(batch_changed) sampler_output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata,