Skip to content

Commit

Permalink
Allow configuration of sqlalchemy query parameter for JdbcHook and Po…
Browse files Browse the repository at this point in the history
…stgresHook through extras (#44910)
  • Loading branch information
dabla authored Dec 18, 2024
1 parent 084218b commit a10b3fc
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 11 deletions.
4 changes: 4 additions & 0 deletions providers/src/airflow/providers/jdbc/hooks/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def driver_class(self) -> str | None:
@property
def sqlalchemy_url(self) -> URL:
conn = self.connection
sqlalchemy_query = conn.extra_dejson.get("sqlalchemy_query", {})
if not isinstance(sqlalchemy_query, dict):
raise AirflowException("The parameter 'sqlalchemy_query' must be of type dict!")
sqlalchemy_scheme = conn.extra_dejson.get("sqlalchemy_scheme")
if sqlalchemy_scheme is None:
raise AirflowException(
Expand All @@ -164,6 +167,7 @@ def sqlalchemy_url(self) -> URL:
host=conn.host,
port=conn.port,
database=conn.schema,
query=sqlalchemy_query,
)

def get_sqlalchemy_engine(self, engine_kwargs=None):
Expand Down
28 changes: 18 additions & 10 deletions providers/src/airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor
from sqlalchemy.engine import URL

from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -85,6 +86,17 @@ class PostgresHook(DbApiHook):
hook_name = "Postgres"
supports_autocommit = True
supports_executemany = True
ignored_extra_options = {
"iam",
"redshift",
"redshift-serverless",
"cursor",
"cluster-identifier",
"workgroup-name",
"aws_conn_id",
"sqlalchemy_scheme",
"sqlalchemy_query",
}

def __init__(
self, *args, options: str | None = None, enable_log_db_messages: bool = False, **kwargs
Expand All @@ -97,14 +109,18 @@ def __init__(

@property
def sqlalchemy_url(self) -> URL:
conn = self.get_connection(self.get_conn_id())
conn = self.connection
query = conn.extra_dejson.get("sqlalchemy_query", {})
if not isinstance(query, dict):
raise AirflowException("The parameter 'sqlalchemy_query' must be of type dict!")
return URL.create(
drivername="postgresql",
username=conn.login,
password=conn.password,
host=conn.host,
port=conn.port,
database=self.database or conn.schema,
query=query,
)

def _get_cursor(self, raw_cursor: str) -> CursorType:
Expand Down Expand Up @@ -143,15 +159,7 @@ def get_conn(self) -> connection:
conn_args["options"] = self.options

for arg_name, arg_val in conn.extra_dejson.items():
if arg_name not in [
"iam",
"redshift",
"redshift-serverless",
"cursor",
"cluster-identifier",
"workgroup-name",
"aws_conn_id",
]:
if arg_name not in self.ignored_extra_options:
conn_args[arg_name] = arg_val

self.conn = psycopg2.connect(**conn_args)
Expand Down
17 changes: 17 additions & 0 deletions providers/tests/jdbc/hooks/test_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,23 @@ def test_sqlalchemy_url_with_sqlalchemy_scheme(self):

assert str(hook.sqlalchemy_url) == "mssql://login:password@host:1234/schema"

def test_sqlalchemy_url_with_sqlalchemy_scheme_and_query(self):
conn_params = dict(
extra=json.dumps(dict(sqlalchemy_scheme="mssql", sqlalchemy_query={"servicename": "test"}))
)
hook_params = {"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"}
hook = get_hook(conn_params=conn_params, hook_params=hook_params)

assert str(hook.sqlalchemy_url) == "mssql://login:password@host:1234/schema?servicename=test"

def test_sqlalchemy_url_with_sqlalchemy_scheme_and_wrong_query_value(self):
conn_params = dict(extra=json.dumps(dict(sqlalchemy_scheme="mssql", sqlalchemy_query="wrong type")))
hook_params = {"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"}
hook = get_hook(conn_params=conn_params, hook_params=hook_params)

with pytest.raises(AirflowException):
hook.sqlalchemy_url

def test_get_sqlalchemy_engine_verify_creator_is_being_used(self):
jdbc_hook = get_hook(
conn_params=dict(extra={"sqlalchemy_scheme": "sqlite"}),
Expand Down
36 changes: 35 additions & 1 deletion providers/tests/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import psycopg2.extras
import pytest

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.types import NOTSET
Expand Down Expand Up @@ -65,9 +66,42 @@ def test_get_uri(self, mock_connect):
assert mock_connect.call_count == 1
assert self.db_hook.get_uri() == "postgresql://login:password@host:5432/database"

def test_sqlalchemy_url(self):
conn = Connection(login="login-conn", password="password-conn", host="host", schema="database")
hook = PostgresHook(connection=conn)
assert str(hook.sqlalchemy_url) == "postgresql://login-conn:password-conn@host/database"

def test_sqlalchemy_url_with_sqlalchemy_query(self):
conn = Connection(
login="login-conn",
password="password-conn",
host="host",
schema="database",
extra=dict(sqlalchemy_query={"gssencmode": "disable"}),
)
hook = PostgresHook(connection=conn)

assert (
str(hook.sqlalchemy_url)
== "postgresql://login-conn:password-conn@host/database?gssencmode=disable"
)

def test_sqlalchemy_url_with_wrong_sqlalchemy_query_value(self):
conn = Connection(
login="login-conn",
password="password-conn",
host="host",
schema="database",
extra=dict(sqlalchemy_query="wrong type"),
)
hook = PostgresHook(connection=conn)

with pytest.raises(AirflowException):
hook.sqlalchemy_url

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
def test_get_conn_cursor(self, mock_connect):
self.connection.extra = '{"cursor": "dictcursor"}'
self.connection.extra = '{"cursor": "dictcursor", "sqlalchemy_query": {"gssencmode": "disable"}}'
self.db_hook.get_conn()
mock_connect.assert_called_once_with(
cursor_factory=psycopg2.extras.DictCursor,
Expand Down

0 comments on commit a10b3fc

Please sign in to comment.