Skip to content

Commit

Permalink
Fix GlueCrawlerOperature failure when using tags (#28005)
Browse files Browse the repository at this point in the history
  • Loading branch information
IAL32 authored Dec 6, 2022
1 parent b3d7e17 commit 3ee5c40
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 11 deletions.
49 changes: 45 additions & 4 deletions airflow/providers/amazon/aws/hooks/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sts import StsHook


class GlueCrawlerHook(AwsBaseHook):
Expand Down Expand Up @@ -78,16 +79,56 @@ def update_crawler(self, **crawler_kwargs) -> bool:
crawler_name = crawler_kwargs["Name"]
current_crawler = self.get_crawler(crawler_name)

tags_updated = self.update_tags(crawler_name, crawler_kwargs.pop("Tags", {}))

update_config = {
key: value for key, value in crawler_kwargs.items() if current_crawler[key] != crawler_kwargs[key]
key: value
for key, value in crawler_kwargs.items()
if current_crawler.get(key, None) != crawler_kwargs.get(key)
}
if update_config != {}:
if update_config:
self.log.info("Updating crawler: %s", crawler_name)
self.glue_client.update_crawler(**crawler_kwargs)
self.log.info("Updated configurations: %s", update_config)
return True
else:
return False
return tags_updated

def update_tags(self, crawler_name: str, crawler_tags: dict) -> bool:
"""
Updates crawler tags
:param crawler_name: Name of the crawler for which to update tags
:param crawler_tags: Dictionary of new tags. If empty, all tags will be deleted
:return: True if tags were updated and false otherwise
"""
account_number = StsHook(aws_conn_id=self.aws_conn_id).get_account_number()
crawler_arn = (
f"arn:{self.conn_partition}:glue:{self.conn_region_name}:{account_number}:crawler/{crawler_name}"
)
current_crawler_tags: dict = self.glue_client.get_tags(ResourceArn=crawler_arn)["Tags"]

update_tags = {}
delete_tags = []
for key, value in current_crawler_tags.items():
wanted_tag_value = crawler_tags.get(key, None)
if wanted_tag_value is None:
# key is missing from new configuration, mark it for deletion
delete_tags.append(key)
elif wanted_tag_value != value:
update_tags[key] = wanted_tag_value

updated_tags = False
if update_tags:
self.log.info("Updating crawler tags: %s", crawler_name)
self.glue_client.tag_resource(ResourceArn=crawler_arn, TagsToAdd=update_tags)
self.log.info("Updated crawler tags: %s", crawler_name)
updated_tags = True
if delete_tags:
self.log.info("Deleting crawler tags: %s", crawler_name)
self.glue_client.untag_resource(ResourceArn=crawler_arn, TagsToRemove=delete_tags)
self.log.info("Deleted crawler tags: %s", crawler_name)
updated_tags = True
return updated_tags

def create_crawler(self, **crawler_kwargs) -> str:
"""
Expand Down
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/operators/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
self,
config,
aws_conn_id="aws_default",
region_name: str | None = None,
poll_interval: int = 5,
wait_for_completion: bool = True,
**kwargs,
Expand All @@ -58,12 +59,13 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.poll_interval = poll_interval
self.wait_for_completion = wait_for_completion
self.region_name = region_name
self.config = config

@cached_property
def hook(self) -> GlueCrawlerHook:
"""Create and return an GlueCrawlerHook."""
return GlueCrawlerHook(self.aws_conn_id)
return GlueCrawlerHook(self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
"""
Expand Down
69 changes: 64 additions & 5 deletions tests/providers/amazon/aws/hooks/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from copy import deepcopy
from unittest import mock

from moto import mock_sts
from moto.core import DEFAULT_ACCOUNT_ID

from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook

mock_crawler_name = "test-crawler"
Expand Down Expand Up @@ -78,13 +81,16 @@
}
""",
"SecurityConfiguration": "test",
"Tags": {"test": "foo"},
"Tags": {"test": "foo", "bar": "test"},
}


class TestGlueCrawlerHook:
def setup_method(self):
self.hook = GlueCrawlerHook(aws_conn_id="aws_default")
self.crawler_arn = (
f"arn:aws:glue:{self.hook.conn_region_name}:{DEFAULT_ACCOUNT_ID}:crawler/{mock_crawler_name}"
)

def test_init(self):
assert self.hook.aws_conn_id == "aws_default"
Expand All @@ -104,16 +110,71 @@ class MockException(Exception):
assert self.hook.has_crawler(mock_crawler_name) is False
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_crawler_needed(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}

mock_config_two = deepcopy(mock_config)
mock_config_two["Role"] = "test-2-role"
mock_config_two.pop("Tags")
assert self.hook.update_crawler(**mock_config_two) is True
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
mock_get_conn.return_value.update_crawler.assert_called_once_with(**mock_config_two)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_crawler_missing_keys(self, mock_get_conn):
mock_config_missing_configuration = deepcopy(mock_config)
mock_config_missing_configuration.pop("Configuration")
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config_missing_configuration}

mock_config_two = deepcopy(mock_config)
mock_config_two.pop("Tags")
assert self.hook.update_crawler(**mock_config_two) is True
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
mock_get_conn.return_value.update_crawler.assert_called_once_with(**mock_config_two)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_tags_not_needed(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
mock_get_conn.return_value.get_tags.return_value = {"Tags": mock_config["Tags"]}

assert self.hook.update_tags(mock_crawler_name, mock_config["Tags"]) is False
mock_get_conn.return_value.get_tags.assert_called_once_with(ResourceArn=self.crawler_arn)
mock_get_conn.return_value.tag_resource.assert_not_called()
mock_get_conn.return_value.untag_resource.assert_not_called()

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_remove_all_tags(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
mock_get_conn.return_value.get_tags.return_value = {"Tags": mock_config["Tags"]}

assert self.hook.update_tags(mock_crawler_name, {}) is True
mock_get_conn.return_value.get_tags.assert_called_once_with(ResourceArn=self.crawler_arn)
mock_get_conn.return_value.tag_resource.assert_not_called()
mock_get_conn.return_value.untag_resource.assert_called_once_with(
ResourceArn=self.crawler_arn, TagsToRemove=["test", "bar"]
)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_replace_tag(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
mock_get_conn.return_value.get_tags.return_value = {"Tags": mock_config["Tags"]}

mock_config_two = deepcopy(mock_config)
mock_config_two.pop("Tags")
assert self.hook.update_tags(mock_crawler_name, {"test": "bla", "bar": "test"}) is True
mock_get_conn.return_value.get_tags.assert_called_once_with(ResourceArn=self.crawler_arn)
mock_get_conn.return_value.untag_resource.assert_not_called()
mock_get_conn.return_value.tag_resource.assert_called_once_with(
ResourceArn=self.crawler_arn, TagsToAdd={"test": "bla"}
)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_crawler_not_needed(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
Expand Down Expand Up @@ -152,8 +213,7 @@ def test_wait_for_crawler_completion_instant_ready(self, mock_get_conn, mock_get
}
]
}
result = self.hook.wait_for_crawler_completion(mock_crawler_name)
assert result == "MOCK_STATUS"
assert self.hook.wait_for_crawler_completion(mock_crawler_name) == "MOCK_STATUS"
mock_get_conn.assert_has_calls(
[
mock.call(),
Expand Down Expand Up @@ -188,8 +248,7 @@ def test_wait_for_crawler_completion_retry_two_times(self, mock_sleep, mock_get_
]
},
]
result = self.hook.wait_for_crawler_completion(mock_crawler_name)
assert result == "MOCK_STATUS"
assert self.hook.wait_for_crawler_completion(mock_crawler_name) == "MOCK_STATUS"
mock_get_conn.assert_has_calls(
[
mock.call(),
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/operators/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_execute_without_failure(self, mock_hook):

mock_hook.assert_has_calls(
[
mock.call("aws_default"),
mock.call("aws_default", region_name=None),
mock.call().has_crawler("test-crawler"),
mock.call().update_crawler(**mock_config),
mock.call().start_crawler(mock_crawler_name),
Expand Down

0 comments on commit 3ee5c40

Please sign in to comment.