diff --git a/client/verta/tests/test_artifacts.py b/client/verta/tests/test_artifacts.py
index 7d79d9921e..a391de6aca 100644
--- a/client/verta/tests/test_artifacts.py
+++ b/client/verta/tests/test_artifacts.py
@@ -9,6 +9,8 @@
 import tempfile
 import zipfile
 
+import requests
+
 from verta._internal_utils import _artifact_utils
 from verta._internal_utils import _utils
 
@@ -91,6 +93,44 @@ def test_upload_dir(self, experiment_run, strs, dir_and_files):
         with zipfile.ZipFile(experiment_run.get_artifact(key), 'r') as zipf:
             assert filepaths == set(zipf.namelist())
 
+    @pytest.mark.not_oss
+    def test_upload_multipart(self, experiment_run, in_tempdir):
+        key = "large"
+
+        # create artifact
+        with tempfile.NamedTemporaryFile(suffix='.bin', dir=".", delete=False) as tempf:
+            # write 6 MB file in 1 MB chunks
+            for _ in range(6):
+                tempf.write(os.urandom(1*(10**6)))
+
+        # log artifact
+        # TODO: set part size in config file when supported
+        PART_SIZE = int(5.4*(10**6))  # 5.4 MB; S3 parts must be > 5 MB
+        os.environ['VERTA_ARTIFACT_PART_SIZE'] = str(PART_SIZE)
+        try:
+            experiment_run.log_artifact(key, tempf.name)
+        finally:
+            del os.environ['VERTA_ARTIFACT_PART_SIZE']
+
+        # get artifact parts
+        committed_parts = experiment_run.get_artifact_parts(key)
+        assert committed_parts
+
+        # part checksums match actual file contents
+        with open(tempf.name, 'rb') as f:
+            file_parts = iter(lambda: f.read(PART_SIZE), b'')
+            for file_part, committed_part in zip(file_parts, committed_parts):
+                part_hash = hashlib.md5(file_part).hexdigest()
+                assert part_hash == committed_part['etag'].strip('"')
+
+        # retrieved artifact matches original file
+        filepath = experiment_run.download_artifact(key, download_to_path=key)
+        with open(filepath, 'rb') as f:
+            file_parts = iter(lambda: f.read(PART_SIZE), b'')
+            for file_part, committed_part in zip(file_parts, committed_parts):
+                part_hash = hashlib.md5(file_part).hexdigest()
+                assert part_hash == committed_part['etag'].strip('"')
+
     def test_empty(self, experiment_run, strs):
         """uploading empty data, e.g. an empty file, raises an error"""
 
diff --git a/client/verta/verta/client.py b/client/verta/verta/client.py
index c9943d8f5a..e9dc903022 100644
--- a/client/verta/verta/client.py
+++ b/client/verta/verta/client.py
@@ -2124,6 +2124,12 @@ def _upload_artifact(self, key, artifact_stream, part_size=64*(10**6)):
             If using multipart upload, number of bytes to upload per part.
 
         """
+        # TODO: add to Client config
+        env_part_size = os.environ.get('VERTA_ARTIFACT_PART_SIZE', "")
+        if env_part_size.isnumeric():
+            part_size = int(float(env_part_size))
+            print("set artifact part size {} from environment".format(part_size))
+
         artifact_stream.seek(0)
         if self._conf.debug:
             print("[DEBUG] uploading {} bytes ({})".format(_artifact_utils.get_stream_length(artifact_stream), key))
@@ -3440,6 +3446,22 @@ def download_artifact(self, key, download_to_path):
 
         return download_to_path
 
+    def get_artifact_parts(self, key):
+        endpoint = "{}://{}/api/v1/modeldb/experiment-run/getCommittedArtifactParts".format(
+            self._conn.scheme,
+            self._conn.socket,
+        )
+        data = {'id': self.id, 'key': key}
+        response = _utils.make_request("GET", endpoint, self._conn, params=data)
+        _utils.raise_for_http_error(response)
+
+        committed_parts = _utils.body_to_json(response).get('artifact_parts', [])
+        committed_parts = list(sorted(
+            committed_parts,
+            key=lambda part: int(part['part_number']),
+        ))
+        return committed_parts
+
     def log_observation(self, key, value, timestamp=None, epoch_num=None):
         """
         Logs an observation to this Experiment Run.
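
For reviewers, a minimal usage sketch of what this change enables, mirroring test_upload_multipart above. Only VERTA_ARTIFACT_PART_SIZE, log_artifact, get_artifact_parts, and the part dict keys ('part_number', 'etag') come from this diff; the Client/host/project/experiment setup and file names are placeholders, and the ETag-equals-MD5-of-each-part assumption is the same one the test relies on.

```python
# Illustrative sketch, not part of this diff: force a multipart upload via the
# new VERTA_ARTIFACT_PART_SIZE environment variable, then verify each committed
# part's ETag against the corresponding chunk of the local file.
import hashlib
import os

from verta import Client

PART_SIZE = int(5.4 * (10**6))  # parts must exceed S3's 5 MB minimum
os.environ['VERTA_ARTIFACT_PART_SIZE'] = str(PART_SIZE)  # read by _upload_artifact()
try:
    client = Client("https://app.verta.ai")      # placeholder host
    client.set_project("Demo")                   # placeholder project
    client.set_experiment("Multipart Uploads")   # placeholder experiment
    run = client.set_experiment_run()
    run.log_artifact("weights", "weights.bin")   # uploaded in PART_SIZE chunks
finally:
    del os.environ['VERTA_ARTIFACT_PART_SIZE']

# get_artifact_parts() returns parts sorted by part_number, so reading the file
# sequentially in PART_SIZE chunks lines up with the committed parts
with open("weights.bin", 'rb') as f:
    for part in run.get_artifact_parts("weights"):
        chunk = f.read(PART_SIZE)
        assert hashlib.md5(chunk).hexdigest() == part['etag'].strip('"')
```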