diff --git a/daft/__init__.py b/daft/__init__.py index 4aeb5144ac..d45d2ccc3e 100644 --- a/daft/__init__.py +++ b/daft/__init__.py @@ -72,7 +72,7 @@ def refresh_logger() -> None: from daft.dataframe import DataFrame from daft.logical.schema import Schema from daft.datatype import DataType, TimeUnit -from daft.expressions import Expression, col, lit, interval, coalesce +from daft.expressions import Expression, col, list_, lit, interval, coalesce from daft.io import ( DataCatalogTable, DataCatalogType, @@ -116,6 +116,7 @@ def refresh_logger() -> None: "from_pylist", "from_ray_dataset", "interval", + "list_", "lit", "planning_config_ctx", "read_csv", diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index a891e862fc..899eadfdf6 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -996,6 +996,7 @@ class PyExpr: def eq(expr1: PyExpr, expr2: PyExpr) -> bool: ... def col(name: str) -> PyExpr: ... def lit(item: Any) -> PyExpr: ... +def list_(items: list[PyExpr]) -> PyExpr: ... def date_lit(item: int) -> PyExpr: ... def time_lit(item: int, tu: PyTimeUnit) -> PyExpr: ... def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ... 
diff --git a/daft/expressions/__init__.py b/daft/expressions/__init__.py
index 93eb173b90..8181900e6b 100644
--- a/daft/expressions/__init__.py
+++ b/daft/expressions/__init__.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
 
-from .expressions import Expression, ExpressionsProjection, col, lit, interval, coalesce
+from .expressions import Expression, ExpressionsProjection, col, list_, lit, interval, coalesce
 
-__all__ = ["Expression", "ExpressionsProjection", "coalesce", "col", "interval", "lit"]
+__all__ = ["Expression", "ExpressionsProjection", "coalesce", "col", "interval", "list_", "lit"]
diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py
index 45d8f16ec4..ebfe3db3bc 100644
--- a/daft/expressions/expressions.py
+++ b/daft/expressions/expressions.py
@@ -164,6 +164,43 @@ def col(name: str) -> Expression:
     return Expression._from_pyexpr(_col(name))
 
 
+def list_(*items: Expression | str) -> Expression:
+    """Constructs a list from the item expressions.
+
+    Example:
+        >>> import daft
+        >>> df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})
+        >>> df = df.select(daft.list_("x", "y").alias("fwd"), daft.list_("y", "x").alias("rev"))
+        >>> df.show()
+        ╭─────────────┬─────────────╮
+        │ fwd         ┆ rev         │
+        │ ---         ┆ ---         │
+        │ List[Int64] ┆ List[Int64] │
+        ╞═════════════╪═════════════╡
+        │ [1, 4]      ┆ [4, 1]      │
+        ├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
+        │ [2, 5]      ┆ [5, 2]      │
+        ├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
+        │ [3, 6]      ┆ [6, 3]      │
+        ╰─────────────┴─────────────╯
+        
+        (Showing first 3 of 3 rows)
+
+    Args:
+        *items (Union[Expression, str]): item expressions to construct the list
+
+    Returns:
+        Expression: Expression representing the constructed list
+
+    Raises:
+        ValueError: If no items are given.
+    """
+    # Raise explicitly instead of using `assert`: asserts are stripped under `python -O`.
+    if not items:
+        raise ValueError("List constructor requires at least one item")
+    return Expression._from_pyexpr(native.list_([col(i)._expr if isinstance(i, str) else i._expr for i in items]))
+
+
 def interval(
     years: int | None = None,
     months: int | None = None,
diff --git a/docs/sphinx/source/expressions.rst
b/docs/sphinx/source/expressions.rst index 170268f4f9..5f251e947d 100644 --- a/docs/sphinx/source/expressions.rst +++ b/docs/sphinx/source/expressions.rst @@ -16,6 +16,7 @@ Constructors col lit + list_ Generic ####### diff --git a/src/daft-core/src/array/growable/mod.rs b/src/daft-core/src/array/growable/mod.rs index 6f6e044e01..5bbd0f18cc 100644 --- a/src/daft-core/src/array/growable/mod.rs +++ b/src/daft-core/src/array/growable/mod.rs @@ -68,6 +68,11 @@ pub trait Growable { /// Extends this [`Growable`] with null elements fn add_nulls(&mut self, additional: usize); + /// Extends this [`Growable`] with null elements (same as add_nulls with arrow naming convention). + fn extend_nulls(&mut self, len: usize) { + self.add_nulls(len); + } + /// Builds an array from the [`Growable`] fn build(&mut self) -> DaftResult; } diff --git a/src/daft-core/src/series/ops/mod.rs b/src/daft-core/src/series/ops/mod.rs index 0fc54cb4fc..d9769ebe25 100644 --- a/src/daft-core/src/series/ops/mod.rs +++ b/src/daft-core/src/series/ops/mod.rs @@ -46,6 +46,7 @@ pub mod take; pub mod time; mod trigonometry; pub mod utf8; +pub mod zip; pub fn cast_series_to_supertype(series: &[&Series]) -> DaftResult> { let supertype = series diff --git a/src/daft-core/src/series/ops/zip.rs b/src/daft-core/src/series/ops/zip.rs new file mode 100644 index 0000000000..c7b0e844a2 --- /dev/null +++ b/src/daft-core/src/series/ops/zip.rs @@ -0,0 +1,90 @@ +use std::cmp::{max, min}; + +use arrow2::offset::Offsets; +use common_error::{DaftError, DaftResult}; +use daft_schema::{dtype::DataType, field::Field}; + +use crate::{ + array::{growable::make_growable, ListArray}, + series::{IntoSeries, Series}, +}; + +impl Series { + /// Zips series into a single series of lists. + /// ex: + /// ```text + /// A: Series := ( a_0, a_1, .. , a_n ) + /// B: Series := ( b_0, b_1, .. 
, b_n )
+    /// C: Series := Zip(A, B) <-> ( [a_0, b_0], [a_1, b_1], .. , [a_n, b_n] )
+    /// ```
+    pub fn zip(field: Field, series: &[&Self]) -> DaftResult<Self> {
+        // err if no series to zip
+        if series.is_empty() {
+            return Err(DaftError::ValueError(
+                "Need at least 1 series to perform zip".to_string(),
+            ));
+        }
+
+        // homogeneity checks naturally happen in make_growable's downcast.
+        let dtype = match &field.dtype {
+            DataType::List(dtype) => dtype.as_ref(),
+            DataType::FixedSizeList(..) => {
+                return Err(DaftError::ValueError(
+                    "Fixed size list constructor is currently not supported".to_string(),
+                ));
+            }
+            _ => {
+                return Err(DaftError::ValueError(
+                    "Cannot zip field with non-list type".to_string(),
+                ));
+            }
+        };
+
+        // 0 -> index of child in 'arrays' vector
+        // 1 -> last index of child (saturates to 0 when the child is empty, avoiding usize underflow)
+        type Child = (usize, usize);
+
+        // build a null series mask so we can skip making full_nulls and avoid downcast "Null to T" errors.
+        let mut mask: Vec<Option<Child>> = vec![];
+        let mut rows = 0;
+        let mut capacity = 0;
+        let mut arrays = vec![];
+
+        for s in series {
+            let len = s.len();
+            if is_null(s) {
+                mask.push(None);
+            } else {
+                mask.push(Some((arrays.len(), len.saturating_sub(1))));
+                arrays.push(*s);
+            }
+            rows = max(rows, len);
+            capacity += len;
+        }
+
+        // initialize a growable child
+        let mut offsets = Offsets::<i64>::with_capacity(capacity);
+        let mut child = make_growable("list", dtype, arrays, true, capacity);
+        let sublist_len = series.len() as i64;
+
+        // merge each series based upon the mask
+        for row in 0..rows {
+            for i in &mask {
+                if let Some((i, end)) = *i {
+                    // min(row, end) repeats the last element of a shorter child (e.g. a length-1 literal)
+                    child.extend(i, min(row, end), 1);
+                } else {
+                    child.extend_nulls(1);
+                }
+            }
+            offsets.try_push(sublist_len)?;
+        }
+
+        // create the outer array with offsets
+        Ok(ListArray::new(field, child.build()?, offsets.into(), None).into_series())
+    }
+}
+
+/// Same null check logic as in Series::concat, but may need an audit since there are other is_null impls.
+fn is_null(series: &&Series) -> bool { + series.data_type() == &DataType::Null +} diff --git a/src/daft-dsl/src/expr/display.rs b/src/daft-dsl/src/expr/display.rs new file mode 100644 index 0000000000..c3a833f400 --- /dev/null +++ b/src/daft-dsl/src/expr/display.rs @@ -0,0 +1,53 @@ +use std::fmt::Write; + +use itertools::Itertools; + +use super::{Expr, ExprRef, Operator}; + +/// Display for Expr::BinaryOp +pub fn expr_binary_op_display_without_formatter( + op: &Operator, + left: &ExprRef, + right: &ExprRef, +) -> std::result::Result { + let mut f = String::default(); + let write_out_expr = |f: &mut String, input: &Expr| match input { + Expr::Alias(e, _) => write!(f, "{e}"), + Expr::BinaryOp { .. } => write!(f, "[{input}]"), + _ => write!(f, "{input}"), + }; + write_out_expr(&mut f, left)?; + write!(&mut f, " {op} ")?; + write_out_expr(&mut f, right)?; + Ok(f) +} + +/// Display for Expr::IsIn +pub fn expr_is_in_display_without_formatter( + expr: &ExprRef, + inputs: &[ExprRef], +) -> std::result::Result { + let mut f = String::default(); + write!(&mut f, "{expr} IN (")?; + for (i, input) in inputs.iter().enumerate() { + if i != 0 { + write!(&mut f, ", ")?; + } + write!(&mut f, "{input}")?; + } + write!(&mut f, ")")?; + Ok(f) +} + +/// Display for Expr::List +pub fn expr_list_display_without_formatter( + items: &[ExprRef], +) -> std::result::Result { + let mut f = String::default(); + write!( + &mut f, + "list({})", + items.iter().map(|x| x.to_string()).join(", ") + )?; + Ok(f) +} diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 6eb2ea1eee..2864c0a8f0 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -1,3 +1,4 @@ +mod display; #[cfg(test)] mod tests; @@ -27,8 +28,7 @@ use serde::{Deserialize, Serialize}; use super::functions::FunctionExpr; use crate::{ functions::{ - binary_op_display_without_formatter, function_display_without_formatter, - function_semantic_id, is_in_display_without_formatter, + 
function_display_without_formatter, function_semantic_id, python::PythonUDF, scalar_function_semantic_id, sketch::{HashableVecPercentiles, SketchExpr}, @@ -112,7 +112,7 @@ pub enum Expr { #[display("{_0}")] Agg(AggExpr), - #[display("{}", binary_op_display_without_formatter(op, left, right)?)] + #[display("{}", display::expr_binary_op_display_without_formatter(op, left, right)?)] BinaryOp { op: Operator, left: ExprRef, @@ -143,12 +143,15 @@ pub enum Expr { #[display("fill_null({_0}, {_1})")] FillNull(ExprRef, ExprRef), - #[display("{}", is_in_display_without_formatter(_0, _1)?)] + #[display("{}", display::expr_is_in_display_without_formatter(_0, _1)?)] IsIn(ExprRef, Vec), #[display("{_0} in [{_1},{_2}]")] Between(ExprRef, ExprRef, ExprRef), + #[display("{}", display::expr_list_display_without_formatter(_0)?)] + List(Vec), + #[display("lit({_0})")] Literal(lit::LiteralValue), @@ -164,8 +167,10 @@ pub enum Expr { #[display("subquery {_0}")] Subquery(Subquery), + #[display("{_0} in {_1}")] InSubquery(ExprRef, Subquery), + #[display("exists {_0}")] Exists(Subquery), @@ -729,6 +734,12 @@ impl Expr { FieldID::new(format!("{child_id}.is_in({items_id})")) } + Self::List(items) => { + let items_id = items.iter().fold(String::new(), |acc, item| { + format!("{},{}", acc, item.semantic_id(schema)) + }); + FieldID::new(format!("List({items_id})")) + } Self::Between(expr, lower, upper) => { let child_id = expr.semantic_id(schema); let lower_id = lower.semantic_id(schema); @@ -805,6 +816,7 @@ impl Expr { Self::IsIn(expr, items) => std::iter::once(expr.clone()) .chain(items.iter().cloned()) .collect::>(), + Self::List(items) => items.clone(), Self::Between(expr, lower, upper) => vec![expr.clone(), lower.clone(), upper.clone()], Self::IfElse { if_true, @@ -867,6 +879,15 @@ impl Expr { Self::IsIn(expr, items) } + Self::List(children_old) => { + let c_len = children.len(); + let c_len_old = children_old.len(); + assert_eq!( + c_len, c_len_old, + "Should have same number of children 
({c_len_old}), found ({c_len})" + ); + Self::List(children) + } Self::Between(..) => Self::Between( children.first().expect("Should have 1 child").clone(), children.get(1).expect("Should have 2 child").clone(), @@ -938,31 +959,22 @@ impl Expr { ))) } } - Self::IsIn(left, right) => { - let left_field = left.to_field(schema)?; - - let first_right_field = right - .first() - .expect("Should have at least 1 child") - .to_field(schema)?; - let all_same_type = right.iter().all(|expr| { - let field = expr.to_field(schema).unwrap(); - // allow nulls to be compared with anything - if field.dtype == DataType::Null || first_right_field.dtype == DataType::Null { - return true; - } - field.dtype == first_right_field.dtype - }); - if !all_same_type { - return Err(DaftError::TypeError(format!( - "Expected all arguments to be of the same type, but received {first_right_field} and others", - ))); - } - - let (result_type, _intermediate, _comp_type) = - InferDataType::from(&left_field.dtype) - .membership_op(&InferDataType::from(&first_right_field.dtype))?; - Ok(Field::new(left_field.name.as_str(), result_type)) + Self::IsIn(expr, items) => { + // Use the expr's field name, and infer membership op type. + let list_dtype = try_compute_is_in_type(items, schema)?.unwrap_or(DataType::Null); + let expr_field = expr.to_field(schema)?; + let expr_type = &expr_field.dtype; + let field_name = &expr_field.name; + let field_type = InferDataType::from(expr_type) + .membership_op(&(&list_dtype).into())? + .0; + Ok(Field::new(field_name, field_type)) + } + Self::List(items) => { + // Use "list" as the field name, and infer list type from items. + let field_name = "list"; + let field_type = try_compute_collection_supertype(items, schema)?; + Ok(Field::new(field_name, DataType::new_list(field_type))) } Self::Between(value, lower, upper) => { let value_field = value.to_field(schema)?; @@ -1107,6 +1119,7 @@ impl Expr { Self::IsIn(expr, ..) => expr.name(), Self::Between(expr, ..) 
=> expr.name(), Self::Literal(..) => "literal", + Self::List(..) => "list", Self::Function { func, inputs } => match func { FunctionExpr::Struct(StructExpr::Get(name)) => name, _ => inputs.first().unwrap().name(), @@ -1195,6 +1208,7 @@ impl Expr { | Expr::Agg(..) | Expr::Cast(..) | Expr::IsIn(..) + | Expr::List(..) | Expr::Between(..) | Expr::Function { .. } | Expr::FillNull(..) @@ -1436,6 +1450,7 @@ pub fn estimated_selectivity(expr: &Expr, schema: &Schema) -> f64 { // Everything else doesn't filter Expr::Subquery(_) => 1.0, Expr::Agg(_) => panic!("Aggregates are not allowed in WHERE clauses"), + Expr::List(_) => 1.0, }; // Lower bound to 1% to prevent overly selective estimate @@ -1477,3 +1492,37 @@ pub fn deduplicate_expr_names(exprs: &[ExprRef]) -> Vec { }) .collect() } + +/// Asserts an expr slice is homogeneous and returns the type, or None if empty or all nulls. +/// None allows for context-dependent handling such as erroring or defaulting to Null. +fn try_compute_is_in_type(exprs: &[ExprRef], schema: &Schema) -> DaftResult> { + let mut dtype: Option = None; + for expr in exprs { + let other_dtype = expr.get_type(schema)?; + // other is null, continue + if other_dtype == DataType::Null { + continue; + } + // other != null and dtype is unset -> set dtype + if dtype.is_none() { + dtype = Some(other_dtype); + continue; + } + // other != null and dtype is set -> compare or err! + if dtype.as_ref() != Some(&other_dtype) { + return Err(DaftError::TypeError(format!("Expected all arguments to be of the same type {}, but found element with type {other_dtype}", dtype.unwrap()))); + } + } + Ok(dtype) +} + +/// Tries to get the supertype of all exprs in the collection. 
+fn try_compute_collection_supertype(exprs: &[ExprRef], schema: &Schema) -> DaftResult { + let mut dtype = DataType::Null; + for expr in exprs { + let other_dtype = expr.get_type(schema)?; + let super_dtype = try_get_supertype(&dtype, &other_dtype)?; + dtype = super_dtype; + } + Ok(dtype) +} diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 7fa3bd8952..3c7cadeea2 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -18,7 +18,7 @@ pub use scalar::*; use serde::{Deserialize, Serialize}; use self::{map::MapExpr, partitioning::PartitioningExpr, sketch::SketchExpr, struct_::StructExpr}; -use crate::{Expr, ExprRef, Operator}; +use crate::ExprRef; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum FunctionExpr { @@ -106,39 +106,6 @@ pub fn function_display_without_formatter( Ok(f) } -pub fn is_in_display_without_formatter( - expr: &ExprRef, - inputs: &[ExprRef], -) -> std::result::Result { - let mut f = String::default(); - write!(&mut f, "{expr} IN (")?; - for (i, input) in inputs.iter().enumerate() { - if i != 0 { - write!(&mut f, ", ")?; - } - write!(&mut f, "{input}")?; - } - write!(&mut f, ")")?; - Ok(f) -} - -pub fn binary_op_display_without_formatter( - op: &Operator, - left: &ExprRef, - right: &ExprRef, -) -> std::result::Result { - let mut f = String::default(); - let write_out_expr = |f: &mut String, input: &Expr| match input { - Expr::Alias(e, _) => write!(f, "{e}"), - Expr::BinaryOp { .. 
} => write!(f, "[{input}]"), - _ => write!(f, "{input}"), - }; - write_out_expr(&mut f, left)?; - write!(&mut f, " {op} ")?; - write_out_expr(&mut f, right)?; - Ok(f) -} - pub fn function_semantic_id(func: &FunctionExpr, inputs: &[ExprRef], schema: &Schema) -> FieldID { let inputs = inputs .iter() diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index a28fa41c46..6908bdf949 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -29,6 +29,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction!(python::col, parent)?)?; parent.add_function(wrap_pyfunction!(python::lit, parent)?)?; + parent.add_function(wrap_pyfunction!(python::list_, parent)?)?; parent.add_function(wrap_pyfunction!(python::date_lit, parent)?)?; parent.add_function(wrap_pyfunction!(python::time_lit, parent)?)?; parent.add_function(wrap_pyfunction!(python::timestamp_lit, parent)?)?; diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index 38a9c8588e..2a92bdafa0 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -31,6 +31,7 @@ pub fn requires_computation(e: &Expr) -> bool { | Expr::NotNull(..) | Expr::FillNull(..) | Expr::IsIn { .. } + | Expr::List { .. } | Expr::Between { .. } | Expr::IfElse { .. } | Expr::Subquery { .. 
} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index df380bd154..c646d0dc40 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -182,6 +182,11 @@ pub fn lit(item: Bound) -> PyResult { } } +#[pyfunction] +pub fn list_(items: Vec) -> PyExpr { + Expr::List(items.into_iter().map(|item| item.into()).collect()).into() +} + #[allow(clippy::too_many_arguments)] #[pyfunction(signature = ( name, diff --git a/src/daft-logical-plan/src/ops/project.rs b/src/daft-logical-plan/src/ops/project.rs index 171899203c..3a5410c3d4 100644 --- a/src/daft-logical-plan/src/ops/project.rs +++ b/src/daft-logical-plan/src/ops/project.rs @@ -202,16 +202,15 @@ impl Project { } } +/// Constructs a new copy of this expression +/// with all occurrences of subexprs_to_replace replaced with a column selection. +/// e.g. e := (a+b)+c, subexprs := {FieldID("(a + b)")} +/// -> Col("(a + b)") + c fn replace_column_with_semantic_id( e: ExprRef, subexprs_to_replace: &IndexSet, schema: &Schema, ) -> Transformed { - // Constructs a new copy of this expression - // with all occurrences of subexprs_to_replace replaced with a column selection. - // e.g. 
e := (a+b)+c, subexprs := {FieldID("(a + b)")}
+///           -> Col("(a + b)") + c
 fn replace_column_with_semantic_id(
     e: ExprRef,
     subexprs_to_replace: &IndexSet<FieldID>,
     schema: &Schema,
 ) -> Transformed<ExprRef> {
-    // Constructs a new copy of this expression
-    // with all occurrences of subexprs_to_replace replaced with a column selection.
-    // e.g. e := (a+b)+c, subexprs := {FieldID("(a + b)")}
-    //           -> Col("(a + b)") + c
-
     let sem_id = e.semantic_id(schema);
     if subexprs_to_replace.contains(&sem_id) {
         let new_expr = Expr::Column(sem_id.id);
@@ -307,6 +306,22 @@ fn replace_column_with_semantic_id(
                     )
                 }
             }
+            Expr::List(items) => {
+                let mut transformed = false;
+                let mut new_items = Vec::<ExprRef>::with_capacity(items.len());
+                for item in items {
+                    let new_item =
+                        replace_column_with_semantic_id(item.clone(), subexprs_to_replace, schema);
+                    // Keep every item, rewritten or not: pushing only the rewritten
+                    // ones would silently drop the untouched elements from the list.
+                    transformed |= new_item.transformed;
+                    new_items.push(new_item.data);
+                }
+                if transformed {
+                    return Transformed::yes(Expr::List(new_items).into());
+                }
+                Transformed::no(e)
+            }
             Expr::Between(child, lower, upper) => {
                 let child =
                     replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema);
diff --git a/src/daft-logical-plan/src/partitioning.rs b/src/daft-logical-plan/src/partitioning.rs
index 62be09ff05..3dfdb17caf 100644
--- a/src/daft-logical-plan/src/partitioning.rs
+++ b/src/daft-logical-plan/src/partitioning.rs
@@ -297,6 +297,13 @@ fn translate_clustering_spec_expr(
                 .collect::<Result<Vec<_>, _>>()?;
             Ok(newchild.is_in(newitems))
         }
+        Expr::List(items) => {
+            let new_items = items
+                .iter()
+                .map(|e| translate_clustering_spec_expr(e, old_colname_to_new_colname))
+                .collect::<Result<Vec<_>, _>>()?;
+            Ok(Expr::List(new_items).into())
+        }
         Expr::Between(child, lower, upper) => {
             let newchild = translate_clustering_spec_expr(child, old_colname_to_new_colname)?;
             let newlower = translate_clustering_spec_expr(lower, old_colname_to_new_colname)?;
diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs
index 05f34a365e..a1a23977ce 100644
--- a/src/daft-sql/src/planner.rs
+++ b/src/daft-sql/src/planner.rs
@@ -1628,7 +1628,17 @@ impl<'a> SQLPlanner<'a> {
             }
             SQLExpr::Map(_) => unsupported_sql_err!("MAP"),
             SQLExpr::Subscript { expr, subscript } => self.plan_subscript(expr, subscript.as_ref()),
-            SQLExpr::Array(_) => unsupported_sql_err!("ARRAY"),
+            SQLExpr::Array(array) => {
+                if
array.elem.is_empty() { + invalid_operation_err!("List constructor requires at least one item") + } + let items = array + .elem + .iter() + .map(|e| self.plan_expr(e)) + .collect::>>()?; + Ok(Expr::List(items).into()) + } SQLExpr::Interval(interval) => { use regex::Regex; diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 0bc7c1a4cc..2d3082fe52 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -553,7 +553,22 @@ impl Table { .eval_expression(child)? .is_in(&s) } - + Expr::List(items) => { + // compute list type to determine each child cast + let field = expr.to_field(&self.schema)?; + // extract list child type (could be de-duped with zip and moved to impl DataType) + let dtype = if let DataType::List(dtype) = &field.dtype { + dtype + } else { + return Err(DaftError::ComputeError("List expression must be of type List(T)".to_string())) + }; + // compute child series with explicit casts to the supertype + let items = items.iter().map(|i| i.clone().cast(dtype)).collect::>(); + let items = items.iter().map(|i| self.eval_expression(i)).collect::>>()?; + let items = items.iter().collect::>(); + // zip the series into a single series of lists + Series::zip(field, items.as_slice()) + } Expr::Between(child, lower, upper) => self .eval_expression(child)? 
.between(&self.eval_expression(lower)?, &self.eval_expression(upper)?), diff --git a/tests/expressions/test_list_.py b/tests/expressions/test_list_.py new file mode 100644 index 0000000000..8ee5d08975 --- /dev/null +++ b/tests/expressions/test_list_.py @@ -0,0 +1,95 @@ +import pytest + +import daft +from daft import DataType as dt +from daft import col, list_, lit + + +def test_list_constructor_empty(): + with pytest.raises(Exception, match="List constructor requires at least one item"): + df = daft.from_pydict({"x": [1, 2, 3]}) + df = df.select(list_()) + + +def test_list_constructor_with_coercions(): + df = daft.from_pydict({"v_i32": [1, 2, 3], "v_bool": [True, True, False]}) + df = df.select(list_(lit(1), col("v_i32"), col("v_bool"))) + assert df.to_pydict() == {"list": [[1, 1, 1], [1, 2, 1], [1, 3, 0]]} + + +def test_list_constructor_with_lit_first(): + df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]}) + df = df.select(list_(lit(1), col("x"), col("y"))) + assert df.to_pydict() == {"list": [[1, 1, 4], [1, 2, 5], [1, 3, 6]]} + + +def test_list_constructor_with_lit_mid(): + df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]}) + df = df.select(list_(col("x"), lit(1), col("y"))) + assert df.to_pydict() == {"list": [[1, 1, 4], [2, 1, 5], [3, 1, 6]]} + + +def test_list_constructor_with_lit_last(): + df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]}) + df = df.select(list_(col("x"), col("y"), lit(1))) + assert df.to_pydict() == {"list": [[1, 4, 1], [2, 5, 1], [3, 6, 1]]} + + +def test_list_constructor_multi_column(): + df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]}) + df = df.select(list_("x", "y").alias("fwd"), list_("y", "x").alias("rev")) + assert df.to_pydict() == {"fwd": [[1, 4], [2, 5], [3, 6]], "rev": [[4, 1], [5, 2], [6, 3]]} + + +def test_list_constructor_different_lengths(): + with pytest.raises(Exception, match="Expected all columns to be of the same length"): + df = daft.from_pydict({"x": [1, 2], "y": [3]}) + df = df.select(list_("x", 
"y")) + + +def test_list_constructor_singleton(): + df = daft.from_pydict({"x": [1, 2, 3]}) + df = df.select(list_(col("x")).alias("singleton")) + assert df.to_pydict() == {"singleton": [[1], [2], [3]]} + + +def test_list_constructor_homogeneous(): + df = daft.from_pydict({"x": [1, 2, 3]}) + df = df.select(list_("x", col("x") * 2, col("x") * 3).alias("homogeneous")) + assert df.to_pydict() == {"homogeneous": [[1, 2, 3], [2, 4, 6], [3, 6, 9]]} + + +def test_list_constructor_heterogeneous(): + df = daft.from_pydict({"x": [1, 2, 3], "y": [True, True, False]}) + df = df.select(list_("x", "y").alias("heterogeneous")) + assert df.to_pydict() == {"heterogeneous": [[1, 1], [2, 1], [3, 0]]} + + +def test_list_constructor_heterogeneous_with_cast(): + df = daft.from_pydict({"x": [1, 2, 3], "y": [True, True, False]}) + df = df.select(list_(col("x").cast(dt.string()), col("y").cast(dt.string())).alias("strs")) + assert df.to_pydict() == {"strs": [["1", "1"], ["2", "1"], ["3", "0"]]} + + +def test_list_constructor_mixed_null_first(): + df = daft.from_pydict({"x": [1, 2, 3]}) + df = df.select(list_(lit(None), col("x")).alias("res")) + assert df.to_pydict() == {"res": [[None, 1], [None, 2], [None, 3]]} + + +def test_list_constructor_mixed_null_mid(): + df = daft.from_pydict({"x": [1, 2, 3]}) + df = df.select(list_(-1 * col("x"), lit(None), col("x")).alias("res")) + assert df.to_pydict() == {"res": [[-1, None, 1], [-2, None, 2], [-3, None, 3]]} + + +def test_list_constructor_mixed_null_last(): + df = daft.from_pydict({"x": [1, 2, 3]}) + df = df.select(list_(col("x"), lit(None)).alias("res")) + assert df.to_pydict() == {"res": [[1, None], [2, None], [3, None]]} + + +def test_list_constructor_all_nulls(): + df = daft.from_pydict({"x": [1, 2, 3]}) + df = df.select(list_(lit(None), lit(None)).alias("res")) + assert df.to_pydict() == {"res": [[None, None], [None, None], [None, None]]} diff --git a/tests/sql/test_list_exprs.py b/tests/sql/test_list_exprs.py index 2f0799fb71..e508d077df 
100644 --- a/tests/sql/test_list_exprs.py +++ b/tests/sql/test_list_exprs.py @@ -1,11 +1,86 @@ import pyarrow as pa +import pytest import daft -from daft import col +from daft import DataType, col, list_ from daft.daft import CountMode from daft.sql.sql import SQLCatalog +def assert_eq(actual, expect): + """Asserts two dataframes are equal for tests.""" + assert actual.collect().to_pydict() == expect.collect().to_pydict() + + +def test_list_constructor_empty(): + with pytest.raises(Exception, match="List constructor requires at least one item"): + df = daft.from_pydict({"x": [1, 2, 3]}) + daft.sql("SELECT [ ] as list FROM df") + df # for ruff ignore unused + + +def test_list_constructor_different_lengths(): + with pytest.raises(Exception, match="Expected all columns to be of the same length"): + df = daft.from_pydict({"x": [1, 2], "y": [3]}) + daft.sql("SELECT [ x, y ] FROM df") + df # for ruff ignore unused + + +def test_list_constructor_singleton(): + df = daft.from_pydict({"x": [1, 2, 3]}) + actual = daft.sql("SELECT [ x ] as list FROM df") + expect = df.select(col("x").apply(lambda x: [x], DataType.list(DataType.int64())).alias("list")) + assert_eq(actual, expect) + + +def test_list_constructor_homogeneous(): + df = daft.from_pydict({"x": [1, 2, 3]}) + actual = daft.sql("SELECT [ x * 1, x * 2, x * 3 ] FROM df") + expect = df.select(col("x").apply(lambda x: [x * 1, x * 2, x * 3], DataType.list(DataType.int64())).alias("list")) + assert_eq(actual, expect) + + +def test_list_constructor_heterogeneous(): + df = daft.from_pydict({"x": [1, 2, 3], "y": [True, True, False]}) + df = daft.sql("SELECT [ x, y ] AS heterogeneous FROM df").collect() + assert df.to_pydict() == {"heterogeneous": [[1, 1], [2, 1], [3, 0]]} + + +def test_list_constructor_heterogeneous_with_cast(): + df = daft.from_pydict({"x": [1, 2, 3], "y": [True, True, False]}) + actual = daft.sql("SELECT [ CAST(x AS STRING), CAST(y AS STRING) ] FROM df") + expect = 
df.select(list_(col("x").cast(DataType.string()), col("y").cast(DataType.string()))) + assert_eq(actual, expect) + + +def test_list_constructor_mixed_null_first(): + df = daft.from_pydict({"x": [1, 2, 3]}) + actual = daft.sql("SELECT [ NULL, x ] FROM df") + expect = df.select(col("x").apply(lambda x: [None, x], DataType.list(DataType.int64())).alias("list")) + assert_eq(actual, expect) + + +def test_list_constructor_mixed_null_mid(): + df = daft.from_pydict({"x": [1, 2, 3]}) + actual = daft.sql("SELECT [ x * -1, NULL, x ] FROM df") + expect = df.select(col("x").apply(lambda x: [x * -1, None, x], DataType.list(DataType.int64())).alias("list")) + assert_eq(actual, expect) + + +def test_list_constructor_mixed_null_last(): + df = daft.from_pydict({"x": [1, 2, 3]}) + actual = daft.sql("SELECT [ x, NULL ] FROM df") + expect = df.select(col("x").apply(lambda x: [x, None], DataType.list(DataType.int64())).alias("list")) + assert_eq(actual, expect) + + +def test_list_constructor_all_nulls(): + df = daft.from_pydict({"x": [1, 2, 3]}) + actual = daft.sql("SELECT [ NULL, NULL ] FROM df") + expect = df.select(col("x").apply(lambda x: [None, None], DataType.list(DataType.null())).alias("list")) + assert_eq(actual, expect) + + def test_list_chunk(): df = daft.from_pydict( {