Use base aws classes in AWS Glue Data Catalog Sensors (apache#40492)
gopidesupavan authored and romsharon98 committed Jul 26, 2024
1 parent e88264c commit 02c429b
Showing 7 changed files with 113 additions and 25 deletions.
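
In short, this commit moves GlueCatalogPartitionSensor from BaseSensorOperator onto the provider's AwsBaseSensor base class, which supplies the hook wiring (aws_conn_id, region_name, verify, botocore_config) declaratively. A minimal sketch of the pattern the diff below follows, using a hypothetical MyPartitionSensor for illustration only:

from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields


class MyPartitionSensor(AwsBaseSensor[GlueCatalogHook]):
    # The base class builds self.hook from aws_hook_class, so the subclass
    # no longer stores aws_conn_id/region_name or defines its own hook property.
    aws_hook_class = GlueCatalogHook
    template_fields = aws_template_fields("table_name", "expression")

    def __init__(self, *, table_name: str, expression: str = "ds='{{ ds }}'", **kwargs):
        super().__init__(**kwargs)
        self.table_name = table_name
        self.expression = expression

    def poke(self, context) -> bool:
        # self.hook is a ready GlueCatalogHook configured from the kwargs above.
        return self.hook.check_for_partition("default", self.table_name, self.expression)
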
42 changes: 25 additions & 17 deletions airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
@@ -18,26 +18,30 @@
from __future__ import annotations

from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from deprecated import deprecated

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.sensors.base import BaseSensorOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class GlueCatalogPartitionSensor(BaseSensorOperator):
class GlueCatalogPartitionSensor(AwsBaseSensor[GlueCatalogHook]):
"""
Waits for a partition to show up in AWS Glue Catalog.
.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:GlueCatalogPartitionSensor`
:param table_name: The name of the table to wait for, supports the dot
notation (my_database.my_table)
:param expression: The partition clause to wait for. This is passed as
@@ -46,19 +50,27 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html
#aws-glue-api-catalog-partitions-GetPartitions
:param aws_conn_id: ID of the Airflow connection where
credentials and extra configuration are stored
:param region_name: Optional aws region name (example: us-east-1). Uses region from connection
if not specified.
:param database_name: The name of the catalog database where the partitions reside.
:param poke_interval: Time in seconds that the job should wait in
between each try
:param deferrable: If true, then the sensor will wait asynchronously for the partition to
show up in the AWS Glue Catalog.
(default: False, but can be overridden in config file by setting default_deferrable to True)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = (
aws_hook_class = GlueCatalogHook

template_fields: Sequence[str] = aws_template_fields(
"database_name",
"table_name",
"expression",
@@ -70,19 +82,16 @@ def __init__(
*,
table_name: str,
expression: str = "ds='{{ ds }}'",
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
database_name: str = "default",
poke_interval: int = 60 * 3,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(poke_interval=poke_interval, **kwargs)
self.aws_conn_id = aws_conn_id
self.region_name = region_name
super().__init__(**kwargs)
self.table_name = table_name
self.expression = expression
self.database_name = database_name
self.poke_interval = poke_interval
self.deferrable = deferrable

def execute(self, context: Context) -> Any:
@@ -93,7 +102,10 @@ def execute(self, context: Context) -> Any:
table_name=self.table_name,
expression=self.expression,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
waiter_delay=int(self.poke_interval),
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.timeout),
@@ -126,7 +138,3 @@ def execute_complete(self, context: Context, event: dict | None = None) -> None:
def get_hook(self) -> GlueCatalogHook:
"""Get the GlueCatalogHook."""
return self.hook

@cached_property
def hook(self) -> GlueCatalogHook:
return GlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
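
With the base class in place, execute() now forwards the new verify and botocore_config options to the trigger when deferring. A hedged usage sketch of the refactored sensor (the DAG id and values here are made up for illustration):

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import GlueCatalogPartitionSensor

with DAG("example_glue_partition_wait", start_date=datetime(2024, 1, 1), schedule=None):
    wait_for_partition = GlueCatalogPartitionSensor(
        task_id="wait_for_partition",
        table_name="my_database.my_table",  # dot notation selects the database
        expression="ds='{{ ds }}'",
        deferrable=True,  # wait asynchronously on the triggerer
        verify=True,  # newly supported: SSL verification for the boto3 session
        botocore_config={"read_timeout": 42},  # newly supported: botocore client config
    )
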
18 changes: 15 additions & 3 deletions airflow/providers/amazon/aws/triggers/glue.py
@@ -99,16 +99,21 @@ def __init__(
database_name: str,
table_name: str,
expression: str = "",
waiter_delay: int = 60,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
waiter_delay: int = 60,
verify: bool | str | None = None,
botocore_config: dict | None = None,
):
self.database_name = database_name
self.table_name = table_name
self.expression = expression
self.waiter_delay = waiter_delay

self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.waiter_delay = waiter_delay
self.verify = verify
self.botocore_config = botocore_config

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
@@ -121,12 +126,19 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"aws_conn_id": self.aws_conn_id,
"region_name": self.region_name,
"waiter_delay": self.waiter_delay,
"verify": self.verify,
"botocore_config": self.botocore_config,
},
)

@cached_property
def hook(self) -> GlueCatalogHook:
return GlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
return GlueCatalogHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)

async def poke(self, client: Any) -> bool:
if "." in self.table_name:
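
serialize() must return every constructor argument so the triggerer process can reconstruct the trigger; adding verify and botocore_config to the returned kwargs keeps that round trip lossless. A quick illustrative sketch of the round trip:

from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger

trigger = GlueCatalogPartitionTrigger(
    database_name="default",
    table_name="my_table",
    expression="ds='2024-01-01'",
    verify=True,
    botocore_config={"read_timeout": 42},
)
classpath, kwargs = trigger.serialize()
# The triggerer imports classpath and calls it with kwargs, so any
# __init__ argument missing from kwargs would be silently dropped.
rebuilt = GlueCatalogPartitionTrigger(**kwargs)
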
14 changes: 14 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/glue.rst
@@ -177,6 +177,20 @@ reaches a terminal state you can use :class:`~airflow.providers.amazon.aws.senso
:start-after: [START howto_sensor_glue_data_quality_rule_recommendation_run]
:end-before: [END howto_sensor_glue_data_quality_rule_recommendation_run]

.. _howto/sensor:GlueCatalogPartitionSensor:

Wait on an AWS Glue Catalog Partition
======================================

To wait for a partition to show up in the AWS Glue Data Catalog you can use
:class:`~airflow.providers.amazon.aws.sensors.glue_catalog_partition.GlueCatalogPartitionSensor`

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_glue.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_glue_catalog_partition]
:end-before: [END howto_sensor_glue_catalog_partition]

Reference
---------

2 changes: 0 additions & 2 deletions tests/always/test_project_structure.py
@@ -535,8 +535,6 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest):
MISSING_EXAMPLES_FOR_CLASSES = {
# S3 Exasol transfer difficult to test, see: https://github.com/apache/airflow/issues/22632
"airflow.providers.amazon.aws.transfers.exasol_to_s3.ExasolToS3Operator",
# Glue Catalog sensor difficult to test
"airflow.providers.amazon.aws.sensors.glue_catalog_partition.GlueCatalogPartitionSensor",
}

DEPRECATED_CLASSES = {
25 changes: 25 additions & 0 deletions tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
@@ -122,3 +122,28 @@ def test_fail_execute_complete(self, soft_fail, expected_exception):
message = f"Trigger error: event is {event}"
with pytest.raises(expected_exception, match=message):
op.execute_complete(context={}, event=event)

def test_init(self):
default_op_kwargs = {
"task_id": "test_task",
"table_name": "test_table",
}

sensor = GlueCatalogPartitionSensor(**default_op_kwargs)
assert sensor.hook.aws_conn_id == "aws_default"
assert sensor.hook._region_name is None
assert sensor.hook._verify is None
assert sensor.hook._config is None

sensor = GlueCatalogPartitionSensor(
**default_op_kwargs,
aws_conn_id=None,
region_name="eu-west-2",
verify=True,
botocore_config={"read_timeout": 42},
)
assert sensor.hook.aws_conn_id is None
assert sensor.hook._region_name == "eu-west-2"
assert sensor.hook._verify is True
assert sensor.hook._config is not None
assert sensor.hook._config.read_timeout == 42
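
The _config assertions work because the hook turns the botocore_config dict into a botocore.config.Config object, which exposes its kwargs as attributes. A tiny sketch of the behaviour the test relies on (assuming standard botocore semantics):

from botocore.config import Config

cfg = Config(read_timeout=42)
assert cfg.read_timeout == 42  # Config mirrors its kwargs as attributes
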
22 changes: 22 additions & 0 deletions tests/providers/amazon/aws/triggers/test_glue.py
@@ -99,6 +99,28 @@ async def test_poke(self, mock_async_get_partitions):

assert response is True

def test_serialization(self):
trigger = GlueCatalogPartitionTrigger(
database_name="test_database",
table_name="test_table",
expression="id=12",
aws_conn_id="fake_conn_id",
region_name="eu-west-2",
verify=True,
)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.amazon.aws.triggers.glue.GlueCatalogPartitionTrigger"
assert kwargs == {
"database_name": "test_database",
"table_name": "test_table",
"expression": "id=12",
"waiter_delay": 60,
"aws_conn_id": "fake_conn_id",
"region_name": "eu-west-2",
"verify": True,
"botocore_config": None,
}


class TestGlueDataQualityEvaluationRunCompletedTrigger:
EXPECTED_WAITER_NAME = "data_quality_ruleset_evaluation_run_complete"
15 changes: 12 additions & 3 deletions tests/system/providers/amazon/aws/example_glue.py
@@ -32,6 +32,7 @@
S3DeleteBucketOperator,
)
from airflow.providers.amazon.aws.sensors.glue import GlueJobSensor
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import GlueCatalogPartitionSensor
from airflow.providers.amazon.aws.sensors.glue_crawler import GlueCrawlerSensor
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder, prune_logs
@@ -49,7 +50,7 @@
sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()

# Example csv data used as input to the example AWS Glue Job.
EXAMPLE_CSV = """
EXAMPLE_CSV = """product,value
apple,0.5
milk,2.5
bread,4.0
@@ -115,7 +116,7 @@ def glue_cleanup(crawler_name: str, job_name: str, db_name: str) -> None:
upload_csv = S3CreateObjectOperator(
task_id="upload_csv",
s3_bucket=bucket_name,
s3_key="input/input.csv",
s3_key="input/category=mixed/input.csv",
data=EXAMPLE_CSV,
replace=True,
)
@@ -146,6 +147,15 @@ def glue_cleanup(crawler_name: str, job_name: str, db_name: str) -> None:
# [END howto_sensor_glue_crawler]
wait_for_crawl.timeout = 500

# [START howto_sensor_glue_catalog_partition]
wait_for_catalog_partition = GlueCatalogPartitionSensor(
task_id="wait_for_catalog_partition",
table_name="input",
database_name=glue_db_name,
expression="category='mixed'",
)
# [END howto_sensor_glue_catalog_partition]

# [START howto_operator_glue]
submit_glue_job = GlueJobOperator(
task_id="submit_glue_job",
@@ -211,7 +221,6 @@ def glue_cleanup(crawler_name: str, job_name: str, db_name: str) -> None:
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()


from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)