Use base aws classes in AWS Glue Data Catalog Sensors #40492

Merged
42 changes: 25 additions & 17 deletions airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
@@ -18,26 +18,30 @@
from __future__ import annotations

from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from deprecated import deprecated

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.sensors.base import BaseSensorOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class GlueCatalogPartitionSensor(BaseSensorOperator):
class GlueCatalogPartitionSensor(AwsBaseSensor[GlueCatalogHook]):
"""
Waits for a partition to show up in AWS Glue Catalog.

.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:GlueCatalogPartitionSensor`

:param table_name: The name of the table to wait for, supports the dot
notation (my_database.my_table)
:param expression: The partition clause to wait for. This is passed as
@@ -46,19 +50,27 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html
#aws-glue-api-catalog-partitions-GetPartitions
:param aws_conn_id: ID of the Airflow connection where
credentials and extra configuration are stored
:param region_name: Optional aws region name (example: us-east-1). Uses region from connection
if not specified.
:param database_name: The name of the catalog database where the partitions reside.
:param poke_interval: Time in seconds that the job should wait in
between each try
:param deferrable: If true, then the sensor will wait asynchronously for the partition to
show up in the AWS Glue Catalog.
(default: False, but can be overridden in config file by setting default_deferrable to True)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = (
aws_hook_class = GlueCatalogHook

template_fields: Sequence[str] = aws_template_fields(
"database_name",
"table_name",
"expression",
@@ -70,19 +82,16 @@ def __init__(
*,
table_name: str,
expression: str = "ds='{{ ds }}'",
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
database_name: str = "default",
poke_interval: int = 60 * 3,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(poke_interval=poke_interval, **kwargs)
self.aws_conn_id = aws_conn_id
self.region_name = region_name
super().__init__(**kwargs)
self.table_name = table_name
self.expression = expression
self.database_name = database_name
self.poke_interval = poke_interval
self.deferrable = deferrable

def execute(self, context: Context) -> Any:
@@ -93,7 +102,10 @@ def execute(self, context: Context) -> Any:
table_name=self.table_name,
expression=self.expression,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
waiter_delay=int(self.poke_interval),
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.timeout),
@@ -126,7 +138,3 @@ def execute_complete(self, context: Context, event: dict | None = None) -> None:
def get_hook(self) -> GlueCatalogHook:
"""Get the GlueCatalogHook."""
return self.hook

@cached_property
def hook(self) -> GlueCatalogHook:
return GlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
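For context, a minimal sketch of how the refactored sensor can be instantiated now that connection handling comes from AwsBaseSensor. The task id, table name, and parameter values below are illustrative placeholders, not part of this change:

from airflow.providers.amazon.aws.sensors.glue_catalog_partition import GlueCatalogPartitionSensor

# Illustrative values only; a real DAG would supply its own names and connection.
wait_for_partition = GlueCatalogPartitionSensor(
    task_id="wait_for_partition",
    table_name="my_database.my_table",     # dot notation selects the database
    expression="ds='{{ ds }}'",            # templated partition filter
    aws_conn_id="aws_default",             # resolved by the AwsBaseSensor base class
    region_name="eu-west-2",
    verify=True,                           # SSL verification forwarded to boto3
    botocore_config={"read_timeout": 42},  # botocore client configuration
    deferrable=True,                       # polls via GlueCatalogPartitionTrigger
)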
18 changes: 15 additions & 3 deletions airflow/providers/amazon/aws/triggers/glue.py
@@ -99,16 +99,21 @@ def __init__(
database_name: str,
table_name: str,
expression: str = "",
waiter_delay: int = 60,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
waiter_delay: int = 60,
verify: bool | str | None = None,
botocore_config: dict | None = None,
):
self.database_name = database_name
self.table_name = table_name
self.expression = expression
self.waiter_delay = waiter_delay

self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.waiter_delay = waiter_delay
self.verify = verify
self.botocore_config = botocore_config

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
@@ -121,12 +126,19 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"aws_conn_id": self.aws_conn_id,
"region_name": self.region_name,
"waiter_delay": self.waiter_delay,
"verify": self.verify,
"botocore_config": self.botocore_config,
},
)

@cached_property
def hook(self) -> GlueCatalogHook:
return GlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
return GlueCatalogHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)

async def poke(self, client: Any) -> bool:
if "." in self.table_name:
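As a quick illustration of why verify and botocore_config were added to serialize(): the serialized kwargs alone must be enough for the triggerer to rebuild an equivalent trigger. The values below are placeholders:

from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger

# Placeholder values; only the round trip through serialize() matters here.
trigger = GlueCatalogPartitionTrigger(
    database_name="default",
    table_name="input",
    expression="category='mixed'",
    aws_conn_id="aws_default",
    verify=True,
    botocore_config={"read_timeout": 42},
)
classpath, kwargs = trigger.serialize()
# kwargs now carry verify and botocore_config, so the triggerer process can
# recreate the hook with the same SSL and botocore client settings.
rebuilt = GlueCatalogPartitionTrigger(**kwargs)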
14 changes: 14 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/glue.rst
@@ -172,6 +172,20 @@ reaches a terminal state you can use :class:`~airflow.providers.amazon.aws.senso
:start-after: [START howto_sensor_glue_data_quality_rule_recommendation_run]
:end-before: [END howto_sensor_glue_data_quality_rule_recommendation_run]

.. _howto/sensor:GlueCatalogPartitionSensor:

Wait on an AWS Glue Catalog Partition
======================================

To wait for a partition to show up in the AWS Glue Catalog you can use
:class:`~airflow.providers.amazon.aws.sensors.glue_catalog_partition.GlueCatalogPartitionSensor`.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_glue.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_glue_catalog_partition]
:end-before: [END howto_sensor_glue_catalog_partition]

Reference
---------

2 changes: 0 additions & 2 deletions tests/always/test_project_structure.py
@@ -535,8 +535,6 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest):
MISSING_EXAMPLES_FOR_CLASSES = {
# S3 Exasol transfer difficult to test, see: https://github.com/apache/airflow/issues/22632
"airflow.providers.amazon.aws.transfers.exasol_to_s3.ExasolToS3Operator",
# Glue Catalog sensor difficult to test
"airflow.providers.amazon.aws.sensors.glue_catalog_partition.GlueCatalogPartitionSensor",
}

DEPRECATED_CLASSES = {
25 changes: 25 additions & 0 deletions tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
@@ -122,3 +122,28 @@ def test_fail_execute_complete(self, soft_fail, expected_exception):
message = f"Trigger error: event is {event}"
with pytest.raises(expected_exception, match=message):
op.execute_complete(context={}, event=event)

def test_init(self):
default_op_kwargs = {
"task_id": "test_task",
"table_name": "test_table",
}

sensor = GlueCatalogPartitionSensor(**default_op_kwargs)
assert sensor.hook.aws_conn_id == "aws_default"
assert sensor.hook._region_name is None
assert sensor.hook._verify is None
assert sensor.hook._config is None

sensor = GlueCatalogPartitionSensor(
**default_op_kwargs,
aws_conn_id=None,
region_name="eu-west-2",
verify=True,
botocore_config={"read_timeout": 42},
)
assert sensor.hook.aws_conn_id is None
assert sensor.hook._region_name == "eu-west-2"
assert sensor.hook._verify is True
assert sensor.hook._config is not None
assert sensor.hook._config.read_timeout == 42
22 changes: 22 additions & 0 deletions tests/providers/amazon/aws/triggers/test_glue.py
@@ -99,6 +99,28 @@ async def test_poke(self, mock_async_get_partitions):

assert response is True

def test_serialization(self):
trigger = GlueCatalogPartitionTrigger(
database_name="test_database",
table_name="test_table",
expression="id=12",
aws_conn_id="fake_conn_id",
region_name="eu-west-2",
verify=True,
)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.amazon.aws.triggers.glue.GlueCatalogPartitionTrigger"
assert kwargs == {
"database_name": "test_database",
"table_name": "test_table",
"expression": "id=12",
"waiter_delay": 60,
"aws_conn_id": "fake_conn_id",
"region_name": "eu-west-2",
"verify": True,
"botocore_config": None,
}


class TestGlueDataQualityEvaluationRunCompletedTrigger:
EXPECTED_WAITER_NAME = "data_quality_ruleset_evaluation_run_complete"
15 changes: 12 additions & 3 deletions tests/system/providers/amazon/aws/example_glue.py
@@ -32,6 +32,7 @@
S3DeleteBucketOperator,
)
from airflow.providers.amazon.aws.sensors.glue import GlueJobSensor
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import GlueCatalogPartitionSensor
from airflow.providers.amazon.aws.sensors.glue_crawler import GlueCrawlerSensor
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder, prune_logs
@@ -49,7 +50,7 @@
sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()

# Example csv data used as input to the example AWS Glue Job.
EXAMPLE_CSV = """
EXAMPLE_CSV = """product,value
apple,0.5
milk,2.5
bread,4.0
@@ -115,7 +116,7 @@ def glue_cleanup(crawler_name: str, job_name: str, db_name: str) -> None:
upload_csv = S3CreateObjectOperator(
task_id="upload_csv",
s3_bucket=bucket_name,
s3_key="input/input.csv",
s3_key="input/category=mixed/input.csv",
data=EXAMPLE_CSV,
replace=True,
)
@@ -146,6 +147,15 @@ def glue_cleanup(crawler_name: str, job_name: str, db_name: str) -> None:
# [END howto_sensor_glue_crawler]
wait_for_crawl.timeout = 500

# [START howto_sensor_glue_catalog_partition]
wait_for_catalog_partition = GlueCatalogPartitionSensor(
task_id="wait_for_catalog_partition",
table_name="input",
database_name=glue_db_name,
expression="category='mixed'",
)
# [END howto_sensor_glue_catalog_partition]

# [START howto_operator_glue]
submit_glue_job = GlueJobOperator(
task_id="submit_glue_job",
@@ -211,7 +221,6 @@ def glue_cleanup(crawler_name: str, job_name: str, db_name: str) -> None:
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()


from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)