From 76e2b72df868c5508f65e1114bc067840fa730ba Mon Sep 17 00:00:00 2001 From: Steven Blake <9101623+slycyberguy@users.noreply.github.com> Date: Mon, 15 Apr 2024 10:27:50 -0700 Subject: [PATCH] Add six unit tests for aws/triggers (#38819) * Remove test references from test_project_structure.py --- tests/always/test_project_structure.py | 6 - .../amazon/aws/triggers/test_athena.py | 42 ++ .../amazon/aws/triggers/test_batch.py | 43 ++ .../providers/amazon/aws/triggers/test_emr.py | 377 ++++++++++++++++++ .../amazon/aws/triggers/test_glue_crawler.py | 49 +++ .../aws/triggers/test_lambda_function.py | 48 +++ .../providers/amazon/aws/triggers/test_rds.py | 121 ++++++ 7 files changed, 680 insertions(+), 6 deletions(-) create mode 100644 tests/providers/amazon/aws/triggers/test_athena.py create mode 100644 tests/providers/amazon/aws/triggers/test_batch.py create mode 100644 tests/providers/amazon/aws/triggers/test_emr.py create mode 100644 tests/providers/amazon/aws/triggers/test_glue_crawler.py create mode 100644 tests/providers/amazon/aws/triggers/test_lambda_function.py create mode 100644 tests/providers/amazon/aws/triggers/test_rds.py diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index cd8e68594ce40..d14885aca9fe6 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -72,13 +72,7 @@ def test_providers_modules_should_have_tests(self): "tests/providers/amazon/aws/sensors/test_emr.py", "tests/providers/amazon/aws/sensors/test_sagemaker.py", "tests/providers/amazon/aws/test_exceptions.py", - "tests/providers/amazon/aws/triggers/test_athena.py", - "tests/providers/amazon/aws/triggers/test_batch.py", "tests/providers/amazon/aws/triggers/test_eks.py", - "tests/providers/amazon/aws/triggers/test_emr.py", - "tests/providers/amazon/aws/triggers/test_glue_crawler.py", - "tests/providers/amazon/aws/triggers/test_lambda_function.py", - "tests/providers/amazon/aws/triggers/test_rds.py", "tests/providers/amazon/aws/triggers/test_step_function.py", "tests/providers/amazon/aws/utils/test_rds.py", "tests/providers/amazon/aws/utils/test_sagemaker.py", diff --git a/tests/providers/amazon/aws/triggers/test_athena.py b/tests/providers/amazon/aws/triggers/test_athena.py new file mode 100644 index 0000000000000..02e0ef9237061 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_athena.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger + + +class TestAthenaTrigger: + def test_serialization(self): + query_execution_id = "test_query_execution_id" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + + trigger = AthenaTrigger( + query_execution_id=query_execution_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.athena.AthenaTrigger" + assert kwargs == { + "query_execution_id": "test_query_execution_id", + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + } diff --git a/tests/providers/amazon/aws/triggers/test_batch.py b/tests/providers/amazon/aws/triggers/test_batch.py new file mode 100644 index 0000000000000..ef6e22e965d4b --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_batch.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger + + +class TestBatchJobTrigger: + def test_serialization(self): + job_id = "test_job_id" + aws_conn_id = "aws_default" + region_name = "us-west-2" + + trigger = BatchJobTrigger( + job_id=job_id, + aws_conn_id=aws_conn_id, + region_name=region_name, + ) + + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.batch.BatchJobTrigger" + assert kwargs == { + "job_id": "test_job_id", + "waiter_delay": 5, + "waiter_max_attempts": 720, + "aws_conn_id": "aws_default", + "region_name": "us-west-2", + } diff --git a/tests/providers/amazon/aws/triggers/test_emr.py b/tests/providers/amazon/aws/triggers/test_emr.py new file mode 100644 index 0000000000000..92fd08857d1d2 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_emr.py @@ -0,0 +1,377 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.triggers.emr import ( + EmrAddStepsTrigger, + EmrContainerTrigger, + EmrCreateJobFlowTrigger, + EmrServerlessCancelJobsTrigger, + EmrServerlessCreateApplicationTrigger, + EmrServerlessDeleteApplicationTrigger, + EmrServerlessStartApplicationTrigger, + EmrServerlessStartJobTrigger, + EmrServerlessStopApplicationTrigger, + EmrStepSensorTrigger, + EmrTerminateJobFlowTrigger, +) + + +class TestEmrAddStepsTrigger: + def test_serialization(self): + job_flow_id = "test_job_flow_id" + step_ids = ["step1", "step2"] + waiter_delay = 10 + waiter_max_attempts = 5 + + trigger = EmrAddStepsTrigger( + job_flow_id=job_flow_id, + step_ids=step_ids, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrAddStepsTrigger" + assert kwargs == { + "job_flow_id": "test_job_flow_id", + "step_ids": ["step1", "step2"], + "waiter_delay": 10, + "waiter_max_attempts": 5, + "aws_conn_id": "aws_default", + } + + +class TestEmrCreateJobFlowTrigger: + def test_init_with_deprecated_params(self): + import warnings + + with warnings.catch_warnings(record=True) as catch_warns: + warnings.simplefilter("always") + + job_flow_id = "test_job_flow_id" + poll_interval = 10 + max_attempts = 5 + aws_conn_id = "aws_default" + waiter_delay = 30 + waiter_max_attempts = 60 + + trigger = EmrCreateJobFlowTrigger( + job_flow_id=job_flow_id, + poll_interval=poll_interval, + max_attempts=max_attempts, + aws_conn_id=aws_conn_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + ) + + assert trigger.waiter_delay == poll_interval + assert len(catch_warns) == 1 + assert issubclass(catch_warns[-1].category, DeprecationWarning) + assert "please use waiter_delay instead of poll_interval" in str(catch_warns[-1].message) + assert "and waiter_max_attempts instead of max_attempts" in str(catch_warns[-1].message) + + def test_serialization(self): + job_flow_id = "test_job_flow_id" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + + trigger = EmrCreateJobFlowTrigger( + job_flow_id=job_flow_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrCreateJobFlowTrigger" + assert kwargs == { + "job_flow_id": "test_job_flow_id", + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + } + + +class TestEmrTerminateJobFlowTrigger: + def test_init_with_deprecated_params(self): + import warnings + + with warnings.catch_warnings(record=True) as catch_warns: + warnings.simplefilter("always") + + job_flow_id = "test_job_flow_id" + poll_interval = 10 + max_attempts = 5 + aws_conn_id = "aws_default" + waiter_delay = 30 + waiter_max_attempts = 60 + + trigger = EmrTerminateJobFlowTrigger( + job_flow_id=job_flow_id, + poll_interval=poll_interval, + max_attempts=max_attempts, + aws_conn_id=aws_conn_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + ) + + assert trigger.waiter_delay == poll_interval # Assert deprecated parameter is correctly used + assert len(catch_warns) == 1 + assert issubclass(catch_warns[-1].category, DeprecationWarning) + assert "please use waiter_delay instead of poll_interval" in str(catch_warns[-1].message) + assert "and waiter_max_attempts instead of max_attempts" in str(catch_warns[-1].message) + + def test_serialization(self): + job_flow_id = "test_job_flow_id" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + + trigger = EmrTerminateJobFlowTrigger( + job_flow_id=job_flow_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrTerminateJobFlowTrigger" + assert kwargs == { + "job_flow_id": "test_job_flow_id", + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + } + + +class TestEmrContainerTrigger: + def test_init_with_deprecated_params(self): + import warnings + + with warnings.catch_warnings(record=True) as catch_warns: + warnings.simplefilter("always") + + virtual_cluster_id = "test_virtual_cluster_id" + job_id = "test_job_id" + aws_conn_id = "aws_default" + poll_interval = 10 + waiter_delay = 30 + waiter_max_attempts = 600 + + trigger = EmrContainerTrigger( + virtual_cluster_id=virtual_cluster_id, + job_id=job_id, + aws_conn_id=aws_conn_id, + poll_interval=poll_interval, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + ) + + assert trigger.waiter_delay == poll_interval # Assert deprecated parameter is correctly used + assert len(catch_warns) == 1 + assert issubclass(catch_warns[-1].category, DeprecationWarning) + assert "please use waiter_delay instead of poll_interval" in str(catch_warns[-1].message) + + def test_serialization(self): + virtual_cluster_id = "test_virtual_cluster_id" + job_id = "test_job_id" + waiter_delay = 30 + waiter_max_attempts = 600 + aws_conn_id = "aws_default" + + trigger = EmrContainerTrigger( + virtual_cluster_id=virtual_cluster_id, + job_id=job_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger" + assert kwargs == { + "virtual_cluster_id": "test_virtual_cluster_id", + "job_id": "test_job_id", + "waiter_delay": 30, + "waiter_max_attempts": 600, + "aws_conn_id": "aws_default", + } + + +class TestEmrStepSensorTrigger: + def test_serialization(self): + job_flow_id = "test_job_flow_id" + step_id = "test_step_id" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + + trigger = EmrStepSensorTrigger( + job_flow_id=job_flow_id, + step_id=step_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger" + assert kwargs == { + "job_flow_id": "test_job_flow_id", + "step_id": "test_step_id", + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + } + + +class TestEmrServerlessCreateApplicationTrigger: + def test_serialization(self): + application_id = "test_application_id" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + + trigger = EmrServerlessCreateApplicationTrigger( + application_id=application_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrServerlessCreateApplicationTrigger" + assert kwargs == { + "application_id": "test_application_id", + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + } + + +class TestEmrServerlessStartApplicationTrigger: + def test_serialization(self): + application_id = "test_application_id" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + + trigger = EmrServerlessStartApplicationTrigger( + application_id=application_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartApplicationTrigger" + assert kwargs == { + "application_id": "test_application_id", + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + } + + +class TestEmrServerlessStopApplicationTrigger: + def test_serialization(self): + application_id = "test_application_id" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + + trigger = EmrServerlessStopApplicationTrigger( + application_id=application_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrServerlessStopApplicationTrigger" + assert kwargs == { + "application_id": "test_application_id", + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + } + + +class TestEmrServerlessStartJobTrigger: + def test_serialization(self): + application_id = "test_application_id" + waiter_delay = 30 + waiter_max_attempts = 60 + job_id = "job_id" + aws_conn_id = "aws_default" + + trigger = EmrServerlessStartJobTrigger( + application_id=application_id, + waiter_delay=waiter_delay, + job_id=job_id, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger" + assert kwargs == { + "application_id": "test_application_id", + "waiter_delay": 30, + "waiter_max_attempts": 60, + "job_id": "job_id", + "aws_conn_id": "aws_default", + } + + +class TestEmrServerlessDeleteApplicationTrigger: + def test_serialization(self): + application_id = "test_application_id" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + + trigger = EmrServerlessDeleteApplicationTrigger( + application_id=application_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrServerlessDeleteApplicationTrigger" + assert kwargs == { + "application_id": "test_application_id", + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + } + + +class TestEmrServerlessCancelJobsTrigger: + def test_serialization(self): + application_id = "test_application_id" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + + trigger = EmrServerlessCancelJobsTrigger( + application_id=application_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrServerlessCancelJobsTrigger" + assert kwargs == { + "application_id": "test_application_id", + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + } diff --git a/tests/providers/amazon/aws/triggers/test_glue_crawler.py b/tests/providers/amazon/aws/triggers/test_glue_crawler.py new file mode 100644 index 0000000000000..1ba2d610877db --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_glue_crawler.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import patch + +from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger + + +class TestGlueCrawlerCompleteTrigger: + @patch("airflow.providers.amazon.aws.triggers.glue_crawler.warnings.warn") + def test_serialization(self, mock_warn): + crawler_name = "test_crawler" + poll_interval = 10 + aws_conn_id = "aws_default" + + trigger = GlueCrawlerCompleteTrigger( + crawler_name=crawler_name, + poll_interval=poll_interval, + aws_conn_id=aws_conn_id, + ) + + assert mock_warn.call_count == 1 + args, kwargs = mock_warn.call_args + assert args[0] == "please use waiter_delay instead of poll_interval." + assert kwargs == {"stacklevel": 2} + + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.glue_crawler.GlueCrawlerCompleteTrigger" + assert kwargs == { + "crawler_name": "test_crawler", + "waiter_delay": 10, + "waiter_max_attempts": 1500, + "aws_conn_id": "aws_default", + } diff --git a/tests/providers/amazon/aws/triggers/test_lambda_function.py b/tests/providers/amazon/aws/triggers/test_lambda_function.py new file mode 100644 index 0000000000000..396209412224d --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_lambda_function.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger + + +class TestLambdaCreateFunctionCompleteTrigger: + def test_serialization(self): + function_name = "test_function_name" + function_arn = "test_function_arn" + waiter_delay = 60 + waiter_max_attempts = 30 + aws_conn_id = "aws_default" + + trigger = LambdaCreateFunctionCompleteTrigger( + function_name=function_name, + function_arn=function_arn, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + classpath, kwargs = trigger.serialize() + assert ( + classpath + == "airflow.providers.amazon.aws.triggers.lambda_function.LambdaCreateFunctionCompleteTrigger" + ) + assert kwargs == { + "function_name": "test_function_name", + "function_arn": "test_function_arn", + "waiter_delay": 60, + "waiter_max_attempts": 30, + "aws_conn_id": "aws_default", + } diff --git a/tests/providers/amazon/aws/triggers/test_rds.py b/tests/providers/amazon/aws/triggers/test_rds.py new file mode 100644 index 0000000000000..728377ef4776c --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_rds.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.triggers.rds import ( + RdsDbAvailableTrigger, + RdsDbDeletedTrigger, + RdsDbStoppedTrigger, +) + + +class TestRdsDbAvailableTrigger: + def test_serialization(self): + db_identifier = "test_db_identifier" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + response = {"key": "value"} + db_type = "instance" # or use an instance of RdsDbType if available + region_name = "us-west-2" + + trigger = RdsDbAvailableTrigger( + db_identifier=db_identifier, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + response=response, + db_type=db_type, + region_name=region_name, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.rds.RdsDbAvailableTrigger" + assert kwargs == { + "db_identifier": "test_db_identifier", + "db_type": "instance", + "response": {"key": "value"}, + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + "region_name": "us-west-2", + } + + +class TestRdsDbDeletedTrigger: + def test_serialization(self): + db_identifier = "test_db_identifier" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + response = {"key": "value"} + db_type = "instance" + region_name = "us-west-2" + + trigger = RdsDbDeletedTrigger( + db_identifier=db_identifier, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + response=response, + db_type=db_type, + region_name=region_name, + ) + + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.rds.RdsDbDeletedTrigger" + assert kwargs == { + "db_identifier": "test_db_identifier", + "db_type": "instance", + "response": {"key": "value"}, + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + "region_name": "us-west-2", + } + + +class TestRdsDbStoppedTrigger: + def test_serialization(self): + db_identifier = "test_db_identifier" + waiter_delay = 30 + waiter_max_attempts = 60 + aws_conn_id = "aws_default" + response = {"key": "value"} + db_type = "instance" + region_name = "us-west-2" + + trigger = RdsDbStoppedTrigger( + db_identifier=db_identifier, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + response=response, + db_type=db_type, + region_name=region_name, + ) + + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.rds.RdsDbStoppedTrigger" + assert kwargs == { + "db_identifier": "test_db_identifier", + "db_type": "instance", + "response": {"key": "value"}, + "waiter_delay": 30, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + "region_name": "us-west-2", + }