From 86e981ba9811ce52eead993c9ecaa1d7f526cfee Mon Sep 17 00:00:00 2001
From: eakmanrq <6326532+eakmanrq@users.noreply.github.com>
Date: Tue, 21 May 2024 06:51:19 -0700
Subject: [PATCH] feat: add support for where clause as a sql expression

---
Review note: the string branch of `where` now wraps the parsed expression
before `.expression` is read, so the whole predicate is applied. The blob
hashes on the `index` line below are those of the original patch.

 sqlframe/base/dataframe.py              | 17 +++++++++++++++--
 tests/integration/test_int_dataframe.py | 11 +++++++++++
 2 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/sqlframe/base/dataframe.py b/sqlframe/base/dataframe.py
index e2384e7..8afa376 100644
--- a/sqlframe/base/dataframe.py
+++ b/sqlframe/base/dataframe.py
@@ -606,8 +606,21 @@ def alias(self, name: str, **kwargs) -> Self:
         return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
 
     @operation(Operation.WHERE)
-    def where(self, column: t.Union[Column, bool], **kwargs) -> Self:
-        col = self._ensure_and_normalize_col(column)
+    def where(self, column: t.Union[Column, str, bool], **kwargs) -> Self:
+        """Filter rows by a Column/bool condition or a SQL expression string.
+
+        A string is parsed using the session's input dialect.
+        """
+        if isinstance(column, str):
+            # parse_one returns a bare sqlglot expression; its `.expression`
+            # attribute is only the right-hand operand (or None), so wrap it
+            # before reading `.expression` to keep the whole predicate.
+            # NOTE(review): assumes _ensure_and_normalize_col accepts a
+            # sqlglot expression -- confirm.
+            parsed = sqlglot.parse_one(column, dialect=self.session.input_dialect)
+            col = self._ensure_and_normalize_col(parsed)
+        else:
+            col = self._ensure_and_normalize_col(column)
         return self.copy(expression=self.expression.where(col.expression))
 
     filter = where
diff --git a/tests/integration/test_int_dataframe.py b/tests/integration/test_int_dataframe.py
index ba57250..1dc4526 100644
--- a/tests/integration/test_int_dataframe.py
+++ b/tests/integration/test_int_dataframe.py
@@ -302,6 +302,17 @@ def test_where_multiple_chained(
     compare_frames(df_employee, dfs_employee)
 
 
+def test_where_sql_expr(
+    pyspark_employee: PySparkDataFrame,
+    get_df: t.Callable[[str], _BaseDataFrame],
+    compare_frames: t.Callable,
+):
+    employee = get_df("employee")
+    df_employee = pyspark_employee.where("age = 37 AND fname = 'Jack'")
+    dfs_employee = employee.where("age = 37 AND fname = 'Jack'")
+    compare_frames(df_employee, dfs_employee)
+
+
 def test_operators(
     pyspark_employee: PySparkDataFrame,
     get_df: t.Callable[[str], _BaseDataFrame],