Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Make expression output type known #19195

Merged
merged 8 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/polars-core/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ impl DataType {
ArrowDataType::Extension(name, _, _) if name.as_str() == "POLARS_EXTENSION_TYPE" => {
#[cfg(feature = "object")]
{
DataType::Object("extension", None)
DataType::Object("object", None)
}
#[cfg(not(feature = "object"))]
{
Expand Down
58 changes: 21 additions & 37 deletions crates/polars-expr/src/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ pub struct ApplyExpr {
function_operates_on_scalar: bool,
allow_rename: bool,
pass_name_to_apply: bool,
input_schema: Option<SchemaRef>,
input_schema: SchemaRef,
allow_threading: bool,
check_lengths: bool,
allow_group_aware: bool,
output_dtype: Option<DataType>,
output_field: Field,
}

impl ApplyExpr {
Expand All @@ -38,8 +38,8 @@ impl ApplyExpr {
expr: Expr,
options: FunctionOptions,
allow_threading: bool,
input_schema: Option<SchemaRef>,
output_dtype: Option<DataType>,
input_schema: SchemaRef,
output_field: Field,
returns_scalar: bool,
) -> Self {
#[cfg(debug_assertions)]
Expand All @@ -62,30 +62,7 @@ impl ApplyExpr {
allow_threading,
check_lengths: options.check_lengths(),
allow_group_aware: options.flags.contains(FunctionFlags::ALLOW_GROUP_AWARE),
output_dtype,
}
}

pub(crate) fn new_minimal(
inputs: Vec<Arc<dyn PhysicalExpr>>,
function: SpecialEq<Arc<dyn ColumnsUdf>>,
expr: Expr,
collect_groups: ApplyOptions,
) -> Self {
Self {
inputs,
function,
expr,
collect_groups,
function_returns_scalar: false,
function_operates_on_scalar: false,
allow_rename: false,
pass_name_to_apply: false,
input_schema: None,
allow_threading: true,
check_lengths: true,
allow_group_aware: true,
output_dtype: None,
output_field,
}
}

Expand Down Expand Up @@ -123,19 +100,16 @@ impl ApplyExpr {
Ok(ac)
}

fn get_input_schema(&self, df: &DataFrame) -> Cow<Schema> {
match &self.input_schema {
Some(schema) => Cow::Borrowed(schema.as_ref()),
None => Cow::Owned(df.schema()),
}
fn get_input_schema(&self, _df: &DataFrame) -> Cow<Schema> {
Cow::Borrowed(self.input_schema.as_ref())
}

/// Evaluates and flattens `Option<Column>` to `Column`.
fn eval_and_flatten(&self, inputs: &mut [Column]) -> PolarsResult<Column> {
if let Some(out) = self.function.call_udf(inputs)? {
Ok(out)
} else {
let field = self.to_field(self.input_schema.as_ref().unwrap()).unwrap();
let field = self.to_field(self.input_schema.as_ref()).unwrap();
Ok(Column::full_null(field.name().clone(), 1, field.dtype()))
}
}
Expand Down Expand Up @@ -179,9 +153,11 @@ impl ApplyExpr {
};

let ca: ListChunked = if self.allow_threading {
let dtype = match &self.output_dtype {
Some(dtype) if dtype.is_known() && !dtype.is_null() => Some(dtype.clone()),
_ => None,
let dtype = if self.output_field.dtype.is_known() && !self.output_field.dtype.is_null()
{
Some(self.output_field.dtype.clone())
} else {
None
};

let lst = agg.list().unwrap();
Expand Down Expand Up @@ -287,6 +263,7 @@ impl ApplyExpr {
}
builder.finish()
} else {
// We still need this branch to materialize unknown/ data dependent types in eager. :(
(0..len)
.map(|_| {
container.clear();
Expand All @@ -303,6 +280,13 @@ impl ApplyExpr {
.collect::<PolarsResult<ListChunked>>()?
.with_name(field.name.clone())
};
#[cfg(debug_assertions)]
{
let inner = ca.dtype().inner_dtype().unwrap();
if field.dtype.is_known() {
assert_eq!(inner, &field.dtype);
}
}

drop(iters);

Expand Down
61 changes: 28 additions & 33 deletions crates/polars-expr/src/expressions/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExp
pub struct ColumnExpr {
name: PlSmallStr,
expr: Expr,
schema: Option<SchemaRef>,
schema: SchemaRef,
}

impl ColumnExpr {
pub fn new(name: PlSmallStr, expr: Expr, schema: Option<SchemaRef>) -> Self {
pub fn new(name: PlSmallStr, expr: Expr, schema: SchemaRef) -> Self {
Self { name, expr, schema }
}
}
Expand Down Expand Up @@ -141,42 +141,37 @@ impl PhysicalExpr for ColumnExpr {
Some(&self.expr)
}
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Series> {
let out = match &self.schema {
None => self.process_by_linear_search(df, state, false),
Some(schema) => {
match schema.get_full(&self.name) {
Some((idx, _, _)) => {
// check if the schema was correct
// if not do O(n) search
match df.get_columns().get(idx) {
Some(out) => self.process_by_idx(
out.as_materialized_series(),
state,
schema,
df,
true,
),
None => {
// partitioned group_by special case
if let Some(schema) = state.get_schema() {
self.process_from_state_schema(df, state, &schema)
} else {
self.process_by_linear_search(df, state, true)
}
},
}
},
// in the future we will throw an error here
// now we do a linear search first as the lazy reported schema may still be incorrect
// in debug builds we panic so that it can be fixed when occurring
let out = match self.schema.get_full(&self.name) {
Some((idx, _, _)) => {
// check if the schema was correct
// if not do O(n) search
match df.get_columns().get(idx) {
Some(out) => self.process_by_idx(
out.as_materialized_series(),
state,
&self.schema,
df,
true,
),
None => {
if self.name.starts_with(CSE_REPLACED) {
return self.process_cse(df, schema);
// partitioned group_by special case
if let Some(schema) = state.get_schema() {
self.process_from_state_schema(df, state, &schema)
} else {
self.process_by_linear_search(df, state, true)
}
self.process_by_linear_search(df, state, true)
},
}
},
// in the future we will throw an error here
// now we do a linear search first as the lazy reported schema may still be incorrect
// in debug builds we panic so that it can be fixed when occurring
None => {
if self.name.starts_with(CSE_REPLACED) {
return self.process_cse(df, &self.schema);
}
self.process_by_linear_search(df, state, true)
},
};
self.check_external_context(out, state)
}
Expand Down
66 changes: 36 additions & 30 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn create_physical_expressions_from_irs(
exprs: &[ExprIR],
context: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
create_physical_expressions_check_state(exprs, context, expr_arena, schema, state, ok_checker)
Expand All @@ -35,7 +35,7 @@ pub(crate) fn create_physical_expressions_check_state<F>(
exprs: &[ExprIR],
context: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
checker: F,
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
Expand All @@ -57,7 +57,7 @@ pub(crate) fn create_physical_expressions_from_nodes(
exprs: &[Node],
context: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
create_physical_expressions_from_nodes_check_state(
Expand All @@ -69,7 +69,7 @@ pub(crate) fn create_physical_expressions_from_nodes_check_state<F>(
exprs: &[Node],
context: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
checker: F,
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
Expand Down Expand Up @@ -165,7 +165,7 @@ pub fn create_physical_expr(
expr_ir: &ExprIR,
ctxt: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
) -> PolarsResult<Arc<dyn PhysicalExpr>> {
let phys_expr = create_physical_expr_inner(expr_ir.node(), ctxt, expr_arena, schema, state)?;
Expand All @@ -185,7 +185,7 @@ fn create_physical_expr_inner(
expression: Node,
ctxt: Context,
expr_arena: &Arena<AExpr>,
schema: Option<&SchemaRef>,
schema: &SchemaRef,
state: &mut ExpressionConversionState,
) -> PolarsResult<Arc<dyn PhysicalExpr>> {
use AExpr::*;
Expand Down Expand Up @@ -309,7 +309,7 @@ fn create_physical_expr_inner(
Column(column) => Ok(Arc::new(ColumnExpr::new(
column.clone(),
node_to_expr(expression, expr_arena),
schema.cloned(),
schema.clone(),
))),
Sort { expr, options } => {
let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
Expand Down Expand Up @@ -410,22 +410,18 @@ fn create_physical_expr_inner(
return Ok(Arc::new(AggQuantileExpr::new(input, quantile, *interpol)));
}

let field = schema
.map(|schema| {
expr_arena.get(expression).to_field(
schema,
Context::Aggregation,
expr_arena,
)
})
.transpose()?;
let field = expr_arena.get(expression).to_field(
schema,
Context::Aggregation,
expr_arena,
)?;

let groupby = GroupByMethod::from(agg.clone());
let agg_type = AggregationType {
groupby,
allow_threading: false,
};
Ok(Arc::new(AggregationExpr::new(input, agg_type, field)))
Ok(Arc::new(AggregationExpr::new(input, agg_type, Some(field))))
},
}
},
Expand Down Expand Up @@ -475,12 +471,10 @@ fn create_physical_expr_inner(
options,
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let output_dtype = schema.and_then(|schema| {
let output_dtype =
expr_arena
.get(expression)
.to_dtype(schema, Context::Default, expr_arena)
.ok()
});
.to_field(schema, Context::Default, expr_arena)?;

let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR)
&& matches!(options.collect_groups, ApplyOptions::GroupWise);
Expand All @@ -504,7 +498,7 @@ fn create_physical_expr_inner(
node_to_expr(expression, expr_arena),
*options,
state.allow_threading,
schema.cloned(),
schema.clone(),
output_dtype,
is_scalar,
)))
Expand All @@ -516,12 +510,10 @@ fn create_physical_expr_inner(
..
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let output_dtype = schema.and_then(|schema| {
let output_field =
expr_arena
.get(expression)
.to_dtype(schema, Context::Default, expr_arena)
.ok()
});
.to_field(schema, Context::Default, expr_arena)?;
let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR)
&& matches!(options.collect_groups, ApplyOptions::GroupWise);
// Will be reset in the function so get that here.
Expand All @@ -544,8 +536,8 @@ fn create_physical_expr_inner(
node_to_expr(expression, expr_arena),
*options,
state.allow_threading,
schema.cloned(),
output_dtype,
schema.clone(),
output_field,
is_scalar,
)))
},
Expand All @@ -570,11 +562,25 @@ fn create_physical_expr_inner(
let function = SpecialEq::new(Arc::new(
move |c: &mut [polars_core::frame::column::Column]| c[0].explode().map(Some),
) as Arc<dyn ColumnsUdf>);
Ok(Arc::new(ApplyExpr::new_minimal(

let field = expr_arena
.get(expression)
.to_field(schema, ctxt, expr_arena)?;
Ok(Arc::new(ApplyExpr::new(
vec![input],
function,
node_to_expr(expression, expr_arena),
ApplyOptions::GroupWise,
FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
fmt_str: "",
cast_to_supertypes: None,
check_lengths: Default::default(),
flags: Default::default(),
},
state.allow_threading,
schema.clone(),
field,
false,
)))
},
Alias(input, name) => {
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-lazy/src/dsl/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized {

// Ensure we get the new schema.
let output_field = eval_field_to_dtype(c.field().as_ref(), &expr, false);
let schema = Arc::new(Schema::from_iter(std::iter::once(output_field.clone())));

let expr = expr.clone();
let mut arena = Arena::with_capacity(10);
Expand All @@ -60,7 +61,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized {
&aexpr,
Context::Default,
&arena,
None,
&schema,
&mut ExpressionConversionState::new(true, 0),
)?;

Expand Down
Loading
Loading