Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CI] Add AI Runtime test case #197

Merged
merged 7 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: |
cd python/aibrix
python -m mypy .
# - name: Run Test
# run: |
# cd python/aibrix
# python -m pytest ./tests
- name: Run Test
run: |
cd python/aibrix/tests
python -m pytest .
7 changes: 5 additions & 2 deletions python/aibrix/aibrix/downloader/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, model_uri: str, model_name: Optional[str] = None):
self.hf_token = envs.DOWNLOADER_HF_TOKEN
self.hf_endpoint = envs.DOWNLOADER_HF_ENDPOINT
self.hf_revision = envs.DOWNLOADER_HF_REVISION
self.hf_api = HfApi(endpoint=self.hf_endpoint, token=self.hf_token)

super().__init__(
model_uri=model_uri,
Expand Down Expand Up @@ -67,6 +68,9 @@ def _valid_config(self):
), "Model uri must be in `repo/name` format."
assert self.bucket_name is None, "Bucket name is empty in HuggingFace."
assert self.model_name is not None, "Model name is not set."
assert self.hf_api.repo_exists(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the remote request? if hugging face repo is not accessible from CN machines, does it still work?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expect this network issue to be bypassed using the HF_ENDPOINT env setting

repo_id=self.model_uri
), f"Model {self.model_uri} not exist."

def _is_directory(self) -> bool:
"""Check if model_uri is a directory.
Expand All @@ -75,8 +79,7 @@ def _is_directory(self) -> bool:
return True

def _directory_list(self, path: str) -> List[str]:
hf_api = HfApi(endpoint=self.hf_endpoint, token=self.hf_token)
return hf_api.list_repo_files(repo_id=self.model_uri)
return self.hf_api.list_repo_files(repo_id=self.model_uri)

def _support_range_download(self) -> bool:
return False
Expand Down
4 changes: 2 additions & 2 deletions python/aibrix/aibrix/downloader/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def __init__(self, model_uri):

def _valid_config(self):
assert (
self.bucket_name is not None or self.bucket_name == ""
self.bucket_name is not None and self.bucket_name != ""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's split the feature the CI test in separate PRs in future

), "S3 bucket name is not set."
assert (
self.bucket_path is not None or self.bucket_path == ""
self.bucket_path is not None and self.bucket_path != ""
), "S3 bucket path is not set."
try:
self.client.head_bucket(Bucket=self.bucket_name)
Expand Down
4 changes: 2 additions & 2 deletions python/aibrix/aibrix/downloader/tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def __init__(self, model_uri):

def _valid_config(self):
assert (
self.bucket_name is not None or self.bucket_name == ""
self.bucket_name is not None and self.bucket_name != ""
), "TOS bucket name is not set."
assert (
self.bucket_path is not None or self.bucket_path == ""
self.bucket_path is not None and self.bucket_path != ""
), "TOS bucket path is not set."
try:
self.client.head_bucket(self.bucket_name)
Expand Down
2 changes: 1 addition & 1 deletion python/aibrix/aibrix/metrics/engine_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@


def get_metric_standard_rules(engine: str) -> Dict[str, StandardRule]:
if engine == "vllm":
if engine.lower() == "vllm":
return VLLM_METRIC_STANDARD_RULES
else:
raise ValueError(f"Engine {engine} is not supported.")
Empty file.
39 changes: 39 additions & 0 deletions python/aibrix/tests/downloader/test_downloader_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from aibrix.downloader.base import get_downloader
from aibrix.downloader.huggingface import HuggingFaceDownloader


def test_get_downloader_hf():
downloader = get_downloader("facebook/opt-125m")
assert isinstance(downloader, HuggingFaceDownloader)


def test_get_downloader_hf_not_exist():
with pytest.raises(AssertionError) as exception:
get_downloader("not_exsit_path/model")
assert "not exist" in str(exception.value)


def test_get_downloader_hf_invalid_uri():
with pytest.raises(AssertionError) as exception:
get_downloader("single_field")
assert "Model uri must be in `repo/name` format." in str(exception.value)

with pytest.raises(AssertionError) as exception:
get_downloader("multi/filed/repo")
assert "Model uri must be in `repo/name` format." in str(exception.value)
72 changes: 72 additions & 0 deletions python/aibrix/tests/downloader/test_downloader_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest

from aibrix.downloader.base import get_downloader
from aibrix.downloader.s3 import S3Downloader

S3_BOTO3_MODULE = "aibrix.downloader.s3.boto3"


def mock_not_exsit_boto3(mock_boto3):
mock_client = mock.Mock()
mock_boto3.client.return_value = mock_client
mock_client.head_bucket.side_effect = Exception("head bucket error")


def mock_exsit_boto3(mock_boto3):
mock_client = mock.Mock()
mock_boto3.client.return_value = mock_client
mock_client.head_bucket.return_value = mock.Mock()


@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3(mock_boto3):
mock_exsit_boto3(mock_boto3)

downloader = get_downloader("s3://bucket/path")
assert isinstance(downloader, S3Downloader)


@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3_path_not_exist(mock_boto3):
mock_not_exsit_boto3(mock_boto3)

with pytest.raises(AssertionError) as exception:
get_downloader("s3://bucket/not_exsit_path")
assert "not exist" in str(exception.value)


@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3_path_empty(mock_boto3):
mock_exsit_boto3(mock_boto3)

# Bucket name and path both are empty,
# will first assert the name
with pytest.raises(AssertionError) as exception:
get_downloader("s3://")
assert "S3 bucket name is not set." in str(exception.value)


@mock.patch(S3_BOTO3_MODULE)
def test_get_downloader_s3_path_empty_path(mock_boto3):
mock_exsit_boto3(mock_boto3)

# bucket path is empty
with pytest.raises(AssertionError) as exception:
get_downloader("s3://bucket/")
assert "S3 bucket path is not set." in str(exception.value)
72 changes: 72 additions & 0 deletions python/aibrix/tests/downloader/test_downloader_tos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest

from aibrix.downloader.base import get_downloader
from aibrix.downloader.tos import TOSDownloader

TOS_MODULE = "aibrix.downloader.tos.tos"


def mock_not_exsit_tos(mock_tos):
mock_client = mock.Mock()
mock_tos.TosClientV2.return_value = mock_client
mock_client.head_bucket.side_effect = Exception("head bucket error")


def mock_exsit_tos(mock_tos):
mock_client = mock.Mock()
mock_tos.TosClientV2.return_value = mock_client
mock_client.head_bucket.return_value = mock.Mock()


@mock.patch(TOS_MODULE)
def test_get_downloader_s3(mock_tos):
mock_exsit_tos(mock_tos)

downloader = get_downloader("tos://bucket/path")
assert isinstance(downloader, TOSDownloader)


@mock.patch(TOS_MODULE)
def test_get_downloader_s3_path_not_exist(mock_tos):
mock_not_exsit_tos(mock_tos)

with pytest.raises(AssertionError) as exception:
get_downloader("tos://bucket/not_exsit_path")
assert "not exist" in str(exception.value)


@mock.patch(TOS_MODULE)
def test_get_downloader_s3_path_empty(mock_tos):
mock_exsit_tos(mock_tos)

# Bucket name and path both are empty,
# will first assert the name
with pytest.raises(AssertionError) as exception:
get_downloader("tos://")
assert "TOS bucket name is not set." in str(exception.value)


@mock.patch(TOS_MODULE)
def test_get_downloader_s3_path_empty_path(mock_tos):
mock_exsit_tos(mock_tos)

# bucket path is empty
with pytest.raises(AssertionError) as exception:
get_downloader("tos://bucket/")
assert "TOS bucket path is not set." in str(exception.value)
35 changes: 35 additions & 0 deletions python/aibrix/tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from aibrix.metrics.engine_rules import get_metric_standard_rules


def test_get_metric_standard_rules_ignore_case():
# Engine str is all lowercase
rules = get_metric_standard_rules("vllm")
assert rules is not None

# The function get_metric_standard_rules is case-insensitive
rules2 = get_metric_standard_rules("vLLM")
assert rules == rules2


def test_get_metric_standard_rules_not_support():
# SGLang and TensorRT-LLM are not supported
with pytest.raises(ValueError):
get_metric_standard_rules("SGLang")

with pytest.raises(ValueError):
get_metric_standard_rules("TensorRT-LLM")
Loading