[V1] LoRA Support #10957

Merged
Changes from all commits
17 changes: 17 additions & 0 deletions tests/lora/conftest.py
@@ -306,3 +306,20 @@ def get_model_patched(**kwargs):
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)


@pytest.fixture(params=[True, False])
def run_with_both_engines_lora(request, monkeypatch):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v1 = request.node.get_closest_marker("skip_v1")

if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
monkeypatch.setenv('VLLM_USE_V1', '1')
else:
monkeypatch.setenv('VLLM_USE_V1', '0')

yield
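
The fixture keys off a custom skip_v1 pytest marker. For that marker to be usable without "unknown marker" warnings it has to be registered somewhere; below is a minimal sketch of one way to do that (where and how the vLLM test suite actually registers it is not shown in this diff, so the location and wording here are assumptions):

def pytest_configure(config):
    # Hypothetical registration of the marker checked by run_with_both_engines_lora;
    # the real registration may live in pyproject.toml or a shared conftest.
    config.addinivalue_line(
        "markers",
        "skip_v1: skip this test when it runs against the V1 engine")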
8 changes: 8 additions & 0 deletions tests/lora/test_baichuan.py
@@ -42,6 +42,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_baichuan_lora(baichuan_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
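
As the comment in the new fixture notes, the autouse wrapper could be promoted into tests/lora/conftest.py so every test in the package runs under both engines without repeating it per module. A rough sketch of what that promotion might look like (illustrative only, not part of this diff):

import pytest


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Package-level autouse: every LoRA test now runs twice, once with
    # VLLM_USE_V1=1 and once with VLLM_USE_V1=0.
    pass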
13 changes: 13 additions & 0 deletions tests/lora/test_chatglm3_tp.py
@@ -2,6 +2,8 @@

from typing import List

import pytest

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest
@@ -47,6 +49,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.skip_v1
@fork_new_process_for_each_test
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
@@ -66,6 +77,7 @@ def test_chatglm3_lora(chatglm3_lora_files):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4(chatglm3_lora_files):
@@ -87,6 +99,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
8 changes: 8 additions & 0 deletions tests/lora/test_gemma.py
@@ -33,6 +33,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
12 changes: 12 additions & 0 deletions tests/lora/test_llama_tp.py
@@ -2,6 +2,7 @@

from typing import List

import pytest
import ray

import vllm
@@ -73,6 +74,14 @@ def generate_and_test(llm, sql_lora_files):
print("removing lora")


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):

@@ -85,6 +94,9 @@ def test_llama_lora(sql_lora_files):
generate_and_test(llm, sql_lora_files)


# Skipping for v1 as v1 doesn't have a good way to expose the num_gpu_blocks
# used by the engine yet.
@pytest.mark.skip_v1
@fork_new_process_for_each_test
def test_llama_lora_warmup(sql_lora_files):
"""Test that the LLM initialization works with a warmup LORA path and
11 changes: 11 additions & 0 deletions tests/lora/test_lora_bias_e2e.py
@@ -30,6 +30,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


# Skipping for V1 for now as we are hitting,
# "Head size 80 is not supported by FlashAttention." error.
@pytest.mark.skip_v1
@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
13 changes: 13 additions & 0 deletions tests/lora/test_phi.py
@@ -2,6 +2,8 @@

from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest

@@ -48,6 +50,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


# Skipping for V1 for now as we are hitting,
# "Head size 80 is not supported by FlashAttention." error.
@pytest.mark.skip_v1
def test_phi2_lora(phi2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
8 changes: 8 additions & 0 deletions tests/lora/test_quant_model.py
@@ -70,6 +70,14 @@ def format_prompt_tuples(prompt):
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
2 changes: 1 addition & 1 deletion tests/v1/core/test_kv_cache_utils.py
@@ -163,7 +163,7 @@ def test_generate_block_hash_extra_keys():

# Test with no overlap
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 6, 10, 0)
assert extra_keys == ()
assert extra_keys is None
assert next_mm_idx == 1

# Test with multiple extra keys
8 changes: 5 additions & 3 deletions vllm/lora/layers.py
@@ -16,8 +16,7 @@
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
tensor_model_parallel_all_reduce)
from vllm.distributed.utils import divide
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -1043,7 +1042,10 @@ def _get_logits(
logits = lm_head.linear_method.apply(lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)

# Gather logits for TP
logits = self.base_layer._gather_logits(logits)

if logits is None:
return None

28 changes: 17 additions & 11 deletions vllm/model_executor/layers/logits_processor.py
@@ -51,7 +51,6 @@ def __init__(self,
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.

parallel_config = get_current_vllm_config().parallel_config
self.use_all_gather = current_platform.is_tpu() \
or envs.VLLM_USE_V1 \
@@ -88,6 +87,20 @@ def forward(

return logits

def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:

Review comment (PR author): Refactor: introduce _gather_logits() that LogitsProcessorWithLoRA also uses.

"""gather/all-gather the logits tensor across model parallel group."""
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits)
else:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
return logits

def _get_logits(
self,
hidden_states: torch.Tensor,
@@ -99,16 +112,9 @@ def _get_logits(
hidden_states,
bias=embedding_bias)

if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits)
else:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
# Gather logits for TP
logits = self._gather_logits(logits)

# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., :self.org_vocab_size]
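
For context, the refactor above lets wrappers reuse the base layer's _gather_logits() instead of calling tensor_model_parallel_gather() directly, so the gather versus all-gather choice (TPU, V1) stays in one place. A minimal sketch of that delegation pattern, with an illustrative class name and simplified signatures (the real code is LogitsProcessorWithLoRA in vllm/lora/layers.py):

import torch


class LoRALogitsWrapperSketch:
    """Illustrative stand-in for LogitsProcessorWithLoRA."""

    def __init__(self, base_layer):
        # base_layer is a LogitsProcessor instance that owns _gather_logits().
        self.base_layer = base_layer

    def _get_logits(self, hidden_states: torch.Tensor, lm_head,
                    embedding_bias=None):
        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        # Delegate the tensor-parallel gather: plain gather returns None on
        # ranks > 0, while all-gather returns a tensor on every rank.
        logits = self.base_layer._gather_logits(logits)
        if logits is None:
            return None
        return logits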