Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GlueCrawlerOperature failure when using tags #28005

Merged
merged 12 commits into from
Dec 6, 2022
54 changes: 50 additions & 4 deletions airflow/providers/amazon/aws/hooks/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sts import StsHook


class GlueCrawlerHook(AwsBaseHook):
Expand Down Expand Up @@ -78,16 +79,61 @@ def update_crawler(self, **crawler_kwargs) -> bool:
crawler_name = crawler_kwargs["Name"]
current_crawler = self.get_crawler(crawler_name)

tags_updated = self.update_tags(crawler_name, crawler_kwargs.pop("Tags", {}))

update_config = {
key: value for key, value in crawler_kwargs.items() if current_crawler[key] != crawler_kwargs[key]
key: value
for key, value in crawler_kwargs.items()
if current_crawler.get(key, None) != crawler_kwargs.get(key)
}
if update_config != {}:
if update_config:
self.log.info("Updating crawler: %s", crawler_name)
self.glue_client.update_crawler(**crawler_kwargs)
self.log.info("Updated configurations: %s", update_config)
return True
else:
return False
return tags_updated

def update_tags(self, crawler_name: str, crawler_tags: dict) -> bool:
"""
Updates crawler tags

:param crawler_name: Name of the crawler for which to update tags
:param crawler_tags: Dictionary of new tags. If empty, all tags will be deleted
:return True if tags were updated and false otherwise
"""
account_number = StsHook(aws_conn_id=self.aws_conn_id).get_account_number()
crawler_arn = (
f"arn:{self.conn_partition}:glue:{self.conn_region_name}:{account_number}:crawler/{crawler_name}"
)
current_crawler_tags: dict = self.glue_client.get_tags(ResourceArn=crawler_arn)["Tags"]

update_tags = {}
delete_tags = []
for key, value in current_crawler_tags.items():
wanted_tag_value = crawler_tags.get(key, None)
if wanted_tag_value is None:
# key is missing from new configuration, mark it for deletion
delete_tags.append(key)
elif wanted_tag_value != value:
update_tags[key] = wanted_tag_value

update_tags = {
key: value
for key, value in crawler_tags.items()
if current_crawler_tags.get(key, None) != crawler_tags.get(key)
}
updated_tags = False
if update_tags:
self.log.info("Updating crawler tags: %s", crawler_name)
self.glue_client.tag_resource(ResourceArn=crawler_arn, TagsToAdd=update_tags)
self.log.info("Updated crawler tags: %s", crawler_name)
updated_tags = True
if delete_tags:
self.log.info("Deleting crawler tags: %s", crawler_name)
self.glue_client.untag_resource(ResourceArn=crawler_arn, TagsToRemove=delete_tags)
self.log.info("Deleted crawler tags: %s", crawler_name)
updated_tags = True
return updated_tags

def create_crawler(self, **crawler_kwargs) -> str:
"""
Expand Down
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/operators/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
self,
config,
aws_conn_id="aws_default",
region_name: str | None = None,
poll_interval: int = 5,
wait_for_completion: bool = True,
**kwargs,
Expand All @@ -58,12 +59,13 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.poll_interval = poll_interval
self.wait_for_completion = wait_for_completion
self.region_name = region_name
self.config = config

@cached_property
def hook(self) -> GlueCrawlerHook:
"""Create and return an GlueCrawlerHook."""
return GlueCrawlerHook(self.aws_conn_id)
return GlueCrawlerHook(self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
"""
Expand Down
94 changes: 81 additions & 13 deletions tests/providers/amazon/aws/hooks/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
from copy import deepcopy
from unittest import mock

from moto import mock_sts
from moto.core import DEFAULT_ACCOUNT_ID

from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
from airflow.providers.amazon.aws.hooks.sts import StsHook

AWS_REGION = "us-west-2"

mock_crawler_name = "test-crawler"
mock_role_name = "test-role"
Expand Down Expand Up @@ -79,22 +85,22 @@
}
""",
"SecurityConfiguration": "test",
"Tags": {"test": "foo"},
"Tags": {"test": "foo", "bar": "test"},
}


class TestGlueCrawlerHook(unittest.TestCase):
@classmethod
def setUp(cls):
cls.hook = GlueCrawlerHook(aws_conn_id="aws_default")
cls.hook = GlueCrawlerHook(aws_conn_id="aws_default", region_name=AWS_REGION)

def test_init(self):
self.assertEqual(self.hook.aws_conn_id, "aws_default")
assert self.hook.aws_conn_id == "aws_default"

@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_has_crawler(self, mock_get_conn):
response = self.hook.has_crawler(mock_crawler_name)
self.assertEqual(response, True)
assert response == True
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)

@mock.patch.object(GlueCrawlerHook, "get_conn")
Expand All @@ -105,39 +111,101 @@ class MockException(Exception):
mock_get_conn.return_value.exceptions.EntityNotFoundException = MockException
mock_get_conn.return_value.get_crawler.side_effect = MockException("AAA")
response = self.hook.has_crawler(mock_crawler_name)
self.assertEqual(response, False)
assert not response
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_crawler_needed(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}

mock_config_two = deepcopy(mock_config)
mock_config_two["Role"] = "test-2-role"
mock_config_two.pop("Tags")
response = self.hook.update_crawler(**mock_config_two)
assert response == True
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
mock_get_conn.return_value.update_crawler.assert_called_once_with(**mock_config_two)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_crawler_missing_keys(self, mock_get_conn):
mock_config_missing_configuration = deepcopy(mock_config)
mock_config_missing_configuration.pop("Configuration")
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config_missing_configuration}

mock_config_two = deepcopy(mock_config)
mock_config_two.pop("Tags")
response = self.hook.update_crawler(**mock_config_two)
self.assertEqual(response, True)
assert response, True
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
mock_get_conn.return_value.update_crawler.assert_called_once_with(**mock_config_two)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_tags_not_needed(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
mock_get_conn.return_value.get_tags.return_value = {"Tags": mock_config["Tags"]}
crawler_arn = f"arn:aws:glue:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:crawler/{mock_crawler_name}"

response = self.hook.update_tags(mock_crawler_name, mock_config["Tags"])
assert not response
mock_get_conn.return_value.get_tags.assert_called_once_with(ResourceArn=crawler_arn)
mock_get_conn.return_value.tag_resource.assert_not_called()
mock_get_conn.return_value.untag_resource.assert_not_called()

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_remove_all_tags(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
mock_get_conn.return_value.get_tags.return_value = {"Tags": mock_config["Tags"]}
crawler_arn = f"arn:aws:glue:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:crawler/{mock_crawler_name}"

response = self.hook.update_tags(mock_crawler_name, {})
assert response == True
mock_get_conn.return_value.get_tags.assert_called_once_with(ResourceArn=crawler_arn)
mock_get_conn.return_value.tag_resource.assert_not_called()
mock_get_conn.return_value.untag_resource.assert_called_once_with(
ResourceArn=crawler_arn, TagsToRemove=["test", "bar"]
)

@mock_sts
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_replace_tag(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
mock_get_conn.return_value.get_tags.return_value = {"Tags": mock_config["Tags"]}
crawler_arn = f"arn:aws:glue:{AWS_REGION}:{DEFAULT_ACCOUNT_ID}:crawler/{mock_crawler_name}"

mock_config_two = deepcopy(mock_config)
mock_config_two.pop("Tags")
response = self.hook.update_tags(mock_crawler_name, {"test": "bla", "bar": "test"})
assert response == True
mock_get_conn.return_value.get_tags.assert_called_once_with(ResourceArn=crawler_arn)
mock_get_conn.return_value.untag_resource.assert_not_called()
mock_get_conn.return_value.tag_resource.assert_called_once_with(
ResourceArn=crawler_arn, TagsToAdd={"test": "bla"}
)

@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_crawler_not_needed(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
mock_get_conn.return_value.get_tags.return_value = {"Tags": mock_config["Tags"]}
response = self.hook.update_crawler(**mock_config)
self.assertEqual(response, False)
assert not response
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)

@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_create_crawler(self, mock_get_conn):
mock_get_conn.return_value.create_crawler.return_value = {"Crawler": {"Name": mock_crawler_name}}
glue_crawler = self.hook.create_crawler(**mock_config)
self.assertIn("Crawler", glue_crawler)
self.assertIn("Name", glue_crawler["Crawler"])
self.assertEqual(glue_crawler["Crawler"]["Name"], mock_crawler_name)
assert "Crawler" in glue_crawler
assert "Name" in glue_crawler["Crawler"]
assert glue_crawler["Crawler"]["Name"] == mock_crawler_name

@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_start_crawler(self, mock_get_conn):
result = self.hook.start_crawler(mock_crawler_name)
self.assertEqual(result, mock_get_conn.return_value.start_crawler.return_value)
assert result == mock_get_conn.return_value.start_crawler.return_value

mock_get_conn.return_value.start_crawler.assert_called_once_with(Name=mock_crawler_name)

Expand All @@ -159,7 +227,7 @@ def test_wait_for_crawler_completion_instant_ready(self, mock_get_conn, mock_get
]
}
result = self.hook.wait_for_crawler_completion(mock_crawler_name)
self.assertEqual(result, "MOCK_STATUS")
assert result == "MOCK_STATUS"
mock_get_conn.assert_has_calls(
[
mock.call(),
Expand Down Expand Up @@ -195,7 +263,7 @@ def test_wait_for_crawler_completion_retry_two_times(self, mock_sleep, mock_get_
},
]
result = self.hook.wait_for_crawler_completion(mock_crawler_name)
self.assertEqual(result, "MOCK_STATUS")
assert result == "MOCK_STATUS"
mock_get_conn.assert_has_calls(
[
mock.call(),
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/operators/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_execute_without_failure(self, mock_hook):

mock_hook.assert_has_calls(
[
mock.call("aws_default"),
mock.call("aws_default", region_name=None),
mock.call().has_crawler("test-crawler"),
mock.call().update_crawler(**mock_config),
mock.call().start_crawler(mock_crawler_name),
Expand Down