diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 071cdbecc689a..5ea66518b4112 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -306,3 +306,20 @@ def get_model_patched(**kwargs): def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings): yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. model_runner.model) + + +@pytest.fixture(params=[True, False]) +def run_with_both_engines_lora(request, monkeypatch): + # Automatically runs tests twice, once with V1 and once without + use_v1 = request.param + # Tests decorated with `@skip_v1` are only run without v1 + skip_v1 = request.node.get_closest_marker("skip_v1") + + if use_v1: + if skip_v1: + pytest.skip("Skipping test on vllm V1") + monkeypatch.setenv('VLLM_USE_V1', '1') + else: + monkeypatch.setenv('VLLM_USE_V1', '0') + + yield diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index 249f7619d6246..d39925948048e 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -42,6 +42,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + def test_baichuan_lora(baichuan_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index 0aa9fe7a94948..ee09afe86777d 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -2,6 +2,8 @@ from typing import List +import pytest + import vllm from tests.utils import fork_new_process_for_each_test from vllm.lora.request import LoRARequest @@ -47,6 +49,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.mark.skip_v1 @fork_new_process_for_each_test def test_chatglm3_lora(chatglm3_lora_files): llm = vllm.LLM(MODEL_PATH, @@ -66,6 +77,7 @@ def test_chatglm3_lora(chatglm3_lora_files): assert output2[i] == EXPECTED_LORA_OUTPUT[i] +@pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) @fork_new_process_for_each_test def test_chatglm3_lora_tp4(chatglm3_lora_files): @@ -87,6 +99,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files): assert output2[i] == EXPECTED_LORA_OUTPUT[i] +@pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) @fork_new_process_for_each_test def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 8923aa2210a55..a1b4c897c45ef 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -33,6 +33,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.xfail(current_platform.is_rocm(), reason="There can be output mismatch on ROCm") def test_gemma_lora(gemma_lora_files): diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 39f779f400ca3..564818f23fd24 100644 --- 
a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -2,6 +2,7 @@ from typing import List +import pytest import ray import vllm @@ -73,6 +74,14 @@ def generate_and_test(llm, sql_lora_files): print("removing lora") +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @fork_new_process_for_each_test def test_llama_lora(sql_lora_files): @@ -85,6 +94,9 @@ def test_llama_lora(sql_lora_files): generate_and_test(llm, sql_lora_files) +# Skipping for v1 as v1 doesn't have a good way to expose the num_gpu_blocks +# used by the engine yet. +@pytest.mark.skip_v1 @fork_new_process_for_each_test def test_llama_lora_warmup(sql_lora_files): """Test that the LLM initialization works with a warmup LORA path and diff --git a/tests/lora/test_lora_bias_e2e.py b/tests/lora/test_lora_bias_e2e.py index cbdd688311d74..3a7b391692cc6 100644 --- a/tests/lora/test_lora_bias_e2e.py +++ b/tests/lora/test_lora_bias_e2e.py @@ -30,6 +30,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +# Skipping for V1 for now as we are hitting, +# "Head size 80 is not supported by FlashAttention." error. +@pytest.mark.skip_v1 @pytest.mark.parametrize("lora_bias", [True]) @pytest.mark.parametrize("fully_sharded", [True, False]) def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool): diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py index 651c89ffce2de..8999e0cf31906 100644 --- a/tests/lora/test_phi.py +++ b/tests/lora/test_phi.py @@ -2,6 +2,8 @@ from typing import List +import pytest + import vllm from vllm.lora.request import LoRARequest @@ -48,6 +50,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +# Skipping for V1 for now as we are hitting, +# "Head size 80 is not supported by FlashAttention." error. +@pytest.mark.skip_v1 def test_phi2_lora(phi2_lora_files): # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, # Otherwise, the lora-test will fail due to CUDA OOM. 
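A note on the `skip_v1` marker used throughout these test files: it is a custom pytest marker that the `run_with_both_engines_lora` fixture looks up via `request.node.get_closest_marker("skip_v1")`. Its registration is not part of this diff; below is a minimal sketch of one way it could be declared in a conftest.py hook (assuming it is not already listed in the project's pytest configuration) so pytest does not warn about an unknown marker.

def pytest_configure(config):
    # Hypothetical registration; the marker name must match the fixture's lookup.
    config.addinivalue_line(
        "markers",
        "skip_v1: run this test only with the V0 engine")

A test module opts into the dual-engine runs by defining the autouse `v1` wrapper fixture, and individual tests opt back out of V1 with `@pytest.mark.skip_v1`.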
diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 5702aa26bd916..7f687f563eb8e 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -70,6 +70,14 @@ def format_prompt_tuples(prompt): return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("tp_size", [1]) def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model, diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 60cf4384d3fde..8df4cbe1be71b 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -163,7 +163,7 @@ def test_generate_block_hash_extra_keys(): # Test with no overlap extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 6, 10, 0) - assert extra_keys == () + assert extra_keys is None assert next_mm_idx == 1 # Test with multiple extra keys diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 9f0297596ccbf..9826aeb9dc27e 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -16,8 +16,7 @@ get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, - tensor_model_parallel_gather) + tensor_model_parallel_all_reduce) from vllm.distributed.utils import divide # yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -1043,7 +1042,10 @@ def _get_logits( logits = lm_head.linear_method.apply(lm_head, hidden_states) if embedding_bias is not None: logits += embedding_bias - logits = tensor_model_parallel_gather(logits) + + # Gather logits for TP + logits = self.base_layer._gather_logits(logits) + if logits is None: return None diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index cdc67ca83d489..0565c6e8be381 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -51,7 +51,6 @@ def __init__(self, # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. - parallel_config = get_current_vllm_config().parallel_config self.use_all_gather = current_platform.is_tpu() \ or envs.VLLM_USE_V1 \ @@ -88,6 +87,20 @@ def forward( return logits + def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor: + """gather/all-gather the logits tensor across model parallel group.""" + if self.use_all_gather: + # Gather is not supported for some devices such as TPUs. + # Use all-gather instead. + # NOTE(woosuk): Here, the outputs of every device should not be None + # because XLA requires strict SPMD among all devices. Every device + # should execute the same operations after gathering the logits. + logits = tensor_model_parallel_all_gather(logits) + else: + # None may be returned for rank > 0 + logits = tensor_model_parallel_gather(logits) + return logits + def _get_logits( self, hidden_states: torch.Tensor, @@ -99,16 +112,9 @@ def _get_logits( hidden_states, bias=embedding_bias) - if self.use_all_gather: - # Gather is not supported for some devices such as TPUs. - # Use all-gather instead. - # NOTE(woosuk): Here, the outputs of every device should not be None - # because XLA requires strict SPMD among all devices. 
Every device - # should execute the same operations after gathering the logits. - logits = tensor_model_parallel_all_gather(logits) - else: - # None may be returned for rank > 0 - logits = tensor_model_parallel_gather(logits) + # Gather logits for TP + logits = self._gather_logits(logits) + # Remove paddings in vocab (if any). if logits is not None: logits = logits[..., :self.org_vocab_size] diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index e0976ba8577b9..6888f1a3e1823 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -170,14 +170,28 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]: return ret -def generate_block_hash_extra_keys( - request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: - """Generate extra keys for the block hash. The extra keys can come from - the multi-modal inputs and request specific metadata (e.g., LoRA ID). - For multi-modal inputs, the extra keys are (mm_hash, start_offset) that - indicate a mm input contained in the block and its starting offset in - the block tokens. +def need_extra_keys(request: Request) -> bool: + """Check whether the blocks allocated to this request need extra hash keys. + + Args: + request (Request): The request. + + Returns: + bool: Whether blocks allocated to this request need extra hash keys. + """ + + # Multimodal requests need to include the MM hash. + # LoRA requests need to include the LoRA ID. + return bool(request.mm_positions) or (request.lora_request is not None) + + +def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, + end_token_idx: int, + start_mm_idx: int) -> Tuple[List[Any], int]: + """Generate extra keys related to MultiModal request for block hash + computation. For multi-modal inputs, the extra keys are + (mm_hash, start_offset) that indicate a mm input contained in the + block and its starting offset in the block tokens. Args: request: The request object. @@ -188,10 +202,11 @@ def generate_block_hash_extra_keys( Returns: A tuple of extra keys and the next multi-modal index. """ + extra_keys: List[Any] = [] mm_positions, mm_hashes = request.mm_positions, request.mm_hashes if not mm_positions: - return None, start_mm_idx + return extra_keys, start_mm_idx if mm_positions and len(mm_positions) != len(mm_hashes): raise ValueError( @@ -204,14 +219,13 @@ def generate_block_hash_extra_keys( # range. This usually happens in the late prefill phase and decoding phase. if mm_positions[-1]["offset"] + mm_positions[-1][ "length"] < start_token_idx: - return None, start_mm_idx + return extra_keys, start_mm_idx # Support start_mm_idx == -1 to indicate the last mm input. if start_mm_idx < 0: assert -start_mm_idx <= len(mm_positions) start_mm_idx = len(mm_positions) + start_mm_idx - extra_keys = [] curr_mm_idx = start_mm_idx while mm_positions and curr_mm_idx < len(mm_positions): assert mm_hashes[curr_mm_idx] is not None @@ -237,7 +251,50 @@ def generate_block_hash_extra_keys( else: # This block has not reached the current mm input. break - return tuple(extra_keys), curr_mm_idx + return extra_keys, curr_mm_idx + + +def _gen_lora_extra_hash_keys(request: Request) -> List[int]: + """Generate extra keys related to LoRA for block hash computation. + + Args: + request: The request object. + + Returns: + Return LoRA id of the request if it is a LoRA request. Return empty + list otherwise. 
+ """ + if not request.lora_request: + return [] + return [request.lora_request.lora_int_id] + + +def generate_block_hash_extra_keys( + request: Request, start_token_idx: int, end_token_idx: int, + start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: + """Generate extra keys for the block hash. The extra keys can come from + the multi-modal inputs and request specific metadata (e.g., LoRA ID). + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + start_mm_idx: The start multi-modal index of the block. + + Returns: + A tuple of extra keys and the next multi-modal index. + """ + mm_extra_keys: List[Any] + mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( + request, start_token_idx, end_token_idx, start_mm_idx) + lora_extra_keys: List[int] = _gen_lora_extra_hash_keys(request) + + extra_keys: List[Any] = lora_extra_keys + mm_extra_keys + + if not extra_keys: + return None, new_start_mm_idx + + return tuple(extra_keys), new_start_mm_idx def hash_block_tokens( @@ -249,9 +306,6 @@ def hash_block_tokens( prefix caching. We use LRU cache for this function to avoid recomputing hash values for the same block contents. - TODO: Support arbitrary metadata so that we could support more - features such as LoRA adapter. - Args: parent_block_hash: The hash of the parent block. None if this is the first block. @@ -291,14 +345,9 @@ def hash_request_tokens(block_size: int, The list of computed hash values. """ token_ids = request.all_token_ids - mm_positions, mm_hashes = request.mm_positions, request.mm_hashes - if mm_positions and len(mm_positions) != len(mm_hashes): - raise ValueError( - "The number of multi-modal positions and hashes must match.") - # TODO: Extend this to support other features such as LoRA. - need_extra_keys = bool(mm_positions) - extra_keys = None + req_need_extra_keys = need_extra_keys(request) + req_extra_keys = None curr_mm_idx = 0 ret = [] @@ -310,13 +359,13 @@ def hash_request_tokens(block_size: int, if len(block_token_ids) < block_size: break - # Add extra keys if the block is a multi-modal block. - if need_extra_keys: - extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + if req_need_extra_keys: + # MM and LoRA requests need extra keys for block-hash computation. + req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( request, start, end, curr_mm_idx) block_hash = hash_block_tokens(parent_block_hash_value, - block_token_ids, extra_keys) + block_token_ids, req_extra_keys) ret.append(block_hash) parent_block_hash_value = block_hash.hash_value return ret diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index fb5e83fe06274..6c44fec6439e7 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -7,6 +7,7 @@ from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) @@ -35,8 +36,6 @@ def __init__( self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config - # TODO: Support LoRA. - assert lora_config is None, "V1 does not support LoRA yet." # Scheduling constraints. 
self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -180,6 +179,14 @@ def schedule(self) -> "SchedulerOutput": self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + # Record the LoRAs in scheduled_running_reqs + requested_loras: Set[int] = set() + if self.lora_config: + requested_loras = set( + req.lora_request.lora_int_id for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(requested_loras) <= self.lora_config.max_loras + # Next, schedule the WAITING requests. if not preempted_reqs: while self.waiting and token_budget > 0: @@ -187,6 +194,23 @@ def schedule(self) -> "SchedulerOutput": break request = self.waiting[0] + + # Check that adding the request still respects the max_loras + # constraint. + if self.lora_config and request.lora_request: + req_lora_id = request.lora_request.lora_int_id + if len(requested_loras) == self.lora_config.max_loras and ( + req_lora_id not in requested_loras): + # Cannot schedule. + # TODO (varun): This means all the other requests in + # the WAITING queue will be blocked by this request, + # even if, + # 1. these other requests do not use LoRA, or, + # 2. these other requests use the already requested + # LoRAs. + # This is too conservative and could be optimized. + break + # Get already-cached tokens. computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks(request) @@ -234,6 +258,8 @@ def schedule(self) -> "SchedulerOutput": raise RuntimeError( f"Invalid request status: {request.status}") + if self.lora_config and request.lora_request: + requested_loras.add(request.lora_request.lora_int_id) req_to_new_block_ids[request.request_id] = [ b.block_id for b in computed_blocks + new_blocks ] @@ -568,6 +594,7 @@ class NewRequestData: sampling_params: SamplingParams block_ids: List[int] num_computed_tokens: int + lora_request: Optional[LoRARequest] @classmethod def from_request( @@ -586,6 +613,7 @@ def from_request( sampling_params=request.sampling_params, block_ids=block_ids, num_computed_tokens=num_computed_tokens, + lora_request=request.lora_request, ) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 39708f833fd58..a31e888656166 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -3,11 +3,12 @@ # Datastructures defining an input batch from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple import numpy as np import torch +from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.sample.metadata import SamplingMetadata @@ -35,6 +36,8 @@ class CachedRequestState: mrope_positions: Optional[torch.Tensor] = None mrope_position_delta: Optional[int] = None + lora_request: Optional[LoRARequest] = None + @property def num_tokens(self) -> int: return len(self.prompt_token_ids) + len(self.output_token_ids) @@ -161,6 +164,12 @@ def __init__( ] self.prompt_token_ids: Optional[torch.Tensor] = None + # lora related + self.request_lora_mapping = np.zeros((self.max_num_reqs, ), + dtype=np.int32) + self.lora_id_to_request_ids: Dict[int, Set[str]] = {} + self.lora_id_to_lora_request: Dict[int, LoRARequest] = {} + # req_index -> generator # NOTE(woosuk): The indices of the requests that do not have their own # generator should not be included in the dictionary. 
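The scheduler change above enforces the `max_loras` constraint per scheduling step: the LoRA IDs of already-scheduled running requests are collected into a set, and a waiting request is admitted only if its adapter is already in that set or the set still has a free slot. As the TODO in the diff notes, a blocked LoRA request also blocks everything behind it in the WAITING queue. A minimal standalone sketch of the admission rule, using illustrative names rather than the scheduler's actual API:

from typing import Optional, Set

def can_admit(requested_loras: Set[int], lora_id: Optional[int],
              max_loras: int) -> bool:
    # Requests without a LoRA adapter never consume an adapter slot.
    if not lora_id:
        return True
    # An already-active adapter is free; otherwise a free slot is required.
    return lora_id in requested_loras or len(requested_loras) < max_loras

# With max_loras=2 and adapters {1, 2} already running, a request for adapter 3
# stays in WAITING, while adapter 1 or a non-LoRA request can still be scheduled.
assert not can_admit({1, 2}, 3, max_loras=2)
assert can_admit({1, 2}, 1, max_loras=2)
assert can_admit({1, 2}, None, max_loras=2)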
@@ -235,6 +244,19 @@ def add_request( if sampling_params.prompt_logprobs: self.prompt_logprob_reqs.add(req_id) + # Add request lora ID + if request.lora_request: + lora_id = request.lora_request.lora_int_id + if lora_id not in self.lora_id_to_request_ids: + self.lora_id_to_request_ids[lora_id] = set() + + self.request_lora_mapping[req_index] = lora_id + self.lora_id_to_request_ids[lora_id].add(request.req_id) + self.lora_id_to_lora_request[lora_id] = request.lora_request + else: + # No LoRA + self.request_lora_mapping[req_index] = 0 + def remove_request(self, req_id: str) -> Optional[int]: req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: @@ -251,6 +273,16 @@ def remove_request(self, req_id: str) -> Optional[int]: self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) + + # LoRA + lora_id = self.request_lora_mapping[req_index] + if lora_id != 0: + self.lora_id_to_request_ids[lora_id].discard(req_id) + if len(self.lora_id_to_request_ids[lora_id]) == 0: + self.lora_id_to_request_ids.pop(lora_id) + self.lora_id_to_lora_request.pop(lora_id) + self.request_lora_mapping[req_index] = 0 + return req_index def clear(self) -> None: @@ -266,6 +298,9 @@ def clear(self) -> None: self.generators.clear() self.num_logprobs.clear() self.prompt_logprob_reqs.clear() + self.request_lora_mapping.fill(0) + self.lora_id_to_lora_request.clear() + self.lora_id_to_request_ids.clear() def condense(self, empty_req_indices: List[int]) -> None: if self.num_reqs == 0: @@ -318,6 +353,9 @@ def condense(self, empty_req_indices: List[int]) -> None: if generator is not None: self.generators[empty_index] = generator + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ + last_req_index] + # Decrement last_req_index since it is now empty. last_req_index -= 1 @@ -401,6 +439,29 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor: return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) + def make_lora_inputs( + self, num_scheduled_tokens: np.ndarray + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]: + """ + Given the num_scheduled_tokens for each request in the batch, return + datastructures used to activate the current LoRAs. + Returns: + 1. prompt_lora_mapping: A tuple of size self.num_reqs where, + prompt_lora_mapping[i] is the LoRA id to use for the ith prompt. + 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens) + where, token_lora_mapping[i] is the LoRA id to use for ith token. + 3. lora_requests: Set of relevant LoRA requests. 
+ """ + + req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + prompt_lora_mapping = tuple(req_lora_mapping) + token_lora_mapping = tuple( + req_lora_mapping.repeat(num_scheduled_tokens)) + active_lora_requests: Set[LoRARequest] = set( + self.lora_id_to_lora_request.values()) + + return prompt_lora_mapping, token_lora_mapping, active_lora_requests + @property def num_reqs(self) -> int: return len(self.req_id_to_index) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7841fac1df34b..c67a8b73f548a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -33,6 +33,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -40,7 +41,7 @@ logger = init_logger(__name__) -class GPUModelRunner: +class GPUModelRunner(LoRAModelRunnerMixin): def __init__( self, @@ -279,6 +280,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: block_ids=new_req_data.block_ids, num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], + lora_request=new_req_data.lora_request, ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -372,15 +374,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. - num_scheduled_tokens = [] + num_scheduled_tokens_list: List[int] = [] max_num_scheduled_tokens = 0 for req_id in self.input_batch.req_ids[:num_reqs]: assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens.append(num_tokens) + num_scheduled_tokens_list.append(num_tokens) max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) - num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) + num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, + dtype=np.int32) assert max_num_scheduled_tokens > 0 # Get request indices. @@ -565,6 +568,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, ) + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + # NOTE(woosuk): Due to chunked prefills, the batch may contain partial # requests. While we should not sample any token from these partial # requests, we do so for simplicity. We will ignore the sampled @@ -867,6 +875,12 @@ def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -1005,14 +1019,32 @@ def profile_run(self) -> None: # Cache the dummy encoder outputs. self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - # Trigger compilation for general shape. 
- hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches) - logits = self.model.compute_logits(hidden_states, None) - logits = logits[:self.max_num_tokens] - # TODO(woosuk): Consider the memory usage of the sampler. - torch.cuda.synchronize() - del hidden_states, logits - self.encoder_cache.clear() + # For profile, have maximum num_reqs and that collectively have + # maximum num_tokens. + num_reqs = self.scheduler_config.max_num_seqs + num_tokens = self.max_num_tokens + min_tokens_per_req: int = num_tokens // num_reqs + + num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + + num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, + dtype=np.int32) + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + + with self.maybe_profile_with_lora(self.lora_config, + num_scheduled_tokens): + # Trigger compilation for general shape. + hidden_states = self._dummy_run(self.max_num_tokens, + dummy_kv_caches) + hidden_states = hidden_states[logit_indices] + logits = self.model.compute_logits(hidden_states, None) + # TODO(woosuk): Consider the memory usage of the sampler. + torch.cuda.synchronize() + del hidden_states, logits + self.encoder_cache.clear() gc.collect() def capture_model(self) -> None: diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py new file mode 100644 index 0000000000000..e7501ad2ea168 --- /dev/null +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Define LoRA functionality mixin for model runners. +""" + +from contextlib import contextmanager +from typing import Set, Tuple + +import numpy as np +import torch.nn as nn + +from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.v1.worker.gpu_input_batch import InputBatch + +logger = init_logger(__name__) + + +# Defined as a mixin for GPUModelRunner +class LoRAModelRunnerMixin: + + LORA_WARMUP_RANK = 8 + + def load_lora_model(self, model: nn.Module, model_config: ModelConfig, + scheduler_config: SchedulerConfig, + lora_config: LoRAConfig, device: str) -> nn.Module: + + assert supports_lora( + model), f"{model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. 
+ if hasattr(model.config, "max_position_embeddings"): + max_pos_embeddings = model.config.max_position_embeddings + else: + max_pos_embeddings = ( + model.config.text_config.max_position_embeddings) + + # Add LoRA Manager to the Model Runner + self.lora_manager = LRUCacheWorkerLoRAManager( + scheduler_config.max_num_seqs, + scheduler_config.max_num_batched_tokens, + model_config.get_vocab_size(), + lora_config, + device, + model.embedding_modules, + model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + return self.lora_manager.create_lora_manager(model) + + def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...], + token_lora_mapping: Tuple[int, ...], + lora_requests: Set[LoRARequest]) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + + # We dont make any distinction between prefills and decodes in the + # scheduler. To that effect, set is_prefill to True so we use the + # sgmv punica kernels always. + lora_mapping = LoRAMapping(token_lora_mapping, + prompt_lora_mapping, + is_prefill=True) + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def set_active_loras(self, input_batch: InputBatch, + num_scheduled_tokens: np.ndarray) -> None: + + prompt_lora_mapping: Tuple[int, ...] # of size input_batch.num_reqs + token_lora_mapping: Tuple[int, + ...] # of size np.sum(num_scheduled_tokens) + lora_requests: Set[LoRARequest] + prompt_lora_mapping, token_lora_mapping, lora_requests = \ + input_batch.make_lora_inputs(num_scheduled_tokens) + return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, + lora_requests) + + @contextmanager + def maybe_profile_with_lora(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray): + if lora_config is None: + yield + else: + # __enter__ code + assert self.lora_manager is not None, "LoRA is not enabled" + + num_reqs = len(num_scheduled_tokens) + num_loras = lora_config.max_loras + + # Make prompt lora mapping + # Assign LoRA IDs cyclically to simulate a worst-case scenario. + prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % + num_loras) + 1 + + # Make token lora mapping + token_lora_mapping = np.repeat(prompt_lora_mapping, + num_scheduled_tokens) + + # Make dummy lora requests + lora_requests: Set[LoRARequest] = { + LoRARequest(lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path") + for lora_id in range(1, num_loras + 1) + } + + with self.lora_manager.dummy_lora_cache(): + # Add the dummy LoRAs here so _set_active_loras doesn't try to + # load from disk. + for lr in lora_requests: + self.lora_manager.add_dummy_lora( + lr, rank=self.LORA_WARMUP_RANK) + + self._set_active_loras(tuple(prompt_lora_mapping), + tuple(token_lora_mapping), + lora_requests) + + yield + + # __exit__ code + self.lora_manager.remove_all_adapters()
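Both `InputBatch.make_lora_inputs` and `maybe_profile_with_lora` derive the token-level LoRA mapping by repeating each request's LoRA ID once per token scheduled for that request. A minimal sketch of that expansion with illustrative values (0 means "no LoRA", as in `request_lora_mapping`):

import numpy as np

# Per-request LoRA IDs for a batch of 4 requests.
request_lora_mapping = np.array([1, 0, 2, 1], dtype=np.int32)
# Number of tokens scheduled for each of those requests in this step.
num_scheduled_tokens = np.array([3, 2, 1, 2], dtype=np.int32)

prompt_lora_mapping = tuple(request_lora_mapping)
# np.ndarray.repeat expands each per-request ID to one entry per scheduled token.
token_lora_mapping = tuple(request_lora_mapping.repeat(num_scheduled_tokens))

assert prompt_lora_mapping == (1, 0, 2, 1)
assert token_lora_mapping == (1, 1, 1, 0, 0, 2, 1, 1)

The profiling path builds its dummy `prompt_lora_mapping` the same way, except the IDs are assigned cyclically with `(np.arange(num_reqs) % num_loras) + 1` so that all `max_loras` adapters are exercised at once.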