Skip to content

Commit

Permalink
Refactor expand_using
Browse files · Browse the repository at this point in the history
  • Loading branch information
VaggelisD committed Sep 13, 2024
1 parent 66c3295 commit ce1e48a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
21 changes: 13 additions & 8 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up
@@ -145,30 +145,35 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
columns = {}

def _gather_source_columns(source_name: str):
for column_name in resolver.get_source_columns(source_name):
if column_name not in columns:
columns[column_name] = source_name

joins = list(scope.find_all(exp.Join))
names = {join.alias_or_name for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]

# Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

for source_name in ordered:
_gather_source_columns(source_name)

for i, join in enumerate(joins):
source_table = ordered[-1]
if source_table:
_gather_source_columns(source_table)

join_table = join.alias_or_name
ordered.append(join_table)

using = join.args.get("using")
if not using:
continue

columns = {}

for source_name in scope.selected_sources:
if source_name in ordered[:-1]:
for column_name in resolver.get_source_columns(source_name):
if column_name not in columns:
columns[column_name] = source_name

join_columns = resolver.get_source_columns(join_table)
conditions = []
using_identifier_count = len(using)
Expand Down
23 changes: 18 additions & 5 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up
@@ -397,21 +397,34 @@ def test_qualify_columns(self, logger):
"SELECT u.user_id AS user_id, l.log_date AS log_date FROM users AS u CROSS JOIN LATERAL (SELECT l1.log_date AS log_date FROM (SELECT l.log_date AS log_date FROM logs AS l WHERE l.user_id = u.user_id AND l.log_date <= 100 ORDER BY l.log_date LIMIT 1) AS l1) AS l",
)

chained_schema = {
"A": {"b_id": "int"},
"B": {"b_id": "int", "c_id": "int"},
"C": {"b_id": "int"},
"D": {"c_id": "int"},
}
self.assertEqual(
optimizer.qualify.qualify(
parse_one(
"SELECT A.b_id FROM A JOIN B ON A.b_id=B.b_id JOIN C USING(c_id)",
dialect="postgres",
),
schema={
"A": {"b_id": "int"},
"B": {"b_id": "int", "c_id": "int"},
"C": {"c_id": "int"},
},
schema=chained_schema,
quote_identifiers=False,
).sql("postgres"),
"SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.c_id = c.c_id",
)
self.assertEqual(
optimizer.qualify.qualify(
parse_one(
"SELECT A.b_id FROM A JOIN B ON A.b_id=B.b_id JOIN C ON B.b_id = C.b_id JOIN D USING(c_id)",
dialect="postgres",
),
schema=chained_schema,
quote_identifiers=False,
).sql("postgres"),
"SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.b_id = c.b_id JOIN d AS d ON b.c_id = d.c_id",
)

self.check_file(
"qualify_columns",
Expand Down

0 comments on commit ce1e48a

Please sign in to comment.