[V1] Get input tokens from scheduler (#13339)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
WoosukKwon authored Feb 17, 2025
1 parent ce77eb9 commit 4c21ce9
Showing 4 changed files with 139 additions and 139 deletions.
1 change: 1 addition & 0 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -154,6 +154,7 @@ def test_update_states_request_resumed(model_runner):
cached_req_data = CachedRequestData(
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=0,
)
43 changes: 28 additions & 15 deletions vllm/v1/core/scheduler.py
@@ -121,6 +121,8 @@ def schedule(self) -> "SchedulerOutput":
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: Dict[str, List[int]] = {}

# For logging.
scheduled_timestamp = time.monotonic()

# First, schedule the RUNNING requests.
@@ -187,6 +189,15 @@ def schedule(self) -> "SchedulerOutput":
token_budget -= num_new_tokens
req_index += 1

# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids[:num_scheduled_spec_tokens])

# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
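To make the new spec-decode accounting concrete, here is a small worked example with hypothetical numbers (not part of the commit): a request with 100 real tokens, 99 of them already computed, and three proposed spec tokens.

num_tokens = 100            # request.num_tokens (prompt + generated tokens)
num_computed_tokens = 99    # request.num_computed_tokens
spec_token_ids = [7, 8, 9]  # request.spec_token_ids (draft proposals)

num_new_tokens = 4          # 1 regular token + 3 spec tokens fit the budget
num_scheduled_spec_tokens = num_new_tokens + num_computed_tokens - num_tokens
assert num_scheduled_spec_tokens == 3
assert spec_token_ids[:num_scheduled_spec_tokens] == [7, 8, 9]

# If the token budget trims num_new_tokens to 2, only one spec token is kept.
num_new_tokens = 2
assert num_new_tokens + num_computed_tokens - num_tokens == 1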
@@ -196,11 +207,6 @@
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget

# Speculative decode related.
if request.spec_token_ids:
scheduled_spec_decode_tokens[
request.request_id] = request.spec_token_ids

# Record the LoRAs in scheduled_running_reqs
requested_loras: Set[int] = set()
if self.lora_config:
@@ -324,23 +330,24 @@ def schedule(self) -> "SchedulerOutput":
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
req_to_new_block_ids[req.request_id],
req.num_computed_tokens)
req_to_new_block_ids[req.request_id])
for req in scheduled_new_reqs
]
resumed_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
req.num_computed_tokens,
resumed_from_preemption=True,
) for req in scheduled_resumed_reqs
]
running_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
req.num_computed_tokens,
resumed_from_preemption=False,
) for req in scheduled_running_reqs
]
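For reference, the len(scheduled_spec_decode_tokens.get(req.request_id, ())) expression above yields 0 for any request that has no speculative tokens scheduled; a quick illustration with hypothetical request IDs:

scheduled_spec_decode_tokens = {"req-a": [11, 12]}
assert len(scheduled_spec_decode_tokens.get("req-a", ())) == 2
assert len(scheduled_spec_decode_tokens.get("req-b", ())) == 0  # no entry -> no spec tokens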
@@ -349,8 +356,8 @@ def schedule(self) -> "SchedulerOutput":
scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
@@ -366,22 +373,28 @@ def schedule(self) -> "SchedulerOutput":
def _make_cached_request_data(
self,
request: Request,
num_scheduled_tokens: int,
num_scheduled_spec_tokens: int,
new_block_ids: List[int],
num_computed_tokens: int,
resumed_from_preemption: bool,
) -> "CachedRequestData":
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
if request.request_id in self._cached_reqs_data:
req_data = self._cached_reqs_data[request.request_id]
num_computed_tokens = request.num_computed_tokens
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_regular_tokens]
req_data = self._cached_reqs_data.get(request.request_id)
if req_data is not None:
req_data.resumed_from_preemption = resumed_from_preemption
req_data.new_token_ids = new_token_ids
req_data.new_block_ids = new_block_ids
req_data.num_computed_tokens = num_computed_tokens
else:
req_data = CachedRequestData.from_request(request,
resumed_from_preemption,
new_block_ids,
num_computed_tokens)
new_token_ids,
new_block_ids)
self._cached_reqs_data[request.request_id] = req_data
return req_data
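A standalone sketch of the slice performed above, with hypothetical values: only the regular (non-speculative) token IDs in the newly scheduled window are sent to the worker, since speculative token IDs travel separately in scheduled_spec_decode_tokens.

all_token_ids = [1, 2, 3, 4, 5, 6]   # request.all_token_ids (hypothetical)
num_computed_tokens = 5              # tokens whose KV cache is already computed
num_scheduled_tokens = 3             # tokens scheduled for this step
num_scheduled_spec_tokens = 2        # of which two are speculative proposals

num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
new_token_ids = all_token_ids[
    num_computed_tokens:num_computed_tokens + num_regular_tokens]
assert new_token_ids == [6]          # only the single regular token is shipped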

15 changes: 8 additions & 7 deletions vllm/v1/core/scheduler_output.py
@@ -30,7 +30,6 @@ def from_request(
cls,
request: "Request",
block_ids: List[int],
num_computed_tokens: int,
) -> "NewRequestData":
return cls(
req_id=request.request_id,
@@ -41,7 +40,7 @@ def from_request(
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
)
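The practical effect of dropping the num_computed_tokens parameter is that the value is now captured from the Request at the moment from_request is called. A tiny self-contained stand-in (not the real vLLM classes) illustrating that contract:

from dataclasses import dataclass
from typing import List

@dataclass
class FakeRequest:            # stand-in for vllm.v1.request.Request
    request_id: str
    num_computed_tokens: int

@dataclass
class FakeNewRequestData:     # stand-in for NewRequestData
    req_id: str
    block_ids: List[int]
    num_computed_tokens: int

    @classmethod
    def from_request(cls, request: FakeRequest,
                     block_ids: List[int]) -> "FakeNewRequestData":
        # Mirrors the change above: read num_computed_tokens off the request.
        return cls(request.request_id, block_ids, request.num_computed_tokens)

data = FakeNewRequestData.from_request(FakeRequest("req-0", 16), [0, 1])
assert data.num_computed_tokens == 16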

@@ -54,6 +53,7 @@ class CachedRequestData:
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool
new_token_ids: List[int]
new_block_ids: List[int]
num_computed_tokens: int

@@ -62,14 +62,15 @@ def from_request(
cls,
request: "Request",
resumed_from_preemption: bool,
new_token_ids: List[int],
new_block_ids: List[int],
num_computed_tokens: int,
) -> "CachedRequestData":
return cls(
req_id=request.request_id,
resumed_from_preemption=resumed_from_preemption,
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
num_computed_tokens=request.num_computed_tokens,
)
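Given the slice in scheduler.py above, new_token_ids carry the token IDs for the positions starting at num_computed_tokens. A hypothetical example of how the two fields line up (field names are from this diff, values are invented):

from vllm.v1.core.scheduler_output import CachedRequestData

req_data = CachedRequestData(
    req_id="req-0",
    resumed_from_preemption=False,
    new_token_ids=[501, 502],    # tokens at positions 10 and 11
    new_block_ids=[9],
    num_computed_tokens=10,
)
positions = range(req_data.num_computed_tokens,
                  req_data.num_computed_tokens + len(req_data.new_token_ids))
assert list(positions) == [10, 11]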


@@ -91,9 +92,9 @@ class SchedulerOutput:
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens: int
# req_id -> spec_decode_tokens
# If a request does not have any spec decode tokens, it will
# not be included in the dictionary.
# req_id -> spec_token_ids
# If a request does not have any spec decode tokens, it will not be
# included in the dictionary.
scheduled_spec_decode_tokens: Dict[str, List[int]]
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
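The fourth changed file (presumably the V1 GPU model runner, given the test touched above) is not rendered here. Purely as an illustration of how a consumer of SchedulerOutput could combine these fields, here is a hypothetical sketch (not the actual vLLM worker code):

from types import SimpleNamespace

def gather_step_input_ids(cached_req, scheduled_spec_decode_tokens):
    # Hypothetical helper: the regular tokens for this step come from
    # CachedRequestData.new_token_ids, and any speculative tokens come from
    # the scheduler's scheduled_spec_decode_tokens mapping.
    spec_ids = scheduled_spec_decode_tokens.get(cached_req.req_id, [])
    return list(cached_req.new_token_ids) + list(spec_ids)

req = SimpleNamespace(req_id="req-a", new_token_ids=[501])
assert gather_step_input_ids(req, {"req-a": [777, 778]}) == [501, 777, 778]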
