Skip to content

Commit

Permalink
Support connection extra parameters in MsSqlHook (#44310)
Browse files Browse the repository at this point in the history
* enable extras

* mark db_test

* connections to fixture
  • Loading branch information
jx2lee authored Nov 26, 2024
1 parent 6748b2a commit e464539
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 161 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,14 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs)
def get_conn(self) -> PymssqlConnection:
"""Return ``pymssql`` connection object."""
conn = self.connection
extra_conn_args = {key: val for key, val in conn.extra_dejson.items() if key != "sqlalchemy_scheme"}
return pymssql.connect(
server=conn.host,
user=conn.login,
password=conn.password,
database=self.schema or conn.schema,
port=str(conn.port),
**extra_conn_args,
)

def set_autocommit(
Expand Down
308 changes: 147 additions & 161 deletions providers/tests/microsoft/mssql/hooks/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from __future__ import annotations

from unittest import mock
from urllib.parse import quote_plus

import pytest

Expand All @@ -31,33 +30,9 @@
except ImportError:
pytest.skip("MSSQL not available", allow_module_level=True)

PYMSSQL_CONN = Connection(
conn_type="mssql", host="ip", schema="share", login="username", password="password", port=8081
)
PYMSSQL_CONN_ALT = Connection(
conn_type="mssql", host="ip", schema="", login="username", password="password", port=8081
)
PYMSSQL_CONN_ALT_1 = Connection(
conn_type="mssql",
host="ip",
schema="",
login="username",
password="password",
port=8081,
extra={"SQlalchemy_Scheme": "mssql+testdriver"},
)
PYMSSQL_CONN_ALT_2 = Connection(
conn_type="mssql",
host="ip",
schema="",
login="username",
password="password",
port=8081,
extra={"SQlalchemy_Scheme": "mssql+testdriver", "myparam": "5@-//*"},
)


def get_primary_keys(self, table: str) -> list[str]:

@pytest.fixture
def get_primary_keys():
return [
"GroupDisplayName",
"OwnerPrincipalName",
Expand All @@ -66,11 +41,49 @@ def get_primary_keys(self, table: str) -> list[str]:
]


@pytest.fixture
def mssql_connections():
return {
"default": Connection(
conn_type="mssql", host="ip", schema="share", login="username", password="password", port=8081
),
"alt": Connection(
conn_type="mssql", host="ip", schema="", login="username", password="password", port=8081
),
"alt_1": Connection(
conn_type="mssql",
host="ip",
schema="",
login="username",
password="password",
port=8081,
extra={"SQlalchemy_Scheme": "mssql+testdriver"},
),
"alt_2": Connection(
conn_type="mssql",
host="ip",
schema="",
login="username",
password="password",
port=8081,
extra={"SQlalchemy_Scheme": "mssql+testdriver", "myparam": "5@-//*"},
),
}


URI_TEST_CASES = [
("default", "mssql+pymssql://username:password@ip:8081/share"),
("alt", "mssql+pymssql://username:password@ip:8081"),
("alt_1", "mssql+testdriver://username:password@ip:8081/"),
("alt_2", "mssql+testdriver://username:password@ip:8081/?myparam=5%40-%2F%2F%2A"),
]


class TestMsSqlHook:
@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn")
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_connection")
def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn):
get_connection.return_value = PYMSSQL_CONN
def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn, mssql_connections):
get_connection.return_value = mssql_connections["default"]
mssql_get_conn.return_value = mock.Mock()

hook = MsSqlHook()
Expand All @@ -81,8 +94,8 @@ def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn)

@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn")
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_connection")
def test_set_autocommit_should_invoke_autocommit(self, get_connection, mssql_get_conn):
get_connection.return_value = PYMSSQL_CONN
def test_set_autocommit_should_invoke_autocommit(self, get_connection, mssql_get_conn, mssql_connections):
get_connection.return_value = mssql_connections["default"]
mssql_get_conn.return_value = mock.Mock()
autocommit_value = mock.Mock()

Expand All @@ -95,8 +108,10 @@ def test_set_autocommit_should_invoke_autocommit(self, get_connection, mssql_get

@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn")
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_connection")
def test_get_autocommit_should_return_autocommit_state(self, get_connection, mssql_get_conn):
get_connection.return_value = PYMSSQL_CONN
def test_get_autocommit_should_return_autocommit_state(
self, get_connection, mssql_get_conn, mssql_connections
):
get_connection.return_value = mssql_connections["default"]
mssql_get_conn.return_value = mock.Mock()
mssql_get_conn.return_value.autocommit_state = "autocommit_state"

Expand All @@ -106,47 +121,10 @@ def test_get_autocommit_should_return_autocommit_state(self, get_connection, mss
mssql_get_conn.assert_called_once()
assert hook.get_autocommit(conn) == "autocommit_state"

@pytest.mark.parametrize(
"conn, exp_uri",
[
(
PYMSSQL_CONN,
(
"mssql+pymssql://"
f"{quote_plus(PYMSSQL_CONN.login)}:{quote_plus(PYMSSQL_CONN.password)}"
f"@{PYMSSQL_CONN.host}:{PYMSSQL_CONN.port}/{PYMSSQL_CONN.schema}"
),
),
(
PYMSSQL_CONN_ALT,
(
"mssql+pymssql://"
f"{quote_plus(PYMSSQL_CONN_ALT.login)}:{quote_plus(PYMSSQL_CONN_ALT.password)}"
f"@{PYMSSQL_CONN_ALT.host}:{PYMSSQL_CONN_ALT.port}"
),
),
(
PYMSSQL_CONN_ALT_1,
(
f"{PYMSSQL_CONN_ALT_1.extra_dejson['SQlalchemy_Scheme']}://"
f"{quote_plus(PYMSSQL_CONN_ALT.login)}:{quote_plus(PYMSSQL_CONN_ALT.password)}"
f"@{PYMSSQL_CONN_ALT.host}:{PYMSSQL_CONN_ALT.port}/"
),
),
(
PYMSSQL_CONN_ALT_2,
(
f"{PYMSSQL_CONN_ALT_2.extra_dejson['SQlalchemy_Scheme']}://"
f"{quote_plus(PYMSSQL_CONN_ALT_2.login)}:{quote_plus(PYMSSQL_CONN_ALT_2.password)}"
f"@{PYMSSQL_CONN_ALT_2.host}:{PYMSSQL_CONN_ALT_2.port}/"
f"?myparam={quote_plus(PYMSSQL_CONN_ALT_2.extra_dejson['myparam'])}"
),
),
],
)
@pytest.mark.parametrize("conn_id,exp_uri", URI_TEST_CASES)
@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
def test_get_uri_driver_rewrite(self, get_connection, conn, exp_uri):
get_connection.return_value = conn
def test_get_uri_driver_rewrite(self, get_connection, mssql_connections, conn_id, exp_uri):
get_connection.return_value = mssql_connections[conn_id]

hook = MsSqlHook()
res_uri = hook.get_uri()
Expand All @@ -155,8 +133,8 @@ def test_get_uri_driver_rewrite(self, get_connection, conn, exp_uri):
assert res_uri == exp_uri

@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
def test_sqlalchemy_scheme_is_default(self, get_connection):
get_connection.return_value = PYMSSQL_CONN
def test_sqlalchemy_scheme_is_default(self, get_connection, mssql_connections):
get_connection.return_value = mssql_connections["default"]

hook = MsSqlHook()
assert hook.sqlalchemy_scheme == hook.DEFAULT_SQLALCHEMY_SCHEME
Expand All @@ -167,101 +145,109 @@ def test_sqlalchemy_scheme_is_from_hook(self):
assert hook.sqlalchemy_scheme == "mssql+mytestdriver"

@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
def test_sqlalchemy_scheme_is_from_conn_extra(self, get_connection):
get_connection.return_value = PYMSSQL_CONN_ALT_1
def test_sqlalchemy_scheme_is_from_conn_extra(self, get_connection, mssql_connections):
get_connection.return_value = mssql_connections["alt_1"]

hook = MsSqlHook()
scheme = hook.sqlalchemy_scheme
get_connection.assert_called()
assert scheme == PYMSSQL_CONN_ALT_1.extra_dejson["SQlalchemy_Scheme"]
assert scheme == mssql_connections["alt_1"].extra_dejson["SQlalchemy_Scheme"]

@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
def test_get_sqlalchemy_engine(self, get_connection):
get_connection.return_value = PYMSSQL_CONN
def test_get_sqlalchemy_engine(self, get_connection, mssql_connections):
get_connection.return_value = mssql_connections["default"]

hook = MsSqlHook()
hook.get_sqlalchemy_engine()

@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_primary_keys", get_primary_keys)
def test_generate_insert_sql(self, get_connection):
get_connection.return_value = PYMSSQL_CONN
def test_generate_insert_sql(self, get_connection, mssql_connections, get_primary_keys):
get_connection.return_value = mssql_connections["default"]

hook = MsSqlHook()
with mock.patch.object(hook, "get_primary_keys", return_value=get_primary_keys):
sql = hook._generate_insert_sql(
table="YAMMER_GROUPS_ACTIVITY_DETAIL",
values=[
"2024-07-17",
"daa5b44c-80d6-4e22-85b5-a94e04cf7206",
"no-reply@microsoft.com",
"2024-07-17",
0,
0.0,
"MICROSOFT FABRIC (FREE)+MICROSOFT 365 E5",
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
"PT0S",
"PT0S",
"PT0S",
0,
0,
0,
"Yes",
0,
0,
"APACHE",
0.0,
0,
"Yes",
1,
"2024-07-17T00:00:00+00:00",
],
target_fields=[
"ReportRefreshDate",
"UserId",
"UserPrincipalName",
"LastActivityDate",
"IsDeleted",
"DeletedDate",
"AssignedProducts",
"TeamChatMessageCount",
"PrivateChatMessageCount",
"CallCount",
"MeetingCount",
"MeetingsOrganizedCount",
"MeetingsAttendedCount",
"AdHocMeetingsOrganizedCount",
"AdHocMeetingsAttendedCount",
"ScheduledOne-timeMeetingsOrganizedCount",
"ScheduledOne-timeMeetingsAttendedCount",
"ScheduledRecurringMeetingsOrganizedCount",
"ScheduledRecurringMeetingsAttendedCount",
"AudioDuration",
"VideoDuration",
"ScreenShareDuration",
"AudioDurationInSeconds",
"VideoDurationInSeconds",
"ScreenShareDurationInSeconds",
"HasOtherAction",
"UrgentMessages",
"PostMessages",
"TenantDisplayName",
"SharedChannelTenantDisplayNames",
"ReplyMessages",
"IsLicensed",
"ReportPeriod",
"LoadDate",
],
replace=True,
)
assert sql == load_file("resources", "replace.sql")

@pytest.mark.db_test
@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
def test_get_extra(self, get_connection, mssql_connections):
get_connection.return_value = mssql_connections["alt_2"]

hook = MsSqlHook()
sql = hook._generate_insert_sql(
table="YAMMER_GROUPS_ACTIVITY_DETAIL",
values=[
"2024-07-17",
"daa5b44c-80d6-4e22-85b5-a94e04cf7206",
"no-reply@microsoft.com",
"2024-07-17",
0,
0.0,
"MICROSOFT FABRIC (FREE)+MICROSOFT 365 E5",
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
"PT0S",
"PT0S",
"PT0S",
0,
0,
0,
"Yes",
0,
0,
"APACHE",
0.0,
0,
"Yes",
1,
"2024-07-17T00:00:00+00:00",
],
target_fields=[
"ReportRefreshDate",
"UserId",
"UserPrincipalName",
"LastActivityDate",
"IsDeleted",
"DeletedDate",
"AssignedProducts",
"TeamChatMessageCount",
"PrivateChatMessageCount",
"CallCount",
"MeetingCount",
"MeetingsOrganizedCount",
"MeetingsAttendedCount",
"AdHocMeetingsOrganizedCount",
"AdHocMeetingsAttendedCount",
"ScheduledOne-timeMeetingsOrganizedCount",
"ScheduledOne-timeMeetingsAttendedCount",
"ScheduledRecurringMeetingsOrganizedCount",
"ScheduledRecurringMeetingsAttendedCount",
"AudioDuration",
"VideoDuration",
"ScreenShareDuration",
"AudioDurationInSeconds",
"VideoDurationInSeconds",
"ScreenShareDurationInSeconds",
"HasOtherAction",
"UrgentMessages",
"PostMessages",
"TenantDisplayName",
"SharedChannelTenantDisplayNames",
"ReplyMessages",
"IsLicensed",
"ReportPeriod",
"LoadDate",
],
replace=True,
)
assert sql == load_file("resources", "replace.sql")
assert hook.get_connection().extra

0 comments on commit e464539

Please sign in to comment.