diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py b/airflow/providers/amazon/aws/hooks/redshift_sql.py index 120ce190ccb1b..e9c2b7fecc78b 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_sql.py +++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py @@ -16,14 +16,20 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING + import redshift_connector from redshift_connector import Connection as RedshiftConnection from sqlalchemy import create_engine from sqlalchemy.engine.url import URL from airflow.compat.functools import cached_property +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.common.sql.hooks.sql import DbApiHook +if TYPE_CHECKING: + from airflow.models.connection import Connection + class RedshiftSQLHook(DbApiHook): """ @@ -31,6 +37,14 @@ class RedshiftSQLHook(DbApiHook): This hook requires the redshift_conn_id connection. + Note: For AWS IAM authentication, use iam in the extra connection parameters + and set it to true. Leave the password field empty. This will use the + "aws_default" connection to get the temporary token unless you override + with aws_conn_id when initializing the hook. + The cluster-identifier is extracted from the beginning of + the host field, so is optional. It can however be overridden in the extra field. 
+ extras example: ``{"iam":true}`` + + :param redshift_conn_id: reference to :ref:`Amazon Redshift connection id` @@ -44,6 +58,10 @@ class RedshiftSQLHook(DbApiHook): hook_name = "Amazon Redshift" supports_autocommit = True + def __init__(self, *args, aws_conn_id: str = "aws_default", **kwargs) -> None: + super().__init__(*args, **kwargs) + self.aws_conn_id = aws_conn_id + @staticmethod def get_ui_field_behaviour() -> dict: """Returns custom field behavior""" @@ -62,6 +80,9 @@ def _get_conn_params(self) -> dict[str, str | int]: conn_params: dict[str, str | int] = {} + if conn.extra_dejson.get("iam", False): + conn.login, conn.password, conn.port = self.get_iam_token(conn) + if conn.login: conn_params["user"] = conn.login if conn.password: @@ -75,6 +96,27 @@ def _get_conn_params(self) -> dict[str, str | int]: return conn_params + def get_iam_token(self, conn: Connection) -> tuple[str, str, int]: + """ + Uses AWSHook to retrieve a temporary password to connect to Redshift. + Port is required. If none is provided, default is used for each service + """ + port = conn.port or 5439 + # Pull the cluster-identifier from the beginning of the Redshift URL + # ex. 
my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster + cluster_identifier = conn.extra_dejson.get("cluster_identifier", conn.host.split(".")[0]) + redshift_client = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="redshift").conn + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials + cluster_creds = redshift_client.get_cluster_credentials( + DbUser=conn.login, + DbName=conn.schema, + ClusterIdentifier=cluster_identifier, + AutoCreate=False, + ) + token = cluster_creds["DbPassword"] + login = cluster_creds["DbUser"] + return login, token, port + def get_uri(self) -> str: """Overrides DbApiHook get_uri to use redshift_connector sqlalchemy dialect as driver name""" conn_params = self._get_conn_params() diff --git a/docs/apache-airflow-providers-amazon/connections/redshift.rst b/docs/apache-airflow-providers-amazon/connections/redshift.rst index 3c4570951406e..ccd48805c4e8a 100644 --- a/docs/apache-airflow-providers-amazon/connections/redshift.rst +++ b/docs/apache-airflow-providers-amazon/connections/redshift.rst @@ -69,11 +69,19 @@ Examples * **Password**: ``********`` * **Port**: ``5439`` -**IAM Authentication** +**Credentials Authentication** -Uses AWS IAM to retrieve a temporary password to connect to Amazon Redshift. Port is required. +Uses the credentials in Connection to connect to Amazon Redshift. Port is required. If none is provided, default is used (5439). This assumes all other Connection fields e.g. **Login** are empty. In this method, **cluster_identifier** replaces **Host** and **Port** in order to uniquely identify the cluster. + +**IAM Authentication** + +Uses the AWS IAM profile given at hook initialization to retrieve a temporary password to connect +to Amazon Redshift. **Port** is required. If none is provided, default is used (5439). **Login** +and **Schema** are also required. This assumes all other Connection fields are empty. 
+In this method, if **cluster_identifier** is not set within the extras, it is automatically +inferred by the **Host** field in Connection. `More details about AWS IAM authentication to generate database user credentials `_. * **Extra**: diff --git a/tests/providers/amazon/aws/hooks/test_redshift_sql.py b/tests/providers/amazon/aws/hooks/test_redshift_sql.py index 531d6a9b470a0..335c2f28f40ef 100644 --- a/tests/providers/amazon/aws/hooks/test_redshift_sql.py +++ b/tests/providers/amazon/aws/hooks/test_redshift_sql.py @@ -23,12 +23,24 @@ from airflow.models import Connection from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook +from airflow.utils.types import NOTSET + +LOGIN_USER = "login" +LOGIN_PASSWORD = "password" +LOGIN_HOST = "host" +LOGIN_PORT = 5439 +LOGIN_SCHEMA = "dev" class TestRedshiftSQLHookConn: def setup_method(self): self.connection = Connection( - conn_type="redshift", login="login", password="password", host="host", port=5439, schema="dev" + conn_type="redshift", + login=LOGIN_USER, + password=LOGIN_PASSWORD, + host=LOGIN_HOST, + port=LOGIN_PORT, + schema=LOGIN_SCHEMA, ) self.db_hook = RedshiftSQLHook() @@ -51,20 +63,59 @@ def test_get_conn(self, mock_connect): def test_get_conn_extra(self, mock_connect): self.connection.extra = json.dumps( { - "iam": True, + "iam": False, "cluster_identifier": "my-test-cluster", "profile": "default", } ) self.db_hook.get_conn() mock_connect.assert_called_once_with( - user="login", - password="password", - host="host", - port=5439, + user=LOGIN_USER, + password=LOGIN_PASSWORD, + host=LOGIN_HOST, + port=LOGIN_PORT, + cluster_identifier="my-test-cluster", + profile="default", + database=LOGIN_SCHEMA, + iam=False, + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect") + @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"]) + def test_get_conn_iam(self, 
mock_connect, mock_aws_hook_conn, aws_conn_id): + mock_conn_extra = {"iam": True, "profile": "default", "cluster_identifier": "my-test-cluster"} + if aws_conn_id is not NOTSET: + self.db_hook.aws_conn_id = aws_conn_id + self.connection.extra = json.dumps(mock_conn_extra) + + mock_db_user = f"IAM:{self.connection.login}" + mock_db_pass = "aws_token" + + # Mock AWS Connection + mock_aws_hook_conn.get_cluster_credentials.return_value = { + "DbPassword": mock_db_pass, + "DbUser": mock_db_user, + } + + self.db_hook.get_conn() + + # Check boto3 'redshift' client method `get_cluster_credentials` call args + mock_aws_hook_conn.get_cluster_credentials.assert_called_once_with( + DbUser=LOGIN_USER, + DbName=LOGIN_SCHEMA, + ClusterIdentifier="my-test-cluster", + AutoCreate=False, + ) + + mock_connect.assert_called_once_with( + user=mock_db_user, + password=mock_db_pass, + host=LOGIN_HOST, + port=LOGIN_PORT, cluster_identifier="my-test-cluster", profile="default", - database="dev", + database=LOGIN_SCHEMA, iam=True, )