From 6ca9ce005a1ff3558997ed8d20f999a16e2264b5 Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Tue, 21 May 2024 20:30:25 -0700 Subject: [PATCH] fix: properly alias with column overlapping name --- sqlframe/base/dataframe.py | 2 +- tests/unit/standalone/test_dataframe.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sqlframe/base/dataframe.py b/sqlframe/base/dataframe.py index 8afa376..d5764dc 100644 --- a/sqlframe/base/dataframe.py +++ b/sqlframe/base/dataframe.py @@ -1097,7 +1097,7 @@ def withColumn(self, colName: str, col: Column) -> Self: ) if existing_col_index: expression = self.expression.copy() - expression.expressions[existing_col_index] = col.expression + expression.expressions[existing_col_index] = col.alias(colName).expression return self.copy(expression=expression) return self.copy().select(col.alias(colName), append=True) diff --git a/tests/unit/standalone/test_dataframe.py b/tests/unit/standalone/test_dataframe.py index 8af2a84..7291c89 100644 --- a/tests/unit/standalone/test_dataframe.py +++ b/tests/unit/standalone/test_dataframe.py @@ -2,6 +2,7 @@ from sqlglot import expressions as exp +from sqlframe.standalone import functions as F from sqlframe.standalone.dataframe import StandaloneDataFrame pytest_plugins = ["tests.common_fixtures", "tests.unit.standalone.fixtures"] @@ -44,3 +45,13 @@ def test_persist_storagelevel(standalone_employee: StandaloneDataFrame, compare_ "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`", ] compare_sql(df, expected_statements) + + +def test_with_column_duplicate_alias(standalone_employee: StandaloneDataFrame): + df = standalone_employee.withColumn("fname", F.col("age").cast("string")) + assert df.columns == ["employee_id", "fname", "lname", "age", "store_id"] + # Make sure that the new column is added with an alias to `fname` + assert ( + df.sql(pretty=False) + == "SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`age` AS STRING) 
AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + )