Skip to content

Commit

Permalink
Refactor(optimizer): Optimize USING expansion (#4115)
Browse files (browse the repository at this point in the history)
* Refactor expand_using

* Set return type

---------

Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
  • Loading branch information
VaggelisD and georgesittas authored Sep 13, 2024
1 parent a34f8b6 commit 7cf1d70
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
21 changes: 13 additions & 8 deletions sqlglot/optimizer/qualify_columns.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -145,30 +145,35 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
columns = {}

def _gather_source_columns(source_name: str) -> None:
for column_name in resolver.get_source_columns(source_name):
if column_name not in columns:
columns[column_name] = source_name

joins = list(scope.find_all(exp.Join))
names = {join.alias_or_name for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]

# Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

for source_name in ordered:
_gather_source_columns(source_name)

for i, join in enumerate(joins):
source_table = ordered[-1]
if source_table:
_gather_source_columns(source_table)

join_table = join.alias_or_name
ordered.append(join_table)

using = join.args.get("using")
if not using:
continue

columns = {}

for source_name in scope.selected_sources:
if source_name in ordered[:-1]:
for column_name in resolver.get_source_columns(source_name):
if column_name not in columns:
columns[column_name] = source_name

join_columns = resolver.get_source_columns(join_table)
conditions = []
using_identifier_count = len(using)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_optimizer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -412,6 +412,22 @@ def test_qualify_columns(self, logger):
).sql("postgres"),
"SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.c_id = c.c_id",
)
self.assertEqual(
optimizer.qualify.qualify(
parse_one(
"SELECT A.b_id FROM A JOIN B ON A.b_id=B.b_id JOIN C ON B.b_id = C.b_id JOIN D USING(d_id)",
dialect="postgres",
),
schema={
"A": {"b_id": "int"},
"B": {"b_id": "int", "d_id": "int"},
"C": {"b_id": "int"},
"D": {"d_id": "int"},
},
quote_identifiers=False,
).sql("postgres"),
"SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.b_id = c.b_id JOIN d AS d ON b.d_id = d.d_id",
)

self.check_file(
"qualify_columns",
Expand Down

0 comments on commit 7cf1d70

Please sign in to comment.