Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Upload model after finishing training #826

Merged
merged 29 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dataquality/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
dataquality.get_insights()
"""

__version__ = "1.6.0"
__version__ = "1.6.1"

import sys
from typing import Any, List, Optional
Expand Down
25 changes: 25 additions & 0 deletions dataquality/clients/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,3 +895,28 @@
},
)
return res

def get_presigned_url_for_model(
franz101 marked this conversation as resolved.
Show resolved Hide resolved
self, project_id: UUID4, run_id: UUID4, model_kind: str, model_parameters: Dict
) -> str:
"""
Returns a presigned url for uploading a model to S3

"""
return self.make_request(

Check warning on line 906 in dataquality/clients/api.py

View check run for this annotation

Codecov / codecov/patch

dataquality/clients/api.py#L906

Added line #L906 was not covered by tests
RequestType.POST,
url=f"{config.api_url}/{Route.projects}/{str(project_id)}/{Route.runs}/{str(run_id)}/{Route.model}",
body={"kind": model_kind, "parameters": model_parameters},
)["upload_url"]

def get_uploaded_model_info(self, project_id: UUID4, run_id: UUID4) -> Any:
"""
Returns information about the model for a given run.
Will also update the status to complete.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does it update the status to complete?

also what does it return, the model or a presigned url to download it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we don't know on the backend when a minio upload is completed we update it every time we get the model and if the filename is not saved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so what status is it updating, the job status?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just the upload is completed. at first it's the put link, once it's put we will pull the download link from minio and add it to the entry

:param project_id: The project id
:param run_id: The run id
"""
return self.make_request(

Check warning on line 919 in dataquality/clients/api.py

View check run for this annotation

Codecov / codecov/patch

dataquality/clients/api.py#L919

Added line #L919 was not covered by tests
RequestType.GET,
url=f"{config.api_url}/{Route.projects}/{str(project_id)}/{Route.runs}/{str(run_id)}/{Route.model}",
)
18 changes: 18 additions & 0 deletions dataquality/core/finish.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from dataquality.analytics import Analytics
from dataquality.clients.api import ApiClient
from dataquality.core._config import config
from dataquality.core.log import get_model_logger
from dataquality.core.report import build_run_report
from dataquality.schemas import RequestType, Route
from dataquality.schemas.job import JobName
from dataquality.schemas.task_type import TaskType
from dataquality.utils.dq_logger import DQ_LOG_FILE_HOME, upload_dq_log_file
from dataquality.utils.helpers import check_noop, gpu_available
from dataquality.utils.thread_pool import ThreadPoolManager
from dataquality.utils.upload_model import upload_model_to_dq

api_client = ApiClient()
a = Analytics(ApiClient, config) # type: ignore
Expand All @@ -25,6 +27,7 @@
wait: bool = True,
create_data_embs: Optional[bool] = None,
data_embs_col: str = "text",
upload_model: bool = bool(os.environ.get("DQ_UPLOAD_MODEL", False)),
) -> str:
"""
Finishes the current run and invokes a job
Expand All @@ -43,6 +46,8 @@
If not set, we default to 'text' which corresponds to the input text.
Can also be set to `target`, `generated_output` or any other column that is
logged as metadata.
:param upload_model: If True, the model will be stored in the galileo project.
Default False or set by the environment variable DQ_UPLOAD_MODEL.
"""
a.log_function("dq/finish")
if create_data_embs is None:
Expand Down Expand Up @@ -85,6 +90,19 @@
f"Job {res['job_name']} successfully submitted. Results will be available "
f"soon at {res['link']}"
)
if upload_model:
try:
helper_data = get_model_logger().logger_config.helper_data
if helper_data and "model" in helper_data:
model = helper_data["model"]
model_parameters = helper_data["model_parameters"]
model_kind = helper_data["model_kind"]
upload_model_to_dq(model, model_parameters, model_kind)
print("Model uploaded successfully.")
else:
print("No model to upload.")
except Exception as e:
print(f"Error uploading model: {e}")

Check warning on line 105 in dataquality/core/finish.py

View check run for this annotation

Codecov / codecov/patch

dataquality/core/finish.py#L103-L105

Added lines #L103 - L105 were not covered by tests
if data_logger.logger_config.conditions:
print(
"Waiting for run to process before building run report... "
Expand Down
11 changes: 11 additions & 0 deletions dataquality/integrations/transformers_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dataquality.clients.api import ApiClient
from dataquality.exceptions import GalileoException
from dataquality.integrations.torch import TorchBaseInstance
from dataquality.schemas.model import ModelUploadType
from dataquality.schemas.split import Split
from dataquality.schemas.torch import DimensionSlice, InputDim, Layer
from dataquality.utils.helpers import check_noop
Expand Down Expand Up @@ -337,6 +338,16 @@ def watch(
# Unpatch Trainer after logging (when finished is called)
cleanup_manager = RefManager(lambda: unwatch(trainer))
helper_data["cleaner"] = Cleanup(cleanup_manager)
helper_data["model"] = trainer.model
helper_data["model_parameters"] = {
"classifier_layer": classifier_layer,
"embedding_dim": embedding_dim,
"logits_dim": logits_dim,
"embedding_fn": embedding_fn,
"logits_fn": logits_fn,
"last_hidden_state_layer": last_hidden_state_layer,
}
helper_data["model_kind"] = ModelUploadType.transformers


def unwatch(trainer: Trainer) -> None:
Expand Down
6 changes: 6 additions & 0 deletions dataquality/schemas/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@ class ModelFramework(str, Enum):
keras = "keras"
hf = "hf"
auto = "auto"


@unique
class ModelUploadType(str, Enum):
    """Kinds of trained models that can be uploaded to the platform after a run."""

    transformers = "transformers"
    setfit = "setfit"
1 change: 1 addition & 0 deletions dataquality/schemas/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Route(str, Enum):
notify = "notify/email"
token = "get-token"
upload_file = "upload_file"
model = "model"
link = "link"

@staticmethod
Expand Down
61 changes: 61 additions & 0 deletions dataquality/utils/upload_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import tarfile
import tempfile
from typing import Any, Dict, Tuple

import requests

from dataquality.clients.api import ApiClient
from dataquality.core._config import config
from dataquality.schemas.model import ModelUploadType

api_client = ApiClient()


def create_tar_archive(source_folder: str, output_filename: str) -> None:
    """
    Creates a tar archive from a folder / model.

    Top-level entries are added with their base name as ``arcname`` so the
    archive extracts flat, without a leading source-folder path component.

    :param source_folder: The folder to archive.
    :param output_filename: The name of the output tar file. If the name
        ends in ``.gz`` / ``.tgz`` the archive is gzip-compressed to match
        the extension (previously a plain uncompressed tar was produced
        even for ``model.tar.gz``).
    """
    # Pick the tarfile mode from the extension so the file's contents
    # agree with its name.
    mode = "w:gz" if output_filename.endswith((".gz", ".tgz")) else "w"
    with tarfile.open(output_filename, mode) as archive:
        for item in os.listdir(source_folder):
            full_path = os.path.join(source_folder, item)
            archive.add(full_path, arcname=item)


def upload_to_minio_using_presigned_url(
    presigned_url: str, file_path: str
) -> Tuple[int, str]:
    """
    Uploads a file to a presigned url.

    The file is streamed from disk (passed as a file object) rather than
    read fully into memory.

    :param presigned_url: Presigned url authorizing the PUT request.
    :param file_path: Path of the local file to upload.
    :return: Tuple of (HTTP status code, response body text).
    """
    with open(file_path, "rb") as f:
        # A timeout keeps finish() from hanging forever on a stalled
        # connection; 600s leaves room for large model archives.
        response = requests.put(presigned_url, data=f, timeout=600)
    return response.status_code, response.text

Check warning on line 33 in dataquality/utils/upload_model.py

View check run for this annotation

Codecov / codecov/patch

dataquality/utils/upload_model.py#L31-L33

Added lines #L31 - L33 were not covered by tests


def upload_model_to_dq(
    model: Any, model_parameters: Dict[str, Any], model_kind: ModelUploadType
) -> None:
    """
    Uploads the model to the Galileo platform.

    The model is exported with ``save_pretrained`` into a temporary
    folder, packed into a tar archive, and PUT to a presigned url
    obtained from the API. Afterwards the model info endpoint is called
    so the backend can mark the upload as complete.

    :param model: Any object exposing ``save_pretrained`` (e.g. a
        transformers model).
    :param model_parameters: Parameters describing how the model was
        watched/trained; stored alongside the upload.
    :param model_kind: The kind of model being uploaded.
    :raises ValueError: If no current project or run id is configured.
    :raises RuntimeError: If the upload request does not succeed.
    :return: None
    """
    # Raise real exceptions (not asserts) so validation survives `python -O`.
    if not config.current_project_id:
        raise ValueError("Project id is required")
    if not config.current_run_id:
        raise ValueError("Run id is required")
    signed_url = api_client.get_presigned_url_for_model(
        project_id=config.current_project_id,
        run_id=config.current_run_id,
        model_kind=model_kind,
        model_parameters=model_parameters,
    )
    # Save to a temporary folder that is cleaned up automatically.
    with tempfile.TemporaryDirectory() as tmpdirname:
        export_dir = os.path.join(tmpdirname, "model_export")
        model.save_pretrained(export_dir)
        tar_path = os.path.join(tmpdirname, "model.tar.gz")
        create_tar_archive(export_dir, tar_path)
        status_code, response_text = upload_to_minio_using_presigned_url(
            signed_url, tar_path
        )
        # Surface failed uploads instead of silently discarding the status.
        if status_code not in (200, 201, 204):
            raise RuntimeError(
                f"Model upload failed with status {status_code}: {response_text}"
            )
    # Fetching the model info lets the backend record the upload as complete.
    api_client.get_uploaded_model_info(
        project_id=config.current_project_id,
        run_id=config.current_run_id,
    )
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -202,4 +202,3 @@ exclude = '''
| __pycache__
)/
'''

31 changes: 31 additions & 0 deletions tests/core/test_finish.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,34 @@ def test_finish_with_conditions(
run_id=test_session_vars.DEFAULT_RUN_ID,
link="https://www.example.com",
)


@mock.patch.object(dataquality.core.finish, "_reset_run")
@mock.patch.object(dataquality.core.finish, "upload_dq_log_file")
@mock.patch.object(dataquality.clients.api.ApiClient, "make_request")
@mock.patch.object(
    dataquality.core.finish.dataquality,
    "get_data_logger",
)
@mock.patch.object(dataquality.core.finish, "upload_model_to_dq")
@mock.patch.object(dataquality.core.finish, "wait_for_run")
def test_finish_dq_upload(
    mock_wait_for_run: MagicMock,
    mock_upload_model_to_dq: MagicMock,
    mock_get_data_logger: MagicMock,
    mock_make_request: MagicMock,
    mock_upload_log_file: MagicMock,
    mock_reset_run: MagicMock,
    set_test_config,
) -> None:
    """finish(upload_model=True) uploads the model when helper_data holds one."""
    # NOTE: mock decorators apply bottom-up, so the parameters above are in
    # reverse order of the decorator list.
    set_test_config(task_type=TaskType.text_classification)
    # conditions=None so finish() skips the run-report branch.
    mock_get_data_logger.return_value = MagicMock(
        logger_config=MagicMock(conditions=None)
    )
    # Seed helper_data the way an integration's watch() would after
    # observing a model; plain strings suffice since the upload is mocked.
    helper_data = dataquality.core.log.get_model_logger().logger_config.helper_data
    helper_data["model"] = "model"
    helper_data["model_parameters"] = "model_parameters"
    helper_data["model_kind"] = "model_kind"
    dataquality.finish(upload_model=True)
    mock_wait_for_run.assert_called_once()
    # The model-upload code path must have been exercised.
    mock_upload_model_to_dq.assert_called()
42 changes: 42 additions & 0 deletions tests/utils/test_upload_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import tarfile
import tempfile

from dataquality.utils.upload_model import create_tar_archive


def test_create_tar_archive() -> None:
    """create_tar_archive packs a folder (incl. subfolders) so it extracts flat."""
    # Create a temporary directory as the source folder
    with tempfile.TemporaryDirectory() as source_folder:
        # Populate the source folder with mock files
        with open(os.path.join(source_folder, "file1.txt"), "w") as f:
            f.write("This is a mock file.")
        with open(os.path.join(source_folder, "file2.txt"), "w") as f:
            f.write("This is another mock file.")

        # Create a subfolder with another mock file
        subfolder_path = os.path.join(source_folder, "subfolder")
        os.mkdir(subfolder_path)
        with open(os.path.join(subfolder_path, "file3.txt"), "w") as f:
            f.write("This is a mock file inside a subfolder.")

        # Write the archive into its own temp dir. Unlike NamedTemporaryFile,
        # the target path is not an open file handle (which breaks on
        # Windows, where an open temp file cannot be reopened for writing),
        # and cleanup is guaranteed by the context manager.
        with tempfile.TemporaryDirectory() as archive_dir:
            tar_filename = os.path.join(archive_dir, "archive.tar")

            # Call the function to create a tar archive
            create_tar_archive(source_folder, tar_filename)

            # Untar the created archive to a new temporary directory
            with tempfile.TemporaryDirectory() as untar_folder:
                with tarfile.open(tar_filename, "r") as archive:
                    archive.extractall(path=untar_folder)

                # Verify the contents of the untarred directory match the original
                assert os.path.exists(os.path.join(untar_folder, "file1.txt"))
                assert os.path.exists(os.path.join(untar_folder, "file2.txt"))
                assert os.path.exists(
                    os.path.join(untar_folder, "subfolder", "file3.txt")
                )

        # The archive is gone once its temporary directory is cleaned up.
        assert not os.path.exists(tar_filename)
Loading