
[WIP] V1 LoRA support #10579

Closed
506 changes: 506 additions & 0 deletions benchmarks/benchmark_lora_throughput.py

Large diffs are not rendered by default.
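
The 506-line benchmark itself is not rendered above. As a rough, hypothetical illustration only (not the PR's actual benchmarks/benchmark_lora_throughput.py), a minimal LoRA throughput measurement against vLLM's offline API could look like the sketch below; the model name, adapter path, prompt set, and request counts are all placeholder assumptions.

# Hypothetical sketch, not the PR's benchmark script. The model name,
# adapter path, and request counts below are placeholders.
import time

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",  # assumed base model
    enable_lora=True,
    max_loras=4,  # should not exceed the scheduler's LoRA budget
)

lora = LoRARequest("demo-adapter", 1, "/path/to/lora_adapter")  # placeholder adapter
prompts = ["Summarize the benefits of LoRA fine-tuning."] * 64
params = SamplingParams(temperature=0.0, max_tokens=128)

start = time.perf_counter()
outputs = llm.generate(prompts, params, lora_request=lora)
elapsed = time.perf_counter() - start

total_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
print(f"{total_tokens / elapsed:.1f} generated tokens/s with LoRA enabled")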

25 changes: 23 additions & 2 deletions vllm/v1/core/scheduler.py
@@ -11,6 +11,7 @@
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.lora.request import LoRARequest

if TYPE_CHECKING:
from vllm.multimodal import MultiModalKwargs
@@ -30,8 +31,6 @@ def __init__(
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
# TODO: Support LoRA.
assert lora_config is None, "V1 does not support LoRA yet."

num_gpu_blocks = cache_config.num_gpu_blocks
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
@@ -170,6 +169,15 @@ def schedule(self) -> "SchedulerOutput":
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget

# Record the LoRAs in scheduled_running_reqs
requested_loras: set[int] = set()
if self.lora_config:
requested_loras = set(
req.lora_request.lora_int_id
for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(requested_loras) <= self.lora_config.max_loras

# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting:
@@ -181,6 +189,17 @@
break

request = self.waiting[0]

# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request:
req_lora_id = request.lora_request.lora_int_id
if len(requested_loras) == self.lora_config.max_loras and \
req_lora_id not in requested_loras:
# Cannot schedule this request without exceeding max_loras.
break
requested_loras.add(req_lora_id)

# Get already-cached tokens.
computed_blocks = self.kv_cache_manager.get_computed_blocks(
request)
@@ -513,6 +532,7 @@ class NewRequestData:
sampling_params: SamplingParams
block_ids: List[int]
num_computed_tokens: int
lora_request: Optional[LoRARequest]

@classmethod
def from_request(
@@ -530,6 +550,7 @@ def from_request(
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
lora_request=request.lora_request,
)


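To make the scheduling rule above concrete: the already-running requests may reference at most max_loras distinct adapters, and a waiting request is only admitted if its adapter is already active or there is still room for another one. The standalone sketch below restates that gating logic with simplified stand-in classes (illustrative only, not vLLM's own Request/LoRARequest types).

# Standalone illustration of the max_loras admission rule; the dataclasses
# below are simplified stand-ins, not vLLM's Request/LoRARequest.
from dataclasses import dataclass
from typing import Optional, Set


@dataclass
class FakeLoRARequest:
    lora_int_id: int


@dataclass
class FakeRequest:
    req_id: str
    lora_request: Optional[FakeLoRARequest] = None


def can_admit(req: FakeRequest, active_loras: Set[int], max_loras: int) -> bool:
    # A request without a LoRA is always admissible; otherwise its adapter
    # must already be active, or the active set must still have room.
    if req.lora_request is None:
        return True
    lora_id = req.lora_request.lora_int_id
    return lora_id in active_loras or len(active_loras) < max_loras


active: Set[int] = {1, 2}
print(can_admit(FakeRequest("a", FakeLoRARequest(2)), active, max_loras=2))  # True: LoRA 2 already active
print(can_admit(FakeRequest("b", FakeLoRARequest(3)), active, max_loras=2))  # False: would need a third adapter
print(can_admit(FakeRequest("c"), active, max_loras=2))                      # True: no LoRA required
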
39 changes: 39 additions & 0 deletions vllm/v1/worker/cached_request_state.py
@@ -0,0 +1,39 @@

from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional

import torch

from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import RunningRequestData

if TYPE_CHECKING:
    from vllm.multimodal.inputs import PlaceholderRange


@dataclass
class CachedRequestState:

    req_id: str
    prompt_token_ids: List[int]
    prompt: Optional[str]
    mm_inputs: List[MultiModalKwargs]
    mm_positions: List["PlaceholderRange"]
    sampling_params: SamplingParams
    generator: Optional[torch.Generator]

    block_ids: List[int]
    num_computed_tokens: int
    output_token_ids: List[int]

    lora_request: Optional[LoRARequest]

    @property
    def num_tokens(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def update(self, req_data: RunningRequestData) -> None:
        self.num_computed_tokens = req_data.num_computed_tokens
        if len(req_data.new_block_ids):
            self.block_ids.extend(req_data.new_block_ids)
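
For reference, a quick usage sketch of the new dataclass, assuming the module path introduced by this PR (vllm.v1.worker.cached_request_state) and substituting a SimpleNamespace for RunningRequestData, since update() only reads num_computed_tokens and new_block_ids:

# Usage sketch for CachedRequestState as added above; the module path is
# taken from this PR, and SimpleNamespace stands in for RunningRequestData.
from types import SimpleNamespace

from vllm.sampling_params import SamplingParams
from vllm.v1.worker.cached_request_state import CachedRequestState

state = CachedRequestState(
    req_id="req-0",
    prompt_token_ids=[1, 2, 3],
    prompt="hello",
    mm_inputs=[],
    mm_positions=[],
    sampling_params=SamplingParams(max_tokens=16),
    generator=None,
    block_ids=[0],
    num_computed_tokens=0,
    output_token_ids=[],
    lora_request=None,  # or a LoRARequest when the request uses an adapter
)

state.output_token_ids.extend([7, 8])
print(state.num_tokens)  # 5 = 3 prompt tokens + 2 output tokens

state.update(SimpleNamespace(num_computed_tokens=3, new_block_ids=[1]))
print(state.block_ids)   # [0, 1]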