Skip to content

Commit

Permalink
feat(ingest/athena): Add option for Athena partitioned profiling (dat…
Browse files Browse the repository at this point in the history
  • Loading branch information
treff7es authored and aviv-julienjehannet committed Jul 25, 2024
1 parent af26252 commit 94f2ddc
Show file tree
Hide file tree
Showing 9 changed files with 323 additions and 44 deletions.
9 changes: 8 additions & 1 deletion metadata-ingestion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,14 @@
# sqlalchemy-bigquery is included here since it provides an implementation of
# a SQLalchemy-conform STRUCT type definition
"athena": sql_common
| {"PyAthena[SQLAlchemy]>=2.6.0,<3.0.0", "sqlalchemy-bigquery>=1.4.1"},
# We need to set tenacity lower than 8.4.0 as
# this version has missing dependency asyncio
# https://github.com/jd/tenacity/issues/471
| {
"PyAthena[SQLAlchemy]>=2.6.0,<3.0.0",
"sqlalchemy-bigquery>=1.4.1",
"tenacity!=8.4.0",
},
"azure-ad": set(),
"bigquery": sql_common
| bigquery_common
Expand Down
122 changes: 116 additions & 6 deletions metadata-ingestion/src/datahub/ingestion/source/ge_data_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import uuid
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Expand All @@ -39,6 +40,7 @@
from great_expectations.dataset.dataset import Dataset
from great_expectations.dataset.sqlalchemy_dataset import SqlAlchemyDataset
from great_expectations.datasource.sqlalchemy_datasource import SqlAlchemyDatasource
from great_expectations.execution_engine.sqlalchemy_dialect import GXSqlDialect
from great_expectations.profile.base import ProfilerDataType
from great_expectations.profile.basic_dataset_profiler import BasicDatasetProfilerBase
from sqlalchemy.engine import Connection, Engine
Expand Down Expand Up @@ -72,9 +74,14 @@
get_query_columns,
)

if TYPE_CHECKING:
from pyathena.cursor import Cursor

assert MARKUPSAFE_PATCHED
logger: logging.Logger = logging.getLogger(__name__)

_original_get_column_median = SqlAlchemyDataset.get_column_median

P = ParamSpec("P")
POSTGRESQL = "postgresql"
MYSQL = "mysql"
Expand Down Expand Up @@ -203,6 +210,47 @@ def _get_column_quantiles_bigquery_patch( # type:ignore
return list()


def _get_column_quantiles_awsathena_patch( # type:ignore
self, column: str, quantiles: Iterable
) -> list:
import ast

table_name = ".".join(
[f'"{table_part}"' for table_part in str(self._table).split(".")]
)

quantiles_list = list(quantiles)
quantiles_query = (
f"SELECT approx_percentile({column}, ARRAY{str(quantiles_list)}) as quantiles "
f"from (SELECT {column} from {table_name})"
)
try:
quantiles_results = self.engine.execute(quantiles_query).fetchone()[0]
quantiles_results_list = ast.literal_eval(quantiles_results)
return quantiles_results_list

except ProgrammingError as pe:
self._treat_quantiles_exception(pe)
return []


def _get_column_median_patch(self, column):
# AWS Athena and presto have an special function that can be used to retrieve the median
if (
self.sql_engine_dialect.name.lower() == GXSqlDialect.AWSATHENA
or self.sql_engine_dialect.name.lower() == GXSqlDialect.TRINO
):
table_name = ".".join(
[f'"{table_part}"' for table_part in str(self._table).split(".")]
)
element_values = self.engine.execute(
f"SELECT approx_percentile({column}, 0.5) FROM {table_name}"
)
return convert_to_json_serializable(element_values.fetchone()[0])
else:
return _original_get_column_median(self, column)


def _is_single_row_query_method(query: Any) -> bool:
SINGLE_ROW_QUERY_FILES = {
# "great_expectations/dataset/dataset.py",
Expand Down Expand Up @@ -1038,6 +1086,12 @@ def generate_profiles(
), unittest.mock.patch(
"great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_bigquery",
_get_column_quantiles_bigquery_patch,
), unittest.mock.patch(
"great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset._get_column_quantiles_awsathena",
_get_column_quantiles_awsathena_patch,
), unittest.mock.patch(
"great_expectations.dataset.sqlalchemy_dataset.SqlAlchemyDataset.get_column_median",
_get_column_median_patch,
), concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as async_executor, SQLAlchemyQueryCombiner(
Expand Down Expand Up @@ -1114,15 +1168,16 @@ def _generate_profile_from_request(
**request.batch_kwargs,
)

def _drop_trino_temp_table(self, temp_dataset: Dataset) -> None:
def _drop_temp_table(self, temp_dataset: Dataset) -> None:
schema = temp_dataset._table.schema
table = temp_dataset._table.name
table_name = f'"{schema}"."{table}"' if schema else f'"{table}"'
try:
with self.base_engine.connect() as connection:
connection.execute(f"drop view if exists {schema}.{table}")
logger.debug(f"View {schema}.{table} was dropped.")
connection.execute(f"drop view if exists {table_name}")
logger.debug(f"View {table_name} was dropped.")
except Exception:
logger.warning(f"Unable to delete trino temporary table: {schema}.{table}")
logger.warning(f"Unable to delete temporary table: {table_name}")

def _generate_single_profile(
self,
Expand All @@ -1149,6 +1204,19 @@ def _generate_single_profile(
}

bigquery_temp_table: Optional[str] = None
temp_view: Optional[str] = None
if platform and platform.upper() == "ATHENA" and (custom_sql):
if custom_sql is not None:
# Note that limit and offset are not supported for custom SQL.
temp_view = create_athena_temp_table(
self, custom_sql, pretty_name, self.base_engine.raw_connection()
)
ge_config["table"] = temp_view
ge_config["schema"] = None
ge_config["limit"] = None
ge_config["offset"] = None
custom_sql = None

if platform == BIGQUERY and (
custom_sql or self.config.limit or self.config.offset
):
Expand Down Expand Up @@ -1234,8 +1302,16 @@ def _generate_single_profile(
)
return None
finally:
if batch is not None and self.base_engine.engine.name == TRINO:
self._drop_trino_temp_table(batch)
if batch is not None and self.base_engine.engine.name.upper() in [
"TRINO",
"AWSATHENA",
]:
if (
self.base_engine.engine.name.upper() == "TRINO"
or temp_view is not None
):
self._drop_temp_table(batch)
# if we are not on Trino then we only drop table if temp table variable was set

def _get_ge_dataset(
self,
Expand Down Expand Up @@ -1299,6 +1375,40 @@ def _get_column_types_to_ignore(dialect_name: str) -> List[str]:
return []


def create_athena_temp_table(
instance: Union[DatahubGEProfiler, _SingleDatasetProfiler],
sql: str,
table_pretty_name: str,
raw_connection: Any,
) -> Optional[str]:
try:
cursor: "Cursor" = cast("Cursor", raw_connection.cursor())
logger.debug(f"Creating view for {table_pretty_name}: {sql}")
temp_view = f"ge_{uuid.uuid4()}"
if "." in table_pretty_name:
schema_part = table_pretty_name.split(".")[-1]
schema_part_quoted = ".".join(
[f'"{part}"' for part in str(schema_part).split(".")]
)
temp_view = f"{schema_part_quoted}_{temp_view}"

temp_view = f"ge_{uuid.uuid4()}"
cursor.execute(f'create or replace view "{temp_view}" as {sql}')
except Exception as e:
if not instance.config.catch_exceptions:
raise e
logger.exception(f"Encountered exception while profiling {table_pretty_name}")
instance.report.report_warning(
table_pretty_name,
f"Profiling exception {e} when running custom sql {sql}",
)
return None
finally:
raw_connection.close()

return temp_view


def create_bigquery_temp_table(
instance: Union[DatahubGEProfiler, _SingleDatasetProfiler],
bq_sql: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class GEProfilingConfig(ConfigModel):

partition_profiling_enabled: bool = Field(
default=True,
description="Whether to profile partitioned tables. Only BigQuery supports this. "
description="Whether to profile partitioned tables. Only BigQuery and Aws Athena supports this. "
"If enabled, latest partition data is used for profiling.",
)
partition_datetime: Optional[datetime.datetime] = Field(
Expand Down
Loading

0 comments on commit 94f2ddc

Please sign in to comment.