From fc4ec374ed98910131f0faf9283da1640b31745e Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 26 Sep 2024 18:34:40 +0200 Subject: [PATCH] Fix sqlglot crasher with 'drop schema ...' statement (#2758) ## Changes Fix a crash ### Linked issues None ### Functionality None ### Tests - [x] added unit tests --------- Co-authored-by: Eric Vergnaud --- .../ucx/source_code/linters/from_table.py | 82 +++++++++++-------- .../source_code/linters/test_from_table.py | 8 ++ 2 files changed, 58 insertions(+), 32 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/from_table.py b/src/databricks/labs/ucx/source_code/linters/from_table.py index 9da4dd7a2c..45ffc01697 100644 --- a/src/databricks/labs/ucx/source_code/linters/from_table.py +++ b/src/databricks/labs/ucx/source_code/linters/from_table.py @@ -1,8 +1,8 @@ import logging from sqlglot import parse as parse_sql -from sqlglot.expressions import Table, Expression, Use, Create +from sqlglot.expressions import Table, Expression, Use, Create, Drop from databricks.labs.ucx.hive_metastore.table_migration_status import TableMigrationIndex -from databricks.labs.ucx.source_code.base import Deprecation, CurrentSessionState, SqlLinter, Fixer +from databricks.labs.ucx.source_code.base import Deprecation, CurrentSessionState, SqlLinter, Fixer, Failure logger = logging.getLogger(__name__) @@ -43,37 +43,55 @@ def schema(self): def lint_expression(self, expression: Expression): for table in expression.find_all(Table): - if isinstance(expression, Use): - # Sqlglot captures the database name in the Use statement as a Table, with - # the schema as the table name. - self._session_state.schema = table.name - continue - if isinstance(expression, Create) and getattr(expression, "kind", None) == "SCHEMA": - # Sqlglot captures the schema name in the Create statement as a Table, with - # the schema as the db name. - self._session_state.schema = table.db - continue + try: + yield from self._unsafe_lint_expression(expression, table) + except Exception as _: # pylint: disable=broad-exception-caught + yield Failure( + code='sql-parse-error', + message=f"Could not parse SQL expression: {expression} ", + # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159 + start_line=0, + start_col=0, + end_line=0, + end_col=1024, + ) - # we only migrate tables in the hive_metastore catalog - if self._catalog(table) != 'hive_metastore': - continue - # Sqlglot uses db instead of schema, watch out for that - src_schema = table.db if table.db else self._session_state.schema - if not src_schema: - logger.error(f"Could not determine schema for table {table.name}") - continue - dst = self._index.get(src_schema, table.name) - if not dst: - continue - yield Deprecation( - code='table-migrated-to-uc', - message=f"Table {src_schema}.{table.name} is migrated to {dst.destination()} in Unity Catalog", - # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159 - start_line=0, - start_col=0, - end_line=0, - end_col=1024, - ) + def _unsafe_lint_expression(self, expression: Expression, table: Table): + if isinstance(expression, Use): + # Sqlglot captures the database name in the Use statement as a Table, with + # the schema as the table name. + self._session_state.schema = table.name + return + if isinstance(expression, Drop) and getattr(expression, "kind", None) == "SCHEMA": + # Sqlglot captures the schema name in the Drop statement as a Table, with + # the schema as the db name. + return + if isinstance(expression, Create) and getattr(expression, "kind", None) == "SCHEMA": + # Sqlglot captures the schema name in the Create statement as a Table, with + # the schema as the db name. + self._session_state.schema = table.db + return + + # we only migrate tables in the hive_metastore catalog + if self._catalog(table) != 'hive_metastore': + return + # Sqlglot uses db instead of schema, watch out for that + src_schema = table.db if table.db else self._session_state.schema + if not src_schema: + logger.error(f"Could not determine schema for table {table.name}") + return + dst = self._index.get(src_schema, table.name) + if not dst: + return + yield Deprecation( + code='table-migrated-to-uc', + message=f"Table {src_schema}.{table.name} is migrated to {dst.destination()} in Unity Catalog", + # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159 + start_line=0, + start_col=0, + end_line=0, + end_col=1024, + ) @staticmethod def _catalog(table): diff --git a/tests/unit/source_code/linters/test_from_table.py b/tests/unit/source_code/linters/test_from_table.py index 6469b1c622..921486ecc5 100644 --- a/tests/unit/source_code/linters/test_from_table.py +++ b/tests/unit/source_code/linters/test_from_table.py @@ -87,6 +87,14 @@ def test_parses_create_schema(migration_index): assert not list(advices) +def test_parses_drop_schema(migration_index): + query = "DROP SCHEMA xyz" + session_state = CurrentSessionState(schema="old") + ftf = FromTableSqlLinter(migration_index, session_state=session_state) + advices = ftf.lint(query) + assert not list(advices) + + def test_raises_advice_when_parsing_unsupported_sql(migration_index): query = "XDESCRIBE DETAILS xyz" # not a valid query session_state = CurrentSessionState(schema="old")