diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 0d591ce81313..b89c957dd119 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1150,9 +1150,6 @@ impl Expr { /// Should be used in aggregation context. If you want to filter on a /// DataFrame level, use `LazyFrame::filter`. pub fn filter>(self, predicate: E) -> Self { - if has_expr(&self, |e| matches!(e, Expr::Wildcard)) { - panic!("filter '*' not allowed, use LazyFrame::filter") - }; Expr::Filter { input: Arc::new(self), by: Arc::new(predicate.into()), diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 25a605124496..029a3ea19400 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -374,6 +374,58 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult output_schema: None, filter: None, }, + // `select(pl.all().filter())` -> `filter()` + // `select((pl.all().).filter())` -> `select(pl.all().).filter()` + DslPlan::Select { + expr, + input, + options, + } if { + // This is a hack so that we don't break `list.eval(pl.element().filter())`. That expression + // hits this codepath with `col("").filter(..)` when it goes through `run_on_group_by_engine` + // and for some reason doesn't give a correct result if we do this rewrite. + // + // We detect that we are called by `run_on_group_by_engine` because it calls to here with + // nearly all optimizations turned off, so we just picked the `PREDICATE_PUSHDOWN` to check. + ctxt.opt_flags.contains(OptFlags::PREDICATE_PUSHDOWN) + } && { + let mut rewrite = false; + + if let [Expr::Filter { input, .. }] = &expr[..] { + if has_expr(input.as_ref(), |e| matches!(e, Expr::Wildcard)) { + rewrite = true + } + } + + rewrite + } => + { + let Expr::Filter { + input: filter_input, + by: predicate, + } = expr.into_iter().next().unwrap() + else { + unreachable!() + }; + + let input = match filter_input.as_ref() { + // If it's just a wildcard we don't need the `Select` node. + Expr::Wildcard => input, + _ => Arc::new(DslPlan::Select { + expr: vec![Arc::unwrap_or_clone(filter_input)], + input, + options, + }), + }; + + return to_alp_impl( + DslPlan::Filter { + input, + predicate: Arc::unwrap_or_clone(predicate), + }, + ctxt, + ); + }, DslPlan::Select { expr, input, diff --git a/py-polars/tests/unit/expr/test_exprs.py b/py-polars/tests/unit/expr/test_exprs.py index 31bf08534df7..72ac1bf2677b 100644 --- a/py-polars/tests/unit/expr/test_exprs.py +++ b/py-polars/tests/unit/expr/test_exprs.py @@ -653,3 +653,47 @@ def test_function_expr_scalar_identification_18755() -> None: pl.DataFrame({"a": [1, 2]}).with_columns(pl.lit(5).shrink_dtype().alias("b")), pl.DataFrame({"a": [1, 2], "b": pl.Series([5, 5], dtype=pl.Int8)}), ) + + +def test_filter_all() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5], + "b": ["p", "q", "r", "s", "t"], + "p": [True, False, True, True, False], + } + ) + + expect = pl.DataFrame({"a": [1, 3, 4], "b": ["p", "r", "s"], "p": True}) + + assert_frame_equal( + df.select(pl.all().filter("p")), + expect, + ) + + q = df.lazy().select(pl.all().filter(~pl.col("p"))) + # Ensure this is re-written to a `Filter` node during IR conversion. + assert r'SELECTION: col("p").not()' in q.explain() + + assert_frame_equal( + q.collect(), + pl.DataFrame({"a": [2, 5], "b": ["q", "t"], "p": False}), + ) + + # Ensure no regression in the non-wildcard case, as the predicate may refer + # to columns that are not part of the `select()` + q = df.lazy().select(pl.col("a").filter(~pl.col("p"))) + assert_frame_equal(q.collect(), pl.DataFrame({"a": [2, 5]})) + + q = df.lazy().select((pl.all().reverse()).filter(~pl.col("p"))) + assert r'FILTER col("p").not()' in q.explain() + + q = df.lazy().select( + pl.sum_horizontal(pl.all().cast(pl.String)).filter(pl.col("p")) + ) + print(q.explain()) + + assert_frame_equal( + q.collect(), + pl.DataFrame({"a": [5, 2], "b": ["t", "q"], "p": False}), + )