Skip to content

Commit

Permalink
feat: add support for where clause as a sql expression
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq committed May 21, 2024
1 parent 9df55e1 commit 86e981b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
7 changes: 5 additions & 2 deletions sqlframe/base/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,8 +606,11 @@ def alias(self, name: str, **kwargs) -> Self:
return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

@operation(Operation.WHERE)
def where(self, column: t.Union[Column, str, bool], **kwargs) -> Self:
    """Return a copy of this DataFrame filtered by ``column``.

    Args:
        column: The filter predicate. A ``str`` is parsed as a SQL
            expression using the session's input dialect; any other value
            is passed through ``_ensure_and_normalize_col``.
        **kwargs: Accepted but not read by this method body.

    Returns:
        A copy of this DataFrame whose expression has the predicate
        applied through ``where``.
    """
    if isinstance(column, str):
        # ``parse_one`` already returns the sqlglot expression for the whole
        # predicate. Reading ``.expression`` off that node would return only
        # its "expression" argument (the right-hand operand of a binary node
        # such as ``AND`` or ``=``) and silently drop the rest of the filter.
        # NOTE(review): string predicates bypass _ensure_and_normalize_col,
        # so any normalization it performs is not applied here -- confirm
        # whether that is required for SQL-string filters.
        condition = sqlglot.parse_one(column, dialect=self.session.input_dialect)
    else:
        condition = self._ensure_and_normalize_col(column).expression
    return self.copy(expression=self.expression.where(condition))

filter = where
Expand Down
11 changes: 11 additions & 0 deletions tests/integration/test_int_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,17 @@ def test_where_multiple_chained(
compare_frames(df_employee, dfs_employee)


def test_where_sql_expr(
    pyspark_employee: PySparkDataFrame,
    get_df: t.Callable[[str], _BaseDataFrame],
    compare_frames: t.Callable,
):
    """Filtering with a SQL string predicate should match PySpark's result."""
    # The same predicate text is given to both engines so the resulting
    # frames can be compared directly.
    predicate = "age = 37 AND fname = 'Jack'"
    expected = pyspark_employee.where(predicate)
    actual = get_df("employee").where(predicate)
    compare_frames(expected, actual)


def test_operators(
pyspark_employee: PySparkDataFrame,
get_df: t.Callable[[str], _BaseDataFrame],
Expand Down

0 comments on commit 86e981b

Please sign in to comment.