diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 591cb3e14d77..49ea9014ca02 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -2107,106 +2107,41 @@ impl JoinBuilder { } - // // Finish with join predicates - // pub fn join_where(self, predicates: Vec) -> PolarsResult { - // let to_inner = Arc::unwrap_or_clone; - // - // let mut ie_left_on = vec![]; - // let mut ie_right_on = vec![]; - // let mut ie_op = vec![]; - // - // let mut eq_left_on = vec![]; - // let mut eq_right_on = vec![]; - // - // let mut remaining_preds = vec![]; - // - // fn to_inequality_operator(op: &Operator) -> Option { - // match op { - // Operator::Lt => Some(InequalityOperator::Lt), - // Operator::LtEq => Some(InequalityOperator::LtEq), - // Operator::Gt => Some(InequalityOperator::Gt), - // Operator::GtEq => Some(InequalityOperator::GtEq), - // _ => None, - // } - // } - // - // for pred in predicates.into_iter() { - // let Expr::BinaryExpr {left, op, right} = pred else { polars_bail!(InvalidOperation: "can only join on binary expressions") }; - // polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate"); - // - // if let Some(ie_op_) = to_inequality_operator(&op) { - // ie_left_on.push(to_inner(left)); - // ie_right_on.push(to_inner(right)); - // ie_op.push(ie_op_) - // } else if matches!(op, Operator::Eq) { - // eq_left_on.push(to_inner(left)); - // eq_right_on.push(to_inner(right)); - // } else { - // - // remaining_preds.push(pred); - // } - // } - // - // - // fn parse_ie_join_expressions( - // expressions: Vec, - // ) -> PolarsResult<(Vec, Vec, Vec)> { - // - // let mut left_on = Vec::with_capacity(2); - // let mut operators = Vec::with_capacity(2); - // let mut right_on = Vec::with_capacity(2); - // - // for expression in expressions.into_iter() { - // let (left, op, right) = parse_inequality_expression(expression)?; - // left_on.push(left); - // operators.push(op); - // right_on.push(right); - // } - // - // Ok((left_on, operators, right_on)) - // } - // - // fn parse_inequality_expression(expression: Expr) -> PolarsResult<(Expr, InequalityOperator, Expr)> { - // fn to_inequality_operator(op: &Operator) -> PolarsResult { - // match op { - // Operator::Lt => Ok(InequalityOperator::Lt), - // Operator::LtEq => Ok(InequalityOperator::LtEq), - // Operator::Gt => Ok(InequalityOperator::Gt), - // Operator::GtEq => Ok(InequalityOperator::GtEq), - // _ => Err(PyValueError::new_err(format!( - // "expected an inequality operator in join inequality, got '{}'", - // op - // ))), - // } - // } - // - // match expression.inner { - // Expr::BinaryExpr { left, op, right } => { - // let inequality_op = to_inequality_operator(&op)?; - // Ok(((*left).clone(), inequality_op, (*right).clone())) - // }, - // _ => Err(PyValueError::new_err( - // "expected a binary expression for a join inequality", - // )), - // } - // } - // - // let mut opt_state = self.lf.opt_state; - // let other = self.other.expect("with not set"); - // - // // If any of the nodes reads from files we must activate this plan as well. - // if other.opt_state.contains(OptFlags::FILE_CACHING) { - // opt_state |= OptFlags::FILE_CACHING; - // } - // - // let args = JoinArgs { - // how: self.how, - // validation: self.validation, - // suffix: self.suffix, - // slice: None, - // join_nulls: self.join_nulls, - // coalesce: self.coalesce, - // }; - // - // } + // Finish with join predicates + pub fn join_where(self, predicates: Vec) -> LazyFrame { + let mut opt_state = self.lf.opt_state; + let other = self.other.expect("with not set"); + + // If any of the nodes reads from files we must activate this plan as well. + if other.opt_state.contains(OptFlags::FILE_CACHING) { + opt_state |= OptFlags::FILE_CACHING; + } + + let args = JoinArgs { + how: self.how, + validation: self.validation, + suffix: self.suffix, + slice: None, + join_nulls: self.join_nulls, + coalesce: self.coalesce, + }; + let options = JoinOptions { + allow_parallel: self.allow_parallel, + force_parallel: self.force_parallel, + args, + ..Default::default() + }; + + let lp = DslPlan::Join { + input_left: Arc::new(self.lf.logical_plan), + input_right: Arc::new(other.logical_plan), + left_on: Default::default(), + right_on: Default::default(), + predicates, + options: Arc::from(options), + }; + + LazyFrame::from_logical_plan(lp, opt_state) + + } } diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 4e55b97125d6..86692ee33619 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -85,6 +85,15 @@ pub(super) struct DslConversionContext<'a> { pub(super) opt_flags: &'a mut OptFlags, } +pub(super) fn run_conversion(lp: IR, ctxt: &mut DslConversionContext, name: &str) -> PolarsResult { + let lp_node = ctxt.lp_arena.add(lp); + ctxt.conversion_optimizer + .coerce_types(ctxt.expr_arena, ctxt.lp_arena, lp_node) + .map_err(|e| e.context(format!("'{name}' failed").into()))?; + + Ok(lp_node) +} + /// converts LogicalPlan to IR /// it adds expressions & lps to the respective arenas as it traverses the plan /// finally it returns the top node of the logical plan @@ -92,14 +101,6 @@ pub(super) struct DslConversionContext<'a> { pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult { let owned = Arc::unwrap_or_clone; - fn run_conversion(lp: IR, ctxt: &mut DslConversionContext, name: &str) -> PolarsResult { - let lp_node = ctxt.lp_arena.add(lp); - ctxt.conversion_optimizer - .coerce_types(ctxt.expr_arena, ctxt.lp_arena, lp_node) - .map_err(|e| e.context(format!("'{name}' failed").into()))?; - - Ok(lp_node) - } let v = match lp { DslPlan::Scan { @@ -541,10 +542,9 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult left_on, right_on, predicates, - mut options, + options, } => { - let ir = join::resolve_join(input_left, input_right, left_on, right_on, predicates, options, ctxt)?; - return run_conversion(ir, ctxt, "join"); + return join::resolve_join(input_left, input_right, left_on, right_on, predicates, options, ctxt) }, DslPlan::HStack { input, diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 0c83d77f357f..b77803f1f6b2 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -5,6 +5,18 @@ use crate::plans::AExpr; use crate::prelude::FunctionOptions; use super::*; +fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> { + for e in keys { + if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) { + polars_bail!( + InvalidOperation: + "'alias' is not allowed in a join key, use 'with_columns' first", + ) + } + } + Ok(()) + +} pub fn resolve_join( input_left: Arc, input_right: Arc, @@ -13,19 +25,21 @@ pub fn resolve_join( predicates: Vec, mut options: Arc, ctxt: &mut DslConversionContext -) -> PolarsResult { +) -> PolarsResult { + if !predicates.is_empty() { + debug_assert!(left_on.is_empty() && right_on.is_empty()); + return resolve_join_where(input_left, input_right, predicates, options, ctxt) + } + let owned = Arc::unwrap_or_clone; if matches!(options.args.how, JoinType::Cross) { polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys"); } else { + check_join_keys(&left_on)?; + check_join_keys(&right_on)?; + let mut turn_off_coalesce = false; for e in left_on.iter().chain(right_on.iter()) { - if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) { - polars_bail!( - ComputeError: - "'alias' is not allowed in a join key, use 'with_columns' first", - ) - } // Any expression that is not a simple column expression will turn of coalescing. turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_))); } @@ -41,7 +55,7 @@ pub fn resolve_join( polars_ensure!( left_on.len() == right_on.len(), - ComputeError: + InvalidOperation: format!( "the number of columns given as join key (left: {}, right:{}) should be equal", left_on.len(), @@ -96,6 +110,125 @@ pub fn resolve_join( right_on, options, }; - Ok(lp) + run_conversion(lp, ctxt, "join") +} + +impl From for Operator { + fn from(value: InequalityOperator) -> Self { + match value { + InequalityOperator::LtEq => Operator::LtEq, + InequalityOperator::Lt => Operator::Lt, + InequalityOperator::GtEq => Operator::GtEq, + InequalityOperator::Gt => Operator::Gt, + } + } +} + +fn resolve_join_where( + input_left: Arc, + input_right: Arc, + predicates: Vec, + mut options: Arc, + ctxt: &mut DslConversionContext +) -> PolarsResult { + check_join_keys(&predicates)?; + + let owned = |e: Arc| (*e).clone(); + + let mut ie_left_on = vec![]; + let mut ie_right_on = vec![]; + let mut ie_op = vec![]; + + let mut eq_left_on = vec![]; + let mut eq_right_on = vec![]; + + let mut remaining_preds = vec![]; + + fn to_inequality_operator(op: &Operator) -> Option { + match op { + Operator::Lt => Some(InequalityOperator::Lt), + Operator::LtEq => Some(InequalityOperator::LtEq), + Operator::Gt => Some(InequalityOperator::Gt), + Operator::GtEq => Some(InequalityOperator::GtEq), + _ => None, + } + } + + for pred in predicates.into_iter() { + let Expr::BinaryExpr {left, op, right} = pred.clone() else { polars_bail!(InvalidOperation: "can only join on binary expressions") }; + polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate"); + + if let Some(ie_op_) = to_inequality_operator(&op) { + // We already have an IEjoin or an Inner join, push to remaining + if ie_op.len() >= 2 || !eq_right_on.is_empty() { + remaining_preds.push(Expr::BinaryExpr {left, op, right}) + } else { + ie_left_on.push(owned(left)); + ie_right_on.push(owned(right)); + ie_op.push(ie_op_) + } + } else if matches!(op, Operator::Eq) { + eq_left_on.push(owned(left)); + eq_right_on.push(owned(right)); + } else { + remaining_preds.push(pred); + } + } + + let join_node = if !eq_left_on.is_empty() { + let join_node = resolve_join(input_left, input_right, eq_left_on, eq_right_on, vec![], options.clone(), ctxt)?; + + for ((l, op), r) in ie_left_on.into_iter().zip(ie_op.into_iter()).zip(ie_right_on.into_iter()) { + remaining_preds.push(Expr::BinaryExpr {left: Arc::from(l), op: op.into(), right: Arc::from(r)}) + } + join_node + + } else if ie_right_on.len() == 2 { + let opts = Arc::make_mut(&mut options); + opts.args.how = JoinType::IEJoin(IEJoinOptions { + operator1: ie_op[0], + operator2: ie_op[1], + }); + + resolve_join(input_left, input_right, ie_left_on, ie_right_on, vec![], options.clone(), ctxt)? + } else { + let opts = Arc::make_mut(&mut options); + opts.args.how = JoinType::Cross; + + resolve_join(input_left, input_right, vec![], vec![], vec![], options.clone(), ctxt)? + }; + + let IR::Join {input_right, ..} = ctxt.lp_arena.get(join_node) else { unreachable!()}; + let schema_right = ctxt.lp_arena.get(*input_right).schema(ctxt.lp_arena).into_owned(); + + + let suffix = options.args.suffix(); + + let mut last_node = join_node; + + // Ensure that the predicates use the proper suffix + for e in remaining_preds { + let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?; + let AExpr::BinaryExpr {left, op, mut right} = *ctxt.expr_arena.get(predicate.node()) else { unreachable!() }; + + let original_right = right; + for name in aexpr_to_leaf_names(right, ctxt.expr_arena) { + if !schema_right.contains(name.as_str()) { + let new_name = _join_suffix_name(name.as_str(), suffix.as_str()); + polars_ensure!(schema_right.contains(new_name.as_str()), ColumnNotFound: "could not find column {name} in the right table during join operation"); + + right = rename_matching_aexpr_leaf_names(right, ctxt.expr_arena, name.as_str(), new_name); + } + } + ctxt.expr_arena.swap(right, original_right); + + let ir = IR::Filter { + input: last_node, + predicate + }; + last_node = ctxt.lp_arena.add(ir); + + } + Ok(last_node) } \ No newline at end of file diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index f5288f3dd80c..275c0a687e92 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -969,21 +969,17 @@ impl PyLazyFrame { .into()) } - fn join_where(&self, other: Self, on: Vec, suffix: String) -> PyResult { + fn join_where(&self, other: Self, predicates: Vec, suffix: String) -> PyResult { let ldf = self.ldf.clone(); let other = other.ldf; - let (left_on, operators, right_on) = parse_ie_join_expressions(on)?; + + let predicates = predicates.to_exprs(); + Ok(ldf .join_builder() .with(other) - .left_on(left_on) - .right_on(right_on) - .how(JoinType::IEJoin(IEJoinOptions { - operator1: operators[0], - operator2: operators[1], - })) .suffix(suffix) - .finish() + .join_where(predicates) .into()) }