Skip to content

Commit

Permalink
Use base aws classes in AWS CloudFormation Operators/Sensors (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Jan 14, 2024
1 parent c2d02b4 commit 1455a3b
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 57 deletions.
63 changes: 37 additions & 26 deletions airflow/providers/amazon/aws/operators/cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,66 +15,79 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains CloudFormation create/delete stack operators."""
"""This module contains AWS CloudFormation create/delete stack operators."""
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class CloudFormationCreateStackOperator(BaseOperator):
class CloudFormationCreateStackOperator(AwsBaseOperator[CloudFormationHook]):
"""
An operator that creates a CloudFormation stack.
An operator that creates a AWS CloudFormation stack.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:CloudFormationCreateStackOperator`
:param stack_name: stack name (templated)
:param cloudformation_parameters: parameters to be passed to CloudFormation.
:param aws_conn_id: aws connection to uses
:param cloudformation_parameters: parameters to be passed to AWS CloudFormation.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("stack_name", "cloudformation_parameters")
template_ext: Sequence[str] = ()
aws_hook_class = CloudFormationHook
template_fields: Sequence[str] = aws_template_fields("stack_name", "cloudformation_parameters")
ui_color = "#6b9659"

def __init__(
self, *, stack_name: str, cloudformation_parameters: dict, aws_conn_id: str = "aws_default", **kwargs
):
def __init__(self, *, stack_name: str, cloudformation_parameters: dict, **kwargs):
super().__init__(**kwargs)
self.stack_name = stack_name
self.cloudformation_parameters = cloudformation_parameters
self.aws_conn_id = aws_conn_id

def execute(self, context: Context):
self.log.info("CloudFormation parameters: %s", self.cloudformation_parameters)

cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id)
cloudformation_hook.create_stack(self.stack_name, self.cloudformation_parameters)
self.hook.create_stack(self.stack_name, self.cloudformation_parameters)


class CloudFormationDeleteStackOperator(BaseOperator):
class CloudFormationDeleteStackOperator(AwsBaseOperator[CloudFormationHook]):
"""
An operator that deletes a CloudFormation stack.
:param stack_name: stack name (templated)
:param cloudformation_parameters: parameters to be passed to CloudFormation.
An operator that deletes a AWS CloudFormation stack.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:CloudFormationDeleteStackOperator`
:param aws_conn_id: aws connection to uses
:param stack_name: stack name (templated)
:param cloudformation_parameters: parameters to be passed to CloudFormation.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("stack_name",)
template_ext: Sequence[str] = ()
aws_hook_class = CloudFormationHook
template_fields: Sequence[str] = aws_template_fields("stack_name")
ui_color = "#1d472b"
ui_fgcolor = "#FFF"

Expand All @@ -93,6 +106,4 @@ def __init__(

def execute(self, context: Context):
self.log.info("CloudFormation Parameters: %s", self.cloudformation_parameters)

cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id)
cloudformation_hook.delete_stack(self.stack_name, self.cloudformation_parameters)
self.hook.delete_stack(self.stack_name, self.cloudformation_parameters)
55 changes: 30 additions & 25 deletions airflow/providers/amazon/aws/sensors/cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@
"""This module contains sensors for AWS CloudFormation."""
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.exceptions import AirflowSkipException
from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook
from airflow.sensors.base import BaseSensorOperator


class CloudFormationCreateStackSensor(BaseSensorOperator):
class CloudFormationCreateStackSensor(AwsBaseSensor[CloudFormationHook]):
"""
Waits for a stack to be created successfully on AWS CloudFormation.
Expand All @@ -38,19 +39,25 @@ class CloudFormationCreateStackSensor(BaseSensorOperator):
:ref:`howto/sensor:CloudFormationCreateStackSensor`
:param stack_name: The name of the stack to wait for (templated)
:param aws_conn_id: ID of the Airflow connection where credentials and extra configuration are
stored
:param poke_interval: Time in seconds that the job should wait between each try
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("stack_name",)
aws_hook_class = CloudFormationHook
template_fields: Sequence[str] = aws_template_fields("stack_name")
ui_color = "#C5CAE9"

def __init__(self, *, stack_name, aws_conn_id="aws_default", region_name=None, **kwargs):
def __init__(self, *, stack_name, **kwargs):
super().__init__(**kwargs)
self.stack_name = stack_name
self.aws_conn_id = aws_conn_id
self.region_name = region_name

def poke(self, context: Context):
stack_status = self.hook.get_stack_status(self.stack_name)
Expand All @@ -65,13 +72,8 @@ def poke(self, context: Context):
raise AirflowSkipException(message)
raise ValueError(message)

@cached_property
def hook(self) -> CloudFormationHook:
"""Create and return a CloudFormationHook."""
return CloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)


class CloudFormationDeleteStackSensor(BaseSensorOperator):
class CloudFormationDeleteStackSensor(AwsBaseSensor[CloudFormationHook]):
"""
Waits for a stack to be deleted successfully on AWS CloudFormation.
Expand All @@ -80,12 +82,20 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator):
:ref:`howto/sensor:CloudFormationDeleteStackSensor`
:param stack_name: The name of the stack to wait for (templated)
:param aws_conn_id: ID of the Airflow connection where credentials and extra configuration are
stored
:param poke_interval: Time in seconds that the job should wait between each try
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("stack_name",)
aws_hook_class = CloudFormationHook
template_fields: Sequence[str] = aws_template_fields("stack_name")
ui_color = "#C5CAE9"

def __init__(
Expand Down Expand Up @@ -113,8 +123,3 @@ def poke(self, context: Context):
if self.soft_fail:
raise AirflowSkipException(message)
raise ValueError(message)

@cached_property
def hook(self) -> CloudFormationHook:
"""Create and return a CloudFormationHook."""
return CloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
53 changes: 53 additions & 0 deletions tests/providers/amazon/aws/operators/test_cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,35 @@ def mocked_hook_client():


class TestCloudFormationCreateStackOperator:
def test_init(self):
op = CloudFormationCreateStackOperator(
task_id="cf_create_stack_init",
stack_name="fake-stack",
cloudformation_parameters={},
# Generic hooks parameters
aws_conn_id="fake-conn-id",
region_name="eu-west-1",
verify=True,
botocore_config={"read_timeout": 42},
)
assert op.hook.client_type == "cloudformation"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "eu-west-1"
assert op.hook._verify is True
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = CloudFormationCreateStackOperator(
task_id="cf_create_stack_init",
stack_name="fake-stack",
cloudformation_parameters={},
)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

def test_create_stack(self, mocked_hook_client):
stack_name = "myStack"
timeout = 15
Expand All @@ -60,6 +89,30 @@ def test_create_stack(self, mocked_hook_client):


class TestCloudFormationDeleteStackOperator:
def test_init(self):
op = CloudFormationDeleteStackOperator(
task_id="cf_delete_stack_init",
stack_name="fake-stack",
# Generic hooks parameters
aws_conn_id="fake-conn-id",
region_name="us-east-1",
verify=False,
botocore_config={"read_timeout": 42},
)
assert op.hook.client_type == "cloudformation"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "us-east-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = CloudFormationDeleteStackOperator(task_id="cf_delete_stack_init", stack_name="fake-stack")
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

def test_delete_stack(self, mocked_hook_client):
stack_name = "myStackToBeDeleted"

Expand Down
Loading

0 comments on commit 1455a3b

Please sign in to comment.