Fix a bunch of deprecation warnings AWS tests (#26857)
uranusjr authored Oct 6, 2022
1 parent 98b283c · commit 6dd4593
Showing 24 changed files with 172 additions and 93 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/hooks/athena.py
@@ -207,8 +207,8 @@ def poll_query_status(
         """
         if max_tries:
             warnings.warn(
-                f"Method `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
-                "in a future release. Please use method `max_polling_attempts` instead.",
+                f"Passing 'max_tries' to {self.__class__.__name__}.poll_query_status is deprecated "
+                f"and will be removed in a future release. Please use 'max_polling_attempts' instead.",
                 DeprecationWarning,
                 stacklevel=2,
             )
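
Note: the reworded message reflects that 'max_tries' is still accepted for now and merely aliased to its replacement. A minimal sketch of the keyword-migration pattern the hunk relies on (names illustrative; the real hook's polling body is omitted):

    import warnings


    class AthenaHookSketch:
        """Illustrative only; not the provider's actual class."""

        def poll_query_status(self, query_execution_id, max_polling_attempts=None, max_tries=None):
            if max_tries:
                warnings.warn(
                    f"Passing 'max_tries' to {self.__class__.__name__}.poll_query_status is deprecated "
                    "and will be removed in a future release. Please use 'max_polling_attempts' instead.",
                    DeprecationWarning,
                    stacklevel=2,  # attribute the warning to the caller's line, not this shim
                )
                max_polling_attempts = max_tries
            # ... poll get_query_execution() up to max_polling_attempts times ...
            return "RUNNING"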
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -82,7 +82,10 @@ def get_uri(self) -> str:
         if 'user' in conn_params:
             conn_params['username'] = conn_params.pop('user')
 
-        return str(URL(drivername='redshift+redshift_connector', **conn_params))
+        # Compatibility: The 'create' factory method was added in SQLAlchemy 1.4
+        # to replace calling the default URL constructor directly.
+        create_url = getattr(URL, "create", URL)
+        return str(create_url(drivername='redshift+redshift_connector', **conn_params))
 
     def get_sqlalchemy_engine(self, engine_kwargs=None):
         """Overrides DbApiHook get_sqlalchemy_engine to pass redshift_connector specific kwargs"""
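
Note: getattr(URL, "create", URL) resolves to the URL.create classmethod on SQLAlchemy 1.4 and later, and falls back to the URL class itself (whose constructor took the same keyword arguments) on older releases, so a single call site works on both. A usage sketch with made-up credentials:

    from sqlalchemy.engine.url import URL

    create_url = getattr(URL, "create", URL)  # URL.create on SQLAlchemy >= 1.4, plain URL before

    url = create_url(
        drivername="redshift+redshift_connector",
        username="awsuser",  # illustrative values only
        password="secret",
        host="examplecluster.abc123.us-east-1.redshift.amazonaws.com",
        port=5439,
        database="dev",
    )
    print(str(url))  # full connection URI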
13 changes: 6 additions & 7 deletions airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -63,8 +63,8 @@ def _render_filename(self, ti, try_number):
     def set_context(self, ti):
         super().set_context(ti)
         self.handler = watchtower.CloudWatchLogHandler(
-            log_group=self.log_group,
-            stream_name=self._render_filename(ti, ti.try_number),
+            log_group_name=self.log_group,
+            log_stream_name=self._render_filename(ti, ti.try_number),
             boto3_client=self.hook.get_conn(),
         )
 
@@ -98,12 +98,11 @@ def get_cloudwatch_logs(self, stream_name: str) -> str:
         :return: string of all logs from the given log stream
         """
         try:
-            events = list(
-                self.hook.get_log_events(
-                    log_group=self.log_group, log_stream_name=stream_name, start_from_head=True
-                )
-            )
-
+            events = self.hook.get_log_events(
+                log_group=self.log_group,
+                log_stream_name=stream_name,
+                start_from_head=True,
+            )
             return '\n'.join(self._event_to_str(event) for event in events)
         except Exception:
             msg = f'Could not read remote logs from log_group: {self.log_group} log_stream: {stream_name}.'
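
Note: the renamed keywords track newer watchtower releases, where CloudWatchLogHandler's log_group/stream_name parameters became log_group_name/log_stream_name. If code has to span both APIs, one hedged approach is to inspect the constructor; this detection logic is our sketch, not Airflow's:

    import inspect

    import watchtower


    def make_cloudwatch_handler(group, stream, client):
        params = inspect.signature(watchtower.CloudWatchLogHandler.__init__).parameters
        if "log_group_name" in params:  # newer watchtower
            return watchtower.CloudWatchLogHandler(
                log_group_name=group, log_stream_name=stream, boto3_client=client
            )
        # older watchtower releases used the since-renamed keywords
        return watchtower.CloudWatchLogHandler(log_group=group, stream_name=stream, boto3_client=client)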
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/operators/athena.py
@@ -112,7 +112,10 @@ def execute(self, context: Context) -> str | None:
             self.client_request_token,
             self.workgroup,
         )
-        query_status = self.hook.poll_query_status(self.query_execution_id, self.max_polling_attempts)
+        query_status = self.hook.poll_query_status(
+            self.query_execution_id,
+            max_polling_attempts=self.max_polling_attempts,
+        )
 
         if query_status in AthenaHook.FAILURE_STATES:
             error_message = self.hook.get_state_change_reason(self.query_execution_id)
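
Note: the positional call was binding the attempt count to the hook's deprecated second parameter; naming the argument routes it to the replacement instead. The EMR operator and sensor hunks below apply the same keyword style. A toy reproduction, assuming a hypothetical signature that mirrors the migration:

    import warnings


    def poll_query_status(query_execution_id, max_tries=None, max_polling_attempts=None):
        if max_tries is not None:
            warnings.warn("'max_tries' is deprecated; use 'max_polling_attempts'.", DeprecationWarning)
            max_polling_attempts = max_tries
        return max_polling_attempts


    poll_query_status("abc-123", 3)                       # binds to max_tries, so it warns
    poll_query_status("abc-123", max_polling_attempts=3)  # explicit keyword, no warning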
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/operators/emr.py
@@ -279,7 +279,9 @@ def execute(self, context: Context) -> str | None:
         )
         if self.wait_for_completion:
             query_status = self.hook.poll_query_status(
-                self.job_id, self.max_polling_attempts, self.poll_interval
+                self.job_id,
+                max_polling_attempts=self.max_polling_attempts,
+                poll_interval=self.poll_interval,
             )
 
         if query_status in EmrContainerHook.FAILURE_STATES:
14 changes: 7 additions & 7 deletions airflow/providers/amazon/aws/secrets/secrets_manager.py
@@ -330,19 +330,19 @@ def _remove_escaping_in_secret_dict(self, secret: dict[str, Any], conn_id: str)
 
         if warn_user:
             msg = (
-                "When ``full_url_mode=True``, URL-encoding secret values is deprecated. In future versions, "
-                f" this value will not be un-escaped. For the conn_id {conn_id!r}, please remove the"
-                " URL-encoding."
-                "\n\nThis warning was raised because the SecretsManagerBackend detected that this connection"
-                " was URL-encoded."
+                "When full_url_mode=False, URL-encoding secret values is deprecated. In future versions, "
+                f"this value will not be un-escaped. For the conn_id {conn_id!r}, please remove the "
+                "URL-encoding.\n\n"
+                "This warning was raised because the SecretsManagerBackend detected that this "
+                "connection was URL-encoded."
             )
             if idempotent:
                 msg = f" Once the values for conn_id {conn_id!r} are decoded, this warning will go away."
             if not idempotent:
                 msg += (
                     " In addition to decoding the values for your connection, you must also set"
-                    " ``secret_values_are_urlencoded=False`` for your config variable"
-                    " ``secrets.backend_kwargs`` because this connection's URL encoding is not idempotent."
+                    " secret_values_are_urlencoded=False for your config variable"
+                    " secrets.backend_kwargs because this connection's URL encoding is not idempotent."
                     " For more information, see:"
                     " https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/secrets-backends"
                     "/aws-secrets-manager.html#url-encoding-of-secrets-when-full-url-mode-is-false"
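
Note: "idempotent" here means a second URL-decode leaves the value unchanged, in which case the encoded and decoded forms coincide and the backend can migrate silently. A minimal sketch of that distinction; this is our illustration of the concept, not the backend's exact check:

    from urllib.parse import unquote


    def urldecoding_is_idempotent(value: str) -> bool:
        # If decoding again changes nothing, the stored value reads the same
        # whether it was written URL-encoded or not.
        return unquote(value) == value


    print(urldecoding_is_idempotent("plain-password"))  # True: decoding is a no-op
    print(urldecoding_is_idempotent("pass%20word"))     # False: decodes to 'pass word'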
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/sensors/emr.py
@@ -285,7 +285,11 @@ def __init__(
         self.max_retries = max_retries
 
     def poke(self, context: Context) -> bool:
-        state = self.hook.poll_query_status(self.job_id, self.max_retries, self.poll_interval)
+        state = self.hook.poll_query_status(
+            self.job_id,
+            max_polling_attempts=self.max_retries,
+            poll_interval=self.poll_interval,
+        )
 
         if state in self.FAILURE_STATES:
             raise AirflowException('EMR Containers sensor failed')
3 changes: 2 additions & 1 deletion tests/providers/amazon/aws/hooks/test_athena.py
@@ -164,7 +164,8 @@ def test_hook_poll_query_when_final(self, mock_conn):
     def test_hook_poll_query_with_timeout(self, mock_conn):
         mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
         result = self.athena.poll_query_status(
-            query_execution_id=MOCK_DATA['query_execution_id'], max_tries=1
+            query_execution_id=MOCK_DATA['query_execution_id'],
+            max_polling_attempts=1,
         )
         mock_conn.return_value.get_query_execution.assert_called_once()
         assert result == 'RUNNING'
6 changes: 6 additions & 0 deletions tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -53,6 +53,9 @@ def setUp(self, get_client_type_mock):
             aws_conn_id='airflow_test',
             region_name=AWS_REGION,
         )
+        # We're mocking all actual AWS calls and don't need a connection. This
+        # avoids an Airflow warning about connection cannot be found.
+        self.batch_client.get_connection = lambda _: None
         self.client_mock = get_client_type_mock.return_value
         assert self.batch_client.client == self.client_mock  # setup client property
 
@@ -307,6 +310,9 @@ class TestBatchClientDelays(unittest.TestCase):
     @mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
     def setUp(self):
         self.batch_client = BatchClientHook(aws_conn_id='airflow_test', region_name=AWS_REGION)
+        # We're mocking all actual AWS calls and don't need a connection. This
+        # avoids an Airflow warning about connection cannot be found.
+        self.batch_client.get_connection = lambda _: None
 
     def test_init(self):
         assert self.batch_client.max_retries == self.batch_client.MAX_RETRIES
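
Note: this stub replaces the hook's connection lookup so credential resolution never touches the Airflow metastore; boto3 then falls back to its default credential chain without logging that the connection is missing. The same stub recurs in the EC2, EKS, S3, and Batch operator tests below. A standalone sketch (the conn_id is illustrative):

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="does_not_exist")
    # Returning None short-circuits the metastore lookup, so no
    # "connection not found" warning is emitted during the test run.
    hook.get_connection = lambda _: None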
3 changes: 3 additions & 0 deletions tests/providers/amazon/aws/hooks/test_ec2.py
@@ -31,6 +31,9 @@ def test_init(self):
             aws_conn_id="aws_conn_test",
             region_name="region-test",
         )
+        # We're mocking all actual AWS calls and don't need a connection. This
+        # avoids an Airflow warning about connection cannot be found.
+        ec2_hook.get_connection = lambda _: None
         assert ec2_hook.aws_conn_id == "aws_conn_test"
         assert ec2_hook.region_name == "region-test"
 
3 changes: 3 additions & 0 deletions tests/providers/amazon/aws/hooks/test_eks.py
@@ -1242,6 +1242,9 @@ def test_generate_config_file(self, mock_conn, aws_conn_id, region_name, expecte
             'cluster': {'certificateAuthority': {'data': 'test-cert'}, 'endpoint': 'test-endpoint'}
         }
         hook = EksHook(aws_conn_id=aws_conn_id, region_name=region_name)
+        # We're mocking all actual AWS calls and don't need a connection. This
+        # avoids an Airflow warning about connection cannot be found.
+        hook.get_connection = lambda _: None
         with hook.generate_config_file(
             eks_cluster_name='test-cluster', pod_namespace='k8s-namespace'
         ) as config_file:
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/hooks/test_emr_containers.py
@@ -110,7 +110,7 @@ def test_query_status_polling_with_timeout(self, mock_session):
         mock_session.return_value = emr_session_mock
         emr_client_mock.describe_job_run.return_value = JOB2_RUN_DESCRIPTION
 
-        query_status = self.emr_containers.poll_query_status(job_id='job123456', max_tries=2)
+        query_status = self.emr_containers.poll_query_status(job_id='job123456', max_polling_attempts=2)
         # should poll until max_tries is reached since query is in non-terminal state
         assert emr_client_mock.describe_job_run.call_count == 2
         assert query_status == 'RUNNING'
6 changes: 6 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
@@ -46,6 +46,9 @@ def test_check_for_bucket_raises_error_with_invalid_conn_id(self, monkeypatch):
         monkeypatch.delenv('AWS_ACCESS_KEY_ID', raising=False)
         monkeypatch.delenv('AWS_SECRET_ACCESS_KEY', raising=False)
         hook = S3Hook(aws_conn_id="does_not_exist")
+        # We're mocking all actual AWS calls and don't need a connection. This
+        # avoids an Airflow warning about connection cannot be found.
+        hook.get_connection = lambda _: None
         with pytest.raises(NoCredentialsError):
             hook.check_for_bucket("test-non-existing-bucket")
 
@@ -587,6 +590,9 @@ def test_that_extra_args_not_changed_between_calls(self, s3_bucket):
             "SSEKMSKeyId": "arn:aws:kms:region:acct-id:key/key-id",
         }
         s3_hook = S3Hook(aws_conn_id="s3_test", extra_args=original)
+        # We're mocking all actual AWS calls and don't need a connection. This
+        # avoids an Airflow warning about connection cannot be found.
+        s3_hook.get_connection = lambda _: None
         assert s3_hook.extra_args == original
         assert s3_hook.extra_args is not original
 
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/operators/test_athena.py
@@ -60,7 +60,7 @@ def setUp(self):
             output_location='s3://test_s3_bucket/',
             client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
             sleep_time=0,
-            max_tries=3,
+            max_polling_attempts=3,
             dag=self.dag,
         )
 
3 changes: 3 additions & 0 deletions tests/providers/amazon/aws/operators/test_batch.py
@@ -66,6 +66,9 @@ def setUp(self, get_client_type_mock):
             tags={},
         )
         self.client_mock = self.get_client_type_mock.return_value
+        # We're mocking all actual AWS calls and don't need a connection. This
+        # avoids an Airflow warning about connection cannot be found.
+        self.batch.hook.get_connection = lambda _: None
         assert self.batch.hook.client == self.client_mock  # setup client property
 
         # don't pause in unit tests
8 changes: 4 additions & 4 deletions tests/providers/amazon/aws/operators/test_emr_containers.py
@@ -22,7 +22,7 @@
 
 import pytest
 
-from airflow import configuration
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
 from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator, EmrEksCreateClusterOperator
 
@@ -41,7 +41,7 @@
 class TestEmrContainerOperator(unittest.TestCase):
     @mock.patch('airflow.providers.amazon.aws.hooks.emr.EmrContainerHook')
     def setUp(self, emr_hook_mock):
-        configuration.load_test_config()
+        conf.load_test_config()
 
         self.emr_hook_mock = emr_hook_mock
         self.emr_container = EmrContainerOperator(
 
@@ -130,7 +130,7 @@ def test_execute_with_polling_timeout(self, mock_check_query_status):
             job_driver={},
             configuration_overrides={},
             poll_interval=0,
-            max_tries=3,
+            max_polling_attempts=3,
         )
 
         with patch('boto3.session.Session', boto3_session_mock):
 
@@ -145,7 +145,7 @@
 class TestEmrEksCreateClusterOperator(unittest.TestCase):
     @mock.patch('airflow.providers.amazon.aws.hooks.emr.EmrContainerHook')
     def setUp(self, emr_hook_mock):
-        configuration.load_test_config()
+        conf.load_test_config()
 
         self.emr_hook_mock = emr_hook_mock
         self.emr_container = EmrEksCreateClusterOperator(
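
Note: load_test_config is a method on the global AirflowConfigParser instance exported as airflow.configuration.conf; calling it through the module-level proxy was deprecated in newer Airflow, hence the import change. Side by side, as we understand the migration:

    # Deprecated: module-level proxy, warns on newer Airflow.
    # from airflow import configuration
    # configuration.load_test_config()

    # Preferred: call the method on the conf object directly.
    from airflow.configuration import conf

    conf.load_test_config()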