Skip to content

Commit

Permalink
Suppress jaydebeapi.Error when setAutoCommit or getAutoCommit is unsu…
Browse files Browse the repository at this point in the history
…pported by JDBC driver (#38707)


---------

Co-authored-by: David Blain <david.blain@infrabel.be>
  • Loading branch information
dabla and davidblain-infrabel authored Apr 12, 2024
1 parent 7ab24c7 commit 41869d3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
19 changes: 17 additions & 2 deletions airflow/providers/jdbc/hooks/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# under the License.
from __future__ import annotations

import traceback
import warnings
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

import jaydebeapi
Expand All @@ -27,6 +30,15 @@
from airflow.models.connection import Connection


@contextmanager
def suppress_and_warn(*exceptions: type[BaseException]):
"""Context manager that suppresses the given exceptions and logs a warning message."""
try:
yield
except exceptions as e:
warnings.warn(f"Exception suppressed: {e}\n{traceback.format_exc()}", category=UserWarning)


class JdbcHook(DbApiHook):
"""General hook for JDBC access.
Expand Down Expand Up @@ -152,7 +164,8 @@ def set_autocommit(self, conn: jaydebeapi.Connection, autocommit: bool) -> None:
:param conn: The connection.
:param autocommit: The connection's autocommit setting.
"""
conn.jconn.setAutoCommit(autocommit)
with suppress_and_warn(jaydebeapi.Error):
conn.jconn.setAutoCommit(autocommit)

def get_autocommit(self, conn: jaydebeapi.Connection) -> bool:
"""Get autocommit setting for the provided connection.
Expand All @@ -162,4 +175,6 @@ def get_autocommit(self, conn: jaydebeapi.Connection) -> bool:
to True on the connection. False if it is either not set, set to
False, or the connection does not support auto-commit.
"""
return conn.jconn.getAutoCommit()
with suppress_and_warn(jaydebeapi.Error):
return conn.jconn.getAutoCommit()
return False
28 changes: 27 additions & 1 deletion tests/providers/jdbc/hooks/test_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
from unittest import mock
from unittest.mock import Mock, patch

import jaydebeapi
import pytest

from airflow.exceptions import DeserializingResultError
from airflow.models import Connection
from airflow.providers.jdbc.hooks.jdbc import JdbcHook
from airflow.providers.jdbc.hooks.jdbc import JdbcHook, suppress_and_warn
from airflow.utils import db
from airflow.utils.context import AirflowContextDeprecationWarning

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -82,13 +85,27 @@ def test_jdbc_conn_set_autocommit(self, _):
jdbc_hook.set_autocommit(jdbc_conn, False)
jdbc_conn.jconn.setAutoCommit.assert_called_once_with(False)

@patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect")
def test_jdbc_conn_set_autocommit_when_not_supported(self, _):
jdbc_hook = JdbcHook()
jdbc_conn = jdbc_hook.get_conn()
jdbc_conn.jconn.setAutoCommit.side_effect = jaydebeapi.Error()
jdbc_hook.set_autocommit(jdbc_conn, False)

@patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect")
def test_jdbc_conn_get_autocommit(self, _):
jdbc_hook = JdbcHook()
jdbc_conn = jdbc_hook.get_conn()
jdbc_hook.get_autocommit(jdbc_conn)
jdbc_conn.jconn.getAutoCommit.assert_called_once_with()

@patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect")
def test_jdbc_conn_get_autocommit_when_not_supported_then_return_false(self, _):
jdbc_hook = JdbcHook()
jdbc_conn = jdbc_hook.get_conn()
jdbc_conn.jconn.getAutoCommit.side_effect = jaydebeapi.Error()
assert jdbc_hook.get_autocommit(jdbc_conn) is False

def test_driver_hook_params(self):
hook = get_hook(hook_params=dict(driver_path="Blah driver path", driver_class="Blah driver class"))
assert hook.driver_path == "Blah driver path"
Expand Down Expand Up @@ -161,3 +178,12 @@ def test_driver_extra_raises_warning_and_returns_default_driver_by_default(self,
"have supplied 'driver_class' via connection extra but it will not be used"
) in caplog.text
assert driver_class == "Blah driver class"

def test_suppress_and_warn_when_raised_exception_is_suppressed(self):
with suppress_and_warn(AirflowContextDeprecationWarning):
raise AirflowContextDeprecationWarning()

def test_suppress_and_warn_when_raised_exception_is_not_suppressed(self):
with pytest.raises(AirflowContextDeprecationWarning):
with suppress_and_warn(DeserializingResultError):
raise AirflowContextDeprecationWarning()

0 comments on commit 41869d3

Please sign in to comment.