Skip to content

Commit

Permalink
Add test for loading lora from huggingface
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeffwan committed Jul 12, 2024
1 parent 9f04976 commit 031d7b4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ def sql_lora_files():
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
    """Hugging Face repo id used to exercise runtime LoRA downloading."""
    return "yard1/llama-2-7b-sql-lora-test"


@pytest.fixture(scope="session")
def mixtral_lora_files():
# Note: this module has incorrect adapter_config.json to test
Expand Down
39 changes: 39 additions & 0 deletions tests/lora/test_lora_huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import List

import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM

# Fixture names to parametrize over: one resolves to a local absolute path,
# the other to a Hugging Face repo id, so both loading paths are covered.
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]


@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
    """LoRA checkpoints must load both from a local path and a HF repo id.

    The parametrized fixture name is resolved at runtime, so the same body
    runs once with an already-downloaded directory and once with a repo id
    that ``get_adapter_absolute_path`` must download first.
    """
    lora_name = request.getfixturevalue(lora_fixture_name)

    # Packed modules (e.g. qkv_proj) expand into their constituent
    # projections; everything else is expected under its own name.
    packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
    expected_lora_modules: List[str] = []
    for module in LlamaForCausalLM.supported_lora_modules:
        expected_lora_modules.extend(
            packed_modules_mapping.get(module, [module]))

    lora_path = get_adapter_absolute_path(lora_name)

    # Loading should work whether lora_name was an absolute path or a
    # Hugging Face id — by this point both resolve to a local directory.
    lora_model = LoRAModel.from_local_checkpoint(
        lora_path,
        expected_lora_modules,
        lora_model_id=1,
        device="cpu",
        embedding_modules=LlamaForCausalLM.embedding_modules,
        embedding_padding_modules=LlamaForCausalLM.embedding_padding_modules)

    # Assertions to ensure the model is loaded correctly
    assert lora_model is not None, "LoRAModel is not loaded correctly"

0 comments on commit 031d7b4

Please sign in to comment.