From 7cf1d70e909ae319ff659e1455e6fcad1e8cf905 Mon Sep 17 00:00:00 2001 From: Vaggelis Danias Date: Fri, 13 Sep 2024 18:17:16 +0300 Subject: [PATCH] Refactor(optimizer): Optimize USING expansion (#4115) * Refactor expand_using * Set return type --------- Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com> --- sqlglot/optimizer/qualify_columns.py | 21 +++++++++++++-------- tests/test_optimizer.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index bffd0ab5b8..12b179cc65 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -145,6 +145,13 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: + columns = {} + + def _gather_source_columns(source_name: str) -> None: + for column_name in resolver.get_source_columns(source_name): + if column_name not in columns: + columns[column_name] = source_name + joins = list(scope.find_all(exp.Join)) names = {join.alias_or_name for join in joins} ordered = [key for key in scope.selected_sources if key not in names] @@ -152,8 +159,14 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: # Mapping of automatically joined column names to an ordered set of source names (dict). column_tables: t.Dict[str, t.Dict[str, t.Any]] = {} + for source_name in ordered: + _gather_source_columns(source_name) + for i, join in enumerate(joins): source_table = ordered[-1] + if source_table: + _gather_source_columns(source_table) + join_table = join.alias_or_name ordered.append(join_table) @@ -161,14 +174,6 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: if not using: continue - columns = {} - - for source_name in scope.selected_sources: - if source_name in ordered[:-1]: - for column_name in resolver.get_source_columns(source_name): - if column_name not in columns: - columns[column_name] = source_name - join_columns = resolver.get_source_columns(join_table) conditions = [] using_identifier_count = len(using) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a71644a39c..857ba1aab5 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -412,6 +412,22 @@ def test_qualify_columns(self, logger): ).sql("postgres"), "SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.c_id = c.c_id", ) + self.assertEqual( + optimizer.qualify.qualify( + parse_one( + "SELECT A.b_id FROM A JOIN B ON A.b_id=B.b_id JOIN C ON B.b_id = C.b_id JOIN D USING(d_id)", + dialect="postgres", + ), + schema={ + "A": {"b_id": "int"}, + "B": {"b_id": "int", "d_id": "int"}, + "C": {"b_id": "int"}, + "D": {"d_id": "int"}, + }, + quote_identifiers=False, + ).sql("postgres"), + "SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.b_id = c.b_id JOIN d AS d ON b.d_id = d.d_id", + ) self.check_file( "qualify_columns",