Skip to content

Commit

Permalink
Don't use Pandas for SQLTableCheckOperator
Browse files (browse the repository at this point in the history)
Pandas is an optional extra for the common-sql provider, so _forcing_ it for
a query that is only going to return a couple of rows is not a good idea.
  • Loading branch information
ashb committed Aug 19, 2022
1 parent 5c48ed1 commit 61a2d66
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 24 deletions.
10 changes: 4 additions & 6 deletions airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,15 @@ def execute(self, context: 'Context'):
self.sql = f"SELECT check_name, check_result FROM ({checks_sql}) "
f"AS check_table {partition_clause_statement};"

records = hook.get_pandas_df(self.sql)
records = hook.get_records(self.sql)

if records.empty:
if not records:
raise AirflowException(f"The following query returned zero rows: {self.sql}")

records.columns = records.columns.str.lower()
self.log.info("Record:\n%s", records)

for row in records.iterrows():
check = row[1].get("check_name")
result = row[1].get("check_result")
for row in records:
check, result = row
self.checks[check]["success"] = parse_boolean(str(result))

failed_tests = _get_failed_checks(self.checks)
Expand Down
25 changes: 7 additions & 18 deletions tests/providers/common/sql/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from unittest import mock
from unittest.mock import MagicMock

import pandas as pd
import pytest

from airflow import DAG
Expand All @@ -47,7 +46,7 @@ class MockHook:
def get_first(self):
return

def get_pandas_df(self):
def get_records(self):
return


Expand Down Expand Up @@ -120,32 +119,22 @@ class TestTableCheckOperator:
}

def _construct_operator(self, monkeypatch, checks, return_df):
def get_pandas_df_return(*arg):
def get_records(*arg):
return return_df

operator = SQLTableCheckOperator(task_id="test_task", table="test_table", checks=checks)
monkeypatch.setattr(operator, "get_db_hook", _get_mock_db_hook)
monkeypatch.setattr(MockHook, "get_pandas_df", get_pandas_df_return)
monkeypatch.setattr(MockHook, "get_records", get_records)
return operator

def test_pass_all_checks_check(self, monkeypatch):
df = pd.DataFrame(
data={
"check_name": ["row_count_check", "column_sum_check"],
"check_result": [
"1",
"y",
],
}
)
operator = self._construct_operator(monkeypatch, self.checks, df)
records = [('row_count_check', 1), ('column_sum_check', 'y')]
operator = self._construct_operator(monkeypatch, self.checks, records)
operator.execute(context=MagicMock())

def test_fail_all_checks_check(self, monkeypatch):
df = pd.DataFrame(
data={"check_name": ["row_count_check", "column_sum_check"], "check_result": ["0", "n"]}
)
operator = self._construct_operator(monkeypatch, self.checks, df)
records = [('row_count_check', 0), ('column_sum_check', 'n')]
operator = self._construct_operator(monkeypatch, self.checks, records)
with pytest.raises(AirflowException):
operator.execute(context=MagicMock())

Expand Down

0 comments on commit 61a2d66

Please sign in to comment.