From e46453966ef521cc8f6dc7019566ea2fdcc0063b Mon Sep 17 00:00:00 2001 From: jaejun <63435794+jx2lee@users.noreply.github.com> Date: Tue, 26 Nov 2024 16:47:16 +0900 Subject: [PATCH] Support connection extra parameters in MsSqlHook (#44310) * enable extras * mark db_test * connections to fixture --- .../providers/microsoft/mssql/hooks/mssql.py | 2 + .../tests/microsoft/mssql/hooks/test_mssql.py | 308 +++++++++--------- 2 files changed, 149 insertions(+), 161 deletions(-) diff --git a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py index d45a43a188ce6..a367250ed33c4 100644 --- a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py +++ b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py @@ -137,12 +137,14 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) def get_conn(self) -> PymssqlConnection: """Return ``pymssql`` connection object.""" conn = self.connection + extra_conn_args = {key: val for key, val in conn.extra_dejson.items() if key != "sqlalchemy_scheme"} return pymssql.connect( server=conn.host, user=conn.login, password=conn.password, database=self.schema or conn.schema, port=str(conn.port), + **extra_conn_args, ) def set_autocommit( diff --git a/providers/tests/microsoft/mssql/hooks/test_mssql.py b/providers/tests/microsoft/mssql/hooks/test_mssql.py index 1b43bb787835b..be8f921112a4a 100644 --- a/providers/tests/microsoft/mssql/hooks/test_mssql.py +++ b/providers/tests/microsoft/mssql/hooks/test_mssql.py @@ -18,7 +18,6 @@ from __future__ import annotations from unittest import mock -from urllib.parse import quote_plus import pytest @@ -31,33 +30,9 @@ except ImportError: pytest.skip("MSSQL not available", allow_module_level=True) -PYMSSQL_CONN = Connection( - conn_type="mssql", host="ip", schema="share", login="username", password="password", port=8081 -) -PYMSSQL_CONN_ALT = Connection( - conn_type="mssql", host="ip", schema="", login="username", password="password", port=8081 -) -PYMSSQL_CONN_ALT_1 = Connection( - conn_type="mssql", - host="ip", - schema="", - login="username", - password="password", - port=8081, - extra={"SQlalchemy_Scheme": "mssql+testdriver"}, -) -PYMSSQL_CONN_ALT_2 = Connection( - conn_type="mssql", - host="ip", - schema="", - login="username", - password="password", - port=8081, - extra={"SQlalchemy_Scheme": "mssql+testdriver", "myparam": "5@-//*"}, -) - - -def get_primary_keys(self, table: str) -> list[str]: + +@pytest.fixture +def get_primary_keys(): return [ "GroupDisplayName", "OwnerPrincipalName", @@ -66,11 +41,49 @@ def get_primary_keys(self, table: str) -> list[str]: ] +@pytest.fixture +def mssql_connections(): + return { + "default": Connection( + conn_type="mssql", host="ip", schema="share", login="username", password="password", port=8081 + ), + "alt": Connection( + conn_type="mssql", host="ip", schema="", login="username", password="password", port=8081 + ), + "alt_1": Connection( + conn_type="mssql", + host="ip", + schema="", + login="username", + password="password", + port=8081, + extra={"SQlalchemy_Scheme": "mssql+testdriver"}, + ), + "alt_2": Connection( + conn_type="mssql", + host="ip", + schema="", + login="username", + password="password", + port=8081, + extra={"SQlalchemy_Scheme": "mssql+testdriver", "myparam": "5@-//*"}, + ), + } + + +URI_TEST_CASES = [ + ("default", "mssql+pymssql://username:password@ip:8081/share"), + ("alt", "mssql+pymssql://username:password@ip:8081"), + ("alt_1", "mssql+testdriver://username:password@ip:8081/"), + ("alt_2", "mssql+testdriver://username:password@ip:8081/?myparam=5%40-%2F%2F%2A"), +] + + class TestMsSqlHook: @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn") @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_connection") - def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn): - get_connection.return_value = PYMSSQL_CONN + def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn, mssql_connections): + get_connection.return_value = mssql_connections["default"] mssql_get_conn.return_value = mock.Mock() hook = MsSqlHook() @@ -81,8 +94,8 @@ def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn) @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn") @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_connection") - def test_set_autocommit_should_invoke_autocommit(self, get_connection, mssql_get_conn): - get_connection.return_value = PYMSSQL_CONN + def test_set_autocommit_should_invoke_autocommit(self, get_connection, mssql_get_conn, mssql_connections): + get_connection.return_value = mssql_connections["default"] mssql_get_conn.return_value = mock.Mock() autocommit_value = mock.Mock() @@ -95,8 +108,10 @@ def test_set_autocommit_should_invoke_autocommit(self, get_connection, mssql_get @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn") @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_connection") - def test_get_autocommit_should_return_autocommit_state(self, get_connection, mssql_get_conn): - get_connection.return_value = PYMSSQL_CONN + def test_get_autocommit_should_return_autocommit_state( + self, get_connection, mssql_get_conn, mssql_connections + ): + get_connection.return_value = mssql_connections["default"] mssql_get_conn.return_value = mock.Mock() mssql_get_conn.return_value.autocommit_state = "autocommit_state" @@ -106,47 +121,10 @@ def test_get_autocommit_should_return_autocommit_state(self, get_connection, mss mssql_get_conn.assert_called_once() assert hook.get_autocommit(conn) == "autocommit_state" - @pytest.mark.parametrize( - "conn, exp_uri", - [ - ( - PYMSSQL_CONN, - ( - "mssql+pymssql://" - f"{quote_plus(PYMSSQL_CONN.login)}:{quote_plus(PYMSSQL_CONN.password)}" - f"@{PYMSSQL_CONN.host}:{PYMSSQL_CONN.port}/{PYMSSQL_CONN.schema}" - ), - ), - ( - PYMSSQL_CONN_ALT, - ( - "mssql+pymssql://" - f"{quote_plus(PYMSSQL_CONN_ALT.login)}:{quote_plus(PYMSSQL_CONN_ALT.password)}" - f"@{PYMSSQL_CONN_ALT.host}:{PYMSSQL_CONN_ALT.port}" - ), - ), - ( - PYMSSQL_CONN_ALT_1, - ( - f"{PYMSSQL_CONN_ALT_1.extra_dejson['SQlalchemy_Scheme']}://" - f"{quote_plus(PYMSSQL_CONN_ALT.login)}:{quote_plus(PYMSSQL_CONN_ALT.password)}" - f"@{PYMSSQL_CONN_ALT.host}:{PYMSSQL_CONN_ALT.port}/" - ), - ), - ( - PYMSSQL_CONN_ALT_2, - ( - f"{PYMSSQL_CONN_ALT_2.extra_dejson['SQlalchemy_Scheme']}://" - f"{quote_plus(PYMSSQL_CONN_ALT_2.login)}:{quote_plus(PYMSSQL_CONN_ALT_2.password)}" - f"@{PYMSSQL_CONN_ALT_2.host}:{PYMSSQL_CONN_ALT_2.port}/" - f"?myparam={quote_plus(PYMSSQL_CONN_ALT_2.extra_dejson['myparam'])}" - ), - ), - ], - ) + @pytest.mark.parametrize("conn_id,exp_uri", URI_TEST_CASES) @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") - def test_get_uri_driver_rewrite(self, get_connection, conn, exp_uri): - get_connection.return_value = conn + def test_get_uri_driver_rewrite(self, get_connection, mssql_connections, conn_id, exp_uri): + get_connection.return_value = mssql_connections[conn_id] hook = MsSqlHook() res_uri = hook.get_uri() @@ -155,8 +133,8 @@ def test_get_uri_driver_rewrite(self, get_connection, conn, exp_uri): assert res_uri == exp_uri @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") - def test_sqlalchemy_scheme_is_default(self, get_connection): - get_connection.return_value = PYMSSQL_CONN + def test_sqlalchemy_scheme_is_default(self, get_connection, mssql_connections): + get_connection.return_value = mssql_connections["default"] hook = MsSqlHook() assert hook.sqlalchemy_scheme == hook.DEFAULT_SQLALCHEMY_SCHEME @@ -167,101 +145,109 @@ def test_sqlalchemy_scheme_is_from_hook(self): assert hook.sqlalchemy_scheme == "mssql+mytestdriver" @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") - def test_sqlalchemy_scheme_is_from_conn_extra(self, get_connection): - get_connection.return_value = PYMSSQL_CONN_ALT_1 + def test_sqlalchemy_scheme_is_from_conn_extra(self, get_connection, mssql_connections): + get_connection.return_value = mssql_connections["alt_1"] hook = MsSqlHook() scheme = hook.sqlalchemy_scheme get_connection.assert_called() - assert scheme == PYMSSQL_CONN_ALT_1.extra_dejson["SQlalchemy_Scheme"] + assert scheme == mssql_connections["alt_1"].extra_dejson["SQlalchemy_Scheme"] @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") - def test_get_sqlalchemy_engine(self, get_connection): - get_connection.return_value = PYMSSQL_CONN + def test_get_sqlalchemy_engine(self, get_connection, mssql_connections): + get_connection.return_value = mssql_connections["default"] hook = MsSqlHook() hook.get_sqlalchemy_engine() @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") - @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_primary_keys", get_primary_keys) - def test_generate_insert_sql(self, get_connection): - get_connection.return_value = PYMSSQL_CONN + def test_generate_insert_sql(self, get_connection, mssql_connections, get_primary_keys): + get_connection.return_value = mssql_connections["default"] + + hook = MsSqlHook() + with mock.patch.object(hook, "get_primary_keys", return_value=get_primary_keys): + sql = hook._generate_insert_sql( + table="YAMMER_GROUPS_ACTIVITY_DETAIL", + values=[ + "2024-07-17", + "daa5b44c-80d6-4e22-85b5-a94e04cf7206", + "no-reply@microsoft.com", + "2024-07-17", + 0, + 0.0, + "MICROSOFT FABRIC (FREE)+MICROSOFT 365 E5", + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + "PT0S", + "PT0S", + "PT0S", + 0, + 0, + 0, + "Yes", + 0, + 0, + "APACHE", + 0.0, + 0, + "Yes", + 1, + "2024-07-17T00:00:00+00:00", + ], + target_fields=[ + "ReportRefreshDate", + "UserId", + "UserPrincipalName", + "LastActivityDate", + "IsDeleted", + "DeletedDate", + "AssignedProducts", + "TeamChatMessageCount", + "PrivateChatMessageCount", + "CallCount", + "MeetingCount", + "MeetingsOrganizedCount", + "MeetingsAttendedCount", + "AdHocMeetingsOrganizedCount", + "AdHocMeetingsAttendedCount", + "ScheduledOne-timeMeetingsOrganizedCount", + "ScheduledOne-timeMeetingsAttendedCount", + "ScheduledRecurringMeetingsOrganizedCount", + "ScheduledRecurringMeetingsAttendedCount", + "AudioDuration", + "VideoDuration", + "ScreenShareDuration", + "AudioDurationInSeconds", + "VideoDurationInSeconds", + "ScreenShareDurationInSeconds", + "HasOtherAction", + "UrgentMessages", + "PostMessages", + "TenantDisplayName", + "SharedChannelTenantDisplayNames", + "ReplyMessages", + "IsLicensed", + "ReportPeriod", + "LoadDate", + ], + replace=True, + ) + assert sql == load_file("resources", "replace.sql") + + @pytest.mark.db_test + @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") + def test_get_extra(self, get_connection, mssql_connections): + get_connection.return_value = mssql_connections["alt_2"] hook = MsSqlHook() - sql = hook._generate_insert_sql( - table="YAMMER_GROUPS_ACTIVITY_DETAIL", - values=[ - "2024-07-17", - "daa5b44c-80d6-4e22-85b5-a94e04cf7206", - "no-reply@microsoft.com", - "2024-07-17", - 0, - 0.0, - "MICROSOFT FABRIC (FREE)+MICROSOFT 365 E5", - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - "PT0S", - "PT0S", - "PT0S", - 0, - 0, - 0, - "Yes", - 0, - 0, - "APACHE", - 0.0, - 0, - "Yes", - 1, - "2024-07-17T00:00:00+00:00", - ], - target_fields=[ - "ReportRefreshDate", - "UserId", - "UserPrincipalName", - "LastActivityDate", - "IsDeleted", - "DeletedDate", - "AssignedProducts", - "TeamChatMessageCount", - "PrivateChatMessageCount", - "CallCount", - "MeetingCount", - "MeetingsOrganizedCount", - "MeetingsAttendedCount", - "AdHocMeetingsOrganizedCount", - "AdHocMeetingsAttendedCount", - "ScheduledOne-timeMeetingsOrganizedCount", - "ScheduledOne-timeMeetingsAttendedCount", - "ScheduledRecurringMeetingsOrganizedCount", - "ScheduledRecurringMeetingsAttendedCount", - "AudioDuration", - "VideoDuration", - "ScreenShareDuration", - "AudioDurationInSeconds", - "VideoDurationInSeconds", - "ScreenShareDurationInSeconds", - "HasOtherAction", - "UrgentMessages", - "PostMessages", - "TenantDisplayName", - "SharedChannelTenantDisplayNames", - "ReplyMessages", - "IsLicensed", - "ReportPeriod", - "LoadDate", - ], - replace=True, - ) - assert sql == load_file("resources", "replace.sql") + assert hook.get_connection().extra