From 5f47e60962b3123b1e6c8b42bef2c2643f54b601 Mon Sep 17 00:00:00 2001 From: darkag Date: Wed, 6 Sep 2023 23:09:53 +0200 Subject: [PATCH] Custom fetch all handler for vertica to not miss errors (#34041) * Custom fetch all handler for vertica to not miss errors * missing parameter * Fix test (set nextset to none) * fix static checks * fix static-check error * fix static-check error * rename variable * add docstring * fix docstring --- airflow/providers/vertica/hooks/vertica.py | 80 ++++++++++++++++++- tests/providers/vertica/hooks/test_vertica.py | 1 + 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/airflow/providers/vertica/hooks/vertica.py b/airflow/providers/vertica/hooks/vertica.py index 06b2e3cf179b7..91672e2aec7a5 100644 --- a/airflow/providers/vertica/hooks/vertica.py +++ b/airflow/providers/vertica/hooks/vertica.py @@ -17,13 +17,45 @@ # under the License. from __future__ import annotations +from typing import Any, Callable, Iterable, Mapping, overload + from vertica_python import connect -from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler + + +def vertica_fetch_all_handler(cursor) -> list[tuple] | None: + """ + Replace the default DbApiHook fetch_all_handler in order to fix this issue https://github.com/apache/airflow/issues/32993. + + The returned value will not change after the initial call of fetch_all_handler; all the remaining code is here + only to make the Vertica client raise errors. + With Vertica, if you run the following sql (with split_statements set to false): + + INSERT INTO MyTable (Key, Label) values (1, 'test 1'); + INSERT INTO MyTable (Key, Label) values (1, 'test 2'); + INSERT INTO MyTable (Key, Label) values (3, 'test 3'); + + each insert will have its own result set and if you don't try to fetch data of those result sets + you won't detect the error on the second insert. 
+ """ + result = fetch_all_handler(cursor) + # loop over all statement result sets to get errors + if cursor.description is not None: + while cursor.nextset(): + if cursor.description is not None: + row = cursor.fetchone() + while row: + row = cursor.fetchone() + return result class VerticaHook(DbApiHook): - """Interact with Vertica.""" + """ + Interact with Vertica. + + This hook uses a customized version of the default fetch_all_handler named vertica_fetch_all_handler. + """ conn_name_attr = "vertica_conn_id" default_conn_name = "vertica_default" @@ -32,7 +64,7 @@ class VerticaHook(DbApiHook): supports_autocommit = True def get_conn(self) -> connect: - """Return verticaql connection object.""" + """Return vertica connection object.""" conn = self.get_connection(self.vertica_conn_id) # type: ignore conn_config = { "user": conn.login, @@ -99,3 +131,45 @@ def get_conn(self) -> connect: conn = connect(**conn_config) return conn + + @overload + def run( + self, + sql: str | Iterable[str], + autocommit: bool = ..., + parameters: Iterable | Mapping[str, Any] | None = ..., + handler: None = ..., + split_statements: bool = ..., + return_last: bool = ..., + ) -> None: + ... + + @overload + def run( + self, + sql: str | Iterable[str], + autocommit: bool = ..., + parameters: Iterable | Mapping[str, Any] | None = ..., + handler: Callable[[Any], Any] = ..., + split_statements: bool = ..., + return_last: bool = ..., + ) -> Any | list[Any]: + ... + + def run( + self, + sql: str | Iterable[str], + autocommit: bool = False, + parameters: Iterable | Mapping | None = None, + handler: Callable[[Any], Any] | None = None, + split_statements: bool = False, + return_last: bool = True, + ) -> Any | list[Any] | None: + """ + Override the common sql run. + + Will automatically replace fetch_all_handler with vertica_fetch_all_handler. 
+ """ + if handler == fetch_all_handler: + handler = vertica_fetch_all_handler + return DbApiHook.run(self, sql, autocommit, parameters, handler, split_statements, return_last) diff --git a/tests/providers/vertica/hooks/test_vertica.py b/tests/providers/vertica/hooks/test_vertica.py index 146c3bcd1136b..e5ff2538ebeab 100644 --- a/tests/providers/vertica/hooks/test_vertica.py +++ b/tests/providers/vertica/hooks/test_vertica.py @@ -127,6 +127,7 @@ def test_get_conn_extra_parameters_cast(self, mock_connect): class TestVerticaHook: def setup_method(self): self.cur = mock.MagicMock(rowcount=0) + self.cur.nextset.side_effect = [None] self.conn = mock.MagicMock() self.conn.cursor.return_value = self.cur conn = self.conn