Skip to content

Commit

Permalink
Fix sqlglot crasher with 'drop schema ...' statement (#2758)
Browse files Browse the repository at this point in the history
## Changes
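Fix a crash in `FromTableSqlLinter` when linting a `DROP SCHEMA ...` statement. The schema reference in the statement is now skipped, and any exception raised while linting an expression is reported as a `sql-parse-error` failure instead of propagating.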

### Linked issues
None

### Functionality
None

### Tests
- [x] added unit tests

---------

Co-authored-by: Eric Vergnaud <eric.vergnaud@databricks.com>
  • Loading branch information
ericvergnaud and ericvergnaud authored Sep 26, 2024
1 parent e3d34d1 commit fc4ec37
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 32 deletions.
82 changes: 50 additions & 32 deletions src/databricks/labs/ucx/source_code/linters/from_table.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
from sqlglot import parse as parse_sql
from sqlglot.expressions import Table, Expression, Use, Create
from sqlglot.expressions import Table, Expression, Use, Create, Drop
from databricks.labs.ucx.hive_metastore.table_migration_status import TableMigrationIndex
from databricks.labs.ucx.source_code.base import Deprecation, CurrentSessionState, SqlLinter, Fixer
from databricks.labs.ucx.source_code.base import Deprecation, CurrentSessionState, SqlLinter, Fixer, Failure

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,37 +43,55 @@ def schema(self):

def lint_expression(self, expression: Expression):
for table in expression.find_all(Table):
if isinstance(expression, Use):
# Sqlglot captures the database name in the Use statement as a Table, with
# the schema as the table name.
self._session_state.schema = table.name
continue
if isinstance(expression, Create) and getattr(expression, "kind", None) == "SCHEMA":
# Sqlglot captures the schema name in the Create statement as a Table, with
# the schema as the db name.
self._session_state.schema = table.db
continue
try:
yield from self._unsafe_lint_expression(expression, table)
except Exception as _: # pylint: disable=broad-exception-caught
yield Failure(
code='sql-parse-error',
message=f"Could not parse SQL expression: {expression} ",
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
start_line=0,
start_col=0,
end_line=0,
end_col=1024,
)

# we only migrate tables in the hive_metastore catalog
if self._catalog(table) != 'hive_metastore':
continue
# Sqlglot uses db instead of schema, watch out for that
src_schema = table.db if table.db else self._session_state.schema
if not src_schema:
logger.error(f"Could not determine schema for table {table.name}")
continue
dst = self._index.get(src_schema, table.name)
if not dst:
continue
yield Deprecation(
code='table-migrated-to-uc',
message=f"Table {src_schema}.{table.name} is migrated to {dst.destination()} in Unity Catalog",
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
start_line=0,
start_col=0,
end_line=0,
end_col=1024,
)
def _unsafe_lint_expression(self, expression: Expression, table: Table):
    """Lint a single table reference found inside ``expression``.

    Generator that yields at most one ``Deprecation``. It is "unsafe" in the
    sense that it does not catch exceptions itself.

    - ``USE`` statements update the session schema and yield nothing.
    - ``DROP SCHEMA`` statements are skipped entirely.
    - ``CREATE SCHEMA`` statements update the session schema and yield nothing.
    - Any other reference is reported only when its catalog (as computed by
      ``self._catalog``) is ``hive_metastore`` and the migration index returns
      an entry for ``<schema>.<table>``.
    """
    if isinstance(expression, Use):
        # In a USE statement sqlglot models the database as a Table whose
        # *name* is the schema.
        self._session_state.schema = table.name
        return

    is_schema_ddl = isinstance(expression, (Drop, Create)) and getattr(expression, "kind", None) == "SCHEMA"
    if is_schema_ddl:
        # In DROP/CREATE SCHEMA sqlglot models the schema as a Table whose
        # *db* is the schema. Only CREATE changes the session schema.
        if not isinstance(expression, Drop):
            self._session_state.schema = table.db
        return

    # Only tables living in the hive_metastore catalog are migrated.
    if self._catalog(table) != 'hive_metastore':
        return

    # sqlglot says "db" where we say "schema"; fall back to the session schema.
    source_schema = table.db or self._session_state.schema
    if not source_schema:
        logger.error(f"Could not determine schema for table {table.name}")
        return

    migrated = self._index.get(source_schema, table.name)
    if migrated:
        yield Deprecation(
            code='table-migrated-to-uc',
            message=f"Table {source_schema}.{table.name} is migrated to {migrated.destination()} in Unity Catalog",
            # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
            start_line=0,
            start_col=0,
            end_line=0,
            end_col=1024,
        )

@staticmethod
def _catalog(table):
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/source_code/linters/test_from_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ def test_parses_create_schema(migration_index):
assert not list(advices)


def test_parses_drop_schema(migration_index):
    """Linting a DROP SCHEMA statement must not crash and must yield no advices."""
    linter = FromTableSqlLinter(migration_index, session_state=CurrentSessionState(schema="old"))
    assert not list(linter.lint("DROP SCHEMA xyz"))


def test_raises_advice_when_parsing_unsupported_sql(migration_index):
query = "XDESCRIBE DETAILS xyz" # not a valid query
session_state = CurrentSessionState(schema="old")
Expand Down

0 comments on commit fc4ec37

Please sign in to comment.