diff --git a/airflow/providers/amazon/aws/hooks/redshift_cluster.py b/airflow/providers/amazon/aws/hooks/redshift_cluster.py index 6e3b454213189..27fc25a1de546 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py +++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py @@ -156,7 +156,11 @@ def restore_from_cluster_snapshot(self, cluster_identifier: str, snapshot_identi return response["Cluster"] if response["Cluster"] else None def create_cluster_snapshot( - self, snapshot_identifier: str, cluster_identifier: str, retention_period: int = -1 + self, + snapshot_identifier: str, + cluster_identifier: str, + retention_period: int = -1, + tags: list[Any] | None = None, ) -> str: """ Creates a snapshot of a cluster @@ -168,11 +172,15 @@ def create_cluster_snapshot( :param cluster_identifier: unique identifier of a cluster :param retention_period: The number of days that a manual snapshot is retained. If the value is -1, the manual snapshot is retained indefinitely. + :param tags: A list of tag instances """ + if tags is None: + tags = [] response = self.get_conn().create_cluster_snapshot( SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier, ManualSnapshotRetentionPeriod=retention_period, + Tags=tags, ) return response["Snapshot"] if response["Snapshot"] else None diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 77ac521c9baf6..2880240b1532d 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -298,6 +298,7 @@ class RedshiftCreateClusterSnapshotOperator(BaseOperator): :param cluster_identifier: The cluster identifier for which you want a snapshot :param retention_period: The number of days that a manual snapshot is retained. If the value is -1, the manual snapshot is retained indefinitely. + :parma tags: A list of tag instances :param wait_for_completion: Whether wait for the cluster snapshot to be in ``available`` state :param poll_interval: Time (in seconds) to wait between two consecutive calls to check state :param max_attempt: The maximum number of attempts to be made to check the state @@ -316,6 +317,7 @@ def __init__( snapshot_identifier: str, cluster_identifier: str, retention_period: int = -1, + tags: list[Any] | None = None, wait_for_completion: bool = False, poll_interval: int = 15, max_attempt: int = 20, @@ -326,6 +328,7 @@ def __init__( self.snapshot_identifier = snapshot_identifier self.cluster_identifier = cluster_identifier self.retention_period = retention_period + self.tags = tags self.wait_for_completion = wait_for_completion self.poll_interval = poll_interval self.max_attempt = max_attempt @@ -343,6 +346,7 @@ def execute(self, context: Context) -> Any: cluster_identifier=self.cluster_identifier, snapshot_identifier=self.snapshot_identifier, retention_period=self.retention_period, + tags=self.tags, ) if self.wait_for_completion: diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py index 64a276f14d02c..fca7dafdaad8c 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py @@ -144,12 +144,24 @@ def test_create_cluster_snapshot_is_called_when_cluster_is_available( cluster_identifier="test_cluster", snapshot_identifier="test_snapshot", retention_period=1, + tags=[ + { + "Key": "user", + "Value": "airflow", + } + ], ) create_snapshot.execute(None) mock_get_conn.return_value.create_cluster_snapshot.assert_called_once_with( ClusterIdentifier="test_cluster", SnapshotIdentifier="test_snapshot", ManualSnapshotRetentionPeriod=1, + Tags=[ + { + "Key": "user", + "Value": "airflow", + } + ], ) mock_get_conn.return_value.get_waiter.assert_not_called()