def update_tags(self, crawler_name: str, crawler_tags: dict) -> bool:
    """
    Updates crawler tags

    Reconciles the tags currently attached to the crawler with ``crawler_tags``:
    tags whose key is absent from ``crawler_tags`` are removed, tags whose value
    differs are overwritten, and tags that are not on the crawler yet are added.

    :param crawler_name: Name of the crawler for which to update tags
    :param crawler_tags: Dictionary of new tags. If empty, all tags will be deleted
    :return: True if tags were updated and false otherwise
    """
    # Treat a missing tag mapping like an empty one instead of failing on ``None.get``.
    wanted_tags = crawler_tags or {}

    account_number = StsHook(aws_conn_id=self.aws_conn_id).get_account_number()
    crawler_arn = (
        f"arn:{self.conn_partition}:glue:{self.conn_region_name}:{account_number}:crawler/{crawler_name}"
    )
    current_crawler_tags: dict = self.glue_client.get_tags(ResourceArn=crawler_arn)["Tags"]

    # Keys that exist on the crawler but are missing from the new configuration
    # (a wanted value of None is treated as missing) are marked for deletion.
    delete_tags = [key for key in current_crawler_tags if wanted_tags.get(key) is None]

    # Iterate over the *wanted* tags so that brand-new keys are picked up as well
    # as changed values; iterating only over the current tags would never add a
    # tag that is not already attached to the crawler.
    update_tags = {
        key: value
        for key, value in wanted_tags.items()
        if value is not None and (key not in current_crawler_tags or current_crawler_tags[key] != value)
    }

    updated_tags = False
    if update_tags:
        self.log.info("Updating crawler tags: %s", crawler_name)
        self.glue_client.tag_resource(ResourceArn=crawler_arn, TagsToAdd=update_tags)
        self.log.info("Updated crawler tags: %s", crawler_name)
        updated_tags = True
    if delete_tags:
        self.log.info("Deleting crawler tags: %s", crawler_name)
        self.glue_client.untag_resource(ResourceArn=crawler_arn, TagsToRemove=delete_tags)
        self.log.info("Deleted crawler tags: %s", crawler_name)
        updated_tags = True
    return updated_tags
-> GlueCrawlerHook: """Create and return an GlueCrawlerHook.""" - return GlueCrawlerHook(self.aws_conn_id) + return GlueCrawlerHook(self.aws_conn_id, region_name=self.region_name) def execute(self, context: Context): """ diff --git a/tests/providers/amazon/aws/hooks/test_glue_crawler.py b/tests/providers/amazon/aws/hooks/test_glue_crawler.py index ac2d3cba2cea0..1f1961e0298a9 100644 --- a/tests/providers/amazon/aws/hooks/test_glue_crawler.py +++ b/tests/providers/amazon/aws/hooks/test_glue_crawler.py @@ -20,6 +20,9 @@ from copy import deepcopy from unittest import mock +from moto import mock_sts +from moto.core import DEFAULT_ACCOUNT_ID + from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook mock_crawler_name = "test-crawler" @@ -78,13 +81,16 @@ } """, "SecurityConfiguration": "test", - "Tags": {"test": "foo"}, + "Tags": {"test": "foo", "bar": "test"}, } class TestGlueCrawlerHook: def setup_method(self): self.hook = GlueCrawlerHook(aws_conn_id="aws_default") + self.crawler_arn = ( + f"arn:aws:glue:{self.hook.conn_region_name}:{DEFAULT_ACCOUNT_ID}:crawler/{mock_crawler_name}" + ) def test_init(self): assert self.hook.aws_conn_id == "aws_default" @@ -104,16 +110,71 @@ class MockException(Exception): assert self.hook.has_crawler(mock_crawler_name) is False mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name) + @mock_sts @mock.patch.object(GlueCrawlerHook, "get_conn") def test_update_crawler_needed(self, mock_get_conn): mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config} mock_config_two = deepcopy(mock_config) mock_config_two["Role"] = "test-2-role" + mock_config_two.pop("Tags") + assert self.hook.update_crawler(**mock_config_two) is True + mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name) + mock_get_conn.return_value.update_crawler.assert_called_once_with(**mock_config_two) + + @mock_sts + @mock.patch.object(GlueCrawlerHook, "get_conn") + def 
test_update_crawler_missing_keys(self, mock_get_conn): + mock_config_missing_configuration = deepcopy(mock_config) + mock_config_missing_configuration.pop("Configuration") + mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config_missing_configuration} + + mock_config_two = deepcopy(mock_config) + mock_config_two.pop("Tags") assert self.hook.update_crawler(**mock_config_two) is True mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name) mock_get_conn.return_value.update_crawler.assert_called_once_with(**mock_config_two) + @mock_sts + @mock.patch.object(GlueCrawlerHook, "get_conn") + def test_update_tags_not_needed(self, mock_get_conn): + mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config} + mock_get_conn.return_value.get_tags.return_value = {"Tags": mock_config["Tags"]} + + assert self.hook.update_tags(mock_crawler_name, mock_config["Tags"]) is False + mock_get_conn.return_value.get_tags.assert_called_once_with(ResourceArn=self.crawler_arn) + mock_get_conn.return_value.tag_resource.assert_not_called() + mock_get_conn.return_value.untag_resource.assert_not_called() + + @mock_sts + @mock.patch.object(GlueCrawlerHook, "get_conn") + def test_remove_all_tags(self, mock_get_conn): + mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config} + mock_get_conn.return_value.get_tags.return_value = {"Tags": mock_config["Tags"]} + + assert self.hook.update_tags(mock_crawler_name, {}) is True + mock_get_conn.return_value.get_tags.assert_called_once_with(ResourceArn=self.crawler_arn) + mock_get_conn.return_value.tag_resource.assert_not_called() + mock_get_conn.return_value.untag_resource.assert_called_once_with( + ResourceArn=self.crawler_arn, TagsToRemove=["test", "bar"] + ) + + @mock_sts + @mock.patch.object(GlueCrawlerHook, "get_conn") + def test_replace_tag(self, mock_get_conn): + mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config} + 
mock_get_conn.return_value.get_tags.return_value = {"Tags": mock_config["Tags"]} + + mock_config_two = deepcopy(mock_config) + mock_config_two.pop("Tags") + assert self.hook.update_tags(mock_crawler_name, {"test": "bla", "bar": "test"}) is True + mock_get_conn.return_value.get_tags.assert_called_once_with(ResourceArn=self.crawler_arn) + mock_get_conn.return_value.untag_resource.assert_not_called() + mock_get_conn.return_value.tag_resource.assert_called_once_with( + ResourceArn=self.crawler_arn, TagsToAdd={"test": "bla"} + ) + + @mock_sts @mock.patch.object(GlueCrawlerHook, "get_conn") def test_update_crawler_not_needed(self, mock_get_conn): mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config} @@ -152,8 +213,7 @@ def test_wait_for_crawler_completion_instant_ready(self, mock_get_conn, mock_get } ] } - result = self.hook.wait_for_crawler_completion(mock_crawler_name) - assert result == "MOCK_STATUS" + assert self.hook.wait_for_crawler_completion(mock_crawler_name) == "MOCK_STATUS" mock_get_conn.assert_has_calls( [ mock.call(), @@ -188,8 +248,7 @@ def test_wait_for_crawler_completion_retry_two_times(self, mock_sleep, mock_get_ ] }, ] - result = self.hook.wait_for_crawler_completion(mock_crawler_name) - assert result == "MOCK_STATUS" + assert self.hook.wait_for_crawler_completion(mock_crawler_name) == "MOCK_STATUS" mock_get_conn.assert_has_calls( [ mock.call(), diff --git a/tests/providers/amazon/aws/operators/test_glue_crawler.py b/tests/providers/amazon/aws/operators/test_glue_crawler.py index 1e83318eaf24f..6070e27b3bcac 100644 --- a/tests/providers/amazon/aws/operators/test_glue_crawler.py +++ b/tests/providers/amazon/aws/operators/test_glue_crawler.py @@ -92,7 +92,7 @@ def test_execute_without_failure(self, mock_hook): mock_hook.assert_has_calls( [ - mock.call("aws_default"), + mock.call("aws_default", region_name=None), mock.call().has_crawler("test-crawler"), mock.call().update_crawler(**mock_config), 
mock.call().start_crawler(mock_crawler_name),