Fix a bunch of deprecation warnings in AWS tests #26857

Merged 7 commits on Oct 6, 2022
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/hooks/athena.py
@@ -207,8 +207,8 @@ def poll_query_status(
"""
if max_tries:
warnings.warn(
f"Method `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
"in a future release. Please use method `max_polling_attempts` instead.",
f"Passing 'max_tries' to {self.__class__.__name__}.poll_query_status is deprecated "
f"and will be removed in a future release. Please use 'max_polling_attempts' instead.",
DeprecationWarning,
stacklevel=2,
)
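Note the wording fix above: max_tries is a keyword argument of poll_query_status, not a method, which the old message implied. For reference, the rename shim behind such a warning usually looks like this minimal, hypothetical sketch (not the hook's actual body):

import warnings

def poll_query_status(query_execution_id, max_polling_attempts=None, max_tries=None):
    # Accept the old keyword for now, but warn and map it onto the new name.
    if max_tries is not None:
        warnings.warn(
            "Passing 'max_tries' is deprecated and will be removed in a future "
            "release. Please use 'max_polling_attempts' instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not this shim
        )
        if max_polling_attempts is None:
            max_polling_attempts = max_tries
    return max_polling_attempts  # the actual polling loop is elided here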
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -82,7 +82,10 @@ def get_uri(self) -> str:
if 'user' in conn_params:
conn_params['username'] = conn_params.pop('user')

return str(URL(drivername='redshift+redshift_connector', **conn_params))
# Compatibility: The 'create' factory method was added in SQLAlchemy 1.4
# to replace calling the default URL constructor directly.
create_url = getattr(URL, "create", URL)
return str(create_url(drivername='redshift+redshift_connector', **conn_params))

def get_sqlalchemy_engine(self, engine_kwargs=None):
"""Overrides DbApiHook get_sqlalchemy_engine to pass redshift_connector specific kwargs"""
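The getattr fallback keeps the hook working on both sides of the SQLAlchemy 1.4 boundary. A standalone sketch of the same trick (all connection values below are placeholders):

from sqlalchemy.engine.url import URL

# On SQLAlchemy >= 1.4, URL.create() is the supported factory; on 1.3 the
# URL class itself is callable with the same keywords, so fall back to it.
create_url = getattr(URL, "create", URL)

uri = str(
    create_url(
        drivername="redshift+redshift_connector",
        username="awsuser",  # placeholder
        password="example-password",  # placeholder
        host="example-cluster.abc123.us-east-1.redshift.amazonaws.com",
        port=5439,
        database="dev",
    )
)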
13 changes: 6 additions & 7 deletions airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -63,8 +63,8 @@ def _render_filename(self, ti, try_number):
def set_context(self, ti):
super().set_context(ti)
self.handler = watchtower.CloudWatchLogHandler(
log_group=self.log_group,
stream_name=self._render_filename(ti, ti.try_number),
log_group_name=self.log_group,
log_stream_name=self._render_filename(ti, ti.try_number),
boto3_client=self.hook.get_conn(),
)

@@ -98,12 +98,11 @@ def get_cloudwatch_logs(self, stream_name: str) -> str:
:return: string of all logs from the given log stream
"""
try:
events = list(
self.hook.get_log_events(
log_group=self.log_group, log_stream_name=stream_name, start_from_head=True
)
events = self.hook.get_log_events(
log_group=self.log_group,
log_stream_name=stream_name,
start_from_head=True,
)

return '\n'.join(self._event_to_str(event) for event in events)
except Exception:
msg = f'Could not read remote logs from log_group: {self.log_group} log_stream: {stream_name}.'
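Both hunks here track upstream API changes: watchtower 2.x renamed the handler keywords (log_group to log_group_name, stream_name to log_stream_name), and the list(...) wrapper around get_log_events is gone, presumably because '\n'.join() can consume the iterator directly. A minimal sketch of the new handler call, assuming watchtower 2.x (group, stream, and region values are placeholders):

import logging

import boto3
import watchtower

handler = watchtower.CloudWatchLogHandler(
    log_group_name="airflow-task-logs",  # placeholder
    log_stream_name="dag_id/task_id/run_id/attempt_1.log",  # placeholder
    boto3_client=boto3.client("logs", region_name="us-east-1"),
)
logging.getLogger("airflow.task").addHandler(handler)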
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/operators/athena.py
@@ -112,7 +112,10 @@ def execute(self, context: Context) -> str | None:
self.client_request_token,
self.workgroup,
)
query_status = self.hook.poll_query_status(self.query_execution_id, self.max_polling_attempts)
query_status = self.hook.poll_query_status(
self.query_execution_id,
max_polling_attempts=self.max_polling_attempts,
)

if query_status in AthenaHook.FAILURE_STATES:
error_message = self.hook.get_state_change_reason(self.query_execution_id)
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/operators/emr.py
@@ -279,7 +279,9 @@ def execute(self, context: Context) -> str | None:
)
if self.wait_for_completion:
query_status = self.hook.poll_query_status(
self.job_id, self.max_polling_attempts, self.poll_interval
self.job_id,
max_polling_attempts=self.max_polling_attempts,
poll_interval=self.poll_interval,
)

if query_status in EmrContainerHook.FAILURE_STATES:
14 changes: 7 additions & 7 deletions airflow/providers/amazon/aws/secrets/secrets_manager.py
@@ -330,19 +330,19 @@ def _remove_escaping_in_secret_dict(self, secret: dict[str, Any], conn_id: str)

if warn_user:
msg = (
"When ``full_url_mode=True``, URL-encoding secret values is deprecated. In future versions, "
f" this value will not be un-escaped. For the conn_id {conn_id!r}, please remove the"
" URL-encoding."
"\n\nThis warning was raised because the SecretsManagerBackend detected that this connection"
" was URL-encoded."
"When full_url_mode=False, URL-encoding secret values is deprecated. In future versions, "
f"this value will not be un-escaped. For the conn_id {conn_id!r}, please remove the "
"URL-encoding.\n\n"
"This warning was raised because the SecretsManagerBackend detected that this "
"connection was URL-encoded."
)
if idempotent:
msg = f" Once the values for conn_id {conn_id!r} are decoded, this warning will go away."
if not idempotent:
msg += (
" In addition to decoding the values for your connection, you must also set"
" ``secret_values_are_urlencoded=False`` for your config variable"
" ``secrets.backend_kwargs`` because this connection's URL encoding is not idempotent."
" secret_values_are_urlencoded=False for your config variable"
" secrets.backend_kwargs because this connection's URL encoding is not idempotent."
" For more information, see:"
" https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/secrets-backends"
"/aws-secrets-manager.html#url-encoding-of-secrets-when-full-url-mode-is-false"
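The message turns on whether a connection's URL encoding is "idempotent", i.e. whether decoding an already-decoded value leaves it unchanged. One way to picture the distinction (a hypothetical check, not the backend's actual logic):

from urllib.parse import unquote

def decode_is_idempotent(value: str) -> bool:
    """True if a second unquote() pass leaves the value unchanged."""
    once = unquote(value)
    return unquote(once) == once

decode_is_idempotent("user%40example.com")  # True: decodes to 'user@example.com' and stays there
decode_is_idempotent("user%2540example.com")  # False: double-encoded, each pass peels another layer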
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/sensors/emr.py
@@ -285,7 +285,11 @@ def __init__(
self.max_retries = max_retries

def poke(self, context: Context) -> bool:
state = self.hook.poll_query_status(self.job_id, self.max_retries, self.poll_interval)
state = self.hook.poll_query_status(
self.job_id,
max_polling_attempts=self.max_retries,
poll_interval=self.poll_interval,
)

if state in self.FAILURE_STATES:
raise AirflowException('EMR Containers sensor failed')
3 changes: 2 additions & 1 deletion tests/providers/amazon/aws/hooks/test_athena.py
@@ -164,7 +164,8 @@ def test_hook_poll_query_when_final(self, mock_conn):
def test_hook_poll_query_with_timeout(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
result = self.athena.poll_query_status(
query_execution_id=MOCK_DATA['query_execution_id'], max_tries=1
query_execution_id=MOCK_DATA['query_execution_id'],
max_polling_attempts=1,
)
mock_conn.return_value.get_query_execution.assert_called_once()
assert result == 'RUNNING'
6 changes: 6 additions & 0 deletions tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -53,6 +53,9 @@ def setUp(self, get_client_type_mock):
aws_conn_id='airflow_test',
region_name=AWS_REGION,
)
# We're mocking all actual AWS calls and don't need a connection. This
# avoids an Airflow warning that the connection cannot be found.
self.batch_client.get_connection = lambda _: None
self.client_mock = get_client_type_mock.return_value
assert self.batch_client.client == self.client_mock # setup client property

@@ -307,6 +310,9 @@ class TestBatchClientDelays(unittest.TestCase):
@mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
def setUp(self):
self.batch_client = BatchClientHook(aws_conn_id='airflow_test', region_name=AWS_REGION)
# We're mocking all actual AWS calls and don't need a connection. This
# avoids an Airflow warning that the connection cannot be found.
self.batch_client.get_connection = lambda _: None

def test_init(self):
assert self.batch_client.max_retries == self.batch_client.MAX_RETRIES
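The same three-line stub now appears in several of these test files; if it keeps spreading, it could be hoisted into a shared fixture, along these lines (a hypothetical helper, not part of this PR):

import pytest

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

@pytest.fixture
def offline_s3_hook():
    hook = S3Hook(aws_conn_id="does_not_exist")
    # All AWS calls in these tests are mocked, so no real Airflow connection
    # is needed; stubbing get_connection suppresses the warning that the
    # connection cannot be found.
    hook.get_connection = lambda _: None
    return hook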
3 changes: 3 additions & 0 deletions tests/providers/amazon/aws/hooks/test_ec2.py
@@ -31,6 +31,9 @@ def test_init(self):
aws_conn_id="aws_conn_test",
region_name="region-test",
)
# We're mocking all actual AWS calls and don't need a connection. This
# avoids an Airflow warning that the connection cannot be found.
ec2_hook.get_connection = lambda _: None
assert ec2_hook.aws_conn_id == "aws_conn_test"
assert ec2_hook.region_name == "region-test"

3 changes: 3 additions & 0 deletions tests/providers/amazon/aws/hooks/test_eks.py
@@ -1242,6 +1242,9 @@ def test_generate_config_file(self, mock_conn, aws_conn_id, region_name, expecte
'cluster': {'certificateAuthority': {'data': 'test-cert'}, 'endpoint': 'test-endpoint'}
}
hook = EksHook(aws_conn_id=aws_conn_id, region_name=region_name)
# We're mocking all actual AWS calls and don't need a connection. This
# avoids an Airflow warning that the connection cannot be found.
hook.get_connection = lambda _: None
with hook.generate_config_file(
eks_cluster_name='test-cluster', pod_namespace='k8s-namespace'
) as config_file:
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/hooks/test_emr_containers.py
@@ -110,7 +110,7 @@ def test_query_status_polling_with_timeout(self, mock_session):
mock_session.return_value = emr_session_mock
emr_client_mock.describe_job_run.return_value = JOB2_RUN_DESCRIPTION

query_status = self.emr_containers.poll_query_status(job_id='job123456', max_tries=2)
query_status = self.emr_containers.poll_query_status(job_id='job123456', max_polling_attempts=2)
# should poll until max_polling_attempts is reached since the query is in a non-terminal state
assert emr_client_mock.describe_job_run.call_count == 2
assert query_status == 'RUNNING'
6 changes: 6 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
@@ -46,6 +46,9 @@ def test_check_for_bucket_raises_error_with_invalid_conn_id(self, monkeypatch):
monkeypatch.delenv('AWS_ACCESS_KEY_ID', raising=False)
monkeypatch.delenv('AWS_SECRET_ACCESS_KEY', raising=False)
hook = S3Hook(aws_conn_id="does_not_exist")
# We're mocking all actual AWS calls and don't need a connection. This
# avoids an Airflow warning that the connection cannot be found.
hook.get_connection = lambda _: None
with pytest.raises(NoCredentialsError):
hook.check_for_bucket("test-non-existing-bucket")

@@ -587,6 +590,9 @@ def test_that_extra_args_not_changed_between_calls(self, s3_bucket):
"SSEKMSKeyId": "arn:aws:kms:region:acct-id:key/key-id",
}
s3_hook = S3Hook(aws_conn_id="s3_test", extra_args=original)
# We're mocking all actual AWS calls and don't need a connection. This
# avoids an Airflow warning that the connection cannot be found.
s3_hook.get_connection = lambda _: None
assert s3_hook.extra_args == original
assert s3_hook.extra_args is not original

2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/operators/test_athena.py
@@ -60,7 +60,7 @@ def setUp(self):
output_location='s3://test_s3_bucket/',
client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
sleep_time=0,
max_tries=3,
max_polling_attempts=3,
dag=self.dag,
)

3 changes: 3 additions & 0 deletions tests/providers/amazon/aws/operators/test_batch.py
@@ -66,6 +66,9 @@ def setUp(self, get_client_type_mock):
tags={},
)
self.client_mock = self.get_client_type_mock.return_value
# We're mocking all actual AWS calls and don't need a connection. This
# avoids an Airflow warning that the connection cannot be found.
self.batch.hook.get_connection = lambda _: None
assert self.batch.hook.client == self.client_mock # setup client property

# don't pause in unit tests
8 changes: 4 additions & 4 deletions tests/providers/amazon/aws/operators/test_emr_containers.py
@@ -22,7 +22,7 @@

import pytest

from airflow import configuration
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator, EmrEksCreateClusterOperator
@@ -41,7 +41,7 @@
class TestEmrContainerOperator(unittest.TestCase):
@mock.patch('airflow.providers.amazon.aws.hooks.emr.EmrContainerHook')
def setUp(self, emr_hook_mock):
configuration.load_test_config()
conf.load_test_config()

self.emr_hook_mock = emr_hook_mock
self.emr_container = EmrContainerOperator(
@@ -130,7 +130,7 @@ def test_execute_with_polling_timeout(self, mock_check_query_status):
job_driver={},
configuration_overrides={},
poll_interval=0,
max_tries=3,
max_polling_attempts=3,
)

with patch('boto3.session.Session', boto3_session_mock):
@@ -145,7 +145,7 @@ def test_execute_with_polling_timeout(self, mock_check_query_status):
class TestEmrEksCreateClusterOperator(unittest.TestCase):
@mock.patch('airflow.providers.amazon.aws.hooks.emr.EmrContainerHook')
def setUp(self, emr_hook_mock):
configuration.load_test_config()
conf.load_test_config()

self.emr_hook_mock = emr_hook_mock
self.emr_container = EmrEksCreateClusterOperator(