diff --git a/observatory_platform/airflow/release.py b/observatory_platform/airflow/release.py
index 4e41b10d1..aace73da8 100644
--- a/observatory_platform/airflow/release.py
+++ b/observatory_platform/airflow/release.py
@@ -17,12 +17,17 @@
 from __future__ import annotations
 
 import logging
+import json
 import os
+import tempfile
+import uuid
+from typing import Optional
 
 import pendulum
 from airflow.exceptions import AirflowException
 
 from observatory_platform.airflow.workflow import make_workflow_folder
+from observatory_platform.google.gcs import gcs_upload_file, gcs_read_blob, gcs_blob_uri
 
 DATE_TIME_FORMAT = "YYYY-MM-DD_HH:mm:ss"
 
@@ -54,6 +59,50 @@ def set_task_state(success: bool, task_id: str, release: Release = None):
         raise AirflowException(msg_failed)
 
 
+def release_blob(uuid: str) -> str:
+    """Generates the blob name for a release object"""
+
+    return f"releases/{uuid}.json"
+
+
+def release_to_bucket(release: Release, bucket: str, id: Optional[str] = None) -> str:
+    """Uploads a release object to a bucket in JSON format. Will put it in {bucket}/releases.
+
+    :param release: The release object
+    :param bucket: The name of the bucket to upload to
+    :param id: The id to use as an identifier. Will be generated if not supplied.
+    :return: The id as a string
+    """
+
+    if not id:
+        id = str(uuid.uuid4())
+
+    with tempfile.NamedTemporaryFile(mode="w") as f:
+        f.write(json.dumps(release.to_dict()))
+        f.flush()  # Force write stream to file
+        success, _ = gcs_upload_file(bucket_name=bucket, blob_name=release_blob(id), file_path=f.name)
+        if not success:
+            raise RuntimeError(f"Release could not be uploaded to gs://{bucket}/{release_blob(id)}")
+
+    return id
+
+
+def release_from_bucket(bucket: str, id: str) -> dict:
+    """Downloads a release from a bucket.
+
+    :param bucket: The name of the bucket containing the release
+    :param id: The id of the release
+    :return: The content of the release as a dictionary
+    """
+
+    blob_name = release_blob(id)
+    content, success = gcs_read_blob(bucket_name=bucket, blob_name=blob_name)
+    if not success:
+        raise RuntimeError(f"Release at gs://{bucket}/{blob_name} could not be downloaded")
+
+    return json.loads(content)
+
+
 class Release:
     def __init__(self, *, dag_id: str, run_id: str):
         """Construct a Release instance
@@ -116,6 +165,15 @@ def transform_folder(self):
         os.makedirs(path, exist_ok=True)
         return path
 
+    @staticmethod
+    def from_dict(_dict: dict):
+        """Converts the release dict to its object equivalent"""
+        raise NotImplementedError("from_dict() not implemented for this Release object")
+
+    def to_dict(self) -> dict:
+        """Transforms the release to its dictionary equivalent"""
+        raise NotImplementedError("to_dict() not implemented for this Release object")
+
     def __str__(self):
         return f"Release(dag_id={self.dag_id}, run_id={self.run_id})"
 
diff --git a/observatory_platform/airflow/tests/test_release.py b/observatory_platform/airflow/tests/test_release.py
index 17c998a6d..6922e8549 100644
--- a/observatory_platform/airflow/tests/test_release.py
+++ b/observatory_platform/airflow/tests/test_release.py
@@ -1,4 +1,4 @@
-# Copyright 2019-2024 Curtin University
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,11 +12,73 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+from random import randint
+import json
+import uuid
+import tempfile
+
 import pendulum
 from airflow.exceptions import AirflowException
 
-from observatory_platform.airflow.release import set_task_state, make_snapshot_date
+from observatory_platform.airflow.release import (
+    Release,
+    set_task_state,
+    make_snapshot_date,
+    release_to_bucket,
+    release_from_bucket,
+)
+from observatory_platform.google.gcs import gcs_list_blobs, gcs_upload_file
 from observatory_platform.sandbox.test_utils import SandboxTestCase
+from observatory_platform.sandbox.sandbox_environment import SandboxEnvironment
+
+
+class _MyRelease(Release):
+    def __init__(self, dag_id: str, run_id: str, my_int: int, my_time: pendulum.DateTime):
+        super().__init__(dag_id=dag_id, run_id=run_id)
+        self.my_int = my_int
+        self.my_time = my_time
+
+    def to_dict(self):
+        return dict(dag_id=self.dag_id, run_id=self.run_id, my_int=self.my_int, my_time=self.my_time.timestamp())
+
+    @staticmethod
+    def from_dict(dict_: dict):
+        return _MyRelease(dict_["dag_id"], dict_["run_id"], dict_["my_int"], pendulum.from_timestamp(dict_["my_time"]))
+
+    def __str__(self):
+        return f"{self.dag_id}, {self.run_id}, {self.my_int}, {self.my_time.timestamp()}"
+
+
+class TestGCSFunctions(SandboxTestCase):
+    release = _MyRelease(
+        dag_id="test_dag",
+        run_id=str(uuid.uuid4()),
+        my_int=randint(-10_000_000_000, 10_000_000_000),
+        my_time=pendulum.datetime(randint(1, 2000), 1, 1),
+    )
+    gcp_project_id = os.getenv("TEST_GCP_PROJECT_ID")
+    gcp_data_location = os.getenv("TEST_GCP_DATA_LOCATION")
+
+    def test_release_to_bucket(self):
+        env = SandboxEnvironment(project_id=self.gcp_project_id, data_location=self.gcp_data_location)
+        bucket = env.add_bucket()
+        with env.create():
+            id = release_to_bucket(self.release, bucket)
+            blobs = [b.name for b in gcs_list_blobs(bucket)]
+            self.assertIn(f"releases/{id}.json", blobs)
+
+    def test_release_from_bucket(self):
+        env = SandboxEnvironment(project_id=self.gcp_project_id, data_location=self.gcp_data_location)
+        bucket = env.add_bucket()
+        id = "test_release"
+        with env.create():
+            with tempfile.NamedTemporaryFile(mode="w") as f:
+                f.write(json.dumps(self.release.to_dict()))
+                f.flush()  # Force write stream to file
+                gcs_upload_file(bucket_name=bucket, blob_name=f"releases/{id}.json", file_path=f.name)
+            release = release_from_bucket(bucket, id)
+            self.assertEqual(str(_MyRelease.from_dict(release)), str(self.release))
 
 
 class TestWorkflow(SandboxTestCase):
diff --git a/observatory_platform/google/gcs.py b/observatory_platform/google/gcs.py
index 7ccd483ae..682f1e875 100644
--- a/observatory_platform/google/gcs.py
+++ b/observatory_platform/google/gcs.py
@@ -27,7 +27,7 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from enum import Enum
 from multiprocessing import cpu_count
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Union
 
 import pendulum
 from airflow import AirflowException
@@ -191,6 +191,56 @@ def gcs_copy_blob(
     return success
 
 
+def gcs_read_blob(
+    *,
+    bucket_name: str,
+    blob_name: str,
+    retries: int = 3,
+    connection_sem: threading.BoundedSemaphore = None,
+    chunk_size: int = DEFAULT_CHUNK_SIZE,
+    client: storage.Client = None,
+) -> Tuple[Union[str, None], bool]:
+    """Read the contents of a blob.
+
+    :param bucket_name: the name of the Google Cloud storage bucket.
+    :param blob_name: the path to the blob.
+    :param retries: the number of times to retry downloading the blob.
+    :param connection_sem: a BoundedSemaphore to limit the number of download connections that can run at once.
+    :param chunk_size: the chunk size to use when downloading a blob in multiple parts, must be a multiple of 256 KB.
+    :param client: Storage client. If None, a default Client is created.
+    :return: a tuple of the blob contents (None if the download failed) and a success flag.
+    """
+
+    func_name = gcs_read_blob.__name__
+    logging.info(f"{func_name}: gs://{bucket_name}/{blob_name}")
+
+    if client is None:
+        client = storage.Client()
+    bucket = client.bucket(bucket_name)
+    blob: Blob = bucket.blob(blob_name)
+    uri = gcs_blob_uri(bucket_name, blob_name)
+
+    # Get connection semaphore
+    if connection_sem is not None:
+        connection_sem.acquire()
+
+    downloaded_blob, success = None, False
+    for i in range(0, retries):
+        try:
+            blob.chunk_size = chunk_size
+            downloaded_blob = blob.download_as_text()
+            success = True
+            break
+        except ChunkedEncodingError as e:
+            logging.error(f"{func_name}: exception downloading file: try={i}, uri={uri}, exception={e}")
+
+    # Release connection semaphore
+    if connection_sem is not None:
+        connection_sem.release()
+
+    return downloaded_blob, success
+
+
 def gcs_download_blob(
     *,
     bucket_name: str,
diff --git a/observatory_platform/google/tests/test_gcs.py b/observatory_platform/google/tests/test_gcs.py
index e3a806a21..d5d2537dd 100644
--- a/observatory_platform/google/tests/test_gcs.py
+++ b/observatory_platform/google/tests/test_gcs.py
@@ -19,6 +19,8 @@
 from datetime import timedelta
 from typing import Optional
 from unittest.mock import patch
+from uuid import uuid4
+import tempfile
 
 import boto3
 import pendulum
@@ -38,6 +40,7 @@
     gcs_delete_bucket_dir,
     gcs_download_blob,
     gcs_download_blobs,
+    gcs_read_blob,
     gcs_upload_file,
     gcs_upload_files,
     gcs_list_buckets_with_prefix,
@@ -188,6 +191,18 @@ def test_gcs_copy_blob(self):
             if blob.exists():
                 blob.delete()
 
+    def test_gcs_read_blob(self):
+        blob_uuid = str(uuid4())
+        content = str(uuid4())
+        with tempfile.NamedTemporaryFile(mode="w") as f:
+            f.write(content)
+            f.flush()  # Force write stream to file
+            gcs_upload_file(bucket_name=self.gc_bucket_name, blob_name=blob_uuid, file_path=f.name)
+
+        dl_content, success = gcs_read_blob(bucket_name=self.gc_bucket_name, blob_name=blob_uuid)
+        self.assertTrue(success)
+        self.assertEqual(content, dl_content)
+
     @patch("observatory_platform.airflow.workflow.Variable.get")
     def test_upload_download_blobs_from_cloud_storage(self, mock_get_variable):
         runner = CliRunner()
diff --git a/observatory_platform/url_utils.py b/observatory_platform/url_utils.py
index e2acc24f4..6682df931 100644
--- a/observatory_platform/url_utils.py
+++ b/observatory_platform/url_utils.py
@@ -266,5 +266,5 @@ def get_filename_from_http_header(url: str) -> str:
     if response.status_code != 200:
         raise AirflowException(f"get_filename_from_http_header: url={response.url}, status_code={response.status_code}")
     header = response.headers["Content-Disposition"]
-    value, params = cgi.parse_header(header)
+    _, params = cgi.parse_header(header)
     return params.get("filename")
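For illustration, a minimal usage sketch of the new helpers: one task serialises a Release subclass to gs://<bucket>/releases/<id>.json with release_to_bucket, the id is handed downstream (for example via XCom), and a later task rebuilds the object with release_from_bucket and the subclass's from_dict. The MyExampleRelease class, dag/run ids and bucket name below are made up for this sketch, and working GCP credentials plus the observatory_platform package are assumed.

import pendulum

from observatory_platform.airflow.release import Release, release_from_bucket, release_to_bucket


class MyExampleRelease(Release):
    """Made-up Release subclass used only for this sketch."""

    def __init__(self, *, dag_id: str, run_id: str, snapshot_date: pendulum.DateTime):
        super().__init__(dag_id=dag_id, run_id=run_id)
        self.snapshot_date = snapshot_date

    def to_dict(self) -> dict:
        # Only JSON-serialisable values, so the release can round-trip through GCS
        return dict(dag_id=self.dag_id, run_id=self.run_id, snapshot_date=self.snapshot_date.timestamp())

    @staticmethod
    def from_dict(_dict: dict):
        return MyExampleRelease(
            dag_id=_dict["dag_id"],
            run_id=_dict["run_id"],
            snapshot_date=pendulum.from_timestamp(_dict["snapshot_date"]),
        )


# In one task: serialise the release to gs://<bucket>/releases/<id>.json and keep the id (e.g. push it to XCom).
release = MyExampleRelease(dag_id="my_dag", run_id="manual__2024-01-01", snapshot_date=pendulum.datetime(2024, 1, 1))
release_id = release_to_bucket(release, "my-example-bucket")

# In a downstream task: pull the id, download the dict and rebuild the release object.
release = MyExampleRelease.from_dict(release_from_bucket("my-example-bucket", release_id))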