diff --git a/airflow/providers/amazon/aws/operators/cloud_formation.py b/airflow/providers/amazon/aws/operators/cloud_formation.py index c6963520f87f1..b24ccf05f453d 100644 --- a/airflow/providers/amazon/aws/operators/cloud_formation.py +++ b/airflow/providers/amazon/aws/operators/cloud_formation.py @@ -15,66 +15,79 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This module contains CloudFormation create/delete stack operators.""" +"""This module contains AWS CloudFormation create/delete stack operators.""" from __future__ import annotations from typing import TYPE_CHECKING, Sequence -from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: from airflow.utils.context import Context -class CloudFormationCreateStackOperator(BaseOperator): +class CloudFormationCreateStackOperator(AwsBaseOperator[CloudFormationHook]): """ - An operator that creates a CloudFormation stack. + An operator that creates a AWS CloudFormation stack. .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:CloudFormationCreateStackOperator` :param stack_name: stack name (templated) - :param cloudformation_parameters: parameters to be passed to CloudFormation. - :param aws_conn_id: aws connection to uses + :param cloudformation_parameters: parameters to be passed to AWS CloudFormation. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("stack_name", "cloudformation_parameters") - template_ext: Sequence[str] = () + aws_hook_class = CloudFormationHook + template_fields: Sequence[str] = aws_template_fields("stack_name", "cloudformation_parameters") ui_color = "#6b9659" - def __init__( - self, *, stack_name: str, cloudformation_parameters: dict, aws_conn_id: str = "aws_default", **kwargs - ): + def __init__(self, *, stack_name: str, cloudformation_parameters: dict, **kwargs): super().__init__(**kwargs) self.stack_name = stack_name self.cloudformation_parameters = cloudformation_parameters - self.aws_conn_id = aws_conn_id def execute(self, context: Context): self.log.info("CloudFormation parameters: %s", self.cloudformation_parameters) - - cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id) - cloudformation_hook.create_stack(self.stack_name, self.cloudformation_parameters) + self.hook.create_stack(self.stack_name, self.cloudformation_parameters) -class CloudFormationDeleteStackOperator(BaseOperator): +class CloudFormationDeleteStackOperator(AwsBaseOperator[CloudFormationHook]): """ - An operator that deletes a CloudFormation stack. - - :param stack_name: stack name (templated) - :param cloudformation_parameters: parameters to be passed to CloudFormation. + An operator that deletes a AWS CloudFormation stack. .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:CloudFormationDeleteStackOperator` - :param aws_conn_id: aws connection to uses + :param stack_name: stack name (templated) + :param cloudformation_parameters: parameters to be passed to CloudFormation. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("stack_name",) - template_ext: Sequence[str] = () + aws_hook_class = CloudFormationHook + template_fields: Sequence[str] = aws_template_fields("stack_name") ui_color = "#1d472b" ui_fgcolor = "#FFF" @@ -93,6 +106,4 @@ def __init__( def execute(self, context: Context): self.log.info("CloudFormation Parameters: %s", self.cloudformation_parameters) - - cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id) - cloudformation_hook.delete_stack(self.stack_name, self.cloudformation_parameters) + self.hook.delete_stack(self.stack_name, self.cloudformation_parameters) diff --git a/airflow/providers/amazon/aws/sensors/cloud_formation.py b/airflow/providers/amazon/aws/sensors/cloud_formation.py index 5c6b1f2246938..044ca50484df9 100644 --- a/airflow/providers/amazon/aws/sensors/cloud_formation.py +++ b/airflow/providers/amazon/aws/sensors/cloud_formation.py @@ -18,18 +18,19 @@ """This module contains sensors for AWS CloudFormation.""" from __future__ import annotations -from functools import cached_property from typing import TYPE_CHECKING, Sequence +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields + if TYPE_CHECKING: from airflow.utils.context import Context from airflow.exceptions import AirflowSkipException from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook -from airflow.sensors.base import BaseSensorOperator -class CloudFormationCreateStackSensor(BaseSensorOperator): +class CloudFormationCreateStackSensor(AwsBaseSensor[CloudFormationHook]): """ Waits for a stack to be created successfully on AWS CloudFormation. @@ -38,19 +39,25 @@ class CloudFormationCreateStackSensor(BaseSensorOperator): :ref:`howto/sensor:CloudFormationCreateStackSensor` :param stack_name: The name of the stack to wait for (templated) - :param aws_conn_id: ID of the Airflow connection where credentials and extra configuration are - stored - :param poke_interval: Time in seconds that the job should wait between each try + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("stack_name",) + aws_hook_class = CloudFormationHook + template_fields: Sequence[str] = aws_template_fields("stack_name") ui_color = "#C5CAE9" - def __init__(self, *, stack_name, aws_conn_id="aws_default", region_name=None, **kwargs): + def __init__(self, *, stack_name, **kwargs): super().__init__(**kwargs) self.stack_name = stack_name - self.aws_conn_id = aws_conn_id - self.region_name = region_name def poke(self, context: Context): stack_status = self.hook.get_stack_status(self.stack_name) @@ -65,13 +72,8 @@ def poke(self, context: Context): raise AirflowSkipException(message) raise ValueError(message) - @cached_property - def hook(self) -> CloudFormationHook: - """Create and return a CloudFormationHook.""" - return CloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) - -class CloudFormationDeleteStackSensor(BaseSensorOperator): +class CloudFormationDeleteStackSensor(AwsBaseSensor[CloudFormationHook]): """ Waits for a stack to be deleted successfully on AWS CloudFormation. @@ -80,12 +82,20 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator): :ref:`howto/sensor:CloudFormationDeleteStackSensor` :param stack_name: The name of the stack to wait for (templated) - :param aws_conn_id: ID of the Airflow connection where credentials and extra configuration are - stored - :param poke_interval: Time in seconds that the job should wait between each try + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("stack_name",) + aws_hook_class = CloudFormationHook + template_fields: Sequence[str] = aws_template_fields("stack_name") ui_color = "#C5CAE9" def __init__( @@ -113,8 +123,3 @@ def poke(self, context: Context): if self.soft_fail: raise AirflowSkipException(message) raise ValueError(message) - - @cached_property - def hook(self) -> CloudFormationHook: - """Create and return a CloudFormationHook.""" - return CloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) diff --git a/docs/apache-airflow-providers-amazon/operators/cloudformation.rst b/docs/apache-airflow-providers-amazon/operators/cloudformation.rst index 4051be0ccd225..ff45efcdb645e 100644 --- a/docs/apache-airflow-providers-amazon/operators/cloudformation.rst +++ b/docs/apache-airflow-providers-amazon/operators/cloudformation.rst @@ -31,6 +31,11 @@ Prerequisite Tasks .. include:: ../_partials/prerequisite_tasks.rst +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + Operators --------- diff --git a/tests/providers/amazon/aws/operators/test_cloud_formation.py b/tests/providers/amazon/aws/operators/test_cloud_formation.py index df54596b02a8c..071ba5c847040 100644 --- a/tests/providers/amazon/aws/operators/test_cloud_formation.py +++ b/tests/providers/amazon/aws/operators/test_cloud_formation.py @@ -40,6 +40,35 @@ def mocked_hook_client(): class TestCloudFormationCreateStackOperator: + def test_init(self): + op = CloudFormationCreateStackOperator( + task_id="cf_create_stack_init", + stack_name="fake-stack", + cloudformation_parameters={}, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="eu-west-1", + verify=True, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.client_type == "cloudformation" + assert op.hook.resource_type is None + assert op.hook.aws_conn_id == "fake-conn-id" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify is True + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + op = CloudFormationCreateStackOperator( + task_id="cf_create_stack_init", + stack_name="fake-stack", + cloudformation_parameters={}, + ) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + def test_create_stack(self, mocked_hook_client): stack_name = "myStack" timeout = 15 @@ -60,6 +89,30 @@ def test_create_stack(self, mocked_hook_client): class TestCloudFormationDeleteStackOperator: + def test_init(self): + op = CloudFormationDeleteStackOperator( + task_id="cf_delete_stack_init", + stack_name="fake-stack", + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="us-east-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.client_type == "cloudformation" + assert op.hook.resource_type is None + assert op.hook.aws_conn_id == "fake-conn-id" + assert op.hook._region_name == "us-east-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + op = CloudFormationDeleteStackOperator(task_id="cf_delete_stack_init", stack_name="fake-stack") + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + def test_delete_stack(self, mocked_hook_client): stack_name = "myStackToBeDeleted" diff --git a/tests/providers/amazon/aws/sensors/test_cloud_formation.py b/tests/providers/amazon/aws/sensors/test_cloud_formation.py index 51b9c385f175d..ca4177441143a 100644 --- a/tests/providers/amazon/aws/sensors/test_cloud_formation.py +++ b/tests/providers/amazon/aws/sensors/test_cloud_formation.py @@ -23,6 +23,7 @@ import pytest from moto import mock_cloudformation +from airflow.exceptions import AirflowSkipException from airflow.providers.amazon.aws.sensors.cloud_formation import ( CloudFormationCreateStackSensor, CloudFormationDeleteStackSensor, @@ -40,6 +41,30 @@ class TestCloudFormationCreateStackSensor: def setup_method(self, method): self.client = boto3.client("cloudformation", region_name="us-east-1") + def test_init(self): + sensor = CloudFormationCreateStackSensor( + task_id="cf_create_stack_init", + stack_name="fake-stack", + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="eu-central-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + assert sensor.hook.client_type == "cloudformation" + assert sensor.hook.resource_type is None + assert sensor.hook.aws_conn_id == "fake-conn-id" + assert sensor.hook._region_name == "eu-central-1" + assert sensor.hook._verify is False + assert sensor.hook._config is not None + assert sensor.hook._config.read_timeout == 42 + + sensor = CloudFormationCreateStackSensor(task_id="cf_create_stack_init", stack_name="fake-stack") + assert sensor.hook.aws_conn_id == "aws_default" + assert sensor.hook._region_name is None + assert sensor.hook._verify is None + assert sensor.hook._config is None + @mock_cloudformation def test_poke(self): self.client.create_stack(StackName="foobar", TemplateBody='{"Resources": {}}') @@ -51,10 +76,17 @@ def test_poke_false(self, mocked_hook_client): op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo") assert not op.poke({}) - def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client): + @pytest.mark.parametrize( + "soft_fail, expected_exception", + [ + pytest.param(True, AirflowSkipException, id="soft-fail"), + pytest.param(False, ValueError, id="non-soft-fail"), + ], + ) + def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client, soft_fail, expected_exception): mocked_hook_client.describe_stacks.return_value = {"Stacks": [{"StackStatus": "bar"}]} - op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo") - with pytest.raises(ValueError, match="Stack foo in bad state: bar"): + op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo", soft_fail=soft_fail) + with pytest.raises(expected_exception, match="Stack foo in bad state: bar"): op.poke({}) @@ -63,6 +95,30 @@ class TestCloudFormationDeleteStackSensor: def setup_method(self, method): self.client = boto3.client("cloudformation", region_name="us-east-1") + def test_init(self): + sensor = CloudFormationDeleteStackSensor( + task_id="cf_delete_stack_init", + stack_name="fake-stack", + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="ca-west-1", + verify=True, + botocore_config={"read_timeout": 42}, + ) + assert sensor.hook.client_type == "cloudformation" + assert sensor.hook.resource_type is None + assert sensor.hook.aws_conn_id == "fake-conn-id" + assert sensor.hook._region_name == "ca-west-1" + assert sensor.hook._verify is True + assert sensor.hook._config is not None + assert sensor.hook._config.read_timeout == 42 + + sensor = CloudFormationDeleteStackSensor(task_id="cf_delete_stack_init", stack_name="fake-stack") + assert sensor.hook.aws_conn_id == "aws_default" + assert sensor.hook._region_name is None + assert sensor.hook._verify is None + assert sensor.hook._config is None + @mock_cloudformation def test_poke(self): stack_name = "foobar" @@ -76,10 +132,17 @@ def test_poke_false(self, mocked_hook_client): op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo") assert not op.poke({}) - def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client): + @pytest.mark.parametrize( + "soft_fail, expected_exception", + [ + pytest.param(True, AirflowSkipException, id="soft-fail"), + pytest.param(False, ValueError, id="non-soft-fail"), + ], + ) + def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client, soft_fail, expected_exception): mocked_hook_client.describe_stacks.return_value = {"Stacks": [{"StackStatus": "bar"}]} - op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo") - with pytest.raises(ValueError, match="Stack foo in bad state: bar"): + op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo", soft_fail=soft_fail) + with pytest.raises(expected_exception, match="Stack foo in bad state: bar"): op.poke({}) @mock_cloudformation