Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GlueCrawlerOperature failure when using tags #28005

Merged
merged 12 commits into from
Dec 6, 2022
49 changes: 45 additions & 4 deletions airflow/providers/amazon/aws/hooks/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sts import StsHook


class GlueCrawlerHook(AwsBaseHook):
Expand Down Expand Up @@ -78,16 +79,56 @@ def update_crawler(self, **crawler_kwargs) -> bool:
crawler_name = crawler_kwargs["Name"]
current_crawler = self.get_crawler(crawler_name)

tags_updated = self.update_tags(crawler_name, crawler_kwargs.pop("Tags", {}))

update_config = {
key: value for key, value in crawler_kwargs.items() if current_crawler[key] != crawler_kwargs[key]
key: value
for key, value in crawler_kwargs.items()
if current_crawler.get(key, None) != crawler_kwargs.get(key)
}
if update_config != {}:
if update_config:
self.log.info("Updating crawler: %s", crawler_name)
self.glue_client.update_crawler(**crawler_kwargs)
self.log.info("Updated configurations: %s", update_config)
return True
else:
return False
return tags_updated

def update_tags(self, crawler_name: str, crawler_tags: dict) -> bool:
    """
    Synchronise the tags of an existing crawler with the wanted tags.

    Tags that exist on the crawler but are missing from ``crawler_tags`` (or are
    mapped to ``None`` there) are removed, tags whose value differs are overwritten
    and tags that the crawler does not have yet are added.

    :param crawler_name: Name of the crawler for which to update tags
    :param crawler_tags: Dictionary of new tags. If empty, all tags will be deleted
    :return: True if tags were updated and false otherwise
    """
    # The tagging calls below address the crawler by ARN, which embeds the account number.
    account_number = StsHook(aws_conn_id=self.aws_conn_id).get_account_number()
    crawler_arn = (
        f"arn:{self.conn_partition}:glue:{self.conn_region_name}:{account_number}:crawler/{crawler_name}"
    )
    current_crawler_tags: dict = self.glue_client.get_tags(ResourceArn=crawler_arn)["Tags"]

    # A key that is absent from (or None in) the wanted tags is marked for deletion.
    delete_tags = [key for key in current_crawler_tags if crawler_tags.get(key) is None]
    # Iterate over the *wanted* tags so that brand-new keys are picked up as well as
    # changed values; iterating only over the current tags would never add a new key.
    update_tags = {
        key: value
        for key, value in crawler_tags.items()
        if value is not None and current_crawler_tags.get(key) != value
    }

    if update_tags:
        self.log.info("Updating crawler tags: %s", crawler_name)
        self.glue_client.tag_resource(ResourceArn=crawler_arn, TagsToAdd=update_tags)
        self.log.info("Updated crawler tags: %s", crawler_name)
    if delete_tags:
        self.log.info("Deleting crawler tags: %s", crawler_name)
        self.glue_client.untag_resource(ResourceArn=crawler_arn, TagsToRemove=delete_tags)
        self.log.info("Deleted crawler tags: %s", crawler_name)
    return bool(update_tags or delete_tags)

def create_crawler(self, **crawler_kwargs) -> str:
"""
Expand Down
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/operators/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
self,
config,
aws_conn_id="aws_default",
region_name: str | None = None,
poll_interval: int = 5,
wait_for_completion: bool = True,
**kwargs,
Expand All @@ -58,12 +59,13 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.poll_interval = poll_interval
self.wait_for_completion = wait_for_completion
self.region_name = region_name
self.config = config

@cached_property
def hook(self) -> GlueCrawlerHook:
    """Create and return a GlueCrawlerHook for the operator's connection and region."""
    # Forward region_name so the value given to the operator is not silently dropped.
    return GlueCrawlerHook(self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
"""
Expand Down
69 changes: 64 additions & 5 deletions tests/providers/amazon/aws/hooks/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from copy import deepcopy
from unittest import mock

from moto import mock_sts
from moto.core import DEFAULT_ACCOUNT_ID

from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook

mock_crawler_name = "test-crawler"
Expand Down Expand Up @@ -78,13 +81,16 @@
}
""",
"SecurityConfiguration": "test",
"Tags": {"test": "foo"},
"Tags": {"test": "foo", "bar": "test"},
}


class TestGlueCrawlerHook:
def setup_method(self):
    """Create the hook under test and the crawler ARN the tag tests assert against."""
    hook = GlueCrawlerHook(aws_conn_id="aws_default")
    self.hook = hook
    region = hook.conn_region_name
    self.crawler_arn = f"arn:aws:glue:{region}:{DEFAULT_ACCOUNT_ID}:crawler/{mock_crawler_name}"

def test_init(self):
assert self.hook.aws_conn_id == "aws_default"
Expand All @@ -104,16 +110,71 @@ class MockException(Exception):
assert self.hook.has_crawler(mock_crawler_name) is False
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_crawler_needed(self, mock_get_conn):
    """Override ``Role`` in the requested config and expect the update to be pushed."""
    glue_client = mock_get_conn.return_value
    glue_client.get_crawler.return_value = {"Crawler": mock_config}

    # Requested configuration: same as the existing crawler, but without tags
    # and with the role overridden.
    wanted_config = deepcopy(mock_config)
    wanted_config["Role"] = "test-2-role"
    del wanted_config["Tags"]

    was_updated = self.hook.update_crawler(**wanted_config)

    assert was_updated is True
    glue_client.get_crawler.assert_called_once_with(Name=mock_crawler_name)
    glue_client.update_crawler.assert_called_once_with(**wanted_config)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_crawler_missing_keys(self, mock_get_conn):
    """The existing crawler lacks ``Configuration``; the update must still go through."""
    glue_client = mock_get_conn.return_value

    # Crawler as returned by Glue, with the "Configuration" key absent.
    existing_crawler = deepcopy(mock_config)
    del existing_crawler["Configuration"]
    glue_client.get_crawler.return_value = {"Crawler": existing_crawler}

    wanted_config = deepcopy(mock_config)
    del wanted_config["Tags"]

    was_updated = self.hook.update_crawler(**wanted_config)

    assert was_updated is True
    glue_client.get_crawler.assert_called_once_with(Name=mock_crawler_name)
    glue_client.update_crawler.assert_called_once_with(**wanted_config)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_tags_not_needed(self, mock_get_conn):
    """Wanted tags equal the current tags, so no tagging call may be issued."""
    glue_client = mock_get_conn.return_value
    current_tags = mock_config["Tags"]
    glue_client.get_crawler.return_value = {"Crawler": mock_config}
    glue_client.get_tags.return_value = {"Tags": current_tags}

    tags_changed = self.hook.update_tags(mock_crawler_name, current_tags)

    assert tags_changed is False
    glue_client.get_tags.assert_called_once_with(ResourceArn=self.crawler_arn)
    glue_client.tag_resource.assert_not_called()
    glue_client.untag_resource.assert_not_called()

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_remove_all_tags(self, mock_get_conn):
    """Passing an empty tag dict must untag every current tag and add nothing."""
    glue_client = mock_get_conn.return_value
    glue_client.get_crawler.return_value = {"Crawler": mock_config}
    glue_client.get_tags.return_value = {"Tags": mock_config["Tags"]}

    tags_changed = self.hook.update_tags(mock_crawler_name, {})

    assert tags_changed is True
    glue_client.get_tags.assert_called_once_with(ResourceArn=self.crawler_arn)
    glue_client.tag_resource.assert_not_called()
    expected_removed = ["test", "bar"]
    glue_client.untag_resource.assert_called_once_with(
        ResourceArn=self.crawler_arn, TagsToRemove=expected_removed
    )

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_replace_tag(self, mock_get_conn):
    """Only the tag whose value changed is re-tagged; nothing is removed."""
    mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
    mock_get_conn.return_value.get_tags.return_value = {"Tags": mock_config["Tags"]}

    # "test" changes value, "bar" keeps its current value.
    assert self.hook.update_tags(mock_crawler_name, {"test": "bla", "bar": "test"}) is True
    mock_get_conn.return_value.get_tags.assert_called_once_with(ResourceArn=self.crawler_arn)
    mock_get_conn.return_value.untag_resource.assert_not_called()
    mock_get_conn.return_value.tag_resource.assert_called_once_with(
        ResourceArn=self.crawler_arn, TagsToAdd={"test": "bla"}
    )

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_crawler_not_needed(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
Expand Down Expand Up @@ -152,8 +213,7 @@ def test_wait_for_crawler_completion_instant_ready(self, mock_get_conn, mock_get
}
]
}
result = self.hook.wait_for_crawler_completion(mock_crawler_name)
assert result == "MOCK_STATUS"
assert self.hook.wait_for_crawler_completion(mock_crawler_name) == "MOCK_STATUS"
mock_get_conn.assert_has_calls(
[
mock.call(),
Expand Down Expand Up @@ -188,8 +248,7 @@ def test_wait_for_crawler_completion_retry_two_times(self, mock_sleep, mock_get_
]
},
]
result = self.hook.wait_for_crawler_completion(mock_crawler_name)
assert result == "MOCK_STATUS"
assert self.hook.wait_for_crawler_completion(mock_crawler_name) == "MOCK_STATUS"
mock_get_conn.assert_has_calls(
[
mock.call(),
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/operators/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_execute_without_failure(self, mock_hook):

mock_hook.assert_has_calls(
[
mock.call("aws_default"),
mock.call("aws_default", region_name=None),
mock.call().has_crawler("test-crawler"),
mock.call().update_crawler(**mock_config),
mock.call().start_crawler(mock_crawler_name),
Expand Down