diff --git a/providers/src/airflow/providers/jdbc/hooks/jdbc.py b/providers/src/airflow/providers/jdbc/hooks/jdbc.py index 808b946bd9762..07b5fc42d9aba 100644 --- a/providers/src/airflow/providers/jdbc/hooks/jdbc.py +++ b/providers/src/airflow/providers/jdbc/hooks/jdbc.py @@ -152,6 +152,9 @@ def driver_class(self) -> str | None: @property def sqlalchemy_url(self) -> URL: conn = self.connection + sqlalchemy_query = conn.extra_dejson.get("sqlalchemy_query", {}) + if not isinstance(sqlalchemy_query, dict): + raise AirflowException("The parameter 'sqlalchemy_query' must be of type dict!") sqlalchemy_scheme = conn.extra_dejson.get("sqlalchemy_scheme") if sqlalchemy_scheme is None: raise AirflowException( @@ -164,6 +167,7 @@ def sqlalchemy_url(self) -> URL: host=conn.host, port=conn.port, database=conn.schema, + query=sqlalchemy_query, ) def get_sqlalchemy_engine(self, engine_kwargs=None): diff --git a/providers/src/airflow/providers/postgres/hooks/postgres.py b/providers/src/airflow/providers/postgres/hooks/postgres.py index f5dcfe2df49af..9b657c14416e4 100644 --- a/providers/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/src/airflow/providers/postgres/hooks/postgres.py @@ -29,6 +29,7 @@ from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor from sqlalchemy.engine import URL +from airflow.exceptions import AirflowException from airflow.providers.common.sql.hooks.sql import DbApiHook if TYPE_CHECKING: @@ -85,6 +86,17 @@ class PostgresHook(DbApiHook): hook_name = "Postgres" supports_autocommit = True supports_executemany = True + ignored_extra_options = { + "iam", + "redshift", + "redshift-serverless", + "cursor", + "cluster-identifier", + "workgroup-name", + "aws_conn_id", + "sqlalchemy_scheme", + "sqlalchemy_query", + } def __init__( self, *args, options: str | None = None, enable_log_db_messages: bool = False, **kwargs @@ -97,7 +109,10 @@ def __init__( @property def sqlalchemy_url(self) -> URL: - conn = self.get_connection(self.get_conn_id()) + conn = self.connection + query = conn.extra_dejson.get("sqlalchemy_query", {}) + if not isinstance(query, dict): + raise AirflowException("The parameter 'sqlalchemy_query' must be of type dict!") return URL.create( drivername="postgresql", username=conn.login, @@ -105,6 +120,7 @@ def sqlalchemy_url(self) -> URL: host=conn.host, port=conn.port, database=self.database or conn.schema, + query=query, ) def _get_cursor(self, raw_cursor: str) -> CursorType: @@ -143,15 +159,7 @@ def get_conn(self) -> connection: conn_args["options"] = self.options for arg_name, arg_val in conn.extra_dejson.items(): - if arg_name not in [ - "iam", - "redshift", - "redshift-serverless", - "cursor", - "cluster-identifier", - "workgroup-name", - "aws_conn_id", - ]: + if arg_name not in self.ignored_extra_options: conn_args[arg_name] = arg_val self.conn = psycopg2.connect(**conn_args) diff --git a/providers/tests/jdbc/hooks/test_jdbc.py b/providers/tests/jdbc/hooks/test_jdbc.py index 73015b5b522ab..ce4e526623440 100644 --- a/providers/tests/jdbc/hooks/test_jdbc.py +++ b/providers/tests/jdbc/hooks/test_jdbc.py @@ -219,6 +219,23 @@ def test_sqlalchemy_url_with_sqlalchemy_scheme(self): assert str(hook.sqlalchemy_url) == "mssql://login:password@host:1234/schema" + def test_sqlalchemy_url_with_sqlalchemy_scheme_and_query(self): + conn_params = dict( + extra=json.dumps(dict(sqlalchemy_scheme="mssql", sqlalchemy_query={"servicename": "test"})) + ) + hook_params = {"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"} + hook = get_hook(conn_params=conn_params, hook_params=hook_params) + + assert str(hook.sqlalchemy_url) == "mssql://login:password@host:1234/schema?servicename=test" + + def test_sqlalchemy_url_with_sqlalchemy_scheme_and_wrong_query_value(self): + conn_params = dict(extra=json.dumps(dict(sqlalchemy_scheme="mssql", sqlalchemy_query="wrong type"))) + hook_params = {"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"} + hook = get_hook(conn_params=conn_params, hook_params=hook_params) + + with pytest.raises(AirflowException): + hook.sqlalchemy_url + def test_get_sqlalchemy_engine_verify_creator_is_being_used(self): jdbc_hook = get_hook( conn_params=dict(extra={"sqlalchemy_scheme": "sqlite"}), diff --git a/providers/tests/postgres/hooks/test_postgres.py b/providers/tests/postgres/hooks/test_postgres.py index 7a720534d4b77..76206d5795866 100644 --- a/providers/tests/postgres/hooks/test_postgres.py +++ b/providers/tests/postgres/hooks/test_postgres.py @@ -25,6 +25,7 @@ import psycopg2.extras import pytest +from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.utils.types import NOTSET @@ -65,9 +66,42 @@ def test_get_uri(self, mock_connect): assert mock_connect.call_count == 1 assert self.db_hook.get_uri() == "postgresql://login:password@host:5432/database" + def test_sqlalchemy_url(self): + conn = Connection(login="login-conn", password="password-conn", host="host", schema="database") + hook = PostgresHook(connection=conn) + assert str(hook.sqlalchemy_url) == "postgresql://login-conn:password-conn@host/database" + + def test_sqlalchemy_url_with_sqlalchemy_query(self): + conn = Connection( + login="login-conn", + password="password-conn", + host="host", + schema="database", + extra=dict(sqlalchemy_query={"gssencmode": "disable"}), + ) + hook = PostgresHook(connection=conn) + + assert ( + str(hook.sqlalchemy_url) + == "postgresql://login-conn:password-conn@host/database?gssencmode=disable" + ) + + def test_sqlalchemy_url_with_wrong_sqlalchemy_query_value(self): + conn = Connection( + login="login-conn", + password="password-conn", + host="host", + schema="database", + extra=dict(sqlalchemy_query="wrong type"), + ) + hook = PostgresHook(connection=conn) + + with pytest.raises(AirflowException): + hook.sqlalchemy_url + @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_conn_cursor(self, mock_connect): - self.connection.extra = '{"cursor": "dictcursor"}' + self.connection.extra = '{"cursor": "dictcursor", "sqlalchemy_query": {"gssencmode": "disable"}}' self.db_hook.get_conn() mock_connect.assert_called_once_with( cursor_factory=psycopg2.extras.DictCursor,