Added release storage to bucket
keegansmith21 committed Dec 19, 2024
1 parent e772ea7 commit 3009584
Showing 5 changed files with 189 additions and 4 deletions.
58 changes: 58 additions & 0 deletions observatory_platform/airflow/release.py
@@ -17,12 +17,17 @@
from __future__ import annotations

import logging
import json
import os
import tempfile
import uuid
from typing import Optional

import pendulum
from airflow.exceptions import AirflowException

from observatory_platform.airflow.workflow import make_workflow_folder
from observatory_platform.google.gcs import gcs_upload_file, gcs_read_blob, gcs_blob_uri

DATE_TIME_FORMAT = "YYYY-MM-DD_HH:mm:ss"

@@ -54,6 +59,50 @@ def set_task_state(success: bool, task_id: str, release: Release = None):
raise AirflowException(msg_failed)


def release_blob(id: str) -> str:
    """Generates the blob name for a release object"""

    return f"releases/{id}.json"


def release_to_bucket(release: Release, bucket: str, id: Optional[str] = None) -> str:
    """Uploads a release object to a bucket in JSON format. The release is stored under releases/ in the given bucket.
    :param release: The release object
    :param bucket: The name of the bucket to upload to
    :param id: The id to use as an identifier. Will be generated if not supplied.
    :return: The id as a string
    """

    if not id:
        id = str(uuid.uuid4())

    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(json.dumps(release.to_dict()))
        f.flush()  # Force write stream to file
        success, _ = gcs_upload_file(bucket_name=bucket, blob_name=release_blob(id), file_path=f.name)
        if not success:
            raise RuntimeError(f"Release could not be uploaded to gs://{bucket}/{release_blob(id)}")

    return id


def release_from_bucket(bucket: str, id: str) -> dict:
    """Downloads a release from a bucket.
    :param bucket: The name of the bucket containing the release
    :param id: The id of the release
    :return: The content of the release as a dictionary
    """

    blob_name = release_blob(id)
    content, success = gcs_read_blob(bucket_name=bucket, blob_name=blob_name)
    if not success:
        raise RuntimeError(f"Release at gs://{bucket}/{blob_name} could not be downloaded")

    return json.loads(content)


class Release:
def __init__(self, *, dag_id: str, run_id: str):
"""Construct a Release instance
@@ -116,6 +165,15 @@ def transform_folder(self):
os.makedirs(path, exist_ok=True)
return path

    @staticmethod
    def from_dict(_dict: dict):
        """Converts the release dict to its object equivalent"""
        raise NotImplementedError("from_dict() not implemented for this Release object")

    def to_dict(self) -> dict:
        """Transforms the release to its dictionary equivalent"""
        raise NotImplementedError("to_dict() not implemented for this Release object")

def __str__(self):
return f"Release(dag_id={self.dag_id}, run_id={self.run_id})"

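To make the new helpers concrete, here is a rough usage sketch (not part of this commit; the subclass, bucket name, and field names are illustrative). A Release subclass implements to_dict/from_dict, release_to_bucket stores it as JSON under releases/{id}.json, and release_from_bucket reads it back:

from observatory_platform.airflow.release import Release, release_from_bucket, release_to_bucket


class MyRelease(Release):
    def __init__(self, *, dag_id: str, run_id: str, snapshot_date: str):
        super().__init__(dag_id=dag_id, run_id=run_id)
        self.snapshot_date = snapshot_date

    def to_dict(self) -> dict:
        # Only JSON-serialisable values should go in here
        return dict(dag_id=self.dag_id, run_id=self.run_id, snapshot_date=self.snapshot_date)

    @staticmethod
    def from_dict(_dict: dict):
        return MyRelease(dag_id=_dict["dag_id"], run_id=_dict["run_id"], snapshot_date=_dict["snapshot_date"])


# In one task: upload the release and pass the returned id downstream (e.g. via XCom)
release = MyRelease(dag_id="my_dag", run_id="manual__2024-12-19", snapshot_date="2024-12-19")
release_id = release_to_bucket(release, "my-workflow-bucket")

# In a later task: rebuild the release from the bucket using the id
restored = MyRelease.from_dict(release_from_bucket("my-workflow-bucket", release_id))
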
66 changes: 64 additions & 2 deletions observatory_platform/airflow/tests/test_release.py
@@ -1,4 +1,4 @@
# Copyright 2019-2024 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,11 +12,73 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from random import randint
import json
import uuid
import tempfile

import pendulum
from airflow.exceptions import AirflowException

from observatory_platform.airflow.release import set_task_state, make_snapshot_date
from observatory_platform.airflow.release import (
Release,
set_task_state,
make_snapshot_date,
release_to_bucket,
release_from_bucket,
)
from observatory_platform.google.gcs import gcs_list_blobs, gcs_upload_file
from observatory_platform.sandbox.test_utils import SandboxTestCase
from observatory_platform.sandbox.sandbox_environment import SandboxEnvironment


class _MyRelease(Release):
def __init__(self, dag_id: str, run_id: str, my_int: int, my_time: pendulum.DateTime):
super().__init__(dag_id=dag_id, run_id=run_id)
self.my_int = my_int
self.my_time = my_time

def to_dict(self):
return dict(dag_id=self.dag_id, run_id=self.run_id, my_int=self.my_int, my_time=self.my_time.timestamp())

@staticmethod
def from_dict(dict_: dict):
return _MyRelease(dict_["dag_id"], dict_["run_id"], dict_["my_int"], pendulum.from_timestamp(dict_["my_time"]))

def __str__(self):
return f"{self.dag_id}, {self.run_id}, {self.my_int}, {self.my_time.timestamp()}"


class TestGCSFunctions(SandboxTestCase):
    release = _MyRelease(
        dag_id="test_dag",
        run_id=str(uuid.uuid4()),
        my_int=randint(int(-10e9), int(10e9)),
        my_time=pendulum.datetime(randint(1, 2000), 1, 1),
    )
    gcp_project_id = os.getenv("TEST_GCP_PROJECT_ID")
    gcp_data_location = os.getenv("TEST_GCP_DATA_LOCATION")

    def test_release_to_bucket(self):
        env = SandboxEnvironment(project_id=self.gcp_project_id, data_location=self.gcp_data_location)
        bucket = env.add_bucket()
        with env.create():
            id = release_to_bucket(self.release, bucket)
            blobs = [b.name for b in gcs_list_blobs(bucket)]
            self.assertIn(f"releases/{id}.json", blobs)

def test_release_from_bucket(self):
env = SandboxEnvironment(project_id=self.gcp_project_id, data_location=self.gcp_data_location)
bucket = env.add_bucket()
id = "test_release"
with env.create():
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(json.dumps(self.release.to_dict()))
f.flush() # Force write stream to file
gcs_upload_file(bucket_name=bucket, blob_name=f"releases/{id}.json", file_path=f.name)
release = release_from_bucket(bucket, id)
self.assertEqual(str(_MyRelease.from_dict(release)), str(self.release))


class TestWorkflow(SandboxTestCase):
52 changes: 51 additions & 1 deletion observatory_platform/google/gcs.py
@@ -27,7 +27,7 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from multiprocessing import cpu_count
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Union

import pendulum
from airflow import AirflowException
@@ -191,6 +191,56 @@ def gcs_copy_blob(
return success


def gcs_read_blob(
*,
bucket_name: str,
blob_name: str,
retries: int = 3,
connection_sem: threading.BoundedSemaphore = None,
chunk_size: int = DEFAULT_CHUNK_SIZE,
client: storage.Client = None,
) -> Tuple[Optional[str], bool]:
"""Read the contents of a blob
:param bucket_name: the name of the Google Cloud storage bucket.
:param blob_name: the path to the blob.
:param retries: the number of times to retry downloading the blob.
:param connection_sem: a BoundedSemaphore to limit the number of download connections that can run at once.
:param chunk_size: the chunk size to use when downloading a blob in multiple parts, must be a multiple of 256 KB.
:param client: Storage client. If None default Client is created.
    :return: A tuple of the blob contents (None if the download failed) and a success flag
"""

func_name = gcs_read_blob.__name__
logging.info(f"{func_name}: gs://{bucket_name}/{blob_name}")

if client is None:
client = storage.Client()
bucket = client.bucket(bucket_name)
blob: Blob = bucket.blob(blob_name)
uri = gcs_blob_uri(bucket_name, blob_name)

# Get connection semaphore
if connection_sem is not None:
connection_sem.acquire()

    success = False
    downloaded_blob = None  # Fallback value if all retries fail
    for i in range(0, retries):
try:
blob.chunk_size = chunk_size
downloaded_blob = blob.download_as_text()
success = True
break
except ChunkedEncodingError as e:
logging.error(f"{func_name}: exception downloading file: try={i}, uri={uri}, exception={e}")

# Release connection semaphore
if connection_sem is not None:
connection_sem.release()

return downloaded_blob, success
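
A minimal usage sketch for gcs_read_blob (bucket and blob names are made up): the function returns the blob contents together with a success flag, so callers should check the flag before using the contents.

import json

from observatory_platform.google.gcs import gcs_read_blob

content, success = gcs_read_blob(bucket_name="my-bucket", blob_name="releases/1234.json")
if not success:
    raise RuntimeError("Blob could not be read")
data = json.loads(content)  # assuming the blob holds JSON, as release_to_bucket writes it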


def gcs_download_blob(
*,
bucket_name: str,
15 changes: 15 additions & 0 deletions observatory_platform/google/tests/test_gcs.py
@@ -19,6 +19,8 @@
from datetime import timedelta
from typing import Optional
from unittest.mock import patch
from uuid import uuid4
import tempfile

import boto3
import pendulum
@@ -38,6 +40,7 @@
gcs_delete_bucket_dir,
gcs_download_blob,
gcs_download_blobs,
gcs_read_blob,
gcs_upload_file,
gcs_upload_files,
gcs_list_buckets_with_prefix,
@@ -188,6 +191,18 @@ def test_gcs_copy_blob(self):
if blob.exists():
blob.delete()

def test_gcs_read_blob(self):
blob_uuid = str(uuid4())
content = str(uuid4())
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(content)
f.flush() # Force write stream to file
gcs_upload_file(bucket_name=self.gc_bucket_name, blob_name=blob_uuid, file_path=f.name)

dl_content, success = gcs_read_blob(bucket_name=self.gc_bucket_name, blob_name=blob_uuid)
self.assertTrue(success)
self.assertEqual(content, dl_content)

@patch("observatory_platform.airflow.workflow.Variable.get")
def test_upload_download_blobs_from_cloud_storage(self, mock_get_variable):
runner = CliRunner()
2 changes: 1 addition & 1 deletion observatory_platform/url_utils.py
@@ -266,5 +266,5 @@ def get_filename_from_http_header(url: str) -> str:
if response.status_code != 200:
raise AirflowException(f"get_filename_from_http_header: url={response.url}, status_code={response.status_code}")
header = response.headers["Content-Disposition"]
value, params = cgi.parse_header(header)
_, params = cgi.parse_header(header)
return params.get("filename")
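
For context, a small illustration of the call on the changed line (the header value below is made up): cgi.parse_header splits a Content-Disposition header into its main value and a dict of parameters, and only the parameters are needed here.

import cgi

_, params = cgi.parse_header('attachment; filename="report.csv"')
print(params.get("filename"))  # report.csv

Note that the cgi module is deprecated since Python 3.11, so this helper may eventually need to switch to an alternative such as email.message for header parsing.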
