Skip to content

Commit

Permalink
Add IAM authentication to Amazon Redshift Connection by AWS Connection (
Browse files Browse the repository at this point in the history
#28187)

* Add IAM authentication to Amazon Redshift Connection by AWS Connection

* Fix type checking

* Fixed documentation
  • Loading branch information
IAL32 authored May 2, 2023
1 parent b7e5b47 commit 2f247a2
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 9 deletions.
42 changes: 42 additions & 0 deletions airflow/providers/amazon/aws/hooks/redshift_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,35 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

import redshift_connector
from redshift_connector import Connection as RedshiftConnection
from sqlalchemy import create_engine
from sqlalchemy.engine.url import URL

from airflow.compat.functools import cached_property
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
from airflow.models.connection import Connection


class RedshiftSQLHook(DbApiHook):
"""
Execute statements against Amazon Redshift, using redshift_connector
This hook requires the redshift_conn_id connection.
Note: For AWS IAM authentication, use iam in the extra connection parameters
and set it to true. Leave the password field empty. This will use the
"aws_default" connection to get the temporary token unless you override
with aws_conn_id when initializing the hook.
The cluster-identifier is extracted from the beginning of
the host field, so is optional. It can however be overridden in the extra field.
extras example: ``{"iam":true}``
:param redshift_conn_id: reference to
:ref:`Amazon Redshift connection id<howto/connection:redshift>`
Expand All @@ -44,6 +58,10 @@ class RedshiftSQLHook(DbApiHook):
hook_name = "Amazon Redshift"
supports_autocommit = True

def __init__(self, *args, aws_conn_id: str = "aws_default", **kwargs) -> None:
    """Initialize the hook, remembering which AWS connection to use for IAM auth.

    :param aws_conn_id: AWS connection used to fetch a temporary IAM token
        when the Redshift connection extras set ``iam`` to true.
    """
    # aws_conn_id is consumed here; everything else goes to DbApiHook.
    self.aws_conn_id = aws_conn_id
    super().__init__(*args, **kwargs)

@staticmethod
def get_ui_field_behaviour() -> dict:
"""Returns custom field behavior"""
Expand All @@ -62,6 +80,9 @@ def _get_conn_params(self) -> dict[str, str | int]:

conn_params: dict[str, str | int] = {}

if conn.extra_dejson.get("iam", False):
conn.login, conn.password, conn.port = self.get_iam_token(conn)

if conn.login:
conn_params["user"] = conn.login
if conn.password:
Expand All @@ -75,6 +96,27 @@ def _get_conn_params(self) -> dict[str, str | int]:

return conn_params

def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
    """Retrieve a temporary database password from AWS for the given connection.

    Calls the boto3 ``redshift`` client's ``get_cluster_credentials`` via
    :class:`AwsBaseHook` and returns ``(login, token, port)``. Port defaults
    to 5439 when the connection does not specify one.
    """
    port = conn.port or 5439
    # The cluster identifier is taken from the extras if present; otherwise it
    # is the leading label of the Redshift host name, e.g.
    # "my-cluster" from "my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com".
    cluster_identifier = conn.extra_dejson.get("cluster_identifier", conn.host.split(".")[0])
    redshift_client = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="redshift").conn
    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
    cluster_creds = redshift_client.get_cluster_credentials(
        DbUser=conn.login,
        DbName=conn.schema,
        ClusterIdentifier=cluster_identifier,
        AutoCreate=False,
    )
    # DbUser comes back prefixed by AWS (e.g. "IAM:<login>"); DbPassword is the token.
    return cluster_creds["DbUser"], cluster_creds["DbPassword"], port

def get_uri(self) -> str:
"""Overrides DbApiHook get_uri to use redshift_connector sqlalchemy dialect as driver name"""
conn_params = self._get_conn_params()
Expand Down
12 changes: 10 additions & 2 deletions docs/apache-airflow-providers-amazon/connections/redshift.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,19 @@ Examples
* **Password**: ``********``
* **Port**: ``5439``

**Credentials Authentication**

Uses the credentials in Connection to connect to Amazon Redshift. Port is required.
If none is provided, default is used (5439). This assumes all other Connection fields e.g. **Login** are empty.
In this method, **cluster_identifier** replaces **Host** and **Port** in order to uniquely identify the cluster.

**IAM Authentication**

Uses the AWS IAM profile given at hook initialization to retrieve a temporary password to connect
to Amazon Redshift. **Port** is required. If none is provided, default is used (5439). **Login**
and **Schema** are also required. This assumes all other Connection fields are empty.
In this method, if **cluster_identifier** is not set within the extras, it is automatically
inferred from the **Host** field in Connection.
`More details about AWS IAM authentication to generate database user credentials <https://docs.aws.amazon.com/redshift/latest/mgmt/generating-user-credentials.html>`_.

* **Extra**:
Expand Down
65 changes: 58 additions & 7 deletions tests/providers/amazon/aws/hooks/test_redshift_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,24 @@

from airflow.models import Connection
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.utils.types import NOTSET

LOGIN_USER = "login"
LOGIN_PASSWORD = "password"
LOGIN_HOST = "host"
LOGIN_PORT = 5439
LOGIN_SCHEMA = "dev"


class TestRedshiftSQLHookConn:
def setup_method(self):
self.connection = Connection(
conn_type="redshift", login="login", password="password", host="host", port=5439, schema="dev"
conn_type="redshift",
login=LOGIN_USER,
password=LOGIN_PASSWORD,
host=LOGIN_HOST,
port=LOGIN_PORT,
schema=LOGIN_SCHEMA,
)

self.db_hook = RedshiftSQLHook()
Expand All @@ -51,20 +63,59 @@ def test_get_conn(self, mock_connect):
def test_get_conn_extra(self, mock_connect):
    """Extras are forwarded to redshift_connector.connect when IAM is disabled."""
    # Diff residue removed: the scraped hunk contained both the old and the
    # new version of the "iam" key and of the connect() kwargs, which made
    # this method invalid Python (duplicate dict keys / keyword arguments).
    self.connection.extra = json.dumps(
        {
            "iam": False,
            "cluster_identifier": "my-test-cluster",
            "profile": "default",
        }
    )
    self.db_hook.get_conn()
    # Connection fields plus every extra must be passed straight through,
    # including iam=False (no token lookup happens in that case).
    mock_connect.assert_called_once_with(
        user=LOGIN_USER,
        password=LOGIN_PASSWORD,
        host=LOGIN_HOST,
        port=LOGIN_PORT,
        cluster_identifier="my-test-cluster",
        profile="default",
        database=LOGIN_SCHEMA,
        iam=False,
    )

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
@pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
def test_get_conn_iam(self, mock_connect, mock_aws_hook_conn, aws_conn_id):
    """With iam=True, credentials come from Redshift get_cluster_credentials.

    Parametrized over aws_conn_id values (NOTSET keeps the hook default).
    """
    # Diff residue removed: the scraped hunk carried both database="dev"
    # (old) and database=LOGIN_SCHEMA (new) — a duplicate keyword argument.
    mock_conn_extra = {"iam": True, "profile": "default", "cluster_identifier": "my-test-cluster"}
    if aws_conn_id is not NOTSET:
        self.db_hook.aws_conn_id = aws_conn_id
    self.connection.extra = json.dumps(mock_conn_extra)

    mock_db_user = f"IAM:{self.connection.login}"
    mock_db_pass = "aws_token"

    # Mock AWS Connection
    mock_aws_hook_conn.get_cluster_credentials.return_value = {
        "DbPassword": mock_db_pass,
        "DbUser": mock_db_user,
    }

    self.db_hook.get_conn()

    # Check boto3 'redshift' client method `get_cluster_credentials` call args
    mock_aws_hook_conn.get_cluster_credentials.assert_called_once_with(
        DbUser=LOGIN_USER,
        DbName=LOGIN_SCHEMA,
        ClusterIdentifier="my-test-cluster",
        AutoCreate=False,
    )

    # The temporary IAM credentials must replace login/password in connect().
    mock_connect.assert_called_once_with(
        user=mock_db_user,
        password=mock_db_pass,
        host=LOGIN_HOST,
        port=LOGIN_PORT,
        cluster_identifier="my-test-cluster",
        profile="default",
        database=LOGIN_SCHEMA,
        iam=True,
    )

Expand Down

0 comments on commit 2f247a2

Please sign in to comment.