From 5a409bbd1c9c6a19540baaad463296ac3275927f Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Fri, 11 Oct 2024 11:43:52 +0200 Subject: [PATCH 1/4] refactor --- src/databricks/labs/ucx/source_code/base.py | 8 ++++---- src/databricks/labs/ucx/source_code/linters/pyspark.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/base.py b/src/databricks/labs/ucx/source_code/base.py index 69bf0e307f..6290f877f8 100644 --- a/src/databricks/labs/ucx/source_code/base.py +++ b/src/databricks/labs/ucx/source_code/base.py @@ -262,14 +262,14 @@ def collect_tables(self, source_code: str) -> Iterable[UsedTable]: ... @dataclass -class TableInfoNode: +class UsedTableNode: table: UsedTable node: NodeNG class TablePyCollector(TableCollector, ABC): - def collect_tables(self, source_code: str): + def collect_tables(self, source_code: str) -> Iterable[UsedTable]: try: tree = Tree.normalize_and_parse(source_code) for table_node in self.collect_tables_from_tree(tree): @@ -282,7 +282,7 @@ def collect_tables(self, source_code: str): logger.warning('syntax-error', exc_info=e) @abstractmethod - def collect_tables_from_tree(self, tree: Tree) -> Iterable[TableInfoNode]: ... + def collect_tables_from_tree(self, tree: Tree) -> Iterable[UsedTableNode]: ... class TableSqlCollector(TableCollector, ABC): ... @@ -467,7 +467,7 @@ def collect_tables(self, source_code: str) -> Iterable[UsedTable]: except AstroidSyntaxError as e: logger.warning('syntax-error', exc_info=e) - def collect_tables_from_tree(self, tree: Tree) -> Iterable[TableInfoNode]: + def collect_tables_from_tree(self, tree: Tree) -> Iterable[UsedTableNode]: for collector in self._table_collectors: yield from collector.collect_tables_from_tree(tree) diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index 6ee0a89d4e..3a8c9eb54c 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -15,7 +15,7 @@ SqlLinter, Fixer, UsedTable, - TableInfoNode, + UsedTableNode, TablePyCollector, TableSqlCollector, DfsaPyCollector, @@ -388,14 +388,14 @@ def _find_matcher(self, node: NodeNG): return None return matcher if matcher.matches(node) else None - def collect_tables_from_tree(self, tree: Tree) -> Iterable[TableInfoNode]: + def collect_tables_from_tree(self, tree: Tree) -> Iterable[UsedTableNode]: for node in tree.walk(): matcher = self._find_matcher(node) if matcher is None: continue assert isinstance(node, Call) for used_table in matcher.collect_tables(self._from_table, self._index, self._session_state, node): - yield TableInfoNode(used_table, node) # B + yield UsedTableNode(used_table, node) class _SparkSqlAnalyzer: @@ -468,11 +468,11 @@ class SparkSqlTablePyCollector(_SparkSqlAnalyzer, TablePyCollector): def __init__(self, sql_collector: TableSqlCollector): self._sql_collector = sql_collector - def collect_tables_from_tree(self, tree: Tree) -> Iterable[TableInfoNode]: + def collect_tables_from_tree(self, tree: Tree) -> Iterable[UsedTableNode]: assert self._sql_collector for call_node, query in self._visit_call_nodes(tree): for value in InferredValue.infer_from_node(query): if not value.is_inferred(): continue # TODO error handling strategy for table in self._sql_collector.collect_tables(value.as_string()): - yield TableInfoNode(table, call_node) # A + yield UsedTableNode(table, call_node) # A From 331232d64f40b131d9b59d93227818bd592d3281 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Fri, 11 Oct 2024 11:44:09 +0200 Subject: [PATCH 2/4] cleanup --- src/databricks/labs/ucx/source_code/linters/pyspark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index 3a8c9eb54c..91ab4c273e 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -475,4 +475,4 @@ def collect_tables_from_tree(self, tree: Tree) -> Iterable[UsedTableNode]: if not value.is_inferred(): continue # TODO error handling strategy for table in self._sql_collector.collect_tables(value.as_string()): - yield UsedTableNode(table, call_node) # A + yield UsedTableNode(table, call_node) From b015a24f7f1b4f9f61358ee7d5d5f2cb692a1468 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Fri, 11 Oct 2024 12:20:54 +0200 Subject: [PATCH 3/4] revert workaround --- src/databricks/labs/ucx/source_code/base.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/base.py b/src/databricks/labs/ucx/source_code/base.py index 6290f877f8..3d0995955c 100644 --- a/src/databricks/labs/ucx/source_code/base.py +++ b/src/databricks/labs/ucx/source_code/base.py @@ -273,11 +273,7 @@ def collect_tables(self, source_code: str) -> Iterable[UsedTable]: try: tree = Tree.normalize_and_parse(source_code) for table_node in self.collect_tables_from_tree(tree): - # see https://github.com/databrickslabs/ucx/issues/2887 - if isinstance(table_node, UsedTable): - yield table_node - else: - yield table_node.table + yield table_node.table except AstroidSyntaxError as e: logger.warning('syntax-error', exc_info=e) @@ -458,12 +454,7 @@ def collect_tables(self, source_code: str) -> Iterable[UsedTable]: try: tree = self._parse_and_append(source_code) for table_node in self.collect_tables_from_tree(tree): - # there's a bug in the code that causes this to be necessary - # see https://github.com/databrickslabs/ucx/issues/2887 - if isinstance(table_node, UsedTable): - yield table_node - else: - yield table_node.table + yield table_node.table except AstroidSyntaxError as e: logger.warning('syntax-error', exc_info=e) From 33efffb630422950c7a6b0c1f43081c31ebb21e3 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Fri, 11 Oct 2024 12:21:05 +0200 Subject: [PATCH 4/4] add samples --- .../source_code/samples/functional/table-access.py | 11 +++++++++++ .../source_code/samples/functional/table-access.sql | 12 ++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 tests/unit/source_code/samples/functional/table-access.py create mode 100644 tests/unit/source_code/samples/functional/table-access.sql diff --git a/tests/unit/source_code/samples/functional/table-access.py b/tests/unit/source_code/samples/functional/table-access.py new file mode 100644 index 0000000000..2539950b44 --- /dev/null +++ b/tests/unit/source_code/samples/functional/table-access.py @@ -0,0 +1,11 @@ +# Databricks notebook source +# ucx[default-format-changed-in-dbr8:+1:0:+1:18] The default format changed in Databricks Runtime 8.0, from Parquet to Delta +spark.table("a.b").count() +spark.sql("SELECT * FROM b.c LEFT JOIN c.d USING (e)") +%sql SELECT * FROM b.c LEFT JOIN c.d USING (e) + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC SELECT * FROM b.c LEFT JOIN c.d USING (e) + diff --git a/tests/unit/source_code/samples/functional/table-access.sql b/tests/unit/source_code/samples/functional/table-access.sql new file mode 100644 index 0000000000..0dabee4794 --- /dev/null +++ b/tests/unit/source_code/samples/functional/table-access.sql @@ -0,0 +1,12 @@ +-- Databricks notebook source + +SELECT * FROM b.c LEFT JOIN c.d USING (e) + +-- COMMAND ---------- + +-- MAGIC %python +-- ucx[default-format-changed-in-dbr8:+1:0:+1:18] The default format changed in Databricks Runtime 8.0, from Parquet to Delta +-- MAGIC spark.table("a.b").count() +-- MAGIC spark.sql("SELECT * FROM b.c LEFT JOIN c.d USING (e)") + +