Skip to content

Commit

Permalink
[Misc] Infer model name from model_uri and check AWS credential (#250)
Browse files Browse the repository at this point in the history
* refact: remove env DOWNLOADER_MODEL_NAME

* style

* test: fix test case

* docs: update doc about remove DOWNLOADER_MODEL_NAME

* misc: add ak sk check in s3 download
  • Loading branch information
brosoul authored Sep 27, 2024
1 parent 72a7087 commit b664587
Show file tree
Hide file tree
Showing 15 changed files with 99 additions and 57 deletions.
3 changes: 0 additions & 3 deletions docs/source/features/runtime.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ First Define the necessary environment variables for the HuggingFace model.
.. code-block:: bash
# General settings
export DOWNLOADER_MODEL_NAME=deepseek-ai/deepseek-coder-6.7b-instruct/
export DOWNLOADER_ALLOW_FILE_SUFFIX=json, safetensors
export DOWNLOADER_NUM_THREADS=16
# HuggingFace settings
Expand All @@ -70,7 +69,6 @@ First Define the necessary environment variables for the S3 model.
.. code-block:: bash
# General settings
export DOWNLOADER_MODEL_NAME=deepseek-ai/deepseek-coder-6.7b-instruct/
export DOWNLOADER_ALLOW_FILE_SUFFIX=json, safetensors
export DOWNLOADER_NUM_THREADS=16
# AWS settings
Expand All @@ -96,7 +94,6 @@ First Define the necessary environment variables for the TOS model.
.. code-block:: bash
# General settings
export DOWNLOADER_MODEL_NAME=deepseek-ai/deepseek-coder-6.7b-instruct/
export DOWNLOADER_ALLOW_FILE_SUFFIX=json, safetensors
export DOWNLOADER_NUM_THREADS=16
# AWS settings
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorial/runtime/runtime-hf-download.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ spec:
- --port
- "8000"
- --model
- /models/deepseek-ai/deepseek-coder-6.7b-instruct
- /models/deepseek-coder-6.7b-instruct
- --served-model-name
- deepseek-ai/deepseek-coder-6.7b-instruct
- --distributed-executor-backend
Expand Down Expand Up @@ -96,9 +96,9 @@ spec:
- deepseek-ai/deepseek-coder-6.7b-instruct
- --local-dir
- /models/
- --model-name
- deepseek-coder-6.7b-instruct
env:
- name: DOWNLOADER_MODEL_NAME
value: deepseek-ai/deepseek-coder-6.7b-instruct
- name: DOWNLOADER_ALLOW_FILE_SUFFIX
value: json, safetensors
- name: HF_TOKEN
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorial/runtime/runtime-s3-download.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ spec:
- --port
- "8000"
- --model
- /models/deepseek-ai/deepseek-coder-6.7b-instruct
- /models/deepseek-coder-6.7b-instruct
- --served-model-name
- deepseek-ai/deepseek-coder-6.7b-instruct
- --distributed-executor-backend
Expand Down Expand Up @@ -96,9 +96,9 @@ spec:
- s3://<input your s3 bucket name>/<input your s3 bucket path>
- --local-dir
- /models/
- --model-name
- deepseek-coder-6.7b-instruct
env:
- name: DOWNLOADER_MODEL_NAME
value: deepseek-ai/deepseek-coder-6.7b-instruct
- name: DOWNLOADER_ALLOW_FILE_SUFFIX
value: json, safetensors
- name: AWS_ACCESS_KEY_ID
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorial/runtime/runtime-tos-download.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ spec:
- --port
- "8000"
- --model
- /models/deepseek-ai/deepseek-coder-6.7b-instruct
- /models/deepseek-coder-6.7b-instruct
- --served-model-name
- deepseek-ai/deepseek-coder-6.7b-instruct
- --distributed-executor-backend
Expand Down Expand Up @@ -96,9 +96,9 @@ spec:
- tos://<input your tos bucket name>/<input your tos bucket path>
- --local-dir
- /models/
- --model-name
- deepseek-coder-6.7b-instruct
env:
- name: DOWNLOADER_MODEL_NAME
value: deepseek-ai/deepseek-coder-6.7b-instruct
- name: DOWNLOADER_ALLOW_FILE_SUFFIX
value: json, safetensors
- name: TOS_ACCESS_KEY
Expand Down
6 changes: 4 additions & 2 deletions python/aibrix/aibrix/downloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
from aibrix.downloader.base import get_downloader


def download_model(model_uri: str, local_path: Optional[str] = None):
def download_model(
model_uri: str, local_path: Optional[str] = None, model_name: Optional[str] = None
):
"""Download model from model_uri to local_path.
Args:
model_uri (str): model uri.
local_path (str): local path to save model.
"""

downloader = get_downloader(model_uri)
downloader = get_downloader(model_uri, model_name)
return downloader.download_model(local_path)


Expand Down
10 changes: 8 additions & 2 deletions python/aibrix/aibrix/downloader/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@ def main():
"--local-dir",
type=str,
default=None,
help="dir to save model files",
help="base dir of the model file. If not set, it will used with env `DOWNLOADER_LOCAL_DIR`",
)
parser.add_argument(
"--model-name",
type=str,
default=None,
help="subdir of the base dir to save model files",
)
args = parser.parse_args()
download_model(args.model_uri, args.local_dir)
download_model(args.model_uri, args.local_dir, args.model_name)


if __name__ == "__main__":
Expand Down
8 changes: 4 additions & 4 deletions python/aibrix/aibrix/downloader/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,17 @@ def __hash__(self):
return hash(tuple(self.__dict__))


def get_downloader(model_uri: str) -> BaseDownloader:
def get_downloader(model_uri: str, model_name: Optional[str] = None) -> BaseDownloader:
"""Get downloader for model_uri."""
if re.match(envs.DOWNLOADER_S3_REGEX, model_uri):
from aibrix.downloader.s3 import S3Downloader

return S3Downloader(model_uri)
return S3Downloader(model_uri, model_name)
elif re.match(envs.DOWNLOADER_TOS_REGEX, model_uri):
from aibrix.downloader.tos import TOSDownloader

return TOSDownloader(model_uri)
return TOSDownloader(model_uri, model_name)
else:
from aibrix.downloader.huggingface import HuggingFaceDownloader

return HuggingFaceDownloader(model_uri)
return HuggingFaceDownloader(model_uri, model_name)
6 changes: 2 additions & 4 deletions python/aibrix/aibrix/downloader/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ def _parse_model_name_from_uri(model_uri: str) -> str:
class HuggingFaceDownloader(BaseDownloader):
def __init__(self, model_uri: str, model_name: Optional[str] = None):
if model_name is None:
if envs.DOWNLOADER_MODEL_NAME is not None:
model_name = envs.DOWNLOADER_MODEL_NAME
else:
model_name = _parse_model_name_from_uri(model_uri)
model_name = _parse_model_name_from_uri(model_uri)
logger.info(f"model_name is not set, using `{model_name}` as model_name")

self.hf_token = envs.DOWNLOADER_HF_TOKEN
self.hf_endpoint = envs.DOWNLOADER_HF_ENDPOINT
Expand Down
19 changes: 15 additions & 4 deletions python/aibrix/aibrix/downloader/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.downloader.utils import meta_file, need_to_download, save_meta_data
from aibrix.downloader.utils import (
infer_model_name,
meta_file,
need_to_download,
save_meta_data,
)
from aibrix.logger import init_logger

logger = init_logger(__name__)
Expand All @@ -37,14 +42,20 @@ def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:


class S3Downloader(BaseDownloader):
def __init__(self, model_uri):
model_name = envs.DOWNLOADER_MODEL_NAME
def __init__(self, model_uri, model_name: Optional[str] = None):
if model_name is None:
model_name = infer_model_name(model_uri)
logger.info(f"model_name is not set, using `{model_name}` as model_name")

ak = envs.DOWNLOADER_AWS_ACCESS_KEY_ID
sk = envs.DOWNLOADER_AWS_SECRET_ACCESS_KEY
endpoint = envs.DOWNLOADER_AWS_ENDPOINT_URL
region = envs.DOWNLOADER_AWS_REGION
bucket_name, bucket_path = _parse_bucket_info_from_uri(model_uri)

assert ak is not None and ak != "", "`AWS_ACCESS_KEY_ID` is not set."
assert sk is not None and sk != "", "`AWS_SECRET_ACCESS_KEY` is not set."

# Avoid warning log "Connection pool is full"
# Refs: https://github.com/boto/botocore/issues/619#issuecomment-583511406
max_pool_connections = (
Expand Down Expand Up @@ -75,7 +86,7 @@ def __init__(self, model_uri):
def _valid_config(self):
assert (
self.model_name is not None and self.model_name != ""
), "S3 model name is not set, please set env variable DOWNLOADER_MODEL_NAME."
), "S3 model name is not set, please check `--model-name`."
assert (
self.bucket_name is not None and self.bucket_name != ""
), "S3 bucket name is not set."
Expand Down
16 changes: 12 additions & 4 deletions python/aibrix/aibrix/downloader/tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@

from aibrix import envs
from aibrix.downloader.base import BaseDownloader
from aibrix.downloader.utils import meta_file, need_to_download, save_meta_data
from aibrix.downloader.utils import (
infer_model_name,
meta_file,
need_to_download,
save_meta_data,
)
from aibrix.logger import init_logger

tos_logger = logging.getLogger("tos")
Expand All @@ -39,8 +44,11 @@ def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:


class TOSDownloader(BaseDownloader):
def __init__(self, model_uri):
model_name = envs.DOWNLOADER_MODEL_NAME
def __init__(self, model_uri, model_name: Optional[str] = None):
if model_name is None:
model_name = infer_model_name(model_uri)
logger.info(f"model_name is not set, using `{model_name}` as model_name")

ak = envs.DOWNLOADER_TOS_ACCESS_KEY or ""
sk = envs.DOWNLOADER_TOS_SECRET_KEY or ""
endpoint = envs.DOWNLOADER_TOS_ENDPOINT or ""
Expand All @@ -62,7 +70,7 @@ def __init__(self, model_uri):
def _valid_config(self):
assert (
self.model_name is not None and self.model_name != ""
), "TOS model name is not set, please set env variable DOWNLOADER_MODEL_NAME."
), "TOS model name is not set, please check `--model-name`."
assert (
self.bucket_name is not None and self.bucket_name != ""
), "TOS bucket name is not set."
Expand Down
7 changes: 7 additions & 0 deletions python/aibrix/aibrix/downloader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,10 @@ def need_to_download(
f"DOWNLOADER_CHECK_FILE_EXIST={envs.DOWNLOADER_CHECK_FILE_EXIST}"
)
return True


def infer_model_name(uri: str) -> str:
    """Infer a model name from the last path segment of a model URI.

    Surrounding whitespace and leading/trailing slashes are ignored, so
    ``s3://bucket/path/to/model/`` yields ``model``.

    Args:
        uri: Model URI or repo id, e.g. ``s3://bucket/path/to/model``.

    Returns:
        The last ``/``-separated segment of the normalized URI.

    Raises:
        ValueError: If ``uri`` is ``None``, empty, or consists only of
            surrounding whitespace and slashes.
    """
    # Normalize before validating: checking only the raw value would let
    # inputs such as "   " or "///" through and yield an empty model name.
    normalized = uri.strip().strip("/") if uri else ""
    if not normalized:
        raise ValueError("Model uri is empty.")

    return normalized.split("/")[-1]
1 change: 0 additions & 1 deletion python/aibrix/aibrix/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def _parse_int_or_none(value: Optional[str]) -> Optional[int]:
DOWNLOADER_LOCAL_DIR = os.getenv("DOWNLOADER_LOCAL_DIR", "/tmp/aibrix/models/")


DOWNLOADER_MODEL_NAME = os.getenv("DOWNLOADER_MODEL_NAME")
DOWNLOADER_NUM_THREADS = int(os.getenv("DOWNLOADER_NUM_THREADS", "4"))
DOWNLOADER_PART_THRESHOLD = _parse_int_or_none(os.getenv("DOWNLOADER_PART_THRESHOLD"))
DOWNLOADER_PART_CHUNKSIZE = _parse_int_or_none(os.getenv("DOWNLOADER_PART_CHUNKSIZE"))
Expand Down
34 changes: 25 additions & 9 deletions python/aibrix/tests/downloader/test_downloader_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ def mock_exsit_boto3(mock_boto3):


env_group = mock.Mock()
env_group.DOWNLOADER_MODEL_NAME = "model_name"
env_group.DOWNLOADER_NUM_THREADS = 4
env_group.DOWNLOADER_AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID"
env_group.DOWNLOADER_AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY"

env_no_ak = mock.Mock()
env_no_ak.DOWNLOADER_NUM_THREADS = 4
env_no_ak.DOWNLOADER_AWS_ACCESS_KEY_ID = ""
env_no_ak.DOWNLOADER_AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY"

env_group_no_model_name = mock.Mock()
env_group_no_model_name.DOWNLOADER_MODEL_NAME = None
env_group_no_model_name.DOWNLOADER_NUM_THREADS = 4
env_no_sk = mock.Mock()
env_no_sk.DOWNLOADER_NUM_THREADS = 4
env_no_sk.DOWNLOADER_AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID"
env_no_sk.DOWNLOADER_AWS_SECRET_ACCESS_KEY = ""


@mock.patch(ENVS_MODULE, env_group)
Expand Down Expand Up @@ -87,11 +93,21 @@ def test_get_downloader_s3_path_empty_path(mock_boto3):
assert "S3 bucket path is not set." in str(exception.value)


@mock.patch(ENVS_MODULE, env_group_no_model_name)
@mock.patch(ENVS_MODULE, env_no_ak)
@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3_no_model_name(mock_tos):
mock_exsit_boto3(mock_tos)
def test_get_downloader_s3_no_ak(mock_boto3):
mock_exsit_boto3(mock_boto3)

with pytest.raises(AssertionError) as exception:
get_downloader("s3://bucket/path")
assert "S3 model name is not set" in str(exception.value)
get_downloader("s3://bucket/")
assert "`AWS_ACCESS_KEY_ID` is not set." in str(exception.value)


@mock.patch(ENVS_MODULE, env_no_sk)
@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3_no_sk(mock_boto3):
    """Creating an S3 downloader with an empty secret key must fail."""
    mock_exsit_boto3(mock_boto3)

    # env_no_sk supplies an access key but an empty secret key, so the
    # credential check is expected to raise an AssertionError.
    with pytest.raises(AssertionError) as raised:
        get_downloader("s3://bucket/")

    message = str(raised.value)
    assert "`AWS_SECRET_ACCESS_KEY` is not set." in message
15 changes: 0 additions & 15 deletions python/aibrix/tests/downloader/test_downloader_tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,6 @@ def mock_exsit_tos(mock_tos):


env_group = mock.Mock()
env_group.DOWNLOADER_MODEL_NAME = "model_name"


env_group_no_model_name = mock.Mock()
env_group_no_model_name.DOWNLOADER_MODEL_NAME = None


@mock.patch(ENVS_MODULE, env_group)
Expand Down Expand Up @@ -83,13 +78,3 @@ def test_get_downloader_tos_path_empty_path(mock_tos):
with pytest.raises(AssertionError) as exception:
get_downloader("tos://bucket/")
assert "TOS bucket path is not set." in str(exception.value)


@mock.patch(ENVS_MODULE, env_group_no_model_name)
@mock.patch(TOS_MODULE)
def test_get_downloader_tos_no_model_name(mock_tos):
mock_exsit_tos(mock_tos)

with pytest.raises(AssertionError) as exception:
get_downloader("tos://bucket/path")
assert "TOS model name is not set" in str(exception.value)
13 changes: 13 additions & 0 deletions python/aibrix/tests/downloader/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
from aibrix.config import DOWNLOAD_CACHE_DIR
from aibrix.downloader.utils import (
check_file_exist,
infer_model_name,
load_meta_data,
meta_file,
need_to_download,
save_meta_data,
)
import pytest


def prepare_file_and_meta_data(file_path, meta_path, file_size, etag):
Expand Down Expand Up @@ -136,3 +138,14 @@ def test_need_to_download(mock_check: mock.Mock):
# recover envs
envs.DOWNLOADER_FORCE_DOWNLOAD = origin_force_download_env
envs.DOWNLOADER_CHECK_FILE_EXIST = origin_check_file_exist


def test_infer_model_name():
    """infer_model_name rejects empty input and returns the last URI segment."""
    # Both the empty string and None must be rejected with ValueError.
    for empty_uri in ("", None):
        with pytest.raises(ValueError):
            infer_model_name(empty_uri)

    # A well-formed URI resolves to its final path component.
    assert infer_model_name("s3://bucket/path/to/model") == "model"

0 comments on commit b664587

Please sign in to comment.