Skip to content

Commit

Permalink
fix: change counter on duplicate cte to random id (#257)
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq authored Jan 25, 2025
1 parent 557dc43 commit 7a37ef4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
15 changes: 7 additions & 8 deletions sqlframe/base/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import sys
import typing as t
import uuid
import zlib
from copy import copy
from dataclasses import dataclass
Expand Down Expand Up @@ -491,22 +492,20 @@ def _add_ctes_to_expression(self, expression: exp.Select, ctes: t.List[exp.CTE])
with_expression = expression.args.get("with")
if with_expression:
existing_ctes = with_expression.expressions
existing_cte_counts = {x.alias_or_name: 0 for x in existing_ctes}
existing_cte_names = {x.alias_or_name for x in existing_ctes}
replaced_cte_names = {} # type: ignore
for cte in ctes:
if replaced_cte_names:
cte = cte.transform(replace_id_value, replaced_cte_names) # type: ignore
if cte.alias_or_name in existing_cte_counts:
existing_cte_counts[cte.alias_or_name] += 10
if cte.alias_or_name in existing_cte_names:
random_filter = exp.Literal.string(uuid.uuid4().hex)
# Add unique where filter to ensure that the hash of the CTE is unique
cte.set(
"this",
cte.this.where(
exp.EQ(
this=exp.Literal.number(existing_cte_counts[cte.alias_or_name]),
expression=exp.Literal.number(
existing_cte_counts[cte.alias_or_name]
),
this=random_filter,
expression=random_filter,
)
),
)
Expand All @@ -520,7 +519,7 @@ def _add_ctes_to_expression(self, expression: exp.Select, ctes: t.List[exp.CTE])
new_cte_alias, dialect=self.session.input_dialect, into=exp.TableAlias
),
)
existing_cte_counts[new_cte_alias] = 0
existing_cte_names.add(new_cte_alias)
existing_ctes.append(cte)
else:
existing_ctes = ctes
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/test_int_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2355,3 +2355,29 @@ def test_union_common_root(
dfs_final = dfs_1.union(dfs_2)

compare_frames(df_final, dfs_final, compare_schema=False)


# https://github.com/eakmanrq/sqlframe/issues/253
def test_union_common_root_again(
    pyspark_employee: PySparkDataFrame,
    get_df: t.Callable[[str], BaseDataFrame],
    compare_frames: t.Callable,
):
    """Check a chained union over frames that share a common root.

    A filtered frame is right-joined against the distinct ``employee_id``
    values of its own source, unioned with that join result, and then
    unioned with the source frame once more. The same pipeline is built
    with PySpark and with the frame returned by ``get_df("employee")``,
    and the two results are compared (schemas are not compared).
    """
    # Reference pipeline, built on the PySpark fixture.
    spark_over_40 = pyspark_employee.filter(F.col("age") > 40)
    spark_distinct_ids = pyspark_employee.select("employee_id").distinct()
    spark_joined = spark_over_40.join(
        spark_distinct_ids,
        on="employee_id",
        how="right",
    )
    expected = spark_over_40.union(spark_joined).union(pyspark_employee)

    # Same pipeline, built on the frame under test.
    employee_frame = get_df("employee")
    over_40 = employee_frame.filter(SF.col("age") > 40)
    distinct_ids = employee_frame.select("employee_id").distinct()
    joined = over_40.join(
        distinct_ids,
        on="employee_id",
        how="right",
    )
    actual = over_40.union(joined).union(employee_frame)

    compare_frames(expected, actual, compare_schema=False)

0 comments on commit 7a37ef4

Please sign in to comment.