Commit b5e9737

Address some feedback from reviewer
Jeffwan committed Jul 16, 2024
1 parent 13a9b59 commit b5e9737
Showing 3 changed files with 18 additions and 18 deletions.
tests/lora/test_utils.py (28 changes: 14 additions, 14 deletions)

@@ -5,7 +5,7 @@
 from huggingface_hub.utils import HfHubHTTPError
 from torch import nn

-from vllm.lora.utils import (get_lora_absolute_path,
+from vllm.lora.utils import (get_adapter_absolute_path,
                              parse_fine_tuned_lora_name, replace_submodule)
 from vllm.utils import LRUCache

@@ -187,53 +187,53 @@ def test_lru_cache():
     assert 6 in cache


-# Unit tests for get_lora_absolute_path
+# Unit tests for get_adapter_absolute_path
 @patch('os.path.isabs')
-def test_get_lora_absolute_path_absolute(mock_isabs):
+def test_get_adapter_absolute_path_absolute(mock_isabs):
     path = '/absolute/path/to/lora'
     mock_isabs.return_value = True
-    assert get_lora_absolute_path(path) == path
+    assert get_adapter_absolute_path(path) == path


 @patch('os.path.expanduser')
-def test_get_lora_absolute_path_expanduser(mock_expanduser):
+def test_get_adapter_absolute_path_expanduser(mock_expanduser):
     # Path with ~ that needs to be expanded
     path = '~/relative/path/to/lora'
     absolute_path = '/home/user/relative/path/to/lora'
     mock_expanduser.return_value = absolute_path
-    assert get_lora_absolute_path(path) == absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path


 @patch('os.path.exists')
 @patch('os.path.abspath')
-def test_get_lora_absolute_path_local_existing(mock_abspath, mock_exist):
+def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
     # Relative path that exists locally
     path = 'relative/path/to/lora'
     absolute_path = '/absolute/path/to/lora'
     mock_exist.return_value = True
     mock_abspath.return_value = absolute_path
-    assert get_lora_absolute_path(path) == absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path


 @patch('huggingface_hub.snapshot_download')
 @patch('os.path.exists')
-def test_get_lora_absolute_path_huggingface(mock_exist,
-                                            mock_snapshot_download):
+def test_get_adapter_absolute_path_huggingface(mock_exist,
+                                               mock_snapshot_download):
     # Hugging Face model identifier
     path = 'org/repo'
     absolute_path = '/mock/snapshot/path'
     mock_exist.return_value = False
     mock_snapshot_download.return_value = absolute_path
-    assert get_lora_absolute_path(path) == absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path


 @patch('huggingface_hub.snapshot_download')
 @patch('os.path.exists')
-def test_get_lora_absolute_path_huggingface_error(mock_exist,
-                                                  mock_snapshot_download):
+def test_get_adapter_absolute_path_huggingface_error(mock_exist,
+                                                      mock_snapshot_download):
     # Hugging Face model identifier with download error
     path = 'org/repo'
     mock_exist.return_value = False
     mock_snapshot_download.side_effect = HfHubHTTPError(
         "failed to query model info")
-    assert get_lora_absolute_path(path) == path
+    assert get_adapter_absolute_path(path) == path
vllm/lora/utils.py (4 changes: 2 additions, 2 deletions)

@@ -145,10 +145,10 @@ def get_adapter_absolute_path(lora_path: str) -> str:
         local_snapshot_path = huggingface_hub.snapshot_download(
             repo_id=lora_path)
     except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
-            HFValidationError) as e:
+            HFValidationError):
         # Handle errors that may occur during the download
         # Return original path instead of throwing error here
-        print(f"Error downloading the Hugging Face model: {e}")
+        logger.exception("Error downloading the HuggingFace model")
         return lora_path

     return local_snapshot_path
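
For reference, here is a minimal sketch of how get_adapter_absolute_path plausibly reads after this change. The try/except block and the function signature come from the hunk above; the resolution steps before the download (absolute path, ~ expansion, existing relative path) are reconstructed from the unit tests in tests/lora/test_utils.py and are an assumption, not a verbatim copy of the repository. The sketch also uses the standard logging module where the real module uses vLLM's init_logger.

import logging
import os

import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
                                   HFValidationError, RepositoryNotFoundError)

# Stand-in for vLLM's init_logger(__name__).
logger = logging.getLogger(__name__)


def get_adapter_absolute_path(lora_path: str) -> str:
    """Resolve a LoRA adapter reference to a local absolute path."""
    # An absolute filesystem path is returned unchanged.
    if os.path.isabs(lora_path):
        return lora_path

    # A leading ~ is expanded to the user's home directory.
    if lora_path.startswith('~'):
        return os.path.expanduser(lora_path)

    # A relative path that exists locally is made absolute.
    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # Anything else is treated as a Hugging Face repo id and downloaded.
    try:
        local_snapshot_path = huggingface_hub.snapshot_download(
            repo_id=lora_path)
    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
            HFValidationError):
        # On failure, log and fall back to the original path instead of
        # raising; this commit switches the log call to logger.exception.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path

    return local_snapshot_path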
vllm/lora/worker_manager.py (4 changes: 2 additions, 2 deletions)

@@ -13,7 +13,7 @@
 from vllm.lora.models import (LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager, create_lora_manager)
 from vllm.lora.request import LoRARequest
-from vllm.lora.utils import get_lora_absolute_path
+from vllm.lora.utils import get_adapter_absolute_path

 logger = init_logger(__name__)

@@ -90,7 +90,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                     packed_modules_mapping[module])
             else:
                 expected_lora_modules.append(module)
-        lora_path = get_lora_absolute_path(lora_request.lora_path)
+        lora_path = get_adapter_absolute_path(lora_request.lora_path)
        lora = self._lora_model_cls.from_local_checkpoint(
             lora_path,
             expected_lora_modules,
